KNN Basics#
Welcome to KNN 101 — where machine learning meets neighborhood gossip. 🗣️
You’ll soon discover that K-Nearest Neighbors (KNN) is less of a “model” and more of a friend who just asks everyone else what they think before deciding.
🎯 The Core Intuition#
When you want to predict something new, KNN asks:
“Who are the K people most like this one — and what did they do?”
For example:
Predict if a customer will churn → look at similar customers.
Predict house price → check what similar houses sold for.
No training, no fancy math. Just good old-fashioned peer influence with distance metrics. 😅
🔢 Step-by-Step Breakdown#
Let’s see how KNN works under the hood:
1. Store all training examples (yes, all of them 🧳).
2. For a new observation:
   - Compute the distance to every training point.
   - Sort the distances (closest first).
   - Take the K nearest points.
3. Predict:
   - Classification: Majority vote 🗳️
   - Regression: Mean/average value 📊
Simple. Elegant. And gloriously lazy.
🧮 Distance Matters!#
The “closeness” between data points is defined using distance metrics.
Common Choices:#
| Metric | Formula | Use Case |
|---|---|---|
| Euclidean | \( d(x,y)=\sqrt{\sum_i (x_i-y_i)^2} \) | Most common (continuous data) |
| Manhattan | \( d(x,y)=\sum_i \lvert x_i-y_i \rvert \) | Grid-like features; less sensitive to outliers |
| Cosine | \( d(x,y)=1-\frac{x \cdot y}{\lVert x \rVert \, \lVert y \rVert} \) | Text and other sparse, high-dimensional data |
💡 Moral: Your “distance formula” decides who your friends are.
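Here’s a tiny sketch (not part of the original demo) of computing each metric with NumPy/SciPy; the vectors `a` and `b` are made-up values just for illustration:

```python
import numpy as np
from scipy.spatial import distance

# Two made-up feature vectors, purely for illustration
a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])

# Euclidean: straight-line distance
print(distance.euclidean(a, b))   # same as np.sqrt(np.sum((a - b) ** 2))

# Manhattan (city-block): sum of absolute differences
print(distance.cityblock(a, b))   # same as np.sum(np.abs(a - b))

# Cosine distance: 1 - cosine similarity (0 here, since b is just a scaled copy of a)
print(distance.cosine(a, b))
```

Swapping the metric changes which points count as “close”, which is exactly why the choice matters.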
🎚️ Choosing the Right K#
Picking K is like choosing how many opinions to listen to before making a decision.
| K | Behavior | Analogy |
|---|---|---|
| 1 | Overfits | “Believes the loudest person.” 📢 |
| 5 | Balanced | “Listens to a small circle of friends.” 👂 |
| 20 | Smooth | “Crowdsources everything.” 🧠 |
Common practice:#
Try odd values (for binary classification).
Use cross-validation to find the best K.
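As a quick, hedged sketch of that cross-validation search (reusing the Iris data that appears later in this chapter; the list of K values is just an illustrative choice):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# Try a handful of odd K values and keep the one with the best cross-validated accuracy
for k in [1, 3, 5, 7, 9, 11]:
    model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=k))
    score = cross_val_score(model, X, y, cv=5).mean()
    print(f"K={k}: mean CV accuracy = {score:.3f}")
```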
🧼 Data Preprocessing for KNN#
KNN cares deeply about scaling because it compares raw values.
👉 Always:
- Normalize or standardize your features (`StandardScaler`, `MinMaxScaler`).
- Handle missing values.
- Encode categorical variables properly.
Otherwise, one feature (like “income”) might dominate others (like “age”) simply because it has a larger scale.
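To make that concrete, here is a small sketch with made-up numbers: the income column swamps the age column until we standardize.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical customers: [age in years, income in dollars]
customers = np.array([
    [25, 40_000],
    [26, 41_000],   # nearly the same age, slightly different income
    [55, 40_500],   # 30 years older, similar income
])

# Raw Euclidean distances: income differences dominate completely
print(np.linalg.norm(customers[0] - customers[1]))  # ~1000 (driven by the $1,000 income gap)
print(np.linalg.norm(customers[0] - customers[2]))  # ~500 (the 30-year age gap barely registers)

# After standardization, both features contribute on a comparable scale
scaled = StandardScaler().fit_transform(customers)
print(np.linalg.norm(scaled[0] - scaled[1]))
print(np.linalg.norm(scaled[0] - scaled[2]))
```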
🧪 Quick Demo#
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# Load dataset
X, y = load_iris(return_X_y=True)
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Fit KNN
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train_scaled, y_train)
# Evaluate
y_pred = knn.predict(X_test_scaled)
print("Accuracy:", accuracy_score(y_test, y_pred))
🧠 Pro tip: Scaling makes or breaks KNN — it’s like giving everyone equal-sized shoes before running the similarity marathon. 👟
🪄 Visual Intuition#
import matplotlib.pyplot as plt
import numpy as np
# Just a visual sketch of KNN decision boundary
plt.xlim(0, 1)  # fix the axes so the text labels below stay in view
plt.ylim(0, 1)
plt.title("KNN Decision Boundary (Concept)")
plt.text(0.4, 0.8, "Class A", fontsize=14, color='blue')
plt.text(0.7, 0.3, "Class B", fontsize=14, color='red')
plt.scatter([0.5], [0.5], c='green', label='New Point')
plt.legend()
plt.show()
Here, our new point (in green) looks around and joins the class with the majority of its K nearest buddies. KNN — because friendship is predictive. ❤️
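The snippet above is only a placeholder; for an actual decision boundary, a minimal sketch (assuming a toy two-class dataset from `make_blobs`) might look like this:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier

# Toy two-class dataset, chosen here just to have something to plot
X, y = make_blobs(n_samples=100, centers=2, random_state=42)
knn = KNeighborsClassifier(n_neighbors=5).fit(X, y)

# Classify every point on a grid and colour the regions
xx, yy = np.meshgrid(
    np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 200),
    np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 200),
)
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm", edgecolor="k")
plt.title("KNN Decision Boundary (k=5)")
plt.show()
```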
💼 Business Analogy#
| Situation | KNN Perspective |
|---|---|
| Predict customer churn | “Ask similar customers if they left.” |
| Recommend a product | “Show what people like you bought.” |
| Classify risk | “Compare with previous borrowers.” |
In short:
“KNN = Business intuition with math.”
🚀 What’s Next#
Next up: Efficient Search Structures. Asking everyone in your dataset gets slow real fast, so we need to speed up the gossip. 🏃♂️💨
Basics of k-NN#
k-Nearest Neighbors (k-NN) is a simple, non-parametric, and instance-based machine learning algorithm used for classification (and regression). In classification, k-NN assigns a class label to a new data point based on the majority class of its \(k\) closest neighbors in the feature space.
Key Concepts#
Distance Metric: Measures how “close” two data points are. Common metrics include:
Euclidean distance: \(\sqrt{\sum_{i=1}^n (x_i - y_i)^2}\)
Manhattan distance: \(\sum_{i=1}^n |x_i - y_i|\)
Minkowski distance (generalization): \(\left( \sum_{i=1}^n |x_i - y_i|^p \right)^{1/p}\)
k Value: The number of neighbors to consider. Choosing \(k\) affects the model:
Small \(k\): Sensitive to noise, overfitting risk.
Large \(k\): Smoother decision boundaries, underfitting risk.
Majority Voting: For classification, the class with the most votes among the \(k\) nearest neighbors is assigned to the new data point.
No Training Phase: k-NN is a lazy learning algorithm. It stores the training data and performs calculations only at prediction time.
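As a small worked example (values chosen only for illustration), take \(x = (1, 2)\) and \(y = (4, 6)\), so the coordinate differences are \(3\) and \(4\):

\[
d_{\text{Euclidean}} = \sqrt{3^2 + 4^2} = 5, \qquad
d_{\text{Manhattan}} = 3 + 4 = 7, \qquad
d_{\text{Minkowski},\,p=3} = \left(3^3 + 4^3\right)^{1/3} = 91^{1/3} \approx 4.50.
\]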
Assumptions#
Data points with similar features belong to the same class.
Features are on comparable scales (normalization/scaling is often required).
Steps in k-NN Classification#
Prepare the dataset (normalize/scale features if needed).
For a new data point:
Calculate the distance to all training points.
Select the \(k\) nearest neighbors.
Assign the class with the majority vote among the \(k\) neighbors.
Implementation of k-NN#
Below is a step-by-step explanation of implementing k-NN for classification, followed by a Python example using both a manual implementation and scikit-learn.
Manual Implementation (Python)#
import numpy as np
from collections import Counter
# Euclidean distance function
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2) ** 2))
# k-NN classifier
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
predictions = [self._predict(x) for x in X]
return np.array(predictions)
def _predict(self, x):
# Compute distances to all training points
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
# Get indices of k-nearest neighbors
k_indices = np.argsort(distances)[:self.k]
# Get labels of k-nearest neighbors
k_nearest_labels = [self.y_train[i] for i in k_indices]
# Majority vote
most_common = Counter(k_nearest_labels).most_common(1)
return most_common[0][0]
# Example usage
if __name__ == "__main__":
# Sample dataset: 2 features, 2 classes
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7]])
y_train = np.array([0, 0, 0, 1, 1]) # 0 and 1 are class labels
X_test = np.array([[4, 4], [5, 5]])
# Initialize and train k-NN
knn = KNN(k=3)
knn.fit(X_train, y_train)
# Predict
predictions = knn.predict(X_test)
print("Predictions:", predictions)
Explanation of Code:
Distance Function: Calculates Euclidean distance between two points.
KNN Class:
- `fit`: Stores the training data (`X_train`, `y_train`).
- `predict`: Loops through test points and predicts their class.
- `_predict`: Computes distances, finds the \(k\) nearest neighbors, and returns the majority class.
Example Dataset: A small dataset with 5 training points, 2 features, and 2 classes (0 and 1). The test points `[4, 4]` and `[5, 5]` are classified.
Output:
Predictions: [0 1]
- `[4, 4]` is closer to class 0 points, so it’s classified as 0.
- `[5, 5]` is closer to class 1 points, so it’s classified as 1.
Implementation with scikit-learn#
The scikit-learn library provides an optimized k-NN implementation.
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
import numpy as np
# Sample dataset
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7]])
y_train = np.array([0, 0, 0, 1, 1])
X_test = np.array([[4, 4], [5, 5]])
# Scale features (important for k-NN)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Initialize and train k-NN
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train_scaled, y_train)
# Predict
predictions = knn.predict(X_test_scaled)
print("Predictions:", predictions)
Explanation:
StandardScaler: Normalizes features to ensure fair distance calculations.
KNeighborsClassifier: scikit-learn’s k-NN implementation, which is optimized for large datasets.
The rest is similar to the manual implementation but leverages scikit-learn’s efficiency.
Output:
Predictions: [0 1]
Practical Considerations#
Choosing \(k\):
Use cross-validation to test different \(k\) values.
Odd \(k\) values are preferred to avoid ties in binary classification.
Feature Scaling:
k-NN is distance-based, so features with larger scales dominate. Always normalize or standardize features.
Computational Cost:
k-NN is computationally expensive for large datasets since it requires calculating distances for every test point.
Use approximate nearest neighbor algorithms (e.g., KD-trees, Ball trees) for optimization (scikit-learn supports these).
Handling Ties:
If multiple classes have the same number of votes, the algorithm may choose the class with the closest neighbor or use weighted voting (e.g., inverse distance weighting).
Advantages:
Simple and intuitive.
No assumptions about data distribution.
Effective for small datasets with clear class boundaries.
Disadvantages:
Sensitive to noise and outliers.
High memory and computational requirements for large datasets.
Performance depends heavily on the choice of \(k\) and distance metric.
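As a hedged sketch of the weighted-voting option mentioned under handling ties, scikit-learn exposes it through the `weights='distance'` argument (the even \(k\) and the Iris data are chosen here just to make ties possible):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Uniform voting: every neighbor counts equally (ties are possible with even k)
uniform = KNeighborsClassifier(n_neighbors=4, weights='uniform')
# Distance weighting: closer neighbors get larger votes, which also resolves most ties
weighted = KNeighborsClassifier(n_neighbors=4, weights='distance')

print("uniform :", cross_val_score(uniform, X, y, cv=5).mean())
print("weighted:", cross_val_score(weighted, X, y, cv=5).mean())
```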
Example with a Real Dataset (Iris)#
Here’s how to apply k-NN to the Iris dataset using scikit-learn:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# Load Iris dataset
iris = load_iris()
X, y = iris.data, iris.target
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Train k-NN
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train_scaled, y_train)
# Predict and evaluate
y_pred = knn.predict(X_test_scaled)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
Output (example):
Accuracy: 0.98
Explanation:
The Iris dataset has 150 samples, 4 features, and 3 classes.
The data is split into 70% training and 30% testing.
Features are scaled, and k-NN with \(k=3\) achieves high accuracy.
Conclusion#
k-NN is a versatile and intuitive algorithm for classification, ideal for small datasets and problems with clear class boundaries. Its main challenges are computational cost and sensitivity to feature scaling and noise. By carefully choosing \(k\), scaling features, and using optimized libraries like scikit-learn, k-NN can be effectively applied to real-world classification tasks.
k-Nearest Neighbors (k-NN) for Regression#
Below is an explanation of the logic and implementation of k-Nearest Neighbors (k-NN) for regression.
Basics of k-NN for Regression#
k-NN for regression is an extension of the k-NN algorithm used for classification. Instead of predicting a class label, k-NN for regression predicts a continuous value by averaging (or weighting) the target values of the \(k\) nearest neighbors.
Key Concepts#
Distance Metric: As in classification, a distance metric (e.g., Euclidean distance) is used to find the \(k\) nearest neighbors:
Euclidean distance: \(\sqrt{\sum_{i=1}^n (x_i - y_i)^2}\)
k Value: The number of neighbors to consider. Affects the smoothness of predictions:
Small \(k\): More sensitive to noise, higher variance.
Large \(k\): Smoother predictions, higher bias.
Aggregation: The predicted value for a new data point is typically the mean (or weighted mean) of the target values of the \(k\) nearest neighbors.
No Training Phase: Like classification, k-NN for regression is a lazy learning algorithm, storing the training data and computing distances at prediction time.
Assumptions#
Data points with similar features have similar target values.
Features should be on comparable scales (normalization/scaling is recommended).
Steps in k-NN Regression#
Prepare the dataset (normalize/scale features if needed).
For a new data point:
Calculate the distance to all training points.
Select the \(k\) nearest neighbors.
Compute the mean (or weighted mean) of the target values of the \(k\) neighbors as the prediction.
Logic for k-NN Regression#
The logic can be formalized as follows:
Input:
Training data: Feature matrix \(X \in \mathbb{R}^{m \times n}\) (m samples, n features) and target values \(y \in \mathbb{R}^m\).
Test data point: \(x_{\text{test}} \in \mathbb{R}^n\).
Number of neighbors: \(k\).
Distance metric (e.g., Euclidean).
Distance Calculation:
For each training point \(x_i \in X\), compute the distance to \(x_{\text{test}}\): \[ d_i = \sqrt{\sum_{j=1}^n (x_{\text{test},j} - x_{i,j})^2} \]
Select k Nearest Neighbors:
Sort the distances \(d_i\) in ascending order.
Select the indices of the \(k\) points with the smallest distances. Let these indices be \(I = \{i_1, i_2, \dots, i_k\}\).
Prediction:
Compute the predicted value as the mean of the target values of the \(k\) nearest neighbors: \[ \hat{y} = \frac{1}{k} \sum_{i \in I} y_i \]
Alternatively, use weighted averaging (e.g., inverse distance weighting): \[ \hat{y} = \frac{\sum_{i \in I} w_i y_i}{\sum_{i \in I} w_i}, \quad w_i = \frac{1}{d_i + \epsilon} \] where \(\epsilon\) is a small constant to avoid division by zero.
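For a quick numeric check (values chosen only for illustration), let \(k=2\), with neighbor targets \(y_{i_1}=2\), \(y_{i_2}=4\) and distances \(d_{i_1}=1\), \(d_{i_2}=3\) (taking \(\epsilon \approx 0\)):

\[
\hat{y}_{\text{mean}} = \frac{2 + 4}{2} = 3, \qquad
\hat{y}_{\text{weighted}} = \frac{1 \cdot 2 + \tfrac{1}{3} \cdot 4}{1 + \tfrac{1}{3}} = \frac{10/3}{4/3} = 2.5,
\]

so the weighted estimate is pulled toward the closer neighbor.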
Implementation of k-NN Regression#
Below is a Python implementation, first manually and then using scikit-learn.
Manual Implementation (Python)#
import numpy as np
# Euclidean distance function
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2) ** 2))
# k-NN regressor
class KNNRegressor:
def __init__(self, k=3, weighted=False):
self.k = k
self.weighted = weighted
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
predictions = [self._predict(x) for x in X]
return np.array(predictions)
def _predict(self, x):
# Compute distances to all training points
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
# Get indices of k-nearest neighbors
k_indices = np.argsort(distances)[:self.k]
# Get target values of k-nearest neighbors
k_nearest_values = [self.y_train[i] for i in k_indices]
if self.weighted:
# Inverse distance weighting
k_distances = [distances[i] for i in k_indices]
weights = [1 / (d + 1e-10) for d in k_distances] # Avoid division by zero
return np.sum(np.array(weights) * np.array(k_nearest_values)) / np.sum(weights)
else:
# Simple mean
return np.mean(k_nearest_values)
# Example usage
if __name__ == "__main__":
# Sample dataset: 2 features, continuous target
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7]])
y_train = np.array([2.5, 3.0, 3.5, 6.0, 7.5]) # Continuous values
X_test = np.array([[4, 4], [5, 5]])
# Initialize and train k-NN
knn = KNNRegressor(k=3, weighted=False)
knn.fit(X_train, y_train)
# Predict
predictions = knn.predict(X_test)
print("Predictions (mean):", predictions)
# Weighted k-NN
knn_weighted = KNNRegressor(k=3, weighted=True)
knn_weighted.fit(X_train, y_train)
predictions_weighted = knn_weighted.predict(X_test)
print("Predictions (weighted):", predictions_weighted)
Explanation:
Distance Function: Computes Euclidean distance.
KNNRegressor Class:
- `fit`: Stores training data.
- `predict`: Loops through test points to predict values.
- `_predict`: Finds \(k\) nearest neighbors and computes the mean (or weighted mean) of their target values.
- `weighted` parameter: If `True`, uses inverse distance weighting; otherwise, uses the simple mean.
Example Dataset: 5 training points with 2 features and continuous target values. The test points `[4, 4]` and `[5, 5]` are predicted.
Output:
Predictions (mean): [3. 5.5]
Predictions (weighted): [3.16666667 5.66666667]
- For `[4, 4]`, the 3 nearest neighbors have targets `[2.5, 3.0, 3.5]`, so the mean is \(3.0\). The weighted prediction adjusts this based on distances.
- For `[5, 5]`, the neighbors have targets `[3.5, 6.0, 7.5]`, so the mean is \(5.5\).
Implementation with scikit-learn#
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler
import numpy as np
# Sample dataset
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7]])
y_train = np.array([2.5, 3.0, 3.5, 6.0, 7.5])
X_test = np.array([[4, 4], [5, 5]])
# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Initialize and train k-NN (simple mean)
knn = KNeighborsRegressor(n_neighbors=3, weights='uniform')
knn.fit(X_train_scaled, y_train)
# Predict
predictions = knn.predict(X_test_scaled)
print("Predictions (mean):", predictions)
# Weighted k-NN
knn_weighted = KNeighborsRegressor(n_neighbors=3, weights='distance')
knn_weighted.fit(X_train_scaled, y_train)
predictions_weighted = knn_weighted.predict(X_test_scaled)
print("Predictions (weighted):", predictions_weighted)
Explanation:
StandardScaler: Normalizes features for fair distance calculations.
KNeighborsRegressor: scikit-learn’s k-NN regression implementation.
- `weights='uniform'`: Simple mean of neighbors’ values.
- `weights='distance'`: Inverse distance weighting.
The output is similar to the manual implementation but optimized.
Output:
Predictions (mean): [3. 5.5]
Predictions (weighted): [3.16666667 5.66666667]
Practical Considerations#
Choosing \(k\):
Use cross-validation to select \(k\).
Small \(k\) may lead to noisy predictions; large \(k\) may oversmooth.
Feature Scaling:
Essential since k-NN relies on distances. Use standardization or normalization.
Weighted vs. Uniform:
Weighted predictions (inverse distance) give more influence to closer neighbors, potentially improving accuracy.
Uniform averaging is simpler but treats all neighbors equally.
Computational Cost:
k-NN regression is computationally expensive for large datasets due to distance calculations.
Use approximate nearest neighbor methods (e.g., KD-trees) for efficiency.
Advantages:
Simple and flexible.
No assumptions about the data distribution.
Works well for small datasets with clear patterns.
Disadvantages:
Sensitive to noise and outliers.
High memory and computational requirements.
Performance depends on \(k\), distance metric, and feature scaling.
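A hedged sketch of tuning \(k\) and the weighting scheme together with `GridSearchCV` (the synthetic `make_regression` data mirrors the example that follows; the parameter grid is an arbitrary choice):

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic regression data, assumed here only for the demonstration
X, y = make_regression(n_samples=200, n_features=2, noise=0.1, random_state=42)

pipe = Pipeline([("scale", StandardScaler()), ("knn", KNeighborsRegressor())])
param_grid = {
    "knn__n_neighbors": [3, 5, 7, 9, 11],
    "knn__weights": ["uniform", "distance"],
}

# 5-fold cross-validation over the grid, scored by (negative) mean squared error
search = GridSearchCV(pipe, param_grid, cv=5, scoring="neg_mean_squared_error")
search.fit(X, y)
print("Best parameters:", search.best_params_)
```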
Example with a Real Dataset#
Here’s an example using scikit-learn on a synthetic dataset:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error
# Generate synthetic dataset
X, y = make_regression(n_samples=100, n_features=2, noise=0.1, random_state=42)
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Train k-NN
knn = KNeighborsRegressor(n_neighbors=5, weights='uniform')
knn.fit(X_train_scaled, y_train)
# Predict and evaluate
y_pred = knn.predict(X_test_scaled)
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")
Output (example):
Mean Squared Error: 0.12
Explanation:
A synthetic dataset with 100 samples and 2 features is used.
Features are scaled, and k-NN with \(k=5\) predicts continuous values.
Mean Squared Error evaluates the model’s performance.
Conclusion#
k-NN for regression is a straightforward extension of k-NN for classification, predicting continuous values by averaging the target values of the \(k\) nearest neighbors. It is intuitive and effective for small datasets but requires careful tuning of \(k\), feature scaling, and consideration of weighting schemes. Optimized libraries like scikit-learn make it easy to apply k-NN regression to real-world problems.
The use of logarithms in analyzing the efficiency of algorithms stems from situations where the algorithm reduces the problem size by a constant factor at each step.
Think about it like this:
Linear time, O(n): If you have a problem of size ‘n’, you have to look at each of the ‘n’ elements. The work grows directly with the size.
Logarithmic time, O(log n): If you have a problem of size ‘n’, and each step effectively halves the amount of data you need to consider, then the number of steps it takes to get down to a single element is related to the logarithm of ‘n’.
Let’s take an example: Binary Search.
Imagine you have a sorted list of ‘n’ numbers and you want to find a specific number.
You look at the middle element.
If it’s the number you’re looking for, you’re done!
If it’s too high, you know the number (if it exists) must be in the left half. You discard the right half.
If it’s too low, you know the number must be in the right half. You discard the left half.
With each step, you effectively halve the search space. How many times can you halve ‘n’ until you get down to 1? That’s where the logarithm comes in. If \(n = 8\), you can halve it 3 times (\(8 \rightarrow 4 \rightarrow 2 \rightarrow 1\)). Notice that \(\log_2(8) = 3\).
So, algorithms with O(log n) time complexity are incredibly efficient for large input sizes because the number of operations grows very slowly as ‘n’ increases.
In summary, the “log” in O(log n) is indeed the logarithm, and it arises in the analysis of algorithms that conquer problems by repeatedly reducing the problem size by a constant factor. Binary search is a classic and intuitive example of this.
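A minimal binary search in Python makes the halving concrete (the step counter is only there to show how few comparisons are needed):

```python
def binary_search(sorted_list, target):
    """Return the index of target in sorted_list, or -1 if it is absent."""
    lo, hi = 0, len(sorted_list) - 1
    steps = 0
    while lo <= hi:
        steps += 1
        mid = (lo + hi) // 2
        if sorted_list[mid] == target:
            print(f"found after {steps} comparisons")  # roughly log2(n)
            return mid
        elif sorted_list[mid] < target:
            lo = mid + 1   # discard the left half
        else:
            hi = mid - 1   # discard the right half
    return -1

# 1024 elements -> at most about 10 comparisons, since log2(1024) = 10
binary_search(list(range(1024)), 777)
```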
Efficient Search with KD-Trees and Ball Trees for k-Nearest Neighbors (k-NN)#
Below is a description of KD-Trees and Ball Trees, their role in optimizing k-NN search, and a Python implementation from scratch demonstrating their use in k-NN. The implementation focuses on KD-Trees for simplicity, as Ball Trees are more complex and typically used in high-dimensional or non-Euclidean settings. I’ll also show how to use scikit-learn’s KD-Tree and Ball Tree implementations for comparison.
Description of KD-Trees and Ball Trees#
k-NN requires finding the \(k\) nearest neighbors for a query point by computing distances to all training points, which has a complexity of \(O(mn)\) per query, where \(m\) is the number of training points and \(n\) is the number of features. For large datasets, this is computationally expensive. KD-Trees and Ball Trees are data structures that organize points in a way that reduces the number of distance calculations, achieving query complexities closer to \(O(\log m)\) in low dimensions.
KD-Trees#
Overview: A KD-Tree (k-dimensional tree) is a binary space-partitioning tree that recursively splits the feature space along one dimension at a time. Each node represents a hyperplane that divides the space into two regions, and leaf nodes contain data points.
Construction:
Choose a dimension (e.g., cycle through dimensions or select the one with the highest variance).
Split the data at the median value of that dimension, creating two child nodes.
Recurse until a stopping criterion (e.g., minimum points per leaf) is met.
Search:
Traverse the tree to find the leaf node containing the query point.
Backtrack to check other branches if their regions could contain closer points (using bounds based on the query point’s distance to the current best neighbors).
Efficiency: Best for low-dimensional numerical data (\(n \leq 20\)). In low dimensions, KD-Trees reduce the number of distance calculations significantly. In high dimensions, the tree becomes less effective due to the curse of dimensionality, where points are roughly equidistant.
Complexity:
Build: \(O(m \log m)\)
Query: \(O(\log m)\) in low dimensions, but can degrade to \(O(m)\) in high dimensions.
Ball Trees#
Overview: A Ball Tree organizes points into hyperspheres (balls) rather than hyperplanes. Each node represents a ball containing a subset of points, defined by a centroid and radius.
Construction:
Select a centroid (e.g., mean of points or a random point).
Assign points to the node and split them into two child balls by selecting two new centroids (e.g., farthest points or via clustering).
Recurse until a stopping criterion is met.
Search:
Traverse the tree to the leaf containing the query point.
Use the triangle inequality to prune branches where the ball’s boundary is too far from the query point to contain closer neighbors.
Efficiency: Better for high-dimensional data or non-Euclidean metrics (e.g., cosine distance). Ball Trees are more robust in high dimensions because they adapt to the data’s geometry, unlike KD-Trees, which rely on axis-aligned splits.
Complexity:
Build: \(O(m \log m)\)
Query: \(O(\log m)\) in favorable cases, but depends on the metric and dimensionality.
KD-Trees vs. Ball Trees#
KD-Trees: Faster for low-dimensional Euclidean data due to simple axis-aligned splits. Less effective in high dimensions or with non-Euclidean metrics.
Ball Trees: More flexible for high-dimensional data or custom metrics, but construction and search are more computationally intensive.
Use Case: Use KD-Trees for small \(n\) (e.g., \(n < 20\)) and Euclidean distance; use Ball Trees for large \(n\) or non-standard metrics.
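Both structures can also be queried directly through scikit-learn’s `KDTree` and `BallTree` classes; here is a small sketch on random data (the sizes and `leaf_size` are arbitrary choices):

```python
import numpy as np
from sklearn.neighbors import KDTree, BallTree

rng = np.random.default_rng(42)
X = rng.random((1000, 3))      # 1,000 points in 3 dimensions
query = rng.random((1, 3))     # a single query point

kd = KDTree(X, leaf_size=30)
dist_kd, idx_kd = kd.query(query, k=5)       # distances and indices of the 5 nearest points

ball = BallTree(X, leaf_size=30, metric="euclidean")
dist_ball, idx_ball = ball.query(query, k=5)

# Both return the same neighbors; they differ only in how the search space is pruned
print(np.array_equal(idx_kd, idx_ball))
```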
Python Implementation from Scratch#
Below is a Python implementation of a KD-Tree for k-NN classification. The KD-Tree is used to efficiently find the \(k\) nearest neighbors, followed by majority voting for classification. Implementing a Ball Tree from scratch is significantly more complex due to its hypersphere-based structure and centroid computations, so I’ll focus on KD-Tree and provide a scikit-learn example for Ball Tree.
KD-Tree k-NN Implementation#
The code includes:
- A `KDTree` class to build and query the tree.
- A `KNNClassifier` class that uses the KD-Tree for efficient neighbor search.
import numpy as np
from collections import Counter
from heapq import heappush, heappop
# KD-Tree Node
class KDNode:
def __init__(self, point, label, axis, left=None, right=None):
self.point = point
self.label = label
self.axis = axis
self.left = left
self.right = right
# KD-Tree for efficient k-NN search
class KDTree:
def __init__(self, k=3, leaf_size=10):
self.k = k
self.leaf_size = leaf_size
def build(self, X, y, depth=0):
if len(X) <= self.leaf_size:
            return KDNode(None, None, None)  # Leaf marker; the query step falls back to brute-force distances (simplified)
# Select axis (cycle through dimensions)
axis = depth % X.shape[1]
# Sort by axis and find median
indices = np.argsort(X[:, axis])
X = X[indices]
y = y[indices]
median_idx = len(X) // 2
# Create node and recurse
node = KDNode(X[median_idx], y[median_idx], axis)
node.left = self.build(X[:median_idx], y[:median_idx], depth + 1)
node.right = self.build(X[median_idx + 1:], y[median_idx + 1:], depth + 1)
return node
def query(self, root, point, X, y, depth=0):
if root.left is None and root.right is None: # Leaf node
# Compute distances to all points in the leaf
            # Simplified leaf handling: brute-force distances over the full training set
            distances = [np.sqrt(np.sum((point - x) ** 2)) for x in X]
            return [(d, y[i]) for d, i in sorted(zip(distances, range(len(y))), key=lambda pair: pair[0])[:self.k]]
axis = root.axis
if point[axis] <= root.point[axis]:
close = root.left
far = root.right
else:
close = root.right
far = root.left
# Search closer subtree
best = self.query(close, point, X, y, depth + 1)
# Check if far subtree could have closer points
current_best_dist = best[0][0] if best else float('inf')
if abs(point[axis] - root.point[axis]) < current_best_dist:
best.extend(self.query(far, point, X, y, depth + 1))
best = sorted(best, key=lambda x: x[0])[:self.k]
return best
# k-NN Classifier using KD-Tree
class KNNClassifier:
def __init__(self, k=3, leaf_size=10):
self.k = k
self.leaf_size = leaf_size
self.kdtree = KDTree(k, leaf_size)
def fit(self, X, y):
self.X_train = X
self.y_train = y
self.root = self.kdtree.build(X, y)
def predict(self, X):
predictions = []
for x in X:
neighbors = self.kdtree.query(self.root, x, self.X_train, self.y_train)
labels = [label for _, label in neighbors if label is not None]
prediction = Counter(labels).most_common(1)[0][0]
predictions.append(prediction)
return np.array(predictions)
# Example usage
if __name__ == "__main__":
# Sample dataset: 2 features, 2 classes
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7]])
y_train = np.array([0, 0, 0, 1, 1])
X_test = np.array([[4, 4], [5, 5]])
# Initialize and train k-NN
knn = KNNClassifier(k=3, leaf_size=2)
knn.fit(X_train, y_train)
# Predict
predictions = knn.predict(X_test)
print("Predictions:", predictions)
Explanation of Code#
KDNode: Represents a node in the KD-Tree, storing a point, label, split axis, and child nodes.
KDTree:
- `build`: Recursively constructs the tree by splitting data along the median of the current axis. Uses a `leaf_size` parameter to stop splitting when few points remain.
- `query`: Finds the \(k\) nearest neighbors by traversing the tree to a leaf, computing distances there, and backtracking to check other branches if needed.
KNNClassifier:
- `fit`: Builds the KD-Tree with training data.
- `predict`: Queries the KD-Tree for each test point and uses majority voting to predict the class.
Example: Uses a small dataset with 2 features and 2 classes. Predicts classes for the test points `[4, 4]` and `[5, 5]`.
Output:
Predictions: [0 1]
- `[4, 4]` is closer to points with label 0.
- `[5, 5]` is closer to points with label 1.
Limitations of Scratch Implementation#
Simplified for clarity; lacks optimizations like scikit-learn’s (e.g., no parallelization, basic tie-breaking).
Does not handle edge cases (e.g., empty nodes, duplicate points).
KD-Tree only; Ball Tree is more complex due to hypersphere calculations.
Using scikit-learn’s KD-Tree and Ball Tree#
scikit-learn’s KNeighborsClassifier and KNeighborsRegressor support KD-Tree and Ball Tree via the algorithm parameter. Here’s an example:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
import numpy as np
# Sample dataset
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7]])
y_train = np.array([0, 0, 0, 1, 1])
X_test = np.array([[4, 4], [5, 5]])
# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# k-NN with KD-Tree
knn_kdtree = KNeighborsClassifier(n_neighbors=3, algorithm='kd_tree')
knn_kdtree.fit(X_train_scaled, y_train)
predictions_kdtree = knn_kdtree.predict(X_test_scaled)
print("KD-Tree Predictions:", predictions_kdtree)
# k-NN with Ball Tree
knn_balltree = KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree')
knn_balltree.fit(X_train_scaled, y_train)
predictions_balltree = knn_balltree.predict(X_test_scaled)
print("Ball Tree Predictions:", predictions_balltree)
Output:
KD-Tree Predictions: [0 1]
Ball Tree Predictions: [0 1]
scikit-learn Details#
KD-Tree: Uses axis-aligned splits, optimized for low-dimensional Euclidean data. Implemented in Cython for speed.
Ball Tree: Uses hyperspheres, better for high-dimensional or non-Euclidean metrics. Also Cython-based.
Parameters:
- `algorithm='kd_tree'` or `'ball_tree'` selects the method.
- `leaf_size`: Controls the minimum number of points in a leaf (default: 30).
Advantages: Highly optimized, handles edge cases, supports parallelization (`n_jobs`).
Practical Considerations#
Choosing KD-Tree vs. Ball Tree:
Use KD-Tree for low-dimensional data (\(n \leq 20\)) with Euclidean distance.
Use Ball Tree for high-dimensional data or non-Euclidean metrics (e.g., cosine).
Test both with cross-validation for your dataset.
Feature Scaling: Always scale features (e.g., `StandardScaler`) since k-NN is distance-based.
Performance:
KD-Tree is faster to build and query in low dimensions.
Ball Tree is more robust in high dimensions but slower to construct.
Limitations:
Both degrade in very high dimensions due to the curse of dimensionality.
For very large datasets, consider approximate nearest neighbor methods (e.g., Annoy, HNSW).
Conclusion#
KD-Trees and Ball Trees are powerful data structures for accelerating k-NN by reducing the number of distance calculations. The provided Python implementation demonstrates a KD-Tree-based k-NN classifier, while scikit-learn offers optimized KD-Tree and Ball Tree implementations. KD-Trees are ideal for low-dimensional Euclidean data, while Ball Trees excel in high-dimensional or non-Euclidean settings. By understanding their construction and search logic, you can choose the right method for your k-NN tasks.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.neighbors import KNeighborsClassifier, BallTree
from matplotlib.patches import Rectangle, Circle
from IPython.display import HTML
import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
# Sample dataset: 2D, 2 classes
np.random.seed(42)
X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 5], [7, 7], [1, 4], [2, 5]])
y_train = np.array([0, 0, 0, 1, 1, 0, 1])
query_point = np.array([4, 4])
k = 3
# Colors and styles
class_colors = {0: 'blue', 1: 'red'}
query_color = 'green'
neighbor_color = 'yellow'
# --- Brute-force k-NN Animation ---
def animate_brute_force():
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Brute-force k-NN (k=3)')
# Plot training points
for x, y in zip(X_train, y_train):
ax.scatter(x[0], x[1], c=class_colors[y], s=100, label='Class 0' if y == 0 else 'Class 1')
ax.scatter(query_point[0], query_point[1], c=query_color, s=150, marker='*', label='Query')
# Calculate distances
distances = [np.sqrt(np.sum((query_point - x) ** 2)) for x in X_train]
sorted_indices = np.argsort(distances)[:k]
lines = []
texts = []
def init():
ax.legend()
return []
def update(frame):
ax.clear()
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Brute-force k-NN (k=3)')
# Plot all points
for x, y in zip(X_train, y_train):
ax.scatter(x[0], x[1], c=class_colors[y], s=100)
ax.scatter(query_point[0], query_point[1], c=query_color, s=150, marker='*')
# Draw distance lines up to current frame
for i in range(min(frame + 1, len(X_train))):
x1, y1 = query_point
x2, y2 = X_train[i]
line = ax.plot([x1, x2], [y1, y2], 'k--', alpha=0.5)[0]
lines.append(line)
dist = distances[i]
text = ax.text((x1 + x2) / 2, (y1 + y2) / 2, f'{dist:.2f}', fontsize=8)
texts.append(text)
# Highlight k nearest neighbors
if frame >= len(X_train):
for i in sorted_indices:
ax.scatter(X_train[i][0], X_train[i][1], c=neighbor_color, s=150, alpha=0.5)
return lines + texts
ani = FuncAnimation(fig, update, frames=len(X_train) + 2, init_func=init, blit=False, interval=500)
plt.close() # Prevent static plot display
return HTML(ani.to_html5_video())
# --- KD-Tree k-NN Animation ---
class KDNode:
def __init__(self, point, label, axis, left=None, right=None):
self.point = point
self.label = label
self.axis = axis
self.left = left
self.right = right
class KDTree:
def build(self, X, y, depth=0):
if len(X) == 0:
return None
if len(X) == 1:
return KDNode(X[0], y[0], depth % 2)
axis = depth % 2
indices = np.argsort(X[:, axis])
X = X[indices]
y = y[indices]
median_idx = len(X) // 2
node = KDNode(X[median_idx], y[median_idx], axis)
node.left = self.build(X[:median_idx], y[:median_idx], depth + 1)
node.right = self.build(X[median_idx + 1:], y[median_idx + 1:], depth + 1)
return node
def animate_kd_tree():
kdtree = KDTree()
root = kdtree.build(X_train, y_train)
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('KD-Tree k-NN (k=3)')
splits = []
def collect_splits(node, bounds):
if node is None:
return
x_min, x_max, y_min, y_max = bounds
if node.axis == 0:
splits.append(('x', node.point[0], y_min, y_max))
else:
splits.append(('y', node.point[1], x_min, x_max))
collect_splits(node.left, (x_min, node.point[node.axis], y_min, y_max) if node.axis == 0 else (x_min, x_max, y_min, node.point[node.axis]))
collect_splits(node.right, (node.point[node.axis], x_max, y_min, y_max) if node.axis == 0 else (x_min, x_max, node.point[node.axis], y_max))
collect_splits(root, (0, 8, 0, 8))
# Simulate search path (simplified)
search_path = []
def search(node, point, bounds):
if node is None:
return
search_path.append((node, bounds))
axis = node.axis
if point[axis] <= node.point[axis]:
close = node.left
far = node.right
close_bounds = (bounds[0], node.point[axis], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], bounds[2], node.point[axis])
far_bounds = (node.point[axis], bounds[1], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], node.point[axis], bounds[3])
else:
close = node.right
far = node.left
close_bounds = (node.point[axis], bounds[1], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], node.point[axis], bounds[3])
far_bounds = (bounds[0], node.point[axis], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], bounds[2], node.point[axis])
search(node.left if point[axis] <= node.point[axis] else node.right, point, close_bounds)
search(root, query_point, (0, 8, 0, 8))
# Use scikit-learn for accurate neighbors
knn = KNeighborsClassifier(n_neighbors=k, algorithm='kd_tree')
knn.fit(X_train, y_train)
neighbors = knn.kneighbors([query_point], return_distance=False)[0]
def init():
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=100)
ax.scatter(query_point[0], query_point[1], c=query_color, s=150, marker='*')
ax.legend(['Class 0', 'Class 1', 'Query'])
return []
def update(frame):
ax.clear()
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('KD-Tree k-NN (k=3)')
# Plot points
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=100)
ax.scatter(query_point[0], query_point[1], c=query_color, s=150, marker='*')
# Draw splits
for i in range(min(frame + 1, len(splits))):
split_type, value, min_bound, max_bound = splits[i]
if split_type == 'x':
ax.axvline(value, ymin=min_bound/8, ymax=max_bound/8, color='black', linestyle='--')
else:
ax.axhline(value, xmin=min_bound/8, xmax=max_bound/8, color='black', linestyle='--')
# Highlight search path
if frame >= len(splits):
for i in range(min(frame - len(splits) + 1, len(search_path))):
node, bounds = search_path[i]
x_min, x_max, y_min, y_max = bounds
rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, fill=False, edgecolor='purple', linewidth=2)
ax.add_patch(rect)
# Highlight neighbors
if frame >= len(splits) + len(search_path):
for i in neighbors:
ax.scatter(X_train[i][0], X_train[i][1], c=neighbor_color, s=150, alpha=0.5)
return []
ani = FuncAnimation(fig, update, frames=len(splits) + len(search_path) + 2, init_func=init, blit=False, interval=1000)
plt.close()
return HTML(ani.to_html5_video())
# --- Ball Tree k-NN Animation ---
def animate_ball_tree():
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Ball Tree k-NN (k=3)')
# Build Ball Tree with scikit-learn
ball_tree = BallTree(X_train, leaf_size=2)
_, neighbors = ball_tree.query([query_point], k=k)
neighbors = neighbors[0]
# Simulate ball construction (simplified)
balls = []
def simulate_balls(X, depth=0):
if len(X) <= 2:
return
centroid = np.mean(X, axis=0)
radius = np.max([np.sqrt(np.sum((x - centroid) ** 2)) for x in X])
balls.append((centroid, radius))
indices = np.argsort([np.sqrt(np.sum((x - centroid) ** 2)) for x in X])
X = X[indices]
mid = len(X) // 2
simulate_balls(X[:mid], depth + 1)
simulate_balls(X[mid:], depth + 1)
simulate_balls(X_train)
def init():
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=100)
ax.scatter(query_point[0], query_point[1], c=query_color, s=150, marker='*')
ax.legend(['Class 0', 'Class 1', 'Query'])
return []
def update(frame):
ax.clear()
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Ball Tree k-NN (k=3)')
# Plot points
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=100)
ax.scatter(query_point[0], query_point[1], c=query_color, s=150, marker='*')
# Draw balls
for i in range(min(frame + 1, len(balls))):
centroid, radius = balls[i]
circle = Circle(centroid, radius, fill=False, edgecolor='black', linestyle='--')
ax.add_patch(circle)
# Highlight neighbors
if frame >= len(balls):
for i in neighbors:
ax.scatter(X_train[i][0], X_train[i][1], c=neighbor_color, s=150, alpha=0.5)
return []
ani = FuncAnimation(fig, update, frames=len(balls) + 2, init_func=init, blit=False, interval=1000)
plt.close()
return HTML(ani.to_html5_video())
# Display animations in Jupyter
print("Brute-force k-NN Animation:")
display(animate_brute_force())
print("KD-Tree k-NN Animation:")
display(animate_kd_tree())
print("Ball Tree k-NN Animation:")
display(animate_ball_tree())
Brute-force k-NN Animation:
KD-Tree k-NN Animation:
Ball Tree k-NN Animation:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from scipy.spatial import distance
from IPython.display import HTML
# Use inline backend for reliable HTML5 video output
%matplotlib inline
# Generate random 2D points
np.random.seed(42)
points = np.random.rand(20, 2) * 10 # 20 points in [0, 10] x [0, 10]
query_point = np.array([5, 5]) # Query point for k-NN
k = 3 # Number of nearest neighbors
# Simple Ball Tree node class
class BallNode:
def __init__(self, points, indices):
self.points = points
self.indices = indices
self.centroid = np.mean(points[indices], axis=0) if len(indices) > 0 else np.zeros(2)
self.radius = np.max([distance.euclidean(self.centroid, points[i]) for i in indices]) if len(indices) > 0 else 0
self.left = None
self.right = None
self.farthest_a = None
self.farthest_b = None
self.left_indices = []
self.right_indices = []
# Function to build the Ball Tree
def build_ball_tree(points, indices, depth=0, max_depth=3):
if len(indices) <= 1 or depth >= max_depth:
return BallNode(points, indices)
node = BallNode(points, indices)
# Find farthest point A from centroid
distances = [distance.euclidean(node.centroid, points[i]) for i in indices]
idx_a = indices[np.argmax(distances)]
node.farthest_a = idx_a
# Find farthest point B from A
distances = [distance.euclidean(points[idx_a], points[i]) for i in indices]
idx_b = indices[np.argmax(distances)]
node.farthest_b = idx_b
# Partition points based on proximity to A or B
left_indices = [i for i in indices if distance.euclidean(points[i], points[idx_a]) <= distance.euclidean(points[i], points[idx_b])]
right_indices = [i for i in indices if i not in left_indices]
node.left_indices = left_indices
node.right_indices = right_indices
# Recursively build child nodes
if left_indices:
node.left = build_ball_tree(points, left_indices, depth + 1, max_depth)
if right_indices:
node.right = build_ball_tree(points, right_indices, depth + 1, max_depth)
return node
# Build the tree
tree = build_ball_tree(points, list(range(len(points))))
# Simulate k-NN query (simplified for animation)
def knn_query(tree, query, k):
nearest_neighbors = []
def traverse(node, depth=0):
if not node:
return
# Compute distance to centroid
dist_to_centroid = distance.euclidean(query, node.centroid)
        # Add points from leaf nodes (nodes with no children), so depth-capped leaves are included too
        if node.left is None and node.right is None:
            for idx in node.indices:
                dist = distance.euclidean(query, points[idx])
                nearest_neighbors.append((dist, idx))
            return
# Traverse children (simplified: always explore both for animation)
traverse(node.left, depth + 1)
traverse(node.right, depth + 1)
traverse(tree)
# Sort and get k nearest
nearest_neighbors.sort()
return [idx for _, idx in nearest_neighbors[:k]]
# Get k-nearest neighbors
knn_indices = knn_query(tree, query_point, k)
# Set up the figure and axis
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
# Animation parameters
construction_frames = 12 # 4 steps per depth (points, root, split, children) for 3 depths
query_frames = 5 # Query point, root, traversal, pruning, neighbors
total_frames = construction_frames + query_frames
# Animation update function
def animate(frame):
ax.clear()
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
# Construction Phase
if frame < construction_frames:
depth = frame // 4
step = frame % 4
def draw_node(node, current_depth=0):
if not node or current_depth > depth:
return
# Step 0: Show points
if current_depth == depth and step == 0:
ax.scatter(points[node.indices, 0], points[node.indices, 1], c='blue', s=50)
ax.set_title(f"Depth {depth}: Points")
# Step 1: Show root ball
elif current_depth == depth and step == 1:
circle = plt.Circle(node.centroid, node.radius, fill=False, color='red', linewidth=1.5)
ax.add_patch(circle)
ax.scatter(points[node.indices, 0], points[node.indices, 1], c='blue', s=50)
ax.set_title(f"Depth {depth}: Root Ball")
# Step 2: Highlight farthest points and split
elif current_depth == depth and step == 2:
# Color-code clusters
colors = ['orange' if i in node.left_indices else 'purple' for i in node.indices]
ax.scatter(points[node.indices, 0], points[node.indices, 1], c=colors, s=50)
# Highlight farthest points A and B
ax.scatter(points[node.farthest_a, 0], points[node.farthest_a, 1], c='green', marker='*', s=200, label='Farthest A')
ax.scatter(points[node.farthest_b, 0], points[node.farthest_b, 1], c='yellow', marker='*', s=200, label='Farthest B')
ax.set_title(f"Depth {depth}: Split Clusters")
ax.legend()
# Step 3: Show child balls
elif current_depth == depth and step == 3:
if node.left:
circle = plt.Circle(node.left.centroid, node.left.radius, fill=False, color='red', linewidth=1.5)
ax.add_patch(circle)
if node.right:
circle = plt.Circle(node.right.centroid, node.right.radius, fill=False, color='red', linewidth=1.5)
ax.add_patch(circle)
colors = ['orange' if i in node.left_indices else 'purple' for i in node.indices]
ax.scatter(points[node.indices, 0], points[node.indices, 1], c=colors, s=50)
ax.set_title(f"Depth {depth}: Child Balls")
# Draw children
draw_node(node.left, current_depth + 1)
draw_node(node.right, current_depth + 1)
draw_node(tree)
# Query Phase
else:
query_step = frame - construction_frames
# Step 0: Show query point
if query_step == 0:
ax.scatter(points[:, 0], points[:, 1], c='blue', s=50)
ax.scatter(query_point[0], query_point[1], c='red', marker='*', s=200, label='Query')
ax.set_title("Query Phase: Query Point")
ax.legend()
# Step 1: Highlight root ball and distance
elif query_step == 1:
circle = plt.Circle(tree.centroid, tree.radius, fill=False, color='green', linewidth=1.5)
ax.add_patch(circle)
ax.scatter(points[:, 0], points[:, 1], c='blue', s=50)
ax.scatter(query_point[0], query_point[1], c='red', marker='*', s=200, label='Query')
# Draw line to centroid
ax.plot([query_point[0], tree.centroid[0]], [query_point[1], tree.centroid[1]], 'k--', label='Distance')
ax.set_title("Query Phase: Root Ball")
ax.legend()
# Step 2: Traverse child nodes
elif query_step == 2:
def draw_traversal(node, depth=0):
if not node:
return
circle = plt.Circle(node.centroid, node.radius, fill=False, color='green', linewidth=1.5)
ax.add_patch(circle)
draw_traversal(node.left, depth + 1)
draw_traversal(node.right, depth + 1)
draw_traversal(tree)
ax.scatter(points[:, 0], points[:, 1], c='blue', s=50)
ax.scatter(query_point[0], query_point[1], c='red', marker='*', s=200, label='Query')
ax.set_title("Query Phase: Traverse Nodes")
ax.legend()
# Step 3: Show pruning (simplified: fade non-relevant nodes)
elif query_step == 3:
def draw_pruned(node, depth=0):
if not node:
return
# Simplified: fade leaf nodes not containing neighbors
if len(node.indices) <= 1 and node.indices[0] not in knn_indices:
circle = plt.Circle(node.centroid, node.radius, fill=False, color='gray', alpha=0.3)
else:
circle = plt.Circle(node.centroid, node.radius, fill=False, color='green', linewidth=1.5)
ax.add_patch(circle)
draw_pruned(node.left, depth + 1)
draw_pruned(node.right, depth + 1)
draw_pruned(tree)
ax.scatter(points[:, 0], points[:, 1], c='blue', s=50)
ax.scatter(query_point[0], query_point[1], c='red', marker='*', s=200, label='Query')
ax.set_title("Query Phase: Pruning")
ax.legend()
# Step 4: Highlight k-nearest neighbors
elif query_step == 4:
ax.scatter(points[:, 0], points[:, 1], c='blue', s=50)
ax.scatter(points[knn_indices, 0], points[knn_indices, 1], c='green', s=100, label='Neighbors')
ax.scatter(query_point[0], query_point[1], c='red', marker='*', s=200, label='Query')
ax.set_title(f"Query Phase: {k}-Nearest Neighbors")
ax.legend()
# Create animation
ani = FuncAnimation(fig, animate, frames=total_frames, interval=1000, repeat=False)
# Display animation as HTML5 video
HTML(ani.to_html5_video())
See above figure#
Construction Phase:#
Frame 1: Show all points scattered in 2D.
Frame 2: Draw the root ball (circle enclosing all points).
Frame 3: Highlight the farthest points (A and B) and split points into two clusters (color-coded).
Frame 4: Draw two smaller balls for the child nodes. Repeat for recursive splits, showing circles shrinking and points being partitioned. Use transitions to smoothly draw circles and recolor points.
Query Phase:#
Frame 1: Introduce the query point (e.g., a red star).
Frame 2: Highlight the root ball and show the distance to its centroid.
Frame 3: Traverse to child nodes, highlighting explored balls.
Frame 4: Show pruning (e.g., fade out balls that are too far).
Frame 5: Highlight the k-nearest neighbors as they are found. Use dynamic updates to show distances and the growing list of neighbors.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.neighbors import KNeighborsClassifier, BallTree
from sklearn.datasets import make_classification
from matplotlib.patches import Rectangle, Circle
from IPython.display import HTML
import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
# Generate larger dataset: 2D, 2 classes, 50 points
np.random.seed(42)
X_train, y_train = make_classification(n_samples=50, n_features=2, n_classes=2, n_clusters_per_class=1, n_redundant=0, random_state=42)
query_point = np.array([0, 0]) # Center of the data for visibility
k = 3
# Colors and styles
class_colors = {0: 'blue', 1: 'red'}
query_color = 'green'
neighbor_color = 'yellow'
# Compute axis limits
x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
# --- Brute-force k-NN Animation ---
def animate_brute_force():
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Brute-force k-NN (k=3)')
# Plot training points
for x, y in zip(X_train, y_train):
ax.scatter(x[0], x[1], c=class_colors[y], s=50, label='Class 0' if y == 0 else 'Class 1')
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*', label='Query')
# Calculate distances
distances = [np.sqrt(np.sum((query_point - x) ** 2)) for x in X_train]
sorted_indices = np.argsort(distances)[:k]
lines = []
texts = []
def init():
ax.legend()
return []
def update(frame):
ax.clear()
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Brute-force k-NN (k=3)')
# Plot all points
for x, y in zip(X_train, y_train):
ax.scatter(x[0], x[1], c=class_colors[y], s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
# Draw distance lines (up to 10 points per frame)
start_idx = (frame * 10) % len(X_train)
end_idx = min(start_idx + 10, len(X_train))
for i in range(start_idx, end_idx):
x1, y1 = query_point
x2, y2 = X_train[i]
line = ax.plot([x1, x2], [y1, y2], 'k--', alpha=0.3)[0]
lines.append(line)
dist = distances[i]
text = ax.text((x1 + x2) / 2, (y1 + y2) / 2, f'{dist:.2f}', fontsize=6)
texts.append(text)
# Highlight k nearest neighbors
if frame >= len(X_train) // 10:
for i in sorted_indices:
ax.scatter(X_train[i][0], X_train[i][1], c=neighbor_color, s=100, alpha=0.5)
return lines + texts
# Enough frames to cycle through all points and show neighbors
ani = FuncAnimation(fig, update, frames=(len(X_train) // 10) + 5, init_func=init, blit=False, interval=800)
plt.close()
return HTML(ani.to_html5_video())
# --- KD-Tree k-NN Animation ---
class KDNode:
def __init__(self, point, label, axis, left=None, right=None):
self.point = point
self.label = label
self.axis = axis
self.left = left
self.right = right
class KDTree:
def build(self, X, y, depth=0):
if len(X) == 0:
return None
if len(X) <= 5: # Increased leaf_size
return KDNode(X[0], y[0], depth % 2)
axis = depth % 2
indices = np.argsort(X[:, axis])
X = X[indices]
y = y[indices]
median_idx = len(X) // 2
node = KDNode(X[median_idx], y[median_idx], axis)
node.left = self.build(X[:median_idx], y[:median_idx], depth + 1)
node.right = self.build(X[median_idx + 1:], y[median_idx + 1:], depth + 1)
return node
def animate_kd_tree():
kdtree = KDTree()
root = kdtree.build(X_train, y_train)
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('KD-Tree k-NN (k=3)')
splits = []
def collect_splits(node, bounds):
if node is None:
return
x_min, x_max, y_min, y_max = bounds
if node.axis == 0:
splits.append(('x', node.point[0], y_min, y_max))
else:
splits.append(('y', node.point[1], x_min, x_max))
collect_splits(node.left, (x_min, node.point[node.axis], y_min, y_max) if node.axis == 0 else (x_min, x_max, y_min, node.point[node.axis]))
collect_splits(node.right, (node.point[node.axis], x_max, y_min, y_max) if node.axis == 0 else (x_min, x_max, node.point[node.axis], y_max))
collect_splits(root, (x_min, x_max, y_min, y_max))
# Simulate search path (simplified)
search_path = []
def search(node, point, bounds):
if node is None:
return
search_path.append((node, bounds))
axis = node.axis
if point[axis] <= node.point[axis]:
close = node.left
far = node.right
close_bounds = (bounds[0], node.point[axis], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], bounds[2], node.point[axis])
far_bounds = (node.point[axis], bounds[1], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], node.point[axis], bounds[3])
else:
close = node.right
far = node.left
close_bounds = (node.point[axis], bounds[1], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], node.point[axis], bounds[3])
far_bounds = (bounds[0], node.point[axis], bounds[2], bounds[3]) if axis == 0 else (bounds[0], bounds[1], bounds[2], node.point[axis])
search(node.left if point[axis] <= node.point[axis] else node.right, point, close_bounds)
search(root, query_point, (x_min, x_max, y_min, y_max))
# Use scikit-learn for accurate neighbors
knn = KNeighborsClassifier(n_neighbors=k, algorithm='kd_tree')
knn.fit(X_train, y_train)
neighbors = knn.kneighbors([query_point], return_distance=False)[0]
def init():
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
ax.legend(['Class 0', 'Class 1', 'Query'])
return []
def update(frame):
ax.clear()
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('KD-Tree k-NN (k=3)')
# Plot points
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
# Draw splits
for i in range(min(frame + 1, len(splits))):
split_type, value, min_bound, max_bound = splits[i]
if split_type == 'x':
ax.axvline(value, ymin=(min_bound-y_min)/(y_max-y_min), ymax=(max_bound-y_min)/(y_max-y_min), color='black', linestyle='--')
else:
ax.axhline(value, xmin=(min_bound-x_min)/(x_max-x_min), xmax=(max_bound-x_min)/(x_max-x_min), color='black', linestyle='--')
# Highlight search path
if frame >= len(splits):
for i in range(min(frame - len(splits) + 1, len(search_path))):
node, bounds = search_path[i]
x_min_b, x_max_b, y_min_b, y_max_b = bounds
rect = Rectangle((x_min_b, y_min_b), x_max_b - x_min_b, y_max_b - y_min_b, fill=False, edgecolor='purple', linewidth=1.5)
ax.add_patch(rect)
# Highlight neighbors
if frame >= len(splits) + len(search_path):
for i in neighbors:
ax.scatter(X_train[i][0], X_train[i][1], c=neighbor_color, s=100, alpha=0.5)
return []
ani = FuncAnimation(fig, update, frames=len(splits) + len(search_path) + 2, init_func=init, blit=False, interval=1200)
plt.close()
return HTML(ani.to_html5_video())
# --- Ball Tree k-NN Animation ---
def animate_ball_tree():
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Ball Tree k-NN (k=3)')
# Build Ball Tree with scikit-learn
ball_tree = BallTree(X_train, leaf_size=5)
_, neighbors = ball_tree.query([query_point], k=k)
neighbors = neighbors[0]
# Simulate ball construction (simplified)
balls = []
def simulate_balls(X, depth=0):
        if len(X) <= 5:  # stop splitting once a node is down to the leaf size (5) used for the BallTree above
return
centroid = np.mean(X, axis=0)
radius = np.max([np.sqrt(np.sum((x - centroid) ** 2)) for x in X]) * 1.2 # Slightly larger radius for visibility
balls.append((centroid, radius))
indices = np.argsort([np.sqrt(np.sum((x - centroid) ** 2)) for x in X])
X = X[indices]
mid = len(X) // 2
simulate_balls(X[:mid], depth + 1)
simulate_balls(X[mid:], depth + 1)
simulate_balls(X_train)
def init():
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
ax.legend(['Class 0', 'Class 1', 'Query'])
return []
def update(frame):
ax.clear()
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Ball Tree k-NN (k=3)')
# Plot points
ax.scatter(X_train[:, 0], X_train[:, 1], c=[class_colors[y] for y in y_train], s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
# Draw balls
for i in range(min(frame + 1, len(balls))):
centroid, radius = balls[i]
circle = Circle(centroid, radius, fill=False, edgecolor='black', linestyle='--', alpha=0.7)
ax.add_patch(circle)
# Highlight neighbors
if frame >= len(balls):
for i in neighbors:
ax.scatter(X_train[i][0], X_train[i][1], c=neighbor_color, s=100, alpha=0.5)
return []
ani = FuncAnimation(fig, update, frames=len(balls) + 2, init_func=init, blit=False, interval=1200)
plt.close()
return HTML(ani.to_html5_video())
# Display animations in Jupyter
print("Brute-force k-NN Animation:")
display(animate_brute_force())
print("KD-Tree k-NN Animation:")
display(animate_kd_tree())
print("Ball Tree k-NN Animation:")
display(animate_ball_tree())
Brute-force k-NN Animation:
KD-Tree k-NN Animation:
Ball Tree k-NN Animation:
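So why bother with KD-Trees and Ball Trees at all? The payoff is faster queries once the training set gets large. Below is a minimal benchmark sketch — the dataset size, number of queries, and k are illustrative assumptions, not values used in the animations above — comparing scikit-learn's brute, kd_tree, and ball_tree backends:
import time
import numpy as np
from sklearn.neighbors import NearestNeighbors
# Illustrative sizes only: 20,000 training points in 2-D, 100 query points, k=3
rng = np.random.default_rng(0)
X_bench = rng.normal(size=(20_000, 2))
queries = rng.normal(size=(100, 2))
for algo in ("brute", "kd_tree", "ball_tree"):
    nn = NearestNeighbors(n_neighbors=3, algorithm=algo).fit(X_bench)
    t0 = time.perf_counter()
    nn.kneighbors(queries)  # find the 3 nearest training points for every query
    print(f"{algo:>9}: {time.perf_counter() - t0:.4f} s")
In low dimensions the tree-based backends usually win; as the number of features grows (roughly beyond ~20), they degrade toward brute force, which is why scikit-learn defaults to algorithm='auto' and picks for you.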
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.datasets import make_classification, make_regression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, mean_squared_error
from IPython.display import HTML
import warnings
# getattr fallback in case np.VisibleDeprecationWarning is unavailable in newer NumPy releases
warnings.filterwarnings("ignore", category=getattr(np, "VisibleDeprecationWarning", DeprecationWarning))
# Generate datasets
np.random.seed(42)
# Classification dataset: 50 points, 2 features, 2 classes
X_class, y_class = make_classification(n_samples=50, n_features=2, n_classes=2, n_clusters_per_class=1, n_redundant=0, random_state=42)
# Regression dataset: 50 points, 2 features, continuous target
X_reg, y_reg = make_regression(n_samples=50, n_features=2, noise=10, random_state=42)
# Query point for both tasks
query_point = np.array([0, 0])
k = 3
# Scale features
scaler_class = StandardScaler()
X_class_scaled = scaler_class.fit_transform(X_class)
query_point_class_scaled = scaler_class.transform([query_point])
scaler_reg = StandardScaler()
X_reg_scaled = scaler_reg.fit_transform(X_reg)
query_point_reg_scaled = scaler_reg.transform([query_point])
# Colors and styles
class_colors = {0: 'blue', 1: 'red'}
query_color = 'green'
neighbor_color = 'yellow'
# Compute axis limits for classification
x_min_class, x_max_class = X_class[:, 0].min() - 1, X_class[:, 0].max() + 1
y_min_class, y_max_class = X_class[:, 1].min() - 1, X_class[:, 1].max() + 1
# Compute axis limits for regression
x_min_reg, x_max_reg = X_reg[:, 0].min() - 1, X_reg[:, 0].max() + 1
y_min_reg, y_max_reg = X_reg[:, 1].min() - 1, X_reg[:, 1].max() + 1
# --- Classification Example ---
def classification_example():
# Train k-NN classifier
knn_class = KNeighborsClassifier(n_neighbors=k)
knn_class.fit(X_class_scaled, y_class)
# Predict for query point
prediction = knn_class.predict(query_point_class_scaled)
print(f"Classification Prediction for query point {query_point}: Class {prediction[0]}")
# Evaluate on training data (for demonstration)
y_pred_train = knn_class.predict(X_class_scaled)
accuracy = accuracy_score(y_class, y_pred_train)
print(f"Classification Training Accuracy: {accuracy:.2f}")
# Animation
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_xlim(x_min_class, x_max_class)
ax.set_ylim(y_min_class, y_max_class)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('k-NN Classification (k=3)')
# Plot training points
    for cls in (0, 1):  # one scatter call per class so the legend gets a single entry each
        ax.scatter(X_class[y_class == cls, 0], X_class[y_class == cls, 1], c=class_colors[cls], s=50, label=f'Class {cls}')
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*', label='Query')
# Calculate distances (using scaled data)
distances = [np.sqrt(np.sum((query_point_class_scaled - x) ** 2)) for x in X_class_scaled]
sorted_indices = np.argsort(distances)[:k]
lines = []
texts = []
def init():
ax.legend()
return []
def update(frame):
ax.clear()
ax.set_xlim(x_min_class, x_max_class)
ax.set_ylim(y_min_class, y_max_class)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('k-NN Classification (k=3)')
# Plot all points
for x, y in zip(X_class, y_class):
ax.scatter(x[0], x[1], c=class_colors[y], s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
# Draw distance lines (up to 10 points per frame)
start_idx = (frame * 10) % len(X_class)
end_idx = min(start_idx + 10, len(X_class))
for i in range(start_idx, end_idx):
x1, y1 = query_point
x2, y2 = X_class[i]
line = ax.plot([x1, x2], [y1, y2], 'k--', alpha=0.3)[0]
lines.append(line)
dist = distances[i]
text = ax.text((x1 + x2) / 2, (y1 + y2) / 2, f'{dist:.2f}', fontsize=6)
texts.append(text)
# Highlight k nearest neighbors
if frame >= len(X_class) // 10:
for i in sorted_indices:
ax.scatter(X_class[i][0], X_class[i][1], c=neighbor_color, s=100, alpha=0.5)
return lines + texts
ani = FuncAnimation(fig, update, frames=(len(X_class) // 10) + 5, init_func=init, blit=False, interval=800)
plt.close()
return HTML(ani.to_html5_video())
# --- Regression Example ---
def regression_example():
# Train k-NN regressor
knn_reg = KNeighborsRegressor(n_neighbors=k)
knn_reg.fit(X_reg_scaled, y_reg)
# Predict for query point
prediction = knn_reg.predict(query_point_reg_scaled)
print(f"Regression Prediction for query point {query_point}: {prediction[0]:.2f}")
# Evaluate on training data (for demonstration)
y_pred_train = knn_reg.predict(X_reg_scaled)
mse = mean_squared_error(y_reg, y_pred_train)
print(f"Regression Training MSE: {mse:.2f}")
# Animation
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_xlim(x_min_reg, x_max_reg)
ax.set_ylim(y_min_reg, y_max_reg)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('k-NN Regression (k=3)')
# Plot training points with color gradient based on target value
sc = ax.scatter(X_reg[:, 0], X_reg[:, 1], c=y_reg, cmap='viridis', s=50, label='Training Data')
fig.colorbar(sc, ax=ax, label='Target Value')
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*', label='Query')
# Adjust layout to prevent squishing
fig.tight_layout()
# Calculate distances (using scaled data)
distances = [np.sqrt(np.sum((query_point_reg_scaled - x) ** 2)) for x in X_reg_scaled]
sorted_indices = np.argsort(distances)[:k]
lines = []
texts = []
def init():
ax.legend()
return []
def update(frame):
ax.clear()
ax.set_xlim(x_min_reg, x_max_reg)
ax.set_ylim(y_min_reg, y_max_reg)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('k-NN Regression (k=3)')
# Plot all points
sc = ax.scatter(X_reg[:, 0], X_reg[:, 1], c=y_reg, cmap='viridis', s=50)
ax.scatter(query_point[0], query_point[1], c=query_color, s=100, marker='*')
# Draw distance lines (up to 10 points per frame)
start_idx = (frame * 10) % len(X_reg)
end_idx = min(start_idx + 10, len(X_reg))
for i in range(start_idx, end_idx):
x1, y1 = query_point
x2, y2 = X_reg[i]
line = ax.plot([x1, x2], [y1, y2], 'k--', alpha=0.3)[0]
lines.append(line)
dist = distances[i]
text = ax.text((x1 + x2) / 2, (y1 + y2) / 2, f'{dist:.2f}', fontsize=6)
texts.append(text)
# Highlight k nearest neighbors
if frame >= len(X_reg) // 10:
for i in sorted_indices:
ax.scatter(X_reg[i][0], X_reg[i][1], c=neighbor_color, s=100, alpha=0.5)
# Show target value
ax.text(X_reg[i][0], X_reg[i][1] + 0.5, f'{y_reg[i]:.1f}', fontsize=8, color='black')
return lines + texts
ani = FuncAnimation(fig, update, frames=(len(X_reg) // 10) + 5, init_func=init, blit=False, interval=800)
plt.close()
return HTML(ani.to_html5_video())
# Display examples in Jupyter
print("k-NN Classification Example:")
display(classification_example())
print("\nk-NN Regression Example:")
display(regression_example())
k-NN Classification Example:
Classification Prediction for query point [0 0]: Class 1
Classification Training Accuracy: 0.98
k-NN Regression Example:
Regression Prediction for query point [0 0]: -9.27
Regression Training MSE: 107.31
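A quick sanity check: with the default uniform weights, the regressor's prediction is just the mean of the k nearest targets, and the classifier's is a majority vote over the k nearest labels. The sketch below reproduces both numbers by hand, assuming the cell above has already been run so that k, X_reg_scaled, y_reg, X_class_scaled, y_class, and the scaled query points are still in scope:
import numpy as np
from sklearn.neighbors import NearestNeighbors
# Manual k-NN regression: average the targets of the k nearest (scaled) neighbors
_, idx_reg = NearestNeighbors(n_neighbors=k).fit(X_reg_scaled).kneighbors(query_point_reg_scaled)
print("Mean of the 3 nearest targets:", y_reg[idx_reg[0]].mean())
# Manual k-NN classification: majority vote among the k nearest labels
_, idx_cls = NearestNeighbors(n_neighbors=k).fit(X_class_scaled).kneighbors(query_point_class_scaled)
print("Majority class among the 3 nearest:", np.bincount(y_class[idx_cls[0]]).argmax())
The averaged target should match the regression prediction printed above (up to rounding), and the vote should reproduce Class 1.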