k-means++

Basic

k-means++ | เลือกเซนทรอยด์เริ่มต้นอย่างชาญฉลาด

まとめ
  • k-means++ เลือกเซนทรอยด์เริ่มต้นให้กระจายตัวมากขึ้น ทำให้ k-means ไม่ติดจุดเฉพาะที่ง่าย
  • เซนทรอยด์ใหม่ถูกสุ่มโดยให้น้ำหนักตามระยะกำลังสองจากเซนทรอยด์ที่มีอยู่ จุดที่ไกลกว่าจะถูกเลือกมากกว่า
  • ใน scikit-learn เพียงตั้ง init="k-means++" ก็ได้ผลลัพธ์ทันที สามารถเปรียบเทียบกับการเริ่มแบบสุ่มได้ง่าย
  • สำหรับข้อมูลชุดใหญ่ การประมวลผลแบบ online เช่น mini-batch k-means ก็ใช้แนวคิด k-means++ เป็นฐาน

ภาพรวมเชิงสัญชาติญาณ #

k-means ไวต่อการเลือกจุดเริ่มต้น หากเริ่มไม่ดีอาจฟิตกับคลัสเตอร์ที่บิดเบี้ยวและได้ WCSS สูง k-means++ แก้ด้วยการเลือกเซนทรอยด์แรกแบบสุ่ม จากนั้นสุ่มจุดถัดไปโดยให้น้ำหนักมากกับจุดที่ไกลจากเซนทรอยด์เดิม

สูตรสำคัญ #

ให้เซนทรอยด์ที่เลือกไปแล้วคือ \({\mu_1,\ldots,\mu_m}\) ความน่าจะเป็นที่เลือกจุด \(x\) เป็นเซนทรอยด์ถัดไปคือ

$$ P(x) = \frac{D(x)^2}{\sum_{x’ \in \mathcal{X}} D(x’)^2}, \qquad D(x) = \min_{1 \le j \le m} \lVert x - \mu_j \rVert. $$

จุดที่อยู่ไกลจึงมีโอกาสถูกเลือกสูง ส่งผลให้เซนทรอยด์เริ่มต้นกระจายดีและลดโอกาสได้ผลลัพธ์ไม่พึงประสงค์ (Arthur & Vassilvitskii, 2007)

ทดลองด้วย Python #

ลองเปรียบเทียบสุ่มเริ่มต้นปกติกับ k-means++ อย่างละ 3 ครั้ง

from __future__ import annotations

import japanize_matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs


def create_blobs_dataset(
    n_samples: int = 3000,
    n_centers: int = 8,
    cluster_std: float = 1.5,
    random_state: int = 11711,
) -> tuple[np.ndarray, np.ndarray]:
    """สร้างข้อมูลสังเคราะห์เพื่อเปรียบเทียบการเริ่มต้น"""
    return make_blobs(
        n_samples=n_samples,
        centers=n_centers,
        cluster_std=cluster_std,
        random_state=random_state,
    )


def compare_initialisation_strategies(
    data: np.ndarray,
    n_clusters: int = 5,
    subset_size: int = 1000,
    n_trials: int = 3,
    random_state: int = 11711,
) -> dict[str, list[float]]:
    """เปรียบเทียบ WCSS ของการเริ่มต้นแบบ random และ k-means++"""
    japanize_matplotlib.japanize()
    rng = np.random.default_rng(random_state)
    inertia_random: list[float] = []
    inertia_kpp: list[float] = []

    fig, axes = plt.subplots(
        n_trials,
        2,
        figsize=(10, 3.2 * n_trials),
        sharex=True,
        sharey=True,
    )

    for trial in range(n_trials):
        indices = rng.choice(len(data), size=subset_size, replace=False)
        subset = data[indices]

        random_model = KMeans(
            n_clusters=n_clusters,
            init="random",
            n_init=1,
            max_iter=1,
            random_state=random_state + trial,
        ).fit(subset)
        kpp_model = KMeans(
            n_clusters=n_clusters,
            init="k-means++",
            n_init=1,
            max_iter=1,
            random_state=random_state + trial,
        ).fit(subset)

        inertia_random.append(float(random_model.inertia_))
        inertia_kpp.append(float(kpp_model.inertia_))

        ax_random = axes[trial, 0] if n_trials > 1 else axes[0]
        ax_kpp = axes[trial, 1] if n_trials > 1 else axes[1]

        ax_random.scatter(subset[:, 0], subset[:, 1], c=random_model.labels_, s=10)
        ax_random.set_title(f"random init (trial {trial + 1})")
        ax_random.grid(alpha=0.2)

        ax_kpp.scatter(subset[:, 0], subset[:, 1], c=kpp_model.labels_, s=10)
        ax_kpp.set_title(f"k-means++ init (trial {trial + 1})")
        ax_kpp.grid(alpha=0.2)

    fig.suptitle("เปรียบเทียบการเริ่มต้นของ k-means")
    fig.tight_layout()
    plt.show()

    return {"random": inertia_random, "k-means++": inertia_kpp}


FEATURES, _ = create_blobs_dataset()
metrics = compare_initialisation_strategies(
    data=FEATURES,
    n_clusters=5,
    subset_size=1000,
    n_trials=3,
    random_state=2024,
)
for method, values in metrics.items():
    print(f"{method}: ค่า WCSS เฉลี่ย = {np.mean(values):.1f}")

เปรียบเทียบการเริ่มต้นด้วย random และ k-means++

เอกสารอ้างอิง #

  • Arthur, D., & Vassilvitskii, S. (2007). k-means++: The Advantages of Careful Seeding. ACM-SIAM SODA.
  • Bahmani, B., Moseley, B., Vattani, A., Kumar, R., & Vassilvitskii, S. (2012). Scalable k-means++. VLDB.
  • scikit-learn developers. (2024). Clustering. https://scikit-learn.org/stable/modules/clustering.html