🌿 Decision Trees#

Welcome to the Decision Tree dojo, where data gets sliced, diced, and neatly organized into if–then–else statements — basically, the algorithmic version of your mom deciding what to cook:

“If it’s raining → make pakoras ☔ Else if it’s sunny → ice cream 🍦 Else → leftovers 😅”

That, my friend, is a Decision Tree.


🌱 The Core Idea#

A Decision Tree works by asking a series of binary questions that split your data into smaller and smaller groups — until each group is so pure it could join a yoga retreat.

Example:

“Is income > ₹60,000?” “Yes? → Go right 🌳” “No? → Go left 🍂”

Each split reduces uncertainty — kind of like narrowing down who ate the last slice of pizza at the office.


🎯 The Goal: Minimize Impurity#

Decision Trees are obsessed with purity — not moral, but mathematical purity. They use measures like:

| Metric | Meaning |
|---|---|
| Gini Impurity | "How mixed-up is this node?" (0 = perfectly pure) |
| Entropy | Borrowed from physics, aka "How chaotic is this node?" |

When splitting data, the tree looks for the feature and threshold that bring the biggest drop in impurity — because fewer mixed decisions = more confident predictions.
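To make those two metrics concrete, here is a minimal NumPy sketch (the function names gini and shannon_entropy are just illustrative, not from any library) that scores a perfectly mixed node against a pure one:

import numpy as np

def gini(labels):
    # Gini impurity: 1 - sum of squared class proportions (0 = perfectly pure)
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def shannon_entropy(labels):
    # Entropy: -sum(p * log2(p)) over the class proportions (0 = perfectly pure)
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

mixed = [0, 0, 1, 1]   # 50/50 node: as impure as it gets for two classes
pure = [1, 1, 1, 1]    # single-class node: nothing left to ask
print(gini(mixed), shannon_entropy(mixed))  # 0.5 1.0
print(gini(pure), shannon_entropy(pure))    # 0.0 -0.0 (i.e. zero)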


🧮 A Quick Example#

Say we have customer data for a telecom company:

| Age | Income | Churned |
|---|---|---|
| 23 | 30K | Yes |
| 42 | 90K | No |
| 35 | 40K | Yes |
| 50 | 100K | No |

A Decision Tree might start with:

“Is Income > 60K?” If yes, most people didn’t churn → go right. If no, they probably churned → go left.

Boom 💥 — you’ve just made your first data-driven business policy.
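You can verify that split with a few lines of NumPy; a quick sketch, with the numbers taken straight from the toy table above (churn encoded as 1 = Yes, 0 = No):

import numpy as np

def entropy(labels):
    # Shannon entropy of a set of class labels
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

income = np.array([30, 90, 40, 100])   # in thousands, same row order as the table
churn = np.array([1, 0, 1, 0])         # 1 = Yes, 0 = No

left = churn[income <= 60]    # low-income branch  -> [1, 1]
right = churn[income > 60]    # high-income branch -> [0, 0]

gain = entropy(churn) - (len(left) / 4) * entropy(left) - (len(right) / 4) * entropy(right)
print(f"Information gain of the 60K split: {gain:.2f}")  # 1.00 bit: the split is perfect here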


🧠 Overfitting: The Tree That Knew Too Much#

Left unchecked, trees love to memorize the entire dataset — like that one intern who remembers every client’s birthday but forgets to send invoices.

This is called overfitting, and it happens when your tree becomes too deep, too specific, and too useless on new data.

So we prune it — ✂️ because in both gardening and machine learning, pruning keeps things healthy.
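Concretely, scikit-learn (which we meet properly in the next section) lets you prune in two ways: pre-pruning with limits like max_depth and min_samples_leaf, or cost-complexity post-pruning via ccp_alpha. A rough sketch, assuming X_train and y_train already exist:

from sklearn.tree import DecisionTreeClassifier

# Pre-pruning: refuse to grow past a certain depth / leaf size in the first place
shallow = DecisionTreeClassifier(max_depth=4, min_samples_leaf=5, random_state=42)
shallow.fit(X_train, y_train)

# Post-pruning: grow a full tree, then cut back weak branches with cost-complexity pruning
full = DecisionTreeClassifier(random_state=42)
path = full.cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]   # mid-range alpha; in practice, tune via cross-validation
pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42).fit(X_train, y_train)

print(shallow.get_depth(), shallow.get_n_leaves())
print(pruned.get_depth(), pruned.get_n_leaves())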


⚙️ In Python#

from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(
    criterion="gini",
    max_depth=4,
    random_state=42
)
tree.fit(X_train, y_train)

You can visualize it with:

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))
plot_tree(tree, filled=True, feature_names=feature_names)  # feature_names: your list of column names
plt.show()

And voilà — a tree diagram that looks suspiciously like your thought process at 3 AM before deadlines.
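If the diagram gets crowded, the same rules can be dumped as plain text, which is often easier to read (again assuming feature_names holds your column names):

from sklearn.tree import export_text

# Prints one indented line per split, e.g. "|--- income <= 60000.00"
print(export_text(tree, feature_names=feature_names))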


🧩 Practice Time#

Try building a tree on your own:

  1. Load a small dataset (e.g., Titanic survivors).

  2. Train a DecisionTreeClassifier.

  3. Visualize the tree.

  4. Find out:

    • Which feature was split first?

    • How many leaves does your tree have?

    • Can you explain one decision path in plain English?

💡 Hint: The first split tells you what your model thinks is most important — like income, age, or whether the customer clicked “unsubscribe” three times this week.
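If you get stuck on the first two questions, the fitted model exposes the answers directly; a small sketch, assuming your fitted classifier is called clf and your column names are in feature_names:

# Which feature was split first? The root node is index 0 in the underlying tree_ arrays.
first_split = feature_names[clf.tree_.feature[0]]
print(f"Root split feature: {first_split}")

# How many leaves?
print(f"Number of leaves: {clf.get_n_leaves()}")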


🌳 Coming Up Next#

Up next: we assemble an entire forest — because if one tree is good, a hundred are tree-mendous 🌲😎

👉 Next: Bagging, RF, XGBoost

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from collections import Counter
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split

# Entropy for classification (expects non-negative integer class labels, as required by np.bincount)
def entropy(y):
    hist = np.bincount(y)
    ps = hist / len(y)
    return -np.sum([p * np.log2(p) for p in ps if p > 0])

# Variance for regression
def variance(y):
    return np.var(y) if len(y) > 0 else 0

# Information Gain for classification
def information_gain(X, y, feature_idx, threshold):
    parent_entropy = entropy(y)
    left_mask = X[:, feature_idx] <= threshold
    right_mask = ~left_mask
    if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
        return 0
    n = len(y)
    n_left, n_right = np.sum(left_mask), np.sum(right_mask)
    child_entropy = (n_left / n) * entropy(y[left_mask]) + (n_right / n) * entropy(y[right_mask])
    return parent_entropy - child_entropy

# Variance reduction for regression
def variance_reduction(X, y, feature_idx, threshold):
    parent_var = variance(y)
    left_mask = X[:, feature_idx] <= threshold
    right_mask = ~left_mask
    if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
        return 0
    n = len(y)
    n_left, n_right = np.sum(left_mask), np.sum(right_mask)
    child_var = (n_left / n) * variance(y[left_mask]) + (n_right / n) * variance(y[right_mask])
    return parent_var - child_var

# Decision Tree Node
class Node:
    def __init__(self, feature_idx=None, threshold=None, left=None, right=None, value=None):
        self.feature_idx = feature_idx
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value

# Decision Tree with path tracking
class DecisionTree:
    def __init__(self, max_depth=3, min_samples_split=2, criterion='entropy'):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.criterion = criterion
        self.root = None
        self.boundaries = []

    def fit(self, X, y):
        self.root = self._grow_tree(X, y, depth=0)

    def _grow_tree(self, X, y, depth):
        n_samples, n_features = X.shape
        if depth >= self.max_depth or n_samples < self.min_samples_split:
            return Node(value=self._leaf_value(y))
        
        best_gain = -1
        best_idx, best_threshold = None, None
        
        for feature_idx in range(n_features):
            thresholds = np.unique(X[:, feature_idx])
            for threshold in thresholds:
                gain = (information_gain(X, y, feature_idx, threshold) if self.criterion in ['entropy']
                        else variance_reduction(X, y, feature_idx, threshold))
                if gain > best_gain:
                    best_gain = gain
                    best_idx = feature_idx
                    best_threshold = threshold
        
        if best_gain == 0:
            return Node(value=self._leaf_value(y))
        
        self.boundaries.append((best_idx, best_threshold, depth))
        left_mask = X[:, best_idx] <= best_threshold
        right_mask = ~left_mask
        left = self._grow_tree(X[left_mask], y[left_mask], depth + 1)
        right = self._grow_tree(X[right_mask], y[right_mask], depth + 1)
        return Node(best_idx, best_threshold, left, right)

    def _leaf_value(self, y):
        return Counter(y).most_common(1)[0][0] if self.criterion in ['entropy'] else np.mean(y)

    def predict(self, X):
        return np.array([self._predict(x, self.root) for x in X])

    def _predict(self, x, node):
        if node.value is not None:
            return node.value
        if x[node.feature_idx] <= node.threshold:
            return self._predict(x, node.left)
        return self._predict(x, node.right)

    def get_prediction_path(self, x):
        path = []
        node = self.root
        while node.value is None:
            path.append((node.feature_idx, node.threshold))
            if x[node.feature_idx] <= node.threshold:
                node = node.left
            else:
                node = node.right
        path.append(('leaf', node.value))
        return path

Classification and Regression Algorithms#

Classification (ID3-like)#

The DecisionTree class performs classification when criterion='entropy' (Gini isn't wired into this from-scratch version, though scikit-learn supports it). It uses Information Gain to select the best feature and threshold for each split. The tree grows recursively until it hits a stopping condition (e.g., max_depth or min_samples_split), and leaf nodes return the majority class.

Example:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load data
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train and predict
clf = DecisionTree(max_depth=3, criterion='entropy')
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(f"Classification Accuracy: {accuracy_score(y_test, y_pred):.2f}")
Classification Accuracy: 0.97

Explanation:

  • Splitting: The algorithm selects splits that maximize Information Gain, reducing entropy in child nodes.

  • Prediction: For a new sample, the tree is traversed from root to leaf based on feature thresholds, and the majority class at the leaf is returned.

  • Overfitting Control: Parameters like max_depth=3 and min_samples_split=2 prevent the tree from growing too complex, reducing overfitting (a quick depth comparison follows this list).

  • Performance: The accuracy score evaluates how well the tree generalizes to unseen data.
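To see that control in action, here is a quick sketch that reuses the iris split above and trains the same from-scratch DecisionTree at a few depths; the gap between train and test accuracy is what to watch:

from sklearn.metrics import accuracy_score

for depth in (1, 3, 10):
    model = DecisionTree(max_depth=depth, criterion='entropy')
    model.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, model.predict(X_train))
    test_acc = accuracy_score(y_test, model.predict(X_test))
    print(f"max_depth={depth}: train={train_acc:.2f}  test={test_acc:.2f}")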

Regression (CART-like)#

For regression, set criterion='variance'. The tree uses variance reduction to choose splits, and leaf nodes return the mean of the target values in that region.

Example:

from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error

# Generate regression data
X, y = make_regression(n_samples=100, n_features=4, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train and predict
reg = DecisionTree(max_depth=3, criterion='variance')
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
print(f"Regression MSE: {mean_squared_error(y_test, y_pred):.2f}")
Regression MSE: 2358.02

Explanation:

  • Splitting: Splits are chosen to maximize variance reduction, ensuring child nodes have more similar target values.

  • Prediction: The tree traverses to a leaf, returning the mean target value of the training samples in that leaf (see the short check after this list).

  • Overfitting Control: Limiting max_depth and setting min_samples_split ensures the tree doesn’t fit noise in the data.

  • Performance: Mean Squared Error (MSE) measures the average squared difference between predicted and actual values, indicating regression quality.
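A quick way to see the "mean of the leaf" behaviour: predictions from a depth-3 tree are piecewise constant, so the training predictions can take at most 2**3 = 8 distinct values (one per leaf). A small check reusing the regression example above:

import numpy as np

train_preds = reg.predict(X_train)
print(f"Distinct predicted values: {len(np.unique(train_preds))} (at most 2**3 = 8 for max_depth=3)")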

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from collections import Counter
from sklearn.datasets import make_moons, make_friedman1
from sklearn.model_selection import train_test_split

# Reuses the entropy/variance helpers, the Node class, and the DecisionTree class defined above

def animate_decision_tree(X, y, tree, test_point, task='classification'):
    fig, ax = plt.subplots(figsize=(12, 8))
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 300), 
                         np.linspace(y_min, y_max, 300))
    grid = np.c_[xx.ravel(), yy.ravel()]
    
    # Predict for the grid
    Z = tree.predict(grid)
    Z = Z.reshape(xx.shape)
    
    # Plot setup
    title = f'Decision Tree {"Classifier" if task == "classification" else "Regressor"}'
    cmap = plt.cm.RdYlBu if task == 'classification' else plt.cm.viridis
    
    if task == 'classification':
        cont = ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.4, levels=20)
        sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=60)
    else:
        cont = ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.4, levels=20)
        sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=60)
        plt.colorbar(sc, label='Target Value', ax=ax)
    
    test_scat = ax.scatter(test_point[0], test_point[1], c='yellow', s=200, 
                          marker='*', edgecolor='black', linewidth=1.5, label='Test Point')
    ax.set_xlabel('Feature 1', fontsize=12)
    ax.set_ylabel('Feature 2', fontsize=12)
    ax.legend(fontsize=10)
    ax.set_title(title, fontsize=14, pad=20)

    # Get prediction path
    path = tree.get_prediction_path(test_point)
    frame_descriptions = []
    
    # Generate descriptions for each frame
    for i, (feat_idx, thresh) in enumerate(path[:-1]):
        direction = "left" if test_point[feat_idx] <= thresh else "right"
        desc = f"Split {i+1}: x[{feat_idx}] ≤ {thresh:.2f}? ({direction} branch)"
        frame_descriptions.append(desc)
    frame_descriptions.append(f"Final prediction: {path[-1][1]:.2f}")

    def update(frame):
        ax.clear()
        current_title = title + "\n" + frame_descriptions[min(frame, len(frame_descriptions)-1)]
        ax.set_title(current_title, fontsize=14, pad=20)
        
        # Recreate main elements
        ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.4, levels=20)
        ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=60)
        ax.scatter(test_point[0], test_point[1], c='yellow', s=200, 
                  marker='*', edgecolor='black', linewidth=1.5, label='Test Point')
        ax.set_xlabel('Feature 1', fontsize=12)
        ax.set_ylabel('Feature 2', fontsize=12)
        ax.legend(fontsize=10)

        # Plot boundaries up to current frame
        for i, (feat_idx, thresh, depth) in enumerate(tree.boundaries[:frame+1]):
            alpha = 0.8 - depth * 0.15
            lw = 2 - depth * 0.3
            if feat_idx == 0:
                line = ax.axvline(thresh, color='navy', linestyle='--', alpha=alpha, linewidth=lw)
                ax.text(thresh, y_max - depth*0.4 - 0.1*i, f'x[{feat_idx}] ≤ {thresh:.2f}',
                        fontsize=10, backgroundcolor='white', 
                        verticalalignment='top', alpha=alpha,
                        bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
            else:
                line = ax.axhline(thresh, color='navy', linestyle='--', alpha=alpha, linewidth=lw)
                ax.text(x_max - depth*0.4 - 0.1*i, thresh, f'x[{feat_idx}] ≤ {thresh:.2f}',
                        fontsize=10, backgroundcolor='white',
                        horizontalalignment='right', alpha=alpha,
                        bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

        # Highlight current path step
        if frame < len(path):
            feat_idx, thresh = path[frame]
            if feat_idx != 'leaf':
                decision = "Yes" if test_point[feat_idx] <= thresh else "No"
                color = 'limegreen' if decision == "Yes" else 'crimson'
                
                if feat_idx == 0:
                    line = ax.axvline(thresh, color=color, linewidth=3, alpha=0.9)
                    ax.text(thresh, np.mean([y_min, y_max]), 
                            f'Test: x[{feat_idx}]={test_point[feat_idx]:.2f}\n≤ {thresh:.2f}? {decision}',
                            fontsize=11, color='black', backgroundcolor='white',
                            verticalalignment='center', horizontalalignment='center',
                            bbox=dict(facecolor=color, alpha=0.3, edgecolor='none'))
                else:
                    line = ax.axhline(thresh, color=color, linewidth=3, alpha=0.9)
                    ax.text(np.mean([x_min, x_max]), thresh, 
                            f'Test: x[{feat_idx}]={test_point[feat_idx]:.2f}\n≤ {thresh:.2f}? {decision}',
                            fontsize=11, color='black', backgroundcolor='white',
                            verticalalignment='center', horizontalalignment='center',
                            bbox=dict(facecolor=color, alpha=0.3, edgecolor='none'))
            else:
                pred_value = thresh
                if task == 'classification':
                    pred_text = f'Predicted Class: {int(pred_value)}'
                else:
                    pred_text = f'Predicted Value: {pred_value:.2f}'
                
                ax.text(0.5, 0.95, pred_text, 
                       transform=ax.transAxes, fontsize=14,
                       color='white', backgroundcolor='green',
                       horizontalalignment='center', verticalalignment='center',
                       bbox=dict(facecolor='green', alpha=0.7, edgecolor='none'))

        return ax,

    ani = FuncAnimation(fig, update, frames=len(path) + 2, interval=2000, blit=False)
    plt.close()
    return HTML(ani.to_html5_video())

# Create more advanced datasets
# Classification - Moons dataset
X_moons, y_moons = make_moons(n_samples=300, noise=0.2, random_state=42)
X_train_clf, X_test_clf, y_train_clf, y_test_clf = train_test_split(X_moons, y_moons, test_size=0.2, random_state=42)
clf_tree = DecisionTree(max_depth=4, criterion='entropy')
clf_tree.fit(X_train_clf, y_train_clf)
test_point_clf = np.array([0.5, -0.3])  # Interesting point near decision boundary

# Regression - Non-linear dataset
np.random.seed(42)
X_reg = np.random.rand(300, 2) * 4 - 2
y_reg = np.sin(X_reg[:, 0] * 2) + np.cos(X_reg[:, 1] * 2) + np.random.normal(0, 0.2, 300)
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=42)
reg_tree = DecisionTree(max_depth=4, criterion='variance')
reg_tree.fit(X_train_reg, y_train_reg)
test_point_reg = np.array([-0.5, 1.2])  # Point in an interesting region

print("Classification Animation (Moons Dataset)")
display(animate_decision_tree(X_train_clf, y_train_clf, clf_tree, test_point_clf, task='classification'))

print("\nRegression Animation (Non-linear Dataset)")
display(animate_decision_tree(X_train_reg, y_train_reg, reg_tree, test_point_reg, task='regression'))
Classification Animation (Moons Dataset)
Regression Animation (Non-linear Dataset)
# Your code here