🌿 Decision Trees#
Welcome to the Decision Tree dojo, where data gets sliced, diced, and neatly organized into if–then–else statements — basically, the algorithmic version of your mom deciding what to cook:
“If it’s raining → make pakoras ☔ Else if it’s sunny → ice cream 🍦 Else → leftovers 😅”
That, my friend, is a Decision Tree.
🌱 The Core Idea#
A Decision Tree works by asking a series of binary questions that split your data into smaller and smaller groups — until each group is so pure it could join a yoga retreat.
Example:
“Is income > ₹60,000?” “Yes? → Go right 🌳” “No? → Go left 🍂”
Each split reduces uncertainty — kind of like narrowing down who ate the last slice of pizza at the office.
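If you squint, a trained tree is just a stack of nested if/else statements. Here's a minimal hand-written sketch of that idea (the thresholds and labels are made up for illustration, not learned from data):
# A tiny hand-written "tree": every node asks one yes/no question.
# Thresholds and labels are illustrative only, not learned from data.
def predict_churn(income: float, age: float) -> str:
    if income > 60_000:           # root question
        return "No churn"         # right branch is pure enough, so stop here
    if age < 30:                  # left branch asks a follow-up question
        return "Churn"
    return "No churn"

print(predict_churn(income=45_000, age=25))   # -> "Churn"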
🎯 The Goal: Minimize Impurity#
Decision Trees are obsessed with purity — not moral, but mathematical purity. They use measures like:
| Metric | Meaning |
|---|---|
| Gini Impurity | “How mixed-up is this node?” (0 = perfectly pure) |
| Entropy | Borrowed from physics — aka “How chaotic is this node?” |
When splitting data, the tree looks for the feature and threshold that bring the biggest drop in impurity — because fewer mixed decisions = more confident predictions.
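To make those two metrics concrete, here's a small sketch that computes Gini impurity and entropy for one node by hand (the class labels are made up; only NumPy is assumed):
import numpy as np

def gini_impurity(labels):
    # Gini impurity: 1 - sum(p_k^2). 0 means the node holds a single class.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def shannon_entropy(labels):
    # Shannon entropy: -sum(p_k * log2(p_k)). 0 means the node is perfectly pure.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

node = np.array([1, 1, 1, 0, 0, 0, 0, 0])        # a mixed node: 3 "Yes", 5 "No"
print(f"Gini:    {gini_impurity(node):.3f}")      # ~0.469 (fairly mixed)
print(f"Entropy: {shannon_entropy(node):.3f}")    # ~0.954 bits (fairly chaotic)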
🧮 A Quick Example#
Say we have customer data for a telecom company:
| Age | Income | Churned |
|---|---|---|
| 23 | 30K | Yes |
| 42 | 90K | No |
| 35 | 40K | Yes |
| 50 | 100K | No |
A Decision Tree might start with:
“Is Income > 60K?” If yes, most people didn’t churn → go right. If no, they probably churned → go left.
Boom 💥 — you’ve just made your first data-driven business policy.
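For the curious, here's a quick sketch of how a tree would score that 60K split on the four rows above, using Gini impurity (encoding Churned as 1/0 is an assumption for the example):
import numpy as np

income  = np.array([30, 90, 40, 100])    # in thousands, from the table above
churned = np.array([1, 0, 1, 0])         # 1 = Yes, 0 = No

def gini_impurity(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

parent = gini_impurity(churned)                     # 0.5: a 2-vs-2 mix, maximally impure
left, right = churned[income <= 60], churned[income > 60]
children = (len(left) / len(churned)) * gini_impurity(left) \
         + (len(right) / len(churned)) * gini_impurity(right)
print(f"Impurity drop: {parent - children:.2f}")    # 0.50, a perfect split on this toy data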
🧠 Overfitting: The Tree That Knew Too Much#
Left unchecked, trees love to memorize the entire dataset — like that one intern who remembers every client’s birthday but forgets to send invoices.
This is called overfitting, and it happens when your tree becomes too deep, too specific, and too useless on new data.
So we prune it — ✂️ because in both gardening and machine learning, pruning keeps things healthy.
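In scikit-learn you can keep a tree in check two ways: pre-pruning (cap the growth up front with max_depth or min_samples_leaf) or post-pruning (grow it fully, then cut weak branches with ccp_alpha). A hedged sketch, assuming X_train and y_train already exist:
from sklearn.tree import DecisionTreeClassifier

# Pre-pruning: stop growth early with depth / leaf-size limits.
shallow = DecisionTreeClassifier(max_depth=4, min_samples_leaf=20, random_state=42)
shallow.fit(X_train, y_train)

# Post-pruning: compute the cost-complexity path, then refit with a chosen ccp_alpha.
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]   # crude pick; in practice, cross-validate
pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42).fit(X_train, y_train)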
⚙️ In Python#
from sklearn.tree import DecisionTreeClassifier
tree = DecisionTreeClassifier(
criterion="gini",
max_depth=4,
random_state=42
)
tree.fit(X_train, y_train)
You can visualize it with:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6))
plot_tree(tree, filled=True, feature_names=feature_names)  # feature_names: your column names
plt.show()
And voilà — a tree diagram that looks suspiciously like your thought process at 3 AM before deadlines.
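Before trusting the picture, check how the tree behaves on data it hasn't seen. A quick sketch, assuming X_test and y_test come from the same train/test split as X_train and y_train, and that tree is the fitted classifier from above:
from sklearn.metrics import accuracy_score

y_pred = tree.predict(X_test)
print(f"Test accuracy: {accuracy_score(y_test, y_pred):.2f}")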
🧩 Practice Time#
Try building a tree on your own:
Load a small dataset (e.g., Titanic survivors).
Train a DecisionTreeClassifier.
Visualize the tree.
Find out:
Which feature was split first?
How many leaves does your tree have?
Can you explain one decision path in plain English?
💡 Hint: The first split tells you what your model thinks is most important — like income, age, or whether the customer clicked “unsubscribe” three times this week.
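If you want a nudge to get started, here's one possible sketch using seaborn's built-in Titanic dataset (the chosen columns and the blunt dropna are assumptions; preprocess however you like):
import seaborn as sns
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.model_selection import train_test_split

# Keep a small numeric slice of the Titanic data for a first tree.
titanic = sns.load_dataset("titanic")[["survived", "pclass", "age", "fare", "sibsp"]].dropna()
X, y = titanic.drop(columns="survived"), titanic["survived"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

print(export_text(clf, feature_names=list(X.columns)))   # the first line shows the first split
print("Number of leaves:", clf.get_n_leaves())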
🌳 Coming Up Next#
Up next: we assemble an entire forest — because if one tree is good, a hundred are tree-mendous 🌲😎
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from collections import Counter
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split
# Entropy for classification
def entropy(y):
hist = np.bincount(y)
ps = hist / len(y)
return -np.sum([p * np.log2(p) for p in ps if p > 0])
# Variance for regression
def variance(y):
return np.var(y) if len(y) > 0 else 0
# Information Gain for classification
def information_gain(X, y, feature_idx, threshold):
parent_entropy = entropy(y)
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
return 0
n = len(y)
n_left, n_right = np.sum(left_mask), np.sum(right_mask)
child_entropy = (n_left / n) * entropy(y[left_mask]) + (n_right / n) * entropy(y[right_mask])
return parent_entropy - child_entropy
# Variance reduction for regression
def variance_reduction(X, y, feature_idx, threshold):
parent_var = variance(y)
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
return 0
n = len(y)
n_left, n_right = np.sum(left_mask), np.sum(right_mask)
child_var = (n_left / n) * variance(y[left_mask]) + (n_right / n) * variance(y[right_mask])
return parent_var - child_var
# Decision Tree Node
class Node:
def __init__(self, feature_idx=None, threshold=None, left=None, right=None, value=None):
self.feature_idx = feature_idx
self.threshold = threshold
self.left = left
self.right = right
self.value = value
# Decision Tree with path tracking
class DecisionTree:
def __init__(self, max_depth=3, min_samples_split=2, criterion='entropy'):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.criterion = criterion
self.root = None
self.boundaries = []
def fit(self, X, y):
self.root = self._grow_tree(X, y, depth=0)
def _grow_tree(self, X, y, depth):
n_samples, n_features = X.shape
if depth >= self.max_depth or n_samples < self.min_samples_split:
return Node(value=self._leaf_value(y))
best_gain = -1
best_idx, best_threshold = None, None
for feature_idx in range(n_features):
thresholds = np.unique(X[:, feature_idx])
for threshold in thresholds:
gain = (information_gain(X, y, feature_idx, threshold) if self.criterion in ['entropy']
else variance_reduction(X, y, feature_idx, threshold))
if gain > best_gain:
best_gain = gain
best_idx = feature_idx
best_threshold = threshold
if best_gain == 0:
return Node(value=self._leaf_value(y))
self.boundaries.append((best_idx, best_threshold, depth))
left_mask = X[:, best_idx] <= best_threshold
right_mask = ~left_mask
left = self._grow_tree(X[left_mask], y[left_mask], depth + 1)
right = self._grow_tree(X[right_mask], y[right_mask], depth + 1)
return Node(best_idx, best_threshold, left, right)
def _leaf_value(self, y):
return Counter(y).most_common(1)[0][0] if self.criterion in ['entropy'] else np.mean(y)
def predict(self, X):
return np.array([self._predict(x, self.root) for x in X])
def _predict(self, x, node):
if node.value is not None:
return node.value
if x[node.feature_idx] <= node.threshold:
return self._predict(x, node.left)
return self._predict(x, node.right)
def get_prediction_path(self, x):
path = []
node = self.root
while node.value is None:
path.append((node.feature_idx, node.threshold))
if x[node.feature_idx] <= node.threshold:
node = node.left
else:
node = node.right
path.append(('leaf', node.value))
return path
Classification and Regression Algorithms#
Classification (ID3-like)#
The DecisionTree class performs classification when criterion='entropy'. (Gini impurity would work the same way, but this from-scratch version only implements entropy.) It uses Information Gain to select the best feature and threshold for splitting. The tree grows recursively until it hits a stopping condition (e.g., max_depth or min_samples_split). Leaf nodes return the majority class.
Example:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Load data
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train and predict
clf = DecisionTree(max_depth=3, criterion='entropy')
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(f"Classification Accuracy: {accuracy_score(y_test, y_pred):.2f}")
Classification Accuracy: 0.97
Explanation:
Splitting: The algorithm selects splits that maximize Information Gain, reducing entropy in child nodes.
Prediction: For a new sample, the tree is traversed from root to leaf based on feature thresholds, and the majority class at the leaf is returned.
Overfitting Control: Parameters like max_depth=3 and min_samples_split=2 prevent the tree from growing too complex, reducing overfitting.
Performance: The accuracy score evaluates how well the tree generalizes to unseen data.
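Since the from-scratch DecisionTree records every split it makes in self.boundaries, you can also peek at what it learned; for example, the root split is the first entry (feature names here come from iris.feature_names):
feat_idx, threshold, depth = clf.boundaries[0]   # first split made while growing the tree
print(f"Root split: {iris.feature_names[feat_idx]} <= {threshold:.2f} (depth {depth})")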
Regression (CART-like)#
For regression, set criterion='variance'. The tree uses variance reduction to choose splits, and leaf nodes return the mean of the target values in that region.
Example:
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error
# Generate regression data
X, y = make_regression(n_samples=100, n_features=4, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train and predict
reg = DecisionTree(max_depth=3, criterion='variance')
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
print(f"Regression MSE: {mean_squared_error(y_test, y_pred):.2f}")
Regression MSE: 2358.02
Explanation:
Splitting: Splits are chosen to maximize variance reduction, ensuring child nodes have more similar target values.
Prediction: The tree traverses to a leaf, returning the mean target value of the training samples in that leaf.
Overfitting Control: Limiting max_depth and setting min_samples_split ensures the tree doesn't fit noise in the data.
Performance: Mean Squared Error (MSE) measures the average squared difference between predicted and actual values, indicating regression quality.
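As a sanity check, you can compare the from-scratch tree against scikit-learn's DecisionTreeRegressor on the same split; with the same depth limit the errors should land in a similar ballpark (an expectation to verify, not a guarantee):
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

sk_reg = DecisionTreeRegressor(max_depth=3, random_state=42).fit(X_train, y_train)
print(f"scikit-learn Regression MSE: {mean_squared_error(y_test, sk_reg.predict(X_test)):.2f}")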
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from collections import Counter
from sklearn.datasets import make_moons, make_friedman1
from sklearn.model_selection import train_test_split
# [Keep all existing functions and classes identical until the animation function]
def animate_decision_tree(X, y, tree, test_point, task='classification'):
fig, ax = plt.subplots(figsize=(12, 8))
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 300),
np.linspace(y_min, y_max, 300))
grid = np.c_[xx.ravel(), yy.ravel()]
# Predict for the grid
Z = tree.predict(grid)
Z = Z.reshape(xx.shape)
# Plot setup
title = f'Decision Tree {"Classifier" if task == "classification" else "Regressor"}'
cmap = plt.cm.RdYlBu if task == 'classification' else plt.cm.viridis
if task == 'classification':
cont = ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.4, levels=20)
sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=60)
else:
cont = ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.4, levels=20)
sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=60)
plt.colorbar(sc, label='Target Value', ax=ax)
test_scat = ax.scatter(test_point[0], test_point[1], c='yellow', s=200,
marker='*', edgecolor='black', linewidth=1.5, label='Test Point')
ax.set_xlabel('Feature 1', fontsize=12)
ax.set_ylabel('Feature 2', fontsize=12)
ax.legend(fontsize=10)
ax.set_title(title, fontsize=14, pad=20)
# Get prediction path
path = tree.get_prediction_path(test_point)
frame_descriptions = []
# Generate descriptions for each frame
for i, (feat_idx, thresh) in enumerate(path[:-1]):
direction = "left" if test_point[feat_idx] <= thresh else "right"
desc = f"Split {i+1}: x[{feat_idx}] ≤ {thresh:.2f}? ({direction} branch)"
frame_descriptions.append(desc)
frame_descriptions.append(f"Final prediction: {path[-1][1]:.2f}")
def update(frame):
ax.clear()
current_title = title + "\n" + frame_descriptions[min(frame, len(frame_descriptions)-1)]
ax.set_title(current_title, fontsize=14, pad=20)
# Recreate main elements
ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.4, levels=20)
ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=60)
ax.scatter(test_point[0], test_point[1], c='yellow', s=200,
marker='*', edgecolor='black', linewidth=1.5, label='Test Point')
ax.set_xlabel('Feature 1', fontsize=12)
ax.set_ylabel('Feature 2', fontsize=12)
ax.legend(fontsize=10)
# Plot boundaries up to current frame
for i, (feat_idx, thresh, depth) in enumerate(tree.boundaries[:frame+1]):
alpha = 0.8 - depth * 0.15
lw = 2 - depth * 0.3
if feat_idx == 0:
line = ax.axvline(thresh, color='navy', linestyle='--', alpha=alpha, linewidth=lw)
ax.text(thresh, y_max - depth*0.4 - 0.1*i, f'x[{feat_idx}] ≤ {thresh:.2f}',
fontsize=10, backgroundcolor='white',
verticalalignment='top', alpha=alpha,
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
else:
line = ax.axhline(thresh, color='navy', linestyle='--', alpha=alpha, linewidth=lw)
ax.text(x_max - depth*0.4 - 0.1*i, thresh, f'x[{feat_idx}] ≤ {thresh:.2f}',
fontsize=10, backgroundcolor='white',
horizontalalignment='right', alpha=alpha,
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
# Highlight current path step
if frame < len(path):
feat_idx, thresh = path[frame]
if feat_idx != 'leaf':
decision = "Yes" if test_point[feat_idx] <= thresh else "No"
color = 'limegreen' if decision == "Yes" else 'crimson'
if feat_idx == 0:
line = ax.axvline(thresh, color=color, linewidth=3, alpha=0.9)
ax.text(thresh, np.mean([y_min, y_max]),
f'Test: x[{feat_idx}]={test_point[feat_idx]:.2f}\n≤ {thresh:.2f}? {decision}',
fontsize=11, color='black', backgroundcolor='white',
verticalalignment='center', horizontalalignment='center',
bbox=dict(facecolor=color, alpha=0.3, edgecolor='none'))
else:
line = ax.axhline(thresh, color=color, linewidth=3, alpha=0.9)
ax.text(np.mean([x_min, x_max]), thresh,
f'Test: x[{feat_idx}]={test_point[feat_idx]:.2f}\n≤ {thresh:.2f}? {decision}',
fontsize=11, color='black', backgroundcolor='white',
verticalalignment='center', horizontalalignment='center',
bbox=dict(facecolor=color, alpha=0.3, edgecolor='none'))
else:
pred_value = thresh
if task == 'classification':
pred_text = f'Predicted Class: {int(pred_value)}'
else:
pred_text = f'Predicted Value: {pred_value:.2f}'
ax.text(0.5, 0.95, pred_text,
transform=ax.transAxes, fontsize=14,
color='white', backgroundcolor='green',
horizontalalignment='center', verticalalignment='center',
bbox=dict(facecolor='green', alpha=0.7, edgecolor='none'))
return ax,
ani = FuncAnimation(fig, update, frames=len(path) + 2, interval=2000, blit=False)
plt.close()
return HTML(ani.to_html5_video())
# Create more advanced datasets
# Classification - Moons dataset
X_moons, y_moons = make_moons(n_samples=300, noise=0.2, random_state=42)
X_train_clf, X_test_clf, y_train_clf, y_test_clf = train_test_split(X_moons, y_moons, test_size=0.2, random_state=42)
clf_tree = DecisionTree(max_depth=4, criterion='entropy')
clf_tree.fit(X_train_clf, y_train_clf)
test_point_clf = np.array([0.5, -0.3]) # Interesting point near decision boundary
# Regression - Non-linear dataset
np.random.seed(42)
X_reg = np.random.rand(300, 2) * 4 - 2
y_reg = np.sin(X_reg[:, 0] * 2) + np.cos(X_reg[:, 1] * 2) + np.random.normal(0, 0.2, 300)
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=42)
reg_tree = DecisionTree(max_depth=4, criterion='variance')
reg_tree.fit(X_train_reg, y_train_reg)
test_point_reg = np.array([-0.5, 1.2]) # Point in an interesting region
print("Classification Animation (Moons Dataset)")
display(animate_decision_tree(X_train_clf, y_train_clf, clf_tree, test_point_clf, task='classification'))
print("\nRegression Animation (Non-linear Dataset)")
display(animate_decision_tree(X_train_reg, y_train_reg, reg_tree, test_point_reg, task='regression'))
Classification Animation (Moons Dataset)
Regression Animation (Non-linear Dataset)
# Your code here