Focal Loss

上級

Focal Loss

最終更新 2026-03-03 読了時間 2 分
まとめ
  • Focal Loss はクロスエントロピーに減衰項 $(1-p_t)^\gamma$ を乗じ、簡単なサンプルの損失を下げることで学習を難しいサンプルに集中させる。
  • 物体検出(RetinaNet)で提案されたが、表形式の不均衡分類にも効果的。
  • $\gamma = 0$ で通常のクロスエントロピーと一致し、$\gamma$ を上げるほど難例重視。

直感 #

不均衡データでクロスエントロピーを使うと、多数派の「簡単なサンプル」が損失の大部分を占め、少数派の学習が進まない。Focal Loss は「すでに高い確率で正しく分類できるサンプル」の損失を自動的に減衰させ、誤分類しやすい境界付近のサンプルに勾配を集中させる。

詳細な解説 #

数式 #

$$ \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) $$
  • $p_t$: 正解クラスの予測確率
  • $\gamma$: フォーカスパラメータ(通常 2.0)
  • $\alpha_t$: クラスの重み

可視化 #

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
import numpy as np
import matplotlib.pyplot as plt

def focal_loss(p, gamma):
    return -(1 - p) ** gamma * np.log(p + 1e-8)

p = np.linspace(0.01, 0.99, 200)
plt.figure(figsize=(8, 4))
for g in [0, 0.5, 1, 2, 5]:
    plt.plot(p, focal_loss(p, g), label=f"γ = {g}")
plt.xlabel("Predicted probability (correct class)")
plt.ylabel("Loss")
plt.title("Focal Loss for different γ values")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

PyTorch 実装 #

1
2
3
4
5
6
7
8
import torch
import torch.nn.functional as F

def focal_loss_fn(logits, targets, gamma=2.0, alpha=0.25):
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)
    loss = alpha * (1 - p_t) ** gamma * bce
    return loss.mean()

scikit-learn + カスタム損失 #

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
from sklearn.datasets import make_classification
from lightgbm import LGBMClassifier
import numpy as np

def focal_objective(y_true, y_pred):
    gamma = 2.0
    p = 1.0 / (1.0 + np.exp(-y_pred))
    grad = p - y_true + gamma * (y_true * (1 - p)**gamma * np.log(p + 1e-8) * p * (1 - p)
            + (1 - y_true) * p**gamma * np.log(1 - p + 1e-8) * p * (1 - p))
    hess = np.abs(grad) * (1 - np.abs(grad))
    return grad, hess

X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=42)
model = LGBMClassifier(objective=focal_objective, n_estimators=100, verbose=-1)

Cross-Entropy vs Focal Loss #

損失難例への集中クラス重み調整パラメータ
Cross-Entropyなしclass_weight で可
Weighted CEなし(一律重み)αα
Focal Lossあり(自動)α + γα, γ