k-means++ | Smarter Seeding Strategies for k-means Clustering
Summary
  • k-means++ spreads the initial centroids apart, reducing the chance that vanilla k-means converges to a poor local optimum.
  • Additional centroids are sampled with probability proportional to the squared distance from the existing centroids, discouraging tight clusters of seeds.
  • In scikit-learn, KMeans(init="k-means++") activates the method, making it easy to compare with purely random initialisation.
  • Large-scale variants such as scalable k-means++ (Bahmani et al., 2012) build on the same idea, and mini-batch k-means pairs naturally with k-means++ seeding in streaming or big-data settings (see the sketch after this list).
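
For quick orientation, here is a minimal, hedged sketch of both seeding modes in scikit-learn, plus the mini-batch variant; the make_blobs toy dataset is a stand-in and the hyperparameters are illustrative rather than tuned.

from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=1_000, centers=5, random_state=0)

# Purely random seeding: initial centroids drawn uniformly from the data.
random_km = KMeans(n_clusters=5, init="random", n_init=10, random_state=0).fit(X)

# k-means++ seeding (also scikit-learn's default).
kpp_km = KMeans(n_clusters=5, init="k-means++", n_init=10, random_state=0).fit(X)

# Mini-batch k-means accepts the same init argument for large datasets.
mb_km = MiniBatchKMeans(
    n_clusters=5, init="k-means++", batch_size=256, n_init=3, random_state=0
).fit(X)

print(random_km.inertia_, kpp_km.inertia_, mb_km.inertia_)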

Intuition #

k-means is notoriously sensitive to the initial centroids. If they all fall inside a single dense region, the algorithm may converge quickly but produce a clustering with a large within-cluster sum of squares (WCSS). k-means++ mitigates this by sampling the first centroid uniformly and each subsequent one with probability that grows with its distance from the already selected centroids.

Mathematical formulation #

Given the set of chosen centroids \(\{\mu_1, \dots, \mu_m\}\), a candidate point \(x\) is selected as the next centroid with probability

$$ P(x) = \frac{D(x)^2}{\sum_{x' \in \mathcal{X}} D(x')^2}, \qquad D(x) = \min_{1 \le j \le m} \lVert x - \mu_j \rVert. $$

Points that are far from every existing centroid (large \(D(x)\)) therefore receive higher probability mass, yielding a well-scattered initial configuration (Arthur & Vassilvitskii, 2007).
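
To make the sampling rule concrete, the following is a minimal NumPy sketch of the seeding loop; kmeans_pp_init is a hypothetical helper written for illustration (scikit-learn's internal implementation is far more optimised), and it assumes Euclidean distances and distinct data points, matching the formula above.

import numpy as np


def kmeans_pp_init(X: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: pick k initial centroids via D(x)^2 sampling."""
    # First centroid: sampled uniformly from the data points.
    centroids = [X[rng.integers(len(X))]]
    for _ in range(k - 1):
        # D(x): distance from each point to its nearest chosen centroid.
        dists = np.min(
            np.linalg.norm(X[:, None, :] - np.asarray(centroids)[None, :, :], axis=2),
            axis=1,
        )
        # Next centroid: probability mass proportional to D(x)^2.
        probs = dists**2 / np.sum(dists**2)
        centroids.append(X[rng.choice(len(X), p=probs)])
    return np.asarray(centroids)

An array returned by this helper can be passed directly as init= to KMeans (with n_init=1), which is one convenient way to inspect custom seedings against the built-in "k-means++" option.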

Experiment in Python #

The snippet below compares random initialisation and k-means++ across three trials, collecting the WCSS after a single k-means iteration.

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs


def create_blobs_dataset(
    n_samples: int = 3000,
    n_centers: int = 8,
    cluster_std: float = 1.5,
    random_state: int = 11711,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate synthetic data for initialisation experiments."""
    return make_blobs(
        n_samples=n_samples,
        centers=n_centers,
        cluster_std=cluster_std,
        random_state=random_state,
    )


def compare_initialisation_strategies(
    data: np.ndarray,
    n_clusters: int = 5,
    subset_size: int = 1000,
    n_trials: int = 3,
    random_state: int = 11711,
) -> dict[str, list[float]]:
    """Compare random versus k-means++ seeding and plot the assignments.

    Args:
        data: Feature matrix to cluster.
        n_clusters: Number of clusters to form.
        subset_size: Number of samples drawn per trial.
        n_trials: Number of comparisons to perform.
        random_state: Random seed for reproducibility.

    Returns:
        Dictionary containing WCSS values for both initialisation modes.
    """
    rng = np.random.default_rng(random_state)
    inertia_random: list[float] = []
    inertia_kpp: list[float] = []

    fig, axes = plt.subplots(
        n_trials,
        2,
        figsize=(10, 3.2 * n_trials),
        sharex=True,
        sharey=True,
    )

    for trial in range(n_trials):
        indices = rng.choice(len(data), size=subset_size, replace=False)
        subset = data[indices]

        random_model = KMeans(
            n_clusters=n_clusters,
            init="random",
            n_init=1,
            max_iter=1,
            random_state=random_state + trial,
        ).fit(subset)
        kpp_model = KMeans(
            n_clusters=n_clusters,
            init="k-means++",
            n_init=1,
            max_iter=1,
            random_state=random_state + trial,
        ).fit(subset)

        inertia_random.append(float(random_model.inertia_))
        inertia_kpp.append(float(kpp_model.inertia_))

        ax_random = axes[trial, 0] if n_trials > 1 else axes[0]
        ax_kpp = axes[trial, 1] if n_trials > 1 else axes[1]

        ax_random.scatter(subset[:, 0], subset[:, 1], c=random_model.labels_, s=10)
        ax_random.set_title(f"Random init (trial {trial + 1})")
        ax_random.grid(alpha=0.2)

        ax_kpp.scatter(subset[:, 0], subset[:, 1], c=kpp_model.labels_, s=10)
        ax_kpp.set_title(f"k-means++ init (trial {trial + 1})")
        ax_kpp.grid(alpha=0.2)

    fig.suptitle("Label assignments after one iteration (random vs k-means++)")
    fig.tight_layout()
    plt.show()

    return {"random": inertia_random, "k-means++": inertia_kpp}


FEATURES, _ = create_blobs_dataset()
results = compare_initialisation_strategies(
    data=FEATURES,
    n_clusters=5,
    subset_size=1000,
    n_trials=3,
    random_state=2024,
)
for method, values in results.items():
    print(f"{method} mean WCSS: {np.mean(values):.1f}")

Figure: label assignments after one iteration, random versus k-means++ initialisation.

References #

  • Arthur, D., & Vassilvitskii, S. (2007). k-means++: The Advantages of Careful Seeding. ACM-SIAM SODA.
  • Bahmani, B., Moseley, B., Vattani, A., Kumar, R., & Vassilvitskii, S. (2012). Scalable k-means++. VLDB.
  • scikit-learn developers. (2024). Clustering. Retrieved from https://scikit-learn.org/stable/modules/clustering.html