There are various parameters in a decision tree, and the results change depending on how they are specified. On this page, we will visualize and check how each parameter works.
max_depth: specifies the maximum depth of the tree.
min_samples_split: specifies the minimum number of samples required to create a split.
min_samples_leaf: specifies the minimum number of samples required to form a leaf.
max_leaf_nodes: specifies the maximum number of leaf nodes.
ccp_alpha: a pruning parameter that penalizes tree complexity.
class_weight: specifies the weighting of classes in classification.
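As a quick orientation, here is a minimal sketch showing where these arguments are passed; the values below are arbitrary and only for illustration.

from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# arbitrary values, for illustration only
reg = DecisionTreeRegressor(
    max_depth=5,
    min_samples_split=20,
    min_samples_leaf=10,
    max_leaf_nodes=16,
    ccp_alpha=0.01,
)
# class_weight applies to classification
clf = DecisionTreeClassifier(class_weight="balanced")

The rest of this page fits such trees on generated data and visualizes the resulting splits in 3D.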
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from mpl_toolkits.mplot3d import Axes3D
from dtreeviz.trees import dtreeviz, rtreeviz_bivar_3D
# dataset
X, y = make_regression(n_samples=100, n_features=2, random_state=11)
# train decision tree
dt = DecisionTreeRegressor(max_depth=3)
dt.fit(X, y)
# visualize
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=120,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
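Besides the 3D plot, the split rules can also be inspected as plain text; a minimal sketch using scikit-learn's export_text on the tree fitted above:

from sklearn.tree import export_text

# print the fitted split rules as indented text
print(export_text(dt, feature_names=["x1", "x2"]))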
Let's examine how a decision tree with a slightly more complex structure behaves as the decision tree parameters are changed. First, check a decision tree with default values for all parameters except max_depth=3.
X, y = make_regression(
    n_samples=500, n_features=2, effective_rank=4, noise=0.1, random_state=1
)
plt.figure(figsize=(10, 10))
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
dt = DecisionTreeRegressor(max_depth=3, random_state=117117)
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
When the value of max_depth is large, a deeper and more complex tree is created.
This allows the tree to represent complex rules, but it may overfit when the number of samples is small.
dt = DecisionTreeRegressor(max_depth=10, random_state=117117)
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
dt = DecisionTreeRegressor(max_depth=5, random_state=117117)
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
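The overfitting tendency can also be checked numerically rather than by eye, for example by comparing cross-validated scores at several depths; a minimal sketch using cross_val_score (the exact scores depend on the data and the CV split):

from sklearn.model_selection import cross_val_score

# default scoring for a regressor is R^2
for depth in [3, 5, 10]:
    dt = DecisionTreeRegressor(max_depth=depth, random_state=117117)
    print(depth, cross_val_score(dt, X, y, cv=5).mean())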
min_samples_split specifies the minimum number of samples required to create a single split.
Smaller values of min_samples_split allow more detailed rules; increasing the value helps avoid overfitting.
dt = DecisionTreeRegressor(max_depth=5, min_samples_split=60, random_state=117117)
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
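min_samples_leaf, listed at the top of this page, is a related constraint that acts on the leaves rather than the splits: every leaf must contain at least that many samples. A minimal sketch (the value 30 is arbitrary):

# every leaf must contain at least 30 samples (arbitrary value)
dt = DecisionTreeRegressor(max_depth=5, min_samples_leaf=30, random_state=117117)
dt.fit(X, y)
print(dt.get_n_leaves())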
ccp_alpha penalizes the complexity of the tree. The higher the value of ccp_alpha, the simpler the resulting tree.
dt = DecisionTreeRegressor(max_depth=5, random_state=117117, ccp_alpha=0.4)
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
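Rather than picking ccp_alpha by hand, scikit-learn can report the alpha values at which pruning would actually change the tree; a minimal sketch using cost_complexity_pruning_path:

# effective alphas at which subtrees would be pruned, and the resulting total leaf impurities
path = DecisionTreeRegressor(max_depth=5, random_state=117117).cost_complexity_pruning_path(X, y)
print(path.ccp_alphas)
print(path.impurities)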
max_leaf_nodes sets an upper limit on the number of leaves that will eventually be created. The number of leaves matches the number of regions the feature space is partitioned into.
dt = DecisionTreeRegressor(max_depth=5, random_state=117117, max_leaf_nodes=5)
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
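To confirm that the fitted tree respects this limit, the number of leaves can be read off directly; a minimal sketch using get_n_leaves:

# should print at most 5, matching max_leaf_nodes above
print(dt.get_n_leaves())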
criterion specifies which splitting criterion to apply when creating a branch.
Let's see how the tree changes when criterion="squared_error" is specified on data containing outliers.
Since squared_error penalizes outliers more strongly than absolute_error, we expect splits to be created around the outliers when squared_error is specified.
# Multiply some target values by 5 to create outliers
X, y = make_regression(n_samples=100, n_features=2, random_state=11)
y[1:20] = y[1:20] * 5
dt = DecisionTreeRegressor(max_depth=5, random_state=117117, criterion="absolute_error")
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
dt = DecisionTreeRegressor(max_depth=5, random_state=117117, criterion="squared_error")
dt.fit(X, y)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
t = rtreeviz_bivar_3D(
    dt,
    X,
    y,
    feature_names=["x1", "x2"],
    target_name="y",
    elev=40,
    azim=240,
    dist=8.0,
    show={"splits", "title"},
    ax=ax,
)
plt.show()
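The difference can also be checked numerically by comparing each model's predictions on the points that were scaled up; a minimal sketch (the names dt_abs and dt_sq are introduced here only for this comparison):

# refit both criteria and compare their mean prediction on the outlier points
dt_abs = DecisionTreeRegressor(max_depth=5, random_state=117117, criterion="absolute_error").fit(X, y)
dt_sq = DecisionTreeRegressor(max_depth=5, random_state=117117, criterion="squared_error").fit(X, y)
print(dt_abs.predict(X[1:20]).mean())
print(dt_sq.predict(X[1:20]).mean())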