Efficient Search Structures — KD-Trees, Ball Trees, and Approximate NN

What you will learn: why brute-force KNN is too slow for large datasets, how KD-Trees partition feature space to prune the search, how Ball Trees handle moderate-dimensional and non-Euclidean settings, and when to switch to approximate nearest-neighbour methods.

Business hook

Why Search Speed Matters in Production

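An e-commerce recommendation engine must respond in < 50 ms. With $n = 5{,}000{,}000$ products and $p = 128$ embedding dimensions, brute-force KNN requires $n \cdot p = 6.4 \times 10^8$ multiply-adds per query. Even at $10^9$ operations per second on a single CPU core, that is about 0.64 s, more than ten times the budget. Search structures that prune most of the data (KD-Trees and Ball Trees in low to moderate dimensions, approximate methods for high-dimensional embeddings like these) are what bring query latency into the millisecond range.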

1. Brute-Force Complexity

Brute-force KNN computes all $n$ distances for every query:

$$\text{Cost per query} = O(n \cdot p)$$

For $n = 10^6$ training points and $p = 50$ features, that is $5 \times 10^7$ floating-point operations per query. At $10^9$ ops/s, that is 50 ms per query, which is unacceptable for real-time APIs.
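
The arithmetic above can be reproduced with a few lines of Python. This is a back-of-envelope sketch only: the helper name `brute_force_cost` is hypothetical, and the throughput of $10^9$ ops/s is the same simplifying assumption used in the text (it ignores memory bandwidth, vectorisation, and sorting the distances).

```python
def brute_force_cost(n, p, ops_per_second=1e9):
    """Return (operations, seconds) for one brute-force query.

    Assumes one floating-point operation per feature per training point
    and a hypothetical sustained throughput of `ops_per_second`.
    """
    ops = n * p
    return ops, ops / ops_per_second


# Cost grows linearly with n for a fixed number of features.
for n in (10**4, 10**5, 10**6):
    ops, seconds = brute_force_cost(n, p=50)
    print(f"n={n:>9,d}: {ops:.1e} ops, about {seconds * 1e3:.1f} ms per query")
```

Multiplying $n$ by ten multiplies the per-query cost by ten, which is why a structure with logarithmic growth in $n$ is attractive.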

The goal: reduce average search complexity to $O(p \log n)$ or better, by pruning large portions of the space without computing distances.

%matplotlib inline
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier, KDTree, BallTree

np.random.seed(42)
# Benchmark brute vs KD-tree vs Ball-tree on increasing n
k = 5
p = 10
n_values = [500, 1000, 2000, 5000, 10000, 20000]
times_brute, times_kd, times_ball = [], [], []

for n in n_values:
    X = np.random.randn(n, p)
    q = np.random.randn(100, p)  # 100 query points

    # Brute force
    t0 = time.perf_counter()
    dists = np.linalg.norm(X[None, :, :] - q[:, None, :], axis=2)
    _ = np.argsort(dists, axis=1)[:, :k]
    times_brute.append(time.perf_counter() - t0)

    # KD-Tree
    kd = KDTree(X)
    t0 = time.perf_counter()
    _ = kd.query(q, k=k)
    times_kd.append(time.perf_counter() - t0)

    # Ball Tree
    bt = BallTree(X)
    t0 = time.perf_counter()
    _ = bt.query(q, k=k)
    times_ball.append(time.perf_counter() - t0)

fig, ax = plt.subplots(figsize=(8, 4))
ax.semilogy(n_values, times_brute, 'r-o', label='Brute force O(n·p)')
ax.semilogy(n_values, times_kd,    'b-s', label='KD-Tree O(p·log n)')
ax.semilogy(n_values, times_ball,  'g-^', label='Ball Tree')
ax.set_xlabel('Training set size n')
ax.set_ylabel('Query time for 100 queries (s, log scale)')
ax.set_title(f'Search speed comparison  p={p} dimensions, k={k}')
ax.legend(); ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

2. KD-Trees — Axis-Aligned Space Partitioning

A KD-Tree (k-dimensional tree) is a binary tree that recursively partitions the training set by splitting along one feature axis at each level.

Construction

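  1. Choose the splitting dimension (usually the one with highest variance; simpler implementations cycle through the axes by depth)

  2. Split at the median value → left subtree holds the points below the median, right subtree the points above

  3. Recurse until a node holds at most leaf_size points (default 30 in KNeighborsClassifier, 40 in the standalone KDTree class)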

Query

For a query point $x_q$:

  1. Traverse the tree to the leaf that would contain $x_q$

  2. Record the best candidates found there

  3. Pruning step: while backtracking, at each node, if the distance from $x_q$ to the splitting hyperplane is greater than the current best distance, prune the subtree on the far side of that hyperplane

The pruning means we avoid computing distances to whole subtrees, which gives an average-case query cost of $O(p \log n)$.
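
To make the effect of pruning concrete, the sketch below counts how many nodes a nearest-neighbour search actually touches. It is an illustrative toy, not sklearn's implementation: the tuple-based tree and the names `build` and `nn_count` are assumptions made for this example, and the tree stores one point per node rather than using leaf buckets.

```python
import numpy as np


def build(points, depth=0):
    """Build a KD-Tree as nested tuples: (point, axis, left, right)."""
    if len(points) == 0:
        return None
    axis = depth % points.shape[1]
    points = points[np.argsort(points[:, axis])]
    mid = len(points) // 2
    return (points[mid], axis,
            build(points[:mid], depth + 1),
            build(points[mid + 1:], depth + 1))


def nn_count(node, query, best_dist=np.inf, visited=0):
    """Return (best distance, number of nodes whose distance was computed)."""
    if node is None:
        return best_dist, visited
    point, axis, left, right = node
    visited += 1
    best_dist = min(best_dist, np.linalg.norm(query - point))
    diff = query[axis] - point[axis]
    near, far = (left, right) if diff <= 0 else (right, left)
    best_dist, visited = nn_count(near, query, best_dist, visited)
    if abs(diff) < best_dist:  # far side may still hold a closer point
        best_dist, visited = nn_count(far, query, best_dist, visited)
    return best_dist, visited


rng = np.random.default_rng(0)
X = rng.random((4096, 2))
tree = build(X)

counts = []
for q in rng.random((50, 2)):
    d, visited = nn_count(tree, q)
    # The pruned search must agree with an exhaustive scan.
    assert np.isclose(d, np.linalg.norm(X - q, axis=1).min())
    counts.append(visited)

print(f"mean nodes visited: {np.mean(counts):.1f} of {len(X)}")
```

In two dimensions the search should touch only a small fraction of the 4096 points while still returning the exact nearest neighbour.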

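Limitation: KD-Trees degrade to $O(n)$ in high dimensions ($p > 20$). The distance from the query to a splitting hyperplane depends on a single coordinate, so in high dimensions it is almost always smaller than the current best distance, and very few subtrees can be pruned.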
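
One way to see why pruning stops working is to look at how distances concentrate as $p$ grows. The sketch below, on synthetic Gaussian data chosen purely for illustration, prints the ratio of the nearest to the farthest distance from a random query. When that ratio approaches 1, the "current best distance" is nearly as large as every other distance, so a bound based on one coordinate rarely rules anything out.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
contrast = {}

for p in [2, 10, 50, 200]:
    X = rng.standard_normal((n, p))
    q = rng.standard_normal(p)
    d = np.linalg.norm(X - q, axis=1)
    # Ratio near 0: the nearest neighbour is much closer than the rest.
    # Ratio near 1: all points are almost equally far away.
    contrast[p] = d.min() / d.max()
    print(f"p={p:3d}: nearest/farthest distance ratio = {contrast[p]:.2f}")
```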

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches


class KDNode:
    def __init__(self, point, left=None, right=None, axis=0):
        self.point = point
        self.left = left
        self.right = right
        self.axis = axis


def build_kdtree(points, depth=0):
    if len(points) == 0:
        return None
    axis = depth % points.shape[1]
    sorted_pts = points[np.argsort(points[:, axis])]
    mid = len(sorted_pts) // 2
    return KDNode(
        point=sorted_pts[mid],
        left=build_kdtree(sorted_pts[:mid], depth + 1),
        right=build_kdtree(sorted_pts[mid + 1:], depth + 1),
        axis=axis
    )


def kdtree_nn(root, query, best=None, best_dist=float('inf')):
    if root is None:
        return best, best_dist
    d = np.linalg.norm(query - root.point)
    if d < best_dist:
        best, best_dist = root.point, d
    axis = root.axis
    diff = query[axis] - root.point[axis]
    near, far = (root.left, root.right) if diff <= 0 else (root.right, root.left)
    best, best_dist = kdtree_nn(near, query, best, best_dist)
    if abs(diff) < best_dist:  # pruning condition
        best, best_dist = kdtree_nn(far, query, best, best_dist)
    return best, best_dist


# Build and visualise KD-Tree splits on 2D data
np.random.seed(7)
pts = np.random.rand(16, 2) * 10
tree = build_kdtree(pts.copy())

query = np.array([5.5, 4.5])
nn_pt, nn_d = kdtree_nn(tree, query)
print(f"Query:   {query}")
print(f"NN:      {nn_pt}  (distance={nn_d:.4f})")

# Verify with scipy
from scipy.spatial import KDTree as ScipyKD
sk = ScipyKD(pts)
d_ref, idx_ref = sk.query(query, k=1)
print(f"Verify:  {pts[idx_ref]}  (scipy dist={d_ref:.4f})")

# Plot
fig, ax = plt.subplots(figsize=(7, 6))
ax.scatter(pts[:,0], pts[:,1], c='steelblue', s=60, zorder=5)
ax.scatter(*query, c='red', s=120, marker='*', label='Query', zorder=6)
ax.scatter(*nn_pt, c='orange', s=100, marker='D', label='NN found', zorder=6)
circle = plt.Circle(query, nn_d, fill=False, color='orange', linestyle='--', label='NN radius')
ax.add_patch(circle)
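# Draw first split line: the root node's own coordinate on axis 0
# (np.median would average the two middle values when n is even)
root_split_x = tree.point[0]
ax.axvline(root_split_x, color='gray', linestyle='--', lw=1, label=f'Root split x={root_split_x:.1f}')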
ax.set_xlim(-0.5, 10.5); ax.set_ylim(-0.5, 10.5)
ax.set_title('KD-Tree: root split + nearest-neighbour search')
ax.legend(fontsize=9); ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()

3. Ball Trees — Hypersphere Partitioning

A Ball Tree partitions the training set into nested hyperspheres (balls) rather than axis-aligned rectangles.

Construction

  1. Find the two farthest points A and B in the current set

  2. Assign each other point to A or B based on which is closer → two clusters

  3. Fit a bounding sphere around each cluster

  4. Recurse
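
The construction steps above can be sketched for a single split as follows. This is a simplified illustration, not sklearn's algorithm: the function name `split_ball` is hypothetical, the "two farthest points" are approximated with a cheap two-pass heuristic, and each ball is centred on the cluster mean, which gives a valid but not necessarily minimal bounding sphere.

```python
import numpy as np


def split_ball(points):
    """One construction step: return two balls as (centre, radius, members)."""
    # Cheap stand-in for "two farthest points": farthest from an arbitrary
    # point, then farthest from that one.
    a = points[np.argmax(np.linalg.norm(points - points[0], axis=1))]
    b = points[np.argmax(np.linalg.norm(points - a, axis=1))]

    # Assign every point to whichever pivot is closer.
    closer_to_a = (np.linalg.norm(points - a, axis=1)
                   <= np.linalg.norm(points - b, axis=1))

    balls = []
    for members in (points[closer_to_a], points[~closer_to_a]):
        centre = members.mean(axis=0)
        radius = np.linalg.norm(members - centre, axis=1).max()
        balls.append((centre, radius, members))
    return balls


rng = np.random.default_rng(1)
pts = rng.random((200, 2))
balls = split_ball(pts)
for centre, radius, members in balls:
    print(f"centre={np.round(centre, 2)}, radius={radius:.2f}, size={len(members)}")
```

A full tree would call the same step recursively on each ball's members until a leaf-size threshold is reached.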

Query pruning

If the distance from the query $x_q$ to the centre of a ball, minus the ball's radius $r$, is greater than the current best distance, the entire ball is pruned:

$$d(x_q, \text{centre}) - r > d_{\text{best}} \implies \text{skip entire subtree}$$
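
The rule is safe because of the triangle inequality: for any point $x$ inside the ball, $d(x_q, x) \geq d(x_q, \text{centre}) - r$. The sketch below checks this numerically on synthetic data. The helper `can_prune` and the chosen query are illustrative assumptions, not part of any library.

```python
import numpy as np

rng = np.random.default_rng(2)
members = rng.standard_normal((100, 3))  # points stored in one ball
centre = members.mean(axis=0)
radius = np.linalg.norm(members - centre, axis=1).max()

query = np.array([10.0, 0.0, 0.0])

# Lower bound on the distance from the query to ANY point in the ball.
lower_bound = max(np.linalg.norm(query - centre) - radius, 0.0)
# True closest distance, found by scanning every member.
true_min = np.linalg.norm(members - query, axis=1).min()


def can_prune(d_best):
    """True when no point in the ball can beat the current best distance."""
    return lower_bound > d_best


print(f"lower bound = {lower_bound:.3f}, true minimum = {true_min:.3f}")
print("prune if best so far is 1.0? ", can_prune(1.0))
print("prune if best so far is 50.0?", can_prune(50.0))
```

The bound never exceeds the true minimum, so pruning on it cannot discard the real nearest neighbour; it only costs one distance computation per ball.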

KD-Tree vs Ball Tree

| Property | KD-Tree | Ball Tree |
| --- | --- | --- |
| Partitioning | Axis-aligned hyperrectangles | Hyperspheres |
| Best in | Low dimensions ($p < 20$) | Moderate dimensions ($p \leq 50$) |
| Non-Euclidean metrics | Limited (Minkowski family only, e.g. Manhattan, Chebyshev) | Yes (e.g. Haversine) |
| Construction cost | $O(n \log n)$ | $O(n \log n)$ |
| Query (average) | $O(p \log n)$ | $O(p \log n)$ |

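
As a small illustration of the non-Euclidean row, the sketch below uses sklearn's BallTree with the haversine (great-circle) metric. The three cities, their approximate coordinates, and the Earth radius constant are illustrative inputs chosen for this example. Haversine expects [latitude, longitude] pairs in radians and returns distances on the unit sphere.

```python
import numpy as np
from sklearn.neighbors import BallTree

# Approximate (latitude, longitude) in degrees; illustrative values.
cities = {
    "Paris": (48.8566, 2.3522),
    "Berlin": (52.5200, 13.4050),
    "Madrid": (40.4168, -3.7038),
}
names = list(cities)
coords = np.radians(list(cities.values()))  # haversine expects radians

tree = BallTree(coords, metric="haversine")

brussels = np.radians([[50.8503, 4.3517]])
dist, idx = tree.query(brussels, k=1)

EARTH_RADIUS_KM = 6371.0  # mean Earth radius, to convert from unit sphere
nearest = names[idx[0, 0]]
km = dist[0, 0] * EARTH_RADIUS_KM
print(f"Nearest city to Brussels: {nearest} (~{km:.0f} km)")
```
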
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KDTree, BallTree
import time

np.random.seed(42)
# Compare KD-Tree vs Ball Tree across dimensions
n = 5000
k = 5
dims = [2, 5, 10, 20, 30, 50]
kd_times, bt_times = [], []

for p in dims:
    X = np.random.randn(n, p)
    q = np.random.randn(200, p)

    kd = KDTree(X, leaf_size=30)
    t0 = time.perf_counter()
    kd.query(q, k=k)
    kd_times.append(time.perf_counter() - t0)

    bt = BallTree(X, leaf_size=30)
    t0 = time.perf_counter()
    bt.query(q, k=k)
    bt_times.append(time.perf_counter() - t0)

fig, ax = plt.subplots(figsize=(8, 4))
ax.semilogy(dims, kd_times, 'b-o', label='KD-Tree')
ax.semilogy(dims, bt_times, 'r-s', label='Ball Tree')
ax.set_xlabel('Number of dimensions p')
ax.set_ylabel('Query time for 200 queries (s)')
ax.set_title(f'KD-Tree vs Ball Tree query time  n={n}, k={k}')
ax.legend(); ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("Note: both degrade in high dimensions as pruning becomes less effective.")
for p, kd_t, bt_t in zip(dims, kd_times, bt_times):
    winner = 'KD-Tree' if kd_t < bt_t else 'BallTree'
    print(f"  p={p:3d}: KD={kd_t*1000:.2f}ms  BT={bt_t*1000:.2f}ms  → {winner}")

4. Using sklearn’s KDTree and BallTree

sklearn provides standalone KDTree and BallTree classes, as well as automatic selection through KNeighborsClassifier(algorithm='auto').

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
import time

np.random.seed(42)
X, y = make_classification(n_samples=10000, n_features=10, n_informative=8,
                            n_redundant=2, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_tr_s = scaler.fit_transform(X_tr)
X_te_s = scaler.transform(X_te)

results = {}
for algo in ['brute', 'kd_tree', 'ball_tree', 'auto']:
    knn = KNeighborsClassifier(n_neighbors=7, algorithm=algo)
    t_fit = time.perf_counter(); knn.fit(X_tr_s, y_tr); t_fit = time.perf_counter() - t_fit
    t_pred = time.perf_counter(); y_pred = knn.predict(X_te_s); t_pred = time.perf_counter() - t_pred
    acc = accuracy_score(y_te, y_pred)
    results[algo] = (acc, t_fit, t_pred)
    print(f"{algo:12s}: acc={acc:.4f}  fit={t_fit*1000:.1f}ms  predict={t_pred*1000:.1f}ms")

# All should give same accuracy (exact methods)
print("\nAll exact methods agree:", len(set(r[0] for r in results.values())) == 1)

# Leaf size effect on Ball Tree speed
from sklearn.neighbors import BallTree

leaf_sizes = [5, 10, 20, 30, 50, 100]
bt_query_times = []
for ls in leaf_sizes:
    bt = BallTree(X_tr_s, leaf_size=ls)
    t0 = time.perf_counter()
    bt.query(X_te_s, k=7)
    bt_query_times.append(time.perf_counter() - t0)

fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(leaf_sizes, [t*1000 for t in bt_query_times], 'r-o')
ax.set_xlabel('leaf_size'); ax.set_ylabel('Query time (ms)')
ax.set_title('Ball Tree query time vs leaf_size  (n=8000, p=10)')
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

5. Approximate Nearest Neighbours

For very large $n$ ($> 10^5$) or high $p$ ($> 50$), exact tree methods still become slow. Approximate NN methods trade a small reduction in recall for large speed gains.
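
Recall here means the fraction of the true $k$ nearest neighbours that the approximate search also returns. A minimal sketch of that measurement follows; the function name `recall_at_k` and the toy index arrays are made up for illustration.

```python
import numpy as np


def recall_at_k(exact_idx, approx_idx):
    """Mean fraction of the true k neighbours found by the approximate search.

    Both arguments are (n_queries, k) integer arrays of neighbour indices.
    """
    hits = [len(set(e) & set(a)) for e, a in zip(exact_idx, approx_idx)]
    return float(np.mean(hits)) / exact_idx.shape[1]


# Two queries, k = 4: the approximate search finds 3 and 2 true neighbours.
exact = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])
approx = np.array([[0, 1, 2, 9], [4, 5, 8, 9]])
print(recall_at_k(exact, approx))  # (3 + 2) / (2 * 4) = 0.625
```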

Locality-Sensitive Hashing (LSH)

Hash points so that similar points collide in the same bucket with high probability. Query: hash the query, check only points in the same bucket.
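
The idea can be sketched with one common LSH family, sign random projections: each bit of the hash records which side of a random hyperplane a point falls on. Everything here is an illustrative assumption rather than a library API: the class name `HyperplaneLSH`, the single hash table, and the choice of 8 bits. Note that this family groups points by angle, while the candidates are then ranked by Euclidean distance.

```python
import numpy as np
from collections import defaultdict


class HyperplaneLSH:
    """Single-table sign-random-projection LSH (illustrative, not tuned)."""

    def __init__(self, dim, n_bits=8, seed=0):
        rng = np.random.default_rng(seed)
        self.planes = rng.standard_normal((n_bits, dim))
        self.buckets = defaultdict(list)

    def _key(self, x):
        # One bit per hyperplane: which side of the plane x falls on.
        return tuple((self.planes @ x > 0).astype(int))

    def fit(self, X):
        self.X = X
        for i, x in enumerate(X):
            self.buckets[self._key(x)].append(i)
        return self

    def candidates(self, q):
        """Indices of stored points that share q's bucket (may be empty)."""
        return np.array(self.buckets.get(self._key(q), []), dtype=int)

    def query(self, q, k=1):
        """Indices of up to k nearest points among q's bucket only."""
        cand = self.candidates(q)
        if cand.size == 0:
            return cand
        d = np.linalg.norm(self.X[cand] - q, axis=1)
        return cand[np.argsort(d)[:k]]


rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 16))
lsh = HyperplaneLSH(dim=16).fit(X)

# Query with a stored point so the demo is deterministic.
q = X[123]
print("candidates checked:", len(lsh.candidates(q)), "of", len(X))
print("approximate NN index:", lsh.query(q)[0])
```

A query that lies close to one of the hyperplanes can land in a different bucket from its true neighbour, which is where the recall loss comes from. Practical systems usually reduce that loss by using several independent hash tables.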

As an order-of-magnitude illustration (actual figures depend on the data and on how the index is tuned):

$$\text{Recall} \approx 95\% \quad \text{Speed} \sim 100\times \text{ faster than exact KNN}$$

Industry-Scale Libraries

| Library | Method | Scale |
| --- | --- | --- |
| sklearn NearestNeighbors | Exact (KD/Ball) | Up to ~$10^5$ |
| faiss (Facebook) | IVF + PQ | Billions of vectors |
| hnswlib | HNSW graph | $10^6$ to $10^7$ |
| annoy (Spotify) | Random projection trees | $10^6$ |

For the lab (customer segmentation), the dataset is small enough for sklearn’s exact methods.

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

# Demonstrate NearestNeighbors API (exact, but clean interface)
np.random.seed(42)
X = np.random.randn(500, 2)
query_pts = np.array([[0.0, 0.0], [1.5, -1.0], [-2.0, 2.0]])

nn = NearestNeighbors(n_neighbors=5, algorithm='ball_tree', metric='euclidean')
nn.fit(X)
distances, indices = nn.kneighbors(query_pts)

print("Nearest neighbours for 3 query points (k=5):")
for i, (q, dists, idxs) in enumerate(zip(query_pts, distances, indices)):
    print(f"  Query {i}: {q} → neighbours at indices {idxs}")
    print(f"    distances: {np.round(dists, 3)}")

# Visualise
fig, ax = plt.subplots(figsize=(7, 6))
ax.scatter(X[:,0], X[:,1], c='lightgray', s=20, zorder=2, label='Training points')
colors = ['red', 'blue', 'green']
for i, (q, dists, idxs) in enumerate(zip(query_pts, distances, indices)):
    ax.scatter(*q, c=colors[i], s=120, marker='*', zorder=5, label=f'Query {i}')
    for j in idxs:
        ax.plot([q[0], X[j,0]], [q[1], X[j,1]], c=colors[i], alpha=0.5, lw=1)
    circle = plt.Circle(q, dists.max(), fill=False, color=colors[i], linestyle='--', alpha=0.4)
    ax.add_patch(circle)
ax.set_title('NearestNeighbors: 5-NN for 3 queries (Ball Tree)')
ax.legend(fontsize=9); ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()

6. Try It in the Browser

Build a minimal KD-Tree and perform a nearest-neighbour query.

import math

class KDNode:
    def __init__(self, point, left=None, right=None, axis=0):
        self.point = point
        self.left = left
        self.right = right
        self.axis = axis

def build(pts, depth=0):
    if not pts:
        return None
    axis = depth % len(pts[0])
    pts.sort(key=lambda p: p[axis])
    mid = len(pts) // 2
    return KDNode(pts[mid], build(pts[:mid], depth+1), build(pts[mid+1:], depth+1), axis)

def nn_search(node, query, best=None, best_d=float('inf')):
    if node is None:
        return best, best_d
    d = math.sqrt(sum((a-b)**2 for a,b in zip(query, node.point)))
    if d < best_d:
        best, best_d = node.point, d
    axis = node.axis
    diff = query[axis] - node.point[axis]
    near, far = (node.left, node.right) if diff <= 0 else (node.right, node.left)
    best, best_d = nn_search(near, query, best, best_d)
    if abs(diff) < best_d:
        best, best_d = nn_search(far, query, best, best_d)
    return best, best_d

points = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2],[1,9]]
tree = build([list(p) for p in points])
query = [6, 5]
nn, dist = nn_search(tree, query)
print(f"Query:  {query}")
print(f"NN:     {nn}")
print(f"Dist:   {dist:.4f}")

# Brute-force check
bf_nn = min(points, key=lambda p: math.sqrt(sum((a-b)**2 for a,b in zip(query,p))))
print(f"Brute-force check: {bf_nn}")
print(f"Match: {nn == bf_nn}")

Knowledge Check

KD-Trees degrade to O(n) search complexity in high dimensions because:

[ ] A) The tree becomes unbalanced
[ ] B) The axis-aligned hyperplane splits become too thin to prune effectively, so most subtrees must be visited
[ ] C) The median calculation is inaccurate in high dimensions
[ ] D) KD-Trees only support Euclidean distance

Ball Trees are preferred over KD-Trees when:

[ ] A) The dataset has fewer than 100 points
[ ] B) Using non-Euclidean metrics (e.g. Haversine) or moderate-dimensional data (p > 20)
[ ] C) The features are all binary
[ ] D) Ball Trees always outperform KD-Trees

Exercises

Exercise 1 — Leaf Size Tuning

On make_classification(n_samples=20000, n_features=8), benchmark BallTree query time for leaf_size ∈ [5, 10, 20, 40, 80, 160]. Plot query time and accuracy vs leaf_size. What is the optimal leaf_size for minimum query time?

Exercise 2 — Haversine Distance (geo-nearest-neighbour)

Generate latitude/longitude data for 1000 random cities (use np.random.uniform). Use BallTree with metric='haversine', which expects [latitude, longitude] in radians, to find the 3 nearest cities for a query point. Compare with a Euclidean KD-Tree on the same coordinates: does the result differ?

Exercise 3 — Approximate vs Exact Recall Trade-off

For n=50000, p=20, compare sklearn BallTree (exact) vs a manual random-projection approximation: project $X$ to 5 random dimensions, find NN in projected space, then verify true distance. Report approximate recall at K=10.

%matplotlib inline
# Exercises 1, 2, 3 — your code here
from sklearn.neighbors import BallTree, KDTree
from sklearn.datasets import make_classification
import numpy as np
import time
# Your code here

Common Pitfalls

Summary
  • Brute-force KNN is $O(n \cdot p)$ per query, which is prohibitive for large $n$.

  • KD-Tree: partitions space with axis-aligned hyperplanes, average $O(p \log n)$ query. Best for $p < 20$.

  • Ball Tree: partitions space with hyperspheres, works with non-Euclidean metrics, better suited to moderate dimensions ($p \leq 50$).

  • Both degrade to $O(n)$ in very high dimensions (curse of dimensionality).

  • For $n > 10^5$ or $p > 50$: use approximate NN (FAISS, HNSW, annoy), trading some recall (on the order of 95%) for a large speedup (on the order of 100×).

  • sklearn API: KDTree, BallTree, NearestNeighbors, and KNeighborsClassifier(algorithm='auto').

Next steps

What’s Next — KNN Lab: Customer Segmentation

You now have the complete KNN toolkit: algorithm, distance metrics, k selection, and efficient search. The next notebook applies everything to a real business problem: customer segmentation using KNN-based similarity search on retail transaction data. You will:

  • Segment customers by purchase behaviour (RFM features)

  • Find similar customers for targeted marketing

  • Evaluate segmentation quality with silhouette score

Proceed to knn_lab.ipynb.