
Efficient Search Structures — KD-Trees, Ball Trees, and Approximate NN¶
What you will learn: why brute-force KNN is too slow for large datasets, how KD-Trees partition feature space to prune the search, how Ball Trees handle moderate-dimensional and non-Euclidean settings, and when to switch to approximate nearest-neighbour methods.

Why Search Speed Matters in Production¶
An e-commerce recommendation engine must respond in < 50 ms. With $n$ products and $p$ embedding dimensions, brute-force KNN requires $n \cdot p$ multiply-adds per query, which for a large catalogue can take > 2 seconds on a single CPU core. When the dimensionality is low enough for pruning to work, a KD-tree can bring typical queries under 10 ms.
1. Brute-Force Complexity¶
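Brute-force KNN computes all $n$ distances for every query $\mathbf{q}$:
$$d(\mathbf{q}, \mathbf{x}_i) = \lVert \mathbf{q} - \mathbf{x}_i \rVert_2, \qquad i = 1, \dots, n$$
For $n$ training points with $p$ features, that is $O(np)$ floating-point operations per query. At $10^9$ ops/s, a workload of $np \approx 5 \times 10^7$ operations takes about 50 ms per query, which is unacceptable for real-time APIs.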
The goal: reduce the average search cost to $O(\log n)$ or better by pruning large regions of the space without computing distances to the points inside them.
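The scan described above can be sketched in a few lines of NumPy. This is a minimal illustration, not sklearn's implementation: the helper name brute_force_knn is ours, and it uses the expansion $\lVert \mathbf{q}-\mathbf{x}\rVert^2 = \lVert \mathbf{q}\rVert^2 - 2\,\mathbf{q}\cdot\mathbf{x} + \lVert \mathbf{x}\rVert^2$ plus np.argpartition so that only the $k$ smallest distances per query are fully sorted. Every query still touches all $n$ points, which is exactly the $O(np)$ cost the tree structures try to avoid.

```python
import numpy as np


def brute_force_knn(X, queries, k):
    """Exact k-NN by scanning every training point: O(n*p) work per query."""
    # Squared Euclidean distances via ||q||^2 - 2 q.x + ||x||^2, shape (m, n)
    sq = (
        np.sum(queries**2, axis=1)[:, None]
        - 2.0 * queries @ X.T
        + np.sum(X**2, axis=1)[None, :]
    )
    sq = np.maximum(sq, 0.0)  # clip tiny negatives caused by rounding
    # argpartition moves the k smallest to the front without a full sort
    idx = np.argpartition(sq, kth=k - 1, axis=1)[:, :k]
    # Order those k candidates by distance
    rows = np.arange(len(queries))[:, None]
    order = np.argsort(sq[rows, idx], axis=1)
    idx = idx[rows, order]
    return np.sqrt(sq[rows, idx]), idx


rng = np.random.default_rng(0)
X = rng.standard_normal((2000, 10))
Q = rng.standard_normal((5, 10))
dist, idx = brute_force_knn(X, Q, k=3)
print(idx.shape, np.round(dist, 3))
```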
%matplotlib inline
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier, KDTree, BallTree
np.random.seed(42)
# Benchmark brute vs KD-tree vs Ball-tree on increasing n
k = 5
p = 10
n_values = [500, 1000, 2000, 5000, 10000, 20000]
times_brute, times_kd, times_ball = [], [], []
for n in n_values:
    X = np.random.randn(n, p)
    q = np.random.randn(100, p)  # 100 query points
    # Brute force
    t0 = time.perf_counter()
    dists = np.linalg.norm(X[None, :, :] - q[:, None, :], axis=2)
    _ = np.argsort(dists, axis=1)[:, :k]
    times_brute.append(time.perf_counter() - t0)
    # KD-Tree (build time excluded; only the query is timed)
    kd = KDTree(X)
    t0 = time.perf_counter()
    _ = kd.query(q, k=k)
    times_kd.append(time.perf_counter() - t0)
    # Ball Tree (build time excluded; only the query is timed)
    bt = BallTree(X)
    t0 = time.perf_counter()
    _ = bt.query(q, k=k)
    times_ball.append(time.perf_counter() - t0)
fig, ax = plt.subplots(figsize=(8, 4))
ax.semilogy(n_values, times_brute, 'r-o', label='Brute force O(n·p)')
ax.semilogy(n_values, times_kd, 'b-s', label='KD-Tree O(p·log n)')
ax.semilogy(n_values, times_ball, 'g-^', label='Ball Tree')
ax.set_xlabel('Training set size n')
ax.set_ylabel('Query time for 100 queries (s, log scale)')
ax.set_title(f'Search speed comparison p={p} dimensions, k={k}')
ax.legend(); ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
2. KD-Trees — Axis-Aligned Space Partitioning¶
A KD-Tree (k-dimensional tree) is a binary tree that recursively partitions the training set by splitting along one feature axis at each level.
Construction¶
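Choose the splitting dimension (usually the one with the highest variance or spread; the from-scratch code in this section simply cycles through the axes by depth)
Split at the median value → the left subtree gets the points below the median, the right subtree the points above
Recurse until each leaf holds at most leaf_size points (default 30 in sklearn)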
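The construction steps can be sketched as below, assuming the highest-variance rule for choosing the axis (unlike the cycling-axis code later in this section). The dictionary-based nodes, the helper names, and the tiny leaf size of 2 are illustrative choices made here so the toy tree has several levels; sklearn's default leaf_size is 30.

```python
import numpy as np


def build(points, leaf_size=2):
    """Split on the highest-variance axis at the median until leaves are small."""
    # Assumes leaf_size >= 1 so every recursive call shrinks the point set
    if len(points) <= leaf_size:
        return {"leaf": points}
    axis = int(np.argmax(points.var(axis=0)))   # highest-variance dimension
    order = np.argsort(points[:, axis])
    mid = len(points) // 2
    split = points[order[mid], axis]            # median value on that axis
    return {
        "axis": axis,
        "split": split,
        "left": build(points[order[:mid]], leaf_size),    # values <= split
        "right": build(points[order[mid:]], leaf_size),   # values >= split
    }


def leaves(node):
    """Collect the point arrays stored in the leaves."""
    if "leaf" in node:
        return [node["leaf"]]
    return leaves(node["left"]) + leaves(node["right"])


def depth(node):
    if "leaf" in node:
        return 0
    return 1 + max(depth(node["left"]), depth(node["right"]))


rng = np.random.default_rng(1)
# Stretch axis 1 so it has by far the largest variance
pts = rng.standard_normal((40, 3)) * np.array([1.0, 5.0, 0.5])
root = build(pts)
print("root split axis:", root["axis"], "| depth:", depth(root),
      "| leaf sizes:", [len(l) for l in leaves(root)])
```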
Query¶
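For a query $\mathbf{q}$:
Traverse the tree down to the leaf that would contain $\mathbf{q}$
Record the best $k$ candidates found so far
Pruning step: while backtracking, if the distance from $\mathbf{q}$ to a node's splitting hyperplane is greater than the current $k$-th best distance, skip the subtree on the far side of that hyperplane
Because whole subtrees are skipped without computing any distances inside them, the average query cost is $O(\log n)$.
Limitation: KD-Trees degrade to $O(n)$ in high dimensions because the current nearest-neighbour radius ends up crossing almost every splitting hyperplane, so very few subtrees can be pruned.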
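The query procedure records the best $k$ candidates, whereas the from-scratch code that follows tracks only a single nearest neighbour. A minimal sketch of the $k$-NN version, assuming a bounded max-heap (Python's heapq storing negated distances) holds the current $k$ best: the far subtree is visited only while the heap is not yet full or the splitting plane is closer than the current $k$-th best distance. The names build, knn and knn_query are hypothetical helpers, and the tree cycles through axes for simplicity.

```python
import heapq
import math
import random


def build(points, depth=0):
    """Median-split KD-tree; each node is a tuple (point, axis, left, right)."""
    if not points:
        return None
    axis = depth % len(points[0])            # cycle through the axes by depth
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    return (points[mid], axis,
            build(points[:mid], depth + 1),
            build(points[mid + 1:], depth + 1))


def knn(node, query, k, heap):
    """Fill `heap` with the k best (negated distance, point) pairs seen so far."""
    if node is None:
        return
    point, axis, left, right = node
    d = math.dist(query, point)
    if len(heap) < k:
        heapq.heappush(heap, (-d, point))
    elif d < -heap[0][0]:                    # beats the current k-th best
        heapq.heapreplace(heap, (-d, point))
    diff = query[axis] - point[axis]
    near, far = (left, right) if diff <= 0 else (right, left)
    knn(near, query, k, heap)
    # Prune: visit the far side only if we still need candidates or the
    # splitting plane is closer than the current k-th best distance
    if len(heap) < k or abs(diff) < -heap[0][0]:
        knn(far, query, k, heap)


def knn_query(root, query, k):
    """Return the k nearest points as (distance, point) pairs, nearest first."""
    heap = []
    knn(root, query, k, heap)
    return sorted((-neg_d, p) for neg_d, p in heap)


random.seed(0)
pts = [(random.random(), random.random()) for _ in range(200)]
root = build(pts)
q = (0.5, 0.5)
result = knn_query(root, q, k=5)
brute = sorted((math.dist(q, p), p) for p in pts)[:5]
print("matches brute force:", result == brute)
```

The max-heap keeps the $k$-th best distance at the root, so each pruning test is $O(1)$ and each improvement costs $O(\log k)$.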
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
class KDNode:
    def __init__(self, point, left=None, right=None, axis=0):
        self.point = point
        self.left = left
        self.right = right
        self.axis = axis

def build_kdtree(points, depth=0):
    if len(points) == 0:
        return None
    axis = depth % points.shape[1]  # cycle through the axes by depth
    sorted_pts = points[np.argsort(points[:, axis])]
    mid = len(sorted_pts) // 2
    return KDNode(
        point=sorted_pts[mid],
        left=build_kdtree(sorted_pts[:mid], depth + 1),
        right=build_kdtree(sorted_pts[mid + 1:], depth + 1),
        axis=axis
    )

def kdtree_nn(root, query, best=None, best_dist=float('inf')):
    if root is None:
        return best, best_dist
    d = np.linalg.norm(query - root.point)
    if d < best_dist:
        best, best_dist = root.point, d
    axis = root.axis
    diff = query[axis] - root.point[axis]
    # Descend first into the side of the split that contains the query
    near, far = (root.left, root.right) if diff <= 0 else (root.right, root.left)
    best, best_dist = kdtree_nn(near, query, best, best_dist)
    if abs(diff) < best_dist:  # pruning condition
        best, best_dist = kdtree_nn(far, query, best, best_dist)
    return best, best_dist
# Build and visualise KD-Tree splits on 2D data
np.random.seed(7)
pts = np.random.rand(16, 2) * 10
tree = build_kdtree(pts.copy())
query = np.array([5.5, 4.5])
nn_pt, nn_d = kdtree_nn(tree, query)
print(f"Query: {query}")
print(f"NN: {nn_pt} (distance={nn_d:.4f})")
# Verify with scipy
from scipy.spatial import KDTree as ScipyKD
sk = ScipyKD(pts)
d_ref, idx_ref = sk.query(query, k=1)
print(f"Verify: {pts[idx_ref]} (scipy dist={d_ref:.4f})")
# Plot
fig, ax = plt.subplots(figsize=(7, 6))
ax.scatter(pts[:,0], pts[:,1], c='steelblue', s=60, zorder=5)
ax.scatter(*query, c='red', s=120, marker='*', label='Query', zorder=6)
ax.scatter(*nn_pt, c='orange', s=100, marker='D', label='NN found', zorder=6)
circle = plt.Circle(query, nn_d, fill=False, color='orange', linestyle='--', label='NN radius')
ax.add_patch(circle)
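# Draw the root split line: build_kdtree splits the root on axis 0 at the point
# it stores (the upper median), which can differ from np.median's average
root_split = tree.point[tree.axis]
ax.axvline(root_split, color='gray', linestyle='--', lw=1, label=f'Root split x={root_split:.1f}')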
ax.set_xlim(-0.5, 10.5); ax.set_ylim(-0.5, 10.5)
ax.set_title('KD-Tree: root split + nearest-neighbour search')
ax.legend(fontsize=9); ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()
3. Ball Trees — Hypersphere Partitioning¶
A Ball Tree partitions the training set into nested hyperspheres (balls) rather than axis-aligned rectangles.
Construction¶
Find the two farthest points A and B in the current set
Assign each other point to A or B based on which is closer → two clusters
Fit a bounding sphere around each cluster
Recurse
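The construction steps above can be sketched for a single split. Finding the exact farthest pair costs $O(n^2)$, so this sketch assumes a common linear-time approximation (A = point farthest from an arbitrary start, B = point farthest from A). It also uses the cluster centroid plus the largest member distance as the bounding sphere, which is valid but not the minimal enclosing ball. The helper name split_ball is ours.

```python
import numpy as np


def split_ball(points):
    """One Ball-Tree split: approximate farthest pair, assign, bound each cluster."""
    # Assumes at least two distinct points, otherwise one cluster would be empty
    # Step 1 (approximate): A = farthest from points[0], B = farthest from A
    start = points[0]
    a = points[np.argmax(np.linalg.norm(points - start, axis=1))]
    b = points[np.argmax(np.linalg.norm(points - a, axis=1))]
    # Step 2: each point joins whichever of A and B is closer
    to_a = np.linalg.norm(points - a, axis=1) <= np.linalg.norm(points - b, axis=1)
    children = []
    for cluster in (points[to_a], points[~to_a]):
        # Step 3: bounding sphere = centroid + distance to the farthest member
        centre = cluster.mean(axis=0)
        radius = np.linalg.norm(cluster - centre, axis=1).max()
        children.append((cluster, centre, radius))
    return children


rng = np.random.default_rng(2)
pts = rng.standard_normal((300, 4))
kids = split_ball(pts)
for cluster, centre, radius in kids:
    print(len(cluster), "points, radius", np.round(radius, 3))
```

Step 4 (recursion) would simply call split_ball again on each child cluster until it holds at most leaf_size points.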
Query pruning¶
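If the distance from the query $\mathbf{q}$ to the centre $\mathbf{c}$ of a ball, minus its radius $r$, is greater than the current best distance $d_{\text{best}}$, the entire ball is pruned:
$$\lVert \mathbf{q} - \mathbf{c} \rVert_2 - r > d_{\text{best}} \;\Longrightarrow\; \text{prune the ball}$$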
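The pruning rule follows from the triangle inequality: for any point $\mathbf{x}$ inside the ball, $\lVert \mathbf{q}-\mathbf{x}\rVert \ge \lVert \mathbf{q}-\mathbf{c}\rVert - \lVert \mathbf{x}-\mathbf{c}\rVert \ge \lVert \mathbf{q}-\mathbf{c}\rVert - r$. The short sketch below (helper names are ours, data synthetic) computes that lower bound for one ball and checks it against the true minimum distance.

```python
import numpy as np


def ball_lower_bound(q, centre, radius):
    """Smallest possible distance from q to any point inside the ball."""
    # Reverse triangle inequality: ||q - x|| >= ||q - c|| - ||x - c|| >= ||q - c|| - r
    return max(0.0, np.linalg.norm(q - centre) - radius)


def can_prune(q, centre, radius, best_dist):
    """True when no point in the ball can beat the current best distance."""
    return ball_lower_bound(q, centre, radius) > best_dist


rng = np.random.default_rng(3)
ball_pts = rng.standard_normal((100, 3)) + np.array([6.0, 0.0, 0.0])
centre = ball_pts.mean(axis=0)
radius = np.linalg.norm(ball_pts - centre, axis=1).max()
q = np.zeros(3)
true_min = np.linalg.norm(ball_pts - q, axis=1).min()
print(f"bound={ball_lower_bound(q, centre, radius):.3f}  true min={true_min:.3f}")
print("prune if best so far is 1.0:", can_prune(q, centre, radius, best_dist=1.0))
```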
KD-Tree vs Ball Tree¶
| Property | KD-Tree | Ball Tree |
|---|---|---|
| Partitioning | Axis-aligned hyperrectangles | Hyperspheres |
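| Best in | Low dimensions | Moderate dimensions |
| Non-Euclidean metrics | Minkowski family only (e.g. Manhattan, Chebyshev) | Yes (metrics satisfying the triangle inequality, e.g. haversine) |
| Construction cost | $O(n \log n)$ | $O(n \log n)$ |
| Query (average) | $O(\log n)$ | $O(\log n)$ |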
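The metric row of the table can be checked directly in sklearn. In this sketch, Manhattan distance (a Minkowski metric) works with both trees, while haversine is accepted by BallTree but rejected by KDTree; catching ValueError reflects sklearn's behaviour for unsupported metrics as we understand it. The coordinates are random and assumed to be [latitude, longitude] in radians, which is what the haversine metric expects.

```python
import numpy as np
from sklearn.neighbors import BallTree, KDTree

rng = np.random.default_rng(4)
# Synthetic coordinates: [latitude, longitude] in radians
coords = np.column_stack([
    rng.uniform(-np.pi / 2, np.pi / 2, 200),
    rng.uniform(-np.pi, np.pi, 200),
])

# Minkowski-family metrics such as Manhattan work with both trees
d_kd, i_kd = KDTree(coords, metric="manhattan").query(coords[:3], k=4)
d_bt, i_bt = BallTree(coords, metric="manhattan").query(coords[:3], k=4)
print("Manhattan results agree:", np.array_equal(i_kd, i_bt))

# Haversine (great-circle) distance is accepted by BallTree ...
d_hav, i_hav = BallTree(coords, metric="haversine").query(coords[:1], k=3)
print("haversine 3-NN indices:", i_hav[0])

# ... but KDTree rejects it at construction time
try:
    KDTree(coords, metric="haversine")
    kd_accepts_haversine = True
except ValueError:
    kd_accepts_haversine = False
print("KDTree accepts haversine:", kd_accepts_haversine)
```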
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KDTree, BallTree
import time
np.random.seed(42)
# Compare KD-Tree vs Ball Tree across dimensions
n = 5000
k = 5
dims = [2, 5, 10, 20, 30, 50]
kd_times, bt_times = [], []
for p in dims:
    X = np.random.randn(n, p)
    q = np.random.randn(200, p)
    kd = KDTree(X, leaf_size=30)
    t0 = time.perf_counter()
    kd.query(q, k=k)
    kd_times.append(time.perf_counter() - t0)
    bt = BallTree(X, leaf_size=30)
    t0 = time.perf_counter()
    bt.query(q, k=k)
    bt_times.append(time.perf_counter() - t0)
fig, ax = plt.subplots(figsize=(8, 4))
ax.semilogy(dims, kd_times, 'b-o', label='KD-Tree')
ax.semilogy(dims, bt_times, 'r-s', label='Ball Tree')
ax.set_xlabel('Number of dimensions p')
ax.set_ylabel('Query time for 200 queries (s)')
ax.set_title(f'KD-Tree vs Ball Tree query time n={n}, k={k}')
ax.legend(); ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("Note: both degrade in high dimensions as pruning becomes less effective.")
for p, kd_t, bt_t in zip(dims, kd_times, bt_times):
    winner = 'KD-Tree' if kd_t < bt_t else 'BallTree'
    print(f"  p={p:3d}: KD={kd_t*1000:.2f}ms  BT={bt_t*1000:.2f}ms  → {winner}")
4. Using sklearn’s KDTree and BallTree¶
sklearn provides standalone KDTree and BallTree classes, as well as automatic selection through KNeighborsClassifier(algorithm='auto').
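A short sketch of the two entry points on synthetic data (the data and labels are made up for illustration): a standalone KDTree answers k-NN queries with query and fixed-radius queries with query_radius, while KNeighborsClassifier(algorithm='auto') performs the same exact search behind the estimator interface, exposed through kneighbors.

```python
import numpy as np
from sklearn.neighbors import KDTree, KNeighborsClassifier

rng = np.random.default_rng(5)
X = rng.standard_normal((1000, 3))
y = (X[:, 0] > 0).astype(int)          # toy labels for illustration
queries = rng.standard_normal((4, 3))

# Standalone tree: build once, then query as often as needed
tree = KDTree(X, leaf_size=30)
dist, ind = tree.query(queries, k=5)   # both arrays have shape (4, 5)
counts = tree.query_radius(queries, r=0.5, count_only=True)  # points within r
print(ind.shape, counts)

# The estimator wraps the same exact search; algorithm='auto' lets sklearn pick
# among 'brute', 'kd_tree' and 'ball_tree' based on the data given to fit()
clf = KNeighborsClassifier(n_neighbors=5, algorithm="auto").fit(X, y)
est_dist, est_ind = clf.kneighbors(queries)
print("same neighbours:", np.array_equal(np.sort(ind, axis=1), np.sort(est_ind, axis=1)))
```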
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
import time
np.random.seed(42)
X, y = make_classification(n_samples=10000, n_features=10, n_informative=8,
n_redundant=2, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_tr_s = scaler.fit_transform(X_tr)
X_te_s = scaler.transform(X_te)
results = {}
for algo in ['brute', 'kd_tree', 'ball_tree', 'auto']:
    knn = KNeighborsClassifier(n_neighbors=7, algorithm=algo)
    t_fit = time.perf_counter(); knn.fit(X_tr_s, y_tr); t_fit = time.perf_counter() - t_fit
    t_pred = time.perf_counter(); y_pred = knn.predict(X_te_s); t_pred = time.perf_counter() - t_pred
    acc = accuracy_score(y_te, y_pred)
    results[algo] = (acc, t_fit, t_pred)
    print(f"{algo:12s}: acc={acc:.4f}  fit={t_fit*1000:.1f}ms  predict={t_pred*1000:.1f}ms")
# All should give same accuracy (exact methods)
print("\nAll exact methods agree:", len(set(r[0] for r in results.values())) == 1)
# Leaf size effect on Ball Tree speed
leaf_sizes = [5, 10, 20, 30, 50, 100]
bt_query_times = []
from sklearn.neighbors import BallTree
for ls in leaf_sizes:
    bt = BallTree(X_tr_s, leaf_size=ls)
    t0 = time.perf_counter()
    bt.query(X_te_s, k=7)
    bt_query_times.append(time.perf_counter() - t0)
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(leaf_sizes, [t*1000 for t in bt_query_times], 'r-o')
ax.set_xlabel('leaf_size'); ax.set_ylabel('Query time (ms)')
ax.set_title('Ball Tree query time vs leaf_size (n=8000, p=10)')
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
5. Approximate Nearest Neighbours¶
For very large $n$ or high $p$, exact tree methods still become slow. Approximate NN methods trade a small reduction in recall (occasionally missing a true neighbour) for large speed gains.
Locality-Sensitive Hashing (LSH)¶
Hash points so that similar points collide in the same bucket with high probability. Query: hash the query, check only points in the same bucket.
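One classic LSH family, random-hyperplane hashing, can be sketched in NumPy. Each point's bucket key is the pattern of signs of its projections onto a few random directions, so this family targets angular (cosine) similarity; Euclidean LSH schemes quantise the projections instead. The sketch uses a single hash table and hypothetical helper names (build_lsh, lsh_query), and it reports recall@k against an exact scan.

```python
import numpy as np
from collections import defaultdict


def build_lsh(X, n_planes, rng):
    """Random-hyperplane LSH: each point's bucket key is its sign pattern."""
    planes = rng.standard_normal((n_planes, X.shape[1]))
    keys = X @ planes.T > 0                      # (n, n_planes) booleans
    buckets = defaultdict(list)
    for i, key in enumerate(keys):
        buckets[key.tobytes()].append(i)
    return planes, buckets


def lsh_query(q, X, planes, buckets, k):
    """Exact distances, but only to points that share the query's bucket."""
    key = (planes @ q > 0).tobytes()
    cand = np.array(buckets.get(key, []), dtype=int)
    if cand.size == 0:
        return cand
    d = np.linalg.norm(X[cand] - q, axis=1)
    return cand[np.argsort(d)[:k]]


rng = np.random.default_rng(6)
X = rng.standard_normal((5000, 16))
Q = X[:20] + 0.05 * rng.standard_normal((20, 16))   # queries near known points
planes, buckets = build_lsh(X, n_planes=8, rng=rng)

k = 10
recalls = []
for q in Q:
    exact = np.argsort(np.linalg.norm(X - q, axis=1))[:k]
    approx = lsh_query(q, X, planes, buckets, k)
    recalls.append(len(set(exact) & set(approx)) / k)   # recall@k
print(f"mean recall@{k}: {np.mean(recalls):.2f}, buckets used: {len(buckets)}")
```

More hyperplanes give smaller buckets and faster queries but lower recall; production systems typically query several independent tables to win recall back at the cost of memory.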
Industry-Scale Libraries¶
| Library | Method | Scale |
|---|---|---|
| sklearn NearestNeighbors | Exact (KD/Ball) | Up to ~$10^5$ |
| faiss (Facebook) | IVF + PQ | Billions of vectors |
| hnswlib | HNSW graph | $10^6$–$10^7$ |
| annoy (Spotify) | Random projection trees | $10^6$ |
For the lab (customer segmentation), the dataset is small enough for sklearn’s exact methods.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
# Demonstrate NearestNeighbors API (exact, but clean interface)
np.random.seed(42)
X = np.random.randn(500, 2)
query_pts = np.array([[0.0, 0.0], [1.5, -1.0], [-2.0, 2.0]])
nn = NearestNeighbors(n_neighbors=5, algorithm='ball_tree', metric='euclidean')
nn.fit(X)
distances, indices = nn.kneighbors(query_pts)
print("Nearest neighbours for 3 query points (k=5):")
for i, (q, dists, idxs) in enumerate(zip(query_pts, distances, indices)):
    print(f"  Query {i}: {q} → neighbours at indices {idxs}")
    print(f"           distances: {np.round(dists, 3)}")
# Visualise
fig, ax = plt.subplots(figsize=(7, 6))
ax.scatter(X[:,0], X[:,1], c='lightgray', s=20, zorder=2, label='Training points')
colors = ['red', 'blue', 'green']
for i, (q, dists, idxs) in enumerate(zip(query_pts, distances, indices)):
    ax.scatter(*q, c=colors[i], s=120, marker='*', zorder=5, label=f'Query {i}')
    for j in idxs:
        ax.plot([q[0], X[j,0]], [q[1], X[j,1]], c=colors[i], alpha=0.5, lw=1)
    circle = plt.Circle(q, dists.max(), fill=False, color=colors[i], linestyle='--', alpha=0.4)
    ax.add_patch(circle)
ax.set_title('NearestNeighbors: 5-NN for 3 queries (Ball Tree)')
ax.legend(fontsize=9); ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()
6. Try It in the Browser¶
Build a minimal KD-Tree and perform a nearest-neighbour query.
import math
class KDNode:
    def __init__(self, point, left=None, right=None, axis=0):
        self.point = point
        self.left = left
        self.right = right
        self.axis = axis

def build(pts, depth=0):
    if not pts:
        return None
    axis = depth % len(pts[0])
    pts.sort(key=lambda p: p[axis])  # sorts in place; the caller passes a copy
    mid = len(pts) // 2
    return KDNode(pts[mid], build(pts[:mid], depth+1), build(pts[mid+1:], depth+1), axis)

def nn_search(node, query, best=None, best_d=float('inf')):
    if node is None:
        return best, best_d
    d = math.sqrt(sum((a-b)**2 for a,b in zip(query, node.point)))
    if d < best_d:
        best, best_d = node.point, d
    axis = node.axis
    diff = query[axis] - node.point[axis]
    near, far = (node.left, node.right) if diff <= 0 else (node.right, node.left)
    best, best_d = nn_search(near, query, best, best_d)
    if abs(diff) < best_d:  # prune the far side unless the split plane is closer than best_d
        best, best_d = nn_search(far, query, best, best_d)
    return best, best_d
points = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2],[1,9]]
tree = build([list(p) for p in points])
query = [6, 5]
nn, dist = nn_search(tree, query)
print(f"Query: {query}")
print(f"NN: {nn}")
print(f"Dist: {dist:.4f}")
# Brute-force check
bf_nn = min(points, key=lambda p: math.sqrt(sum((a-b)**2 for a,b in zip(query,p))))
print(f"Brute-force check: {bf_nn}")
print(f"Match: {nn == bf_nn}")
Knowledge Check¶
KD-Trees degrade to O(n) search complexity in high dimensions because:¶
Ball Trees are preferred over KD-Trees when:¶
Exercises¶
Exercise 1 — Leaf Size Tuning¶
On make_classification(n_samples=20000, n_features=8), benchmark BallTree query time for leaf_size ∈ [5, 10, 20, 40, 80, 160]. Plot query time and accuracy vs leaf_size. What is the optimal leaf_size for minimum query time?
Exercise 2 — Haversine Distance (geo-nearest-neighbour)¶
Generate latitude/longitude coordinates for 1000 random cities with np.random.uniform. Use BallTree with metric='haversine' to find the 3 nearest cities to a query point (sklearn's haversine metric expects [latitude, longitude] in radians). Compare with a Euclidean KD-Tree on the same coordinates: does the result differ?
Exercise 3 — Approximate vs Exact Recall Trade-off¶
For n=50000, p=20, compare sklearn BallTree (exact) against a manual random-projection approximation: project to 5 random dimensions, find the nearest neighbours in the projected space, then verify their true distances. Report the approximate recall at k=10.
%matplotlib inline
# Exercises 1, 2, 3 — your code here
from sklearn.neighbors import BallTree, KDTree
from sklearn.datasets import make_classification
import numpy as np
import time
# Your code here
Common Pitfalls¶
Summary
Brute-force KNN is $O(np)$ per query, prohibitive for large $n$.
KD-Tree: partitions space with axis-aligned hyperplanes, $O(\log n)$ average query. Best for low $p$.
Ball Tree: partitions space with hyperspheres, supports a wider range of metrics (e.g. haversine), better for moderate $p$.
Both degrade to $O(n)$ in very high dimensions (curse of dimensionality).
For very large $n$ or high $p$: use approximate NN (FAISS, HNSW, annoy), which typically give up a few percent of recall (e.g. ~95% recall) for speedups of around 100×.
sklearn API: KDTree, BallTree, NearestNeighbors, and KNeighborsClassifier(algorithm='auto').

What’s Next — KNN Lab: Customer Segmentation¶
You now have the complete KNN toolkit: algorithm, distance metrics, k selection, and efficient search. The next notebook applies everything to a real business problem: customer segmentation using KNN-based similarity search on retail transaction data. You will:
Segment customers by purchase behaviour (RFM features)
Find similar customers for targeted marketing
Evaluate segmentation quality with silhouette score
Proceed to knn_lab.ipynb.