Visualizing Gradient Boosting


Last updated 2020-04-08
Summary
  • Visualizing Gradient Boosting stage by stage reveals how each tree corrects remaining residuals.
  • Tracking intermediate predictions helps diagnose underfitting, overfitting, and unstable learning schedules.
  • The visualization makes learning_rate and n_estimators tuning decisions interpretable rather than trial-and-error.

Intuition #

The key idea is to inspect not only the final curve but also each intermediate stage. Seeing where each tree adds or subtracts prediction mass makes the boosting mechanism concrete.

Detailed Explanation #

Train and final prediction #

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor

# Noisy sine wave as a 1-D regression problem
n_samples = 500
X = np.linspace(-10, 10, n_samples)[:, np.newaxis]
noise = np.random.rand(n_samples) * 10
y = np.sin(X).ravel() * 10 + 10 + noise

n_estimators = 10
learning_rate = 0.5
reg = GradientBoostingRegressor(
    n_estimators=n_estimators,
    learning_rate=learning_rate,
)
reg.fit(X, y)
y_pred = reg.predict(X)

plt.figure(figsize=(20, 10))
plt.scatter(X, y, c="k", marker="x", label="train")
plt.plot(X, y_pred, c="r", label="final", linewidth=1)
plt.xlabel("x")
plt.ylabel("y")
plt.axhline(y=np.mean(y), color="gray", linestyle=":", label="baseline")
plt.title("Fitting on training data")
plt.legend()
plt.show()

Train and final prediction figure
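Before stacking contributions visually, the same stage-by-stage view can be obtained numerically with scikit-learn's `staged_predict`, which yields the cumulative prediction after each tree. A minimal sketch (regenerating data equivalent to the snippet above, with a fixed seed added for reproducibility):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Same noisy sine setup as above, seeded for reproducibility
rng = np.random.default_rng(0)
X = np.linspace(-10, 10, 500)[:, np.newaxis]
y = np.sin(X).ravel() * 10 + 10 + rng.random(500) * 10

reg = GradientBoostingRegressor(n_estimators=10, learning_rate=0.5).fit(X, y)

# staged_predict yields the cumulative prediction after each added tree
mses = [np.mean((y - y_stage) ** 2) for y_stage in reg.staged_predict(X)]
for i, mse in enumerate(mses):
    print(f"after tree {i + 1}: train MSE = {mse:.3f}")
```

Watching the training error shrink stage by stage is the numeric counterpart of the plots: a curve that flattens early suggests `n_estimators` could be reduced, while one still falling at the last tree suggests underfitting.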

Stack per-tree contributions #

fig, ax = plt.subplots(figsize=(20, 10))
# Start from the initial prediction: the mean of y (squared-error loss)
temp = np.zeros(n_samples) + np.mean(y)

for i in range(n_estimators):
    # Each tree's contribution, scaled by the learning rate
    res = reg.estimators_[i][0].predict(X) * learning_rate
    ax.bar(X.flatten(), res, bottom=temp, label=f"tree {i+1}", alpha=0.05)
    temp += res

plt.scatter(X.flatten(), y, c="k", marker="x", label="train")
plt.plot(X, y_pred, c="r", label="final", linewidth=1)
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
plt.show()

Stack per-tree contributions figure
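The premise of the stacked-bar plot can be verified directly: the initial mean prediction plus every learning-rate-scaled tree output should reconstruct the model's final prediction exactly. A sketch of that check (self-contained, same data recipe as above with a fixed seed; assumes the default squared-error loss, whose initial prediction is the training mean):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
X = np.linspace(-10, 10, 500)[:, np.newaxis]
y = np.sin(X).ravel() * 10 + 10 + rng.random(500) * 10

learning_rate = 0.5
reg = GradientBoostingRegressor(n_estimators=10, learning_rate=learning_rate)
reg.fit(X, y)

# Rebuild the prediction by summing the stacked contributions by hand
total = np.full(X.shape[0], y.mean())
for stage in reg.estimators_:  # each stage holds one regression tree
    total += learning_rate * stage[0].predict(X)

ok = np.allclose(total, reg.predict(X))
print(ok)
```

If this prints `True`, the bars in the figure really do sum to the red final curve, so nothing in the visualization is an approximation.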

Partial stacking (staged improvement) #

for i in range(5):
    fig, ax = plt.subplots(figsize=(20, 10))
    plt.title(f"Up to tree {i+1}")
    # Rebuild the running prediction from scratch at each stage
    temp = np.zeros(n_samples) + np.mean(y)

    for j in range(i + 1):
        res = reg.estimators_[j][0].predict(X) * learning_rate
        ax.bar(X.flatten(), res, bottom=temp, label=f"{j+1}", alpha=0.05)
        temp += res

    plt.scatter(X.flatten(), y, c="k", marker="x", label="train")
    plt.legend()
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()
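The staged view also makes the learning-rate bullet from the summary concrete: with the same number of trees, a larger step per stage leaves less residual behind on the training data. A hedged sketch of that comparison (hypothetical learning-rate values chosen for illustration, not taken from the article):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
X = np.linspace(-10, 10, 500)[:, np.newaxis]
y = np.sin(X).ravel() * 10 + 10 + rng.random(500) * 10

# Compare staged training error for two illustrative learning rates
final_mse = {}
for lr in (0.1, 0.5):
    reg = GradientBoostingRegressor(n_estimators=10, learning_rate=lr).fit(X, y)
    mses = [np.mean((y - p) ** 2) for p in reg.staged_predict(X)]
    final_mse[lr] = mses[-1]
    print(f"lr={lr}: MSE after tree 1 = {mses[0]:.2f}, after tree 10 = {mses[-1]:.2f}")
```

Note this only measures training error; on held-out data a large learning rate can overshoot, which is exactly the kind of instability the stage-by-stage plots help diagnose.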