勾配ブースティングの可視化

2.4.6

勾配ブースティングの可視化

最終更新 2020-04-08 読了時間 2 分
まとめ
  • 勾配ブースティングの可視化では、各ステージの木がどの残差を補正したかを追うことで学習過程を理解できる。
  • 初期モデルから最終モデルまでの予測の遷移を見ると、過学習や学習不足の兆候を判断しやすい。
  • n_estimatorslearning_rate の設定意図を、可視化結果と結びつけて説明できるようになる。

直感 #

このページの核心は、最終予測だけでなく「途中の予測」を見ることです。各木が前段の誤差をどの領域で埋めたかを確認すると、勾配ブースティングが段階的に関数を組み立てる過程を具体的に理解できます。

詳細な解説 #

import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
from sklearn.ensemble import GradientBoostingRegressor

1. 学習と最終予測の確認 #

まずは簡単なデータに対して勾配ブースティング回帰を学習し、最終的な予測曲線を確認します。

n_samples = 500
X = np.linspace(-10, 10, n_samples)[:, np.newaxis]
noise = np.random.rand(X.shape[0]) * 10
y = (np.sin(X).ravel()) * 10 + 10 + noise

# モデル作成
n_estimators = 10
learning_rate = 0.5
reg = GradientBoostingRegressor(
    n_estimators=n_estimators,
    learning_rate=learning_rate,
)
reg.fit(X, y)

# 予測
y_pred = reg.predict(X)

# 可視化
plt.figure(figsize=(20, 10))
plt.scatter(X, y, c="k", marker="x", label="訓練データ")
plt.plot(X, y_pred, c="r", label="最終予測", linewidth=1)
plt.axhline(y=np.mean(y), color="gray", linestyle=":", label="初期モデル(平均値)")
plt.xlabel("x"); plt.ylabel("y")
plt.title("勾配ブースティングの最終予測")
plt.legend(); plt.show()

まずは簡単なデータに対して勾配ブースティング回帰を学習し、最終的な予測曲線を確認しますの図

解説

  • 灰色の破線は「初期モデル」(平均値のみの予測)。
  • 赤い線が 10本の木を足し合わせた最終予測
  • 初期値から出発して、木を追加するごとに予測が改良されていきます。

2. 木ごとの寄与を積み上げて表示 #

次に「各木がどれだけ予測を修正したか」を棒グラフで積み上げます。

fig, ax = plt.subplots(figsize=(20, 10))
temp = np.zeros(n_samples) + np.mean(y)  # 初期モデルは平均値

for i in range(n_estimators):
    # i本目の木の予測値 × learning_rate が寄与部分
    res = reg.estimators_[i][0].predict(X) * learning_rate
    ax.bar(X.flatten(), res, bottom=temp, label=f"{i+1} 本目の木", alpha=0.05)
    temp += res  # 累積して次へ

# データと最終予測を重ねる
plt.scatter(X.flatten(), y, c="k", marker="x", label="訓練データ")
plt.plot(X, y_pred, c="r", label="最終予測", linewidth=1)
plt.xlabel("x"); plt.ylabel("y")
plt.title("木ごとの寄与を積み上げた可視化")
plt.legend(); plt.show()

次に「各木がどれだけ予測を修正したか」を棒グラフで積み上げますの図

解説

  • 薄い棒が「各木がどれだけ予測を修正したか」を示します。
  • それらを累積すると、最終的に赤い予測曲線になります。
  • 学習率 learning_rate を掛けることで、一歩ずつ慎重に修正しています。

3. 途中までの積み上げ(段階的な改善) #

さらに「木を5本まで足した時点」での予測を順に可視化します。

for i in range(5):
    fig, ax = plt.subplots(figsize=(20, 10))
    plt.title(f"{i+1} 本目までの寄与で作られた予測")
    temp = np.zeros(n_samples) + np.mean(y)

    for j in range(i + 1):
        res = reg.estimators_[j][0].predict(X) * learning_rate
        ax.bar(X.flatten(), res, bottom=temp, label=f"{j+1} 本目", alpha=0.05)
        temp += res

    # データと予測を描画
    plt.scatter(X.flatten(), y, c="k", marker="x", label="訓練データ")
    plt.plot(X, temp, c="r", linewidth=1.2, label="途中の予測")
    plt.xlabel("x"); plt.ylabel("y")
    plt.legend(); plt.show()

さらに「木を5本まで足した時点」での予測を順に可視化しますの図

さらに「木を5本まで足した時点」での予測を順に可視化しますの図

さらに「木を5本まで足した時点」での予測を順に可視化しますの図

さらに「木を5本まで足した時点」での予測を順に可視化しますの図

さらに「木を5本まで足した時点」での予測を順に可視化しますの図

さらに「木を5本まで足した時点」での予測を順に可視化しますの図

解説

  • 1本目:大まかに残差を補正
  • 2〜3本目:細かいパターンに対応
  • 5本目:だいぶ曲線に近づいてくる
  • さらに木を追加すると、最終的に赤い曲線(完成版)になります