Learning Goals
3 min
- Understand a tree as learned if-else questions that split the data.
- Train a DecisionTreeClassifier and read its accuracy.
- Visualise the tree and the feature importances.
- Control max_depth to prevent overfitting.
Warm-Up · The 20-Questions Game
5 min
Is petal_length < 2.5?
├── yes → setosa
└── no → Is petal_width < 1.75?
    ├── yes → versicolor
    └── no → virginica

A tree learns which question to ask first (the most informative split), then the next, and so on. Each leaf is a prediction. Unlike KNN, no scaling is needed: trees split on one feature at a time, so units don't matter.
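The "no scaling needed" claim can be checked with a small experiment. The sketch below is an illustration rather than part of the lesson's own code: it assumes scikit-learn is installed, and the variable names (X_mm, tree_cm, tree_mm) are made up for this example. It trains the same tree twice, once on centimetres and once on the same measurements converted to millimetres, and compares the predictions.

```python
# Sketch: a tree's predictions should not change when features are rescaled.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_mm = X * 10.0  # same measurements, centimetres converted to millimetres

tree_cm = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
tree_mm = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_mm, y)

# Fraction of samples on which the two trees give the same answer
agreement = (tree_cm.predict(X) == tree_mm.predict(X_mm)).mean()
print("agreement:", agreement)
```

Rescaling moves the thresholds (a split near 2.45 cm becomes one near 24.5 mm) but leaves the ordering of the samples unchanged, so the same flowers should end up in the same leaves.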
A tree greedily picks the question that best separates the classes, then repeats on each branch. The result is human-readable rules — which is why trees are loved in regulated fields where you must explain decisions.
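One way to make "best separates the classes" concrete is Gini impurity, which is scikit-learn's default splitting criterion. The sketch below is a simplified illustration, not the library's implementation: the helper names gini and split_impurity and the six-flower toy data are hypothetical. It scores three candidate questions and shows that the lowest weighted impurity marks the cleanest split.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 0.0 for a pure group, higher for a mixed one."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def split_impurity(values, labels, threshold):
    """Weighted Gini impurity of the two groups made by `value <= threshold`."""
    left = [lab for v, lab in zip(values, labels) if v <= threshold]
    right = [lab for v, lab in zip(values, labels) if v > threshold]
    n = len(labels)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

# Hypothetical toy data: petal lengths and species for six flowers
petal_length = [1.4, 1.5, 1.3, 4.7, 5.1, 5.9]
species = ["setosa", "setosa", "setosa", "versicolor", "virginica", "virginica"]

candidates = (1.45, 2.5, 5.0)
for t in candidates:
    score = split_impurity(petal_length, species, t)
    print(f"petal_length <= {t}: impurity {score:.3f}")

# The greedy choice is the candidate with the lowest weighted impurity
best = min(candidates, key=lambda t: split_impurity(petal_length, species, t))
print("best threshold:", best)
```

On this toy data the threshold 2.5 wins because it isolates all three setosa flowers in one pure group. A real tree repeats the same search over every feature and every candidate threshold at each node.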
New Concept · Training & Reading a Tree
14 min
Train
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# Hold out 20% of the flowers for testing
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=0)

tree = DecisionTreeClassifier(max_depth=3, random_state=0)
tree.fit(Xtr, ytr)
# round() works whether score() returns a Python float or a NumPy float
print("accuracy:", round(tree.score(Xte, yte), 3))
Read the rules as text
from sklearn.tree import export_text

print(export_text(tree, feature_names=load_iris().feature_names))
|--- petal length (cm) <= 2.45
|   |--- class: 0
|--- petal length (cm) > 2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) > 1.75
|   |   |--- class: 2
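To see why a tree is "learned if-else questions", the printed rules can be copied by hand into an ordinary function. This is only an illustrative sketch: the function name classify is made up, the thresholds are taken from the rules printed above, and the class numbers are mapped to iris's species names (0 = setosa, 1 = versicolor, 2 = virginica).

```python
def classify(petal_length, petal_width):
    """Hand-written copy of the printed rules; measurements are in cm."""
    if petal_length <= 2.45:
        return "setosa"       # class 0
    if petal_width <= 1.75:
        return "versicolor"   # class 1
    return "virginica"        # class 2

print(classify(1.4, 0.2))  # short petals
print(classify(4.5, 1.5))  # long, narrow petals
print(classify(6.0, 2.5))  # long, wide petals
```

The fitted tree does the same thing, except that it found the questions and thresholds from the data instead of having them typed in.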
Visualise it
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

plt.figure(figsize=(12, 6))
plot_tree(tree,
          feature_names=load_iris().feature_names,
          class_names=load_iris().target_names,
          filled=True)
plt.savefig("tree.png", dpi=120)
Feature importances
for name, imp in zip(load_iris().feature_names, tree.feature_importances_):
    print(f" {name:<22} {imp:.2f}")
Importance = how much each feature reduced impurity across the tree. Iris's petals dominate; sepals barely matter.
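The importances are easier to read when ranked. The sketch below is a self-contained illustration that, as an assumption for brevity, fits the tree on the full iris dataset rather than the training split used earlier; the names ranked and total are made up. It also checks that scikit-learn's impurity-based importances are normalised to add up to 1.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# Pair each feature name with its importance, largest first
ranked = sorted(zip(iris.feature_names, tree.feature_importances_),
                key=lambda pair: pair[1], reverse=True)
for name, imp in ranked:
    print(f"{name:<22} {imp:.2f}")

# Importances are shares of the total impurity reduction
total = sum(imp for _, imp in ranked)
print("total:", round(total, 6))
```

A feature the tree never splits on gets an importance of exactly 0, which says it was unused in this tree, not that it carries no information.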
Overfitting & max_depth
no max_depth         tree grows until every leaf is pure → memorises noise
max_depth=3          forces simplicity → generalises better
min_samples_leaf=5   another way to keep leaves from being too specific
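The effect of these settings shows up directly in the size of the fitted tree. The sketch below is an illustration under assumptions: it fits on the full iris dataset for brevity, and the settings and sizes dictionaries are invented for this example. It uses get_depth() and get_n_leaves() to compare the three configurations listed above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# The three configurations from the list above
settings = {
    "no max_depth": {},
    "max_depth=3": {"max_depth": 3},
    "min_samples_leaf=5": {"min_samples_leaf": 5},
}

sizes = {}
for label, params in settings.items():
    t = DecisionTreeClassifier(random_state=0, **params).fit(X, y)
    sizes[label] = (t.get_depth(), t.get_n_leaves())
    print(f"{label:<20} depth={t.get_depth():<3} leaves={t.get_n_leaves()}")
```

A smaller tree is not automatically a better one; the size only tells you how much room the model has to memorise, which is why the settings are tuned against held-out data.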
Worked Example · Depth vs Overfitting
12 min
# tree_depth.py: watch a tree overfit as it deepens
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.25, random_state=0)

print(f"{'depth':>5} {'train':>7} {'test':>7}")
for d in [1, 2, 3, 5, 10, None]:
    t = DecisionTreeClassifier(max_depth=d, random_state=0).fit(Xtr, ytr)
    tr = t.score(Xtr, ytr)
    te = t.score(Xte, yte)
    print(f"{str(d):>5} {tr:>7.1%} {te:>7.1%}")
Sample output
depth   train    test
    1   92.5%   88.8%
    2   94.6%   91.6%
    3   97.4%   93.7%
    5   99.5%   93.0%
   10  100.0%   91.6%
 None  100.0%   90.9%
Read the diff
Watch the gap. At depth 3, train and test are still fairly close (97.4% vs 93.7%) and test accuracy peaks, so the tree generalises. By depth 10 and beyond, train hits 100% but test drops to about 91%: the tree has memorised training quirks (overfitting). The sweet spot here is depth 3. A widening train-vs-test gap is the standard sign of overfitting (Lesson 25).
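The gap can be read off with simple arithmetic. The sketch below does not train anything: it copies the percentages from the sample output above into a list (the names rows and best_depth are made up for this example) and prints train minus test for each depth.

```python
# (depth, train %, test %) copied from the sample output above
rows = [
    ("1", 92.5, 88.8),
    ("2", 94.6, 91.6),
    ("3", 97.4, 93.7),
    ("5", 99.5, 93.0),
    ("10", 100.0, 91.6),
    ("None", 100.0, 90.9),
]

print(f"{'depth':>5} {'gap':>6}")
for depth, train, test in rows:
    # Gap in percentage points between train and test accuracy
    print(f"{depth:>5} {train - test:>6.1f}")

# Depth with the highest test accuracy in this table
best_depth = max(rows, key=lambda r: r[2])[0]
print("highest test accuracy at depth:", best_depth)
```

The gap stays under 4 points up to depth 3 and then grows steadily. Picking a depth from test scores like this is fine for reading a table, but for real tuning cross-validation is the safer choice, since it keeps the test set untouched.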
Try It Yourself
13 min
Train a depth-2 tree on iris. Print its rules with export_text.
Plot the feature importances of a tree as a horizontal bar chart, sorted.
Use cross-validation to choose max_depth from 1-15 for any dataset. Plot CV accuracy vs depth.
Hint
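from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# X, y: the features and labels of whichever dataset you loaded
scores = [
    cross_val_score(DecisionTreeClassifier(max_depth=d, random_state=0), X, y, cv=5).mean()
    for d in range(1, 16)
]
best = scores.index(max(scores)) + 1  # list index 0 corresponds to depth 1
print("best depth:", best)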
Mini-Challenge · Explain a Prediction
8 min
Trees can explain individual predictions. Use decision_path to print, for one test sample, the exact chain of questions the tree asked to reach its answer.
Show one possible solution
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)
sample = iris.data[100].reshape(1, -1)

node_idx = tree.decision_path(sample).indices
feat = tree.tree_.feature
thr = tree.tree_.threshold

print("Decision path:")
for n in node_idx:
    if feat[n] != -2:  # not a leaf
        f = iris.feature_names[feat[n]]
        op = "<=" if sample[0, feat[n]] <= thr[n] else ">"
        print(f" {f} = {sample[0, feat[n]]:.1f} {op} {thr[n]:.2f}")
print("Prediction:", iris.target_names[tree.predict(sample)[0]])
Non-negotiables: trace the actual nodes visited for one sample. This explainability is a tree's superpower — most models can't do this.
Recap
3 min
A decision tree learns yes/no questions that split classes apart. It needs no scaling, gives readable rules and feature importances, and can explain individual predictions. Its weakness: deep trees overfit. Control with max_depth / min_samples_leaf, tuned by CV. Next: combine many trees into a forest.
Vocabulary Card
- decision tree: A model of learned if-else splits; each leaf is a prediction.
- split: A question on one feature that divides the data to separate classes.
- feature importance: How much each feature contributed to the tree's splits.
- max_depth: The cap on how many questions deep the tree can go; your main overfitting control.
Homework
4 min
Train a tree on any dataset. Produce: the train-vs-test-by-depth table, a feature-importance chart, and the tree visualisation at the best depth. Write one sentence naming the most important feature and whether it surprised you.
Combine the depth table (worked example) with the importance chart and plot_tree. The sentence interprets the top importance.