Decision Tree Parameters

There are various parameters in a decision tree, and the resulting tree changes depending on how they are specified. On this page, we visualize and check how each parameter affects the fitted tree; a combined constructor sketch follows the parameter list below.

  • max_depth specifies the maximum depth of the tree.
  • min_samples_split specifies the minimum number of samples required to split an internal node.
  • min_samples_leaf specifies the minimum number of samples required at a leaf node.
  • max_leaf_nodes specifies the maximum number of leaf nodes.
  • ccp_alpha is a pruning parameter that penalizes tree complexity (minimal cost-complexity pruning).
  • class_weight specifies the weighting of classes in classification.
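
As a quick reference, the sketch below shows how these parameters are passed to the scikit-learn constructors. The values are purely illustrative, not recommendations; note that class_weight applies only to classification trees.

from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Regression tree: every value below is an illustrative placeholder
reg = DecisionTreeRegressor(
    max_depth=5,            # cap the depth of the tree
    min_samples_split=20,   # need at least 20 samples to split an internal node
    min_samples_leaf=10,    # every leaf must keep at least 10 samples
    max_leaf_nodes=16,      # grow at most 16 leaves
    ccp_alpha=0.01,         # strength of cost-complexity pruning
)

# class_weight is only available for classification
clf = DecisionTreeClassifier(max_depth=5, class_weight={0: 1.0, 1: 5.0})
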
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from mpl_toolkits.mplot3d import Axes3D
from dtreeviz.trees import dtreeviz, rtreeviz_bivar_3D

Applying a decision tree to a simple dataset

# dataset
X, y = make_regression(n_samples=100, n_features=2, random_state=11)

# train decision tree
dt = DecisionTreeRegressor(max_depth=3)
dt.fit(X, y)

# visualize
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="MPG",
    elev=40,
    azim=120,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the regression tree fitted with max_depth=3 on the simple dataset)
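
Beyond the 3-D plot, the fitted estimator can also be inspected numerically. A minimal check, assuming the dt and X fitted in the block above:

# Basic facts about the fitted tree
print("depth:", dt.get_depth())      # bounded by max_depth=3
print("leaves:", dt.get_n_leaves())  # at most 2**3 = 8 for a depth-3 tree
print("prediction for the first sample:", dt.predict(X[:1]))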

Studying decision trees with various parameters

Let's examine how a decision tree with a somewhat more complex structure behaves when its parameters are changed. First, look at a decision tree with default values for all parameters except max_depth=3.

X, y = make_regression(
    n_samples=500, n_features=2, effective_rank=4, noise=0.1, random_state=1
)
plt.figure(figsize=(10, 10))
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()

dt = DecisionTreeRegressor(max_depth=3, random_state=117117)
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: scatter plot of the generated dataset, colored by the target value)

(figure: 3-D plot of the tree fitted with max_depth=3)

max_depth=10

When the value of max_depth is large, a deeper and more complex tree is created. Such a tree can represent complex rules, but it may overfit when the dataset is small.
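
One way to make the over-fitting risk visible is to compare training and held-out scores as max_depth grows. A rough sketch, assuming the X and y generated above and an illustrative 80/20 split:

from sklearn.model_selection import train_test_split

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)
for depth in [3, 5, 10]:
    m = DecisionTreeRegressor(max_depth=depth, random_state=117117).fit(X_tr, y_tr)
    # a widening gap between the two R^2 scores points to over-fitting
    print(depth, round(m.score(X_tr, y_tr), 3), round(m.score(X_te, y_te), 3))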

dt = DecisionTreeRegressor(max_depth=10, random_state=117117)
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with max_depth=10)

max_depth=5

dt = DecisionTreeRegressor(max_depth=5, random_state=117117)
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with max_depth=5)

min_samples_split=60

min_samples_split specifies the minimum number of samples required to split an internal node. Smaller values allow more detailed rules, while larger values help avoid overfitting.
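
The effect is easy to read off the number of leaves. A small check, assuming the same X and y as above:

for mss in [2, 20, 60]:
    m = DecisionTreeRegressor(
        max_depth=5, min_samples_split=mss, random_state=117117
    ).fit(X, y)
    # larger min_samples_split -> fewer splits -> fewer leaves
    print(f"min_samples_split={mss}: {m.get_n_leaves()} leaves")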

dt = DecisionTreeRegressor(max_depth=5, min_samples_split=60, random_state=117117)
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with min_samples_split=60)

ccp_alpha=0.4

This parameter penalizes the complexity of the tree through minimal cost-complexity pruning. The higher the value of ccp_alpha, the more aggressively the tree is pruned and the simpler it becomes.
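
scikit-learn can also report the candidate alpha values at which subtrees would be pruned away, via cost_complexity_pruning_path. A short sketch, assuming the same X and y; the alphas tried in the loop are illustrative:

path = DecisionTreeRegressor(random_state=117117).cost_complexity_pruning_path(X, y)
print("effective alphas:", path.ccp_alphas[:5], "...", path.ccp_alphas[-1])

for alpha in [0.0, 0.1, 0.4]:
    m = DecisionTreeRegressor(
        max_depth=5, ccp_alpha=alpha, random_state=117117
    ).fit(X, y)
    # larger ccp_alpha -> heavier pruning -> fewer leaves
    print(f"ccp_alpha={alpha}: {m.get_n_leaves()} leaves")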

dt = DecisionTreeRegressor(max_depth=5, random_state=117117, ccp_alpha=0.4)
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with ccp_alpha=0.4)

max_leaf_nodes=5

This parameter limits the maximum number of leaves the tree may grow. In the 3-D plot, the number of leaves matches the number of prediction regions.
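
The cap can be verified directly on a fitted tree. A one-off check, assuming the same X and y:

m = DecisionTreeRegressor(
    max_depth=5, max_leaf_nodes=5, random_state=117117
).fit(X, y)
# the grown tree never exceeds the requested number of leaves
print(m.get_n_leaves())  # at most 5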

dt = DecisionTreeRegressor(max_depth=5, random_state=117117, max_leaf_nodes=5)
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with max_leaf_nodes=5)

When the dataset contains outliers

criterion specifies which error measure is used when evaluating a split. Let's see how the tree changes when criterion="squared_error" is used on data that contains outliers. Since squared_error penalizes large errors more strongly than absolute_error, we expect the tree to spend splits isolating the outliers when squared_error is specified. A numerical comparison of the two fitted trees follows after the plots.

# Multiply some target values by 5 to create outliers
X, y = make_regression(n_samples=100, n_features=2, random_state=11)
y[1:20] = y[1:20] * 5
dt = DecisionTreeRegressor(max_depth=5, random_state=117117, criterion="absolute_error")
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with criterion="absolute_error")

dt = DecisionTreeRegressor(max_depth=5, random_state=117117, criterion="squared_error")
dt.fit(X, y)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()

(figure: 3-D plot of the tree fitted with criterion="squared_error")
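
To complement the plots, the two criteria can also be compared numerically on the same outlier-contaminated data. A short sketch that refits both trees under new (hypothetical) names dt_abs and dt_sq:

dt_abs = DecisionTreeRegressor(
    max_depth=5, criterion="absolute_error", random_state=117117
).fit(X, y)
dt_sq = DecisionTreeRegressor(
    max_depth=5, criterion="squared_error", random_state=117117
).fit(X, y)

# squared_error chases the inflated targets harder, which tends to show up
# in the root split and in the leaf counts
print("root split (absolute_error):", dt_abs.tree_.feature[0], dt_abs.tree_.threshold[0])
print("root split (squared_error): ", dt_sq.tree_.feature[0], dt_sq.tree_.threshold[0])
print("leaves:", dt_abs.get_n_leaves(), "vs", dt_sq.get_n_leaves())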
