k-NN

Basic

k-Nearest Neighbors (k-NN) | การเรียนรู้แบบขี้เกียจจากระยะทาง

まとめ
  • k-NN ไม่สร้างโมเดลล่วงหน้า แต่ตัดสินฉลากโดยดูเพื่อนบ้านที่ใกล้ที่สุดขณะทำนาย จึงเรียบง่ายและเข้าใจง่าย
  • พารามิเตอร์หลักมีเพียงจำนวนเพื่อนบ้าน \(k\) และรูปแบบการให้น้ำหนักระยะทาง ทำให้ปรับจูนไม่ยุ่งยาก
  • สามารถสร้างเส้นแบ่งที่ไม่เชิงเส้นได้ตามธรรมชาติ แต่เมื่อมิติสูง ระยะทางมักแยกความแตกต่างได้ยาก
  • การทำมาตรฐานและการเลือกฟีเจอร์ล่วงหน้าช่วยให้การคำนวณระยะทางมีเสถียรภาพและได้ผลดีขึ้น

ภาพรวมเชิงสัญชาติญาณ #

หากเชื่อว่าจุดที่ใกล้กันมักมีฉลากเดียวกัน เราสามารถอนุมานฉลากของจุดใหม่จาก \(k\) ตัวอย่างที่ใกล้ที่สุด ด้วยการโหวตแบบถ่วงน้ำหนักตามระยะทาง วิธีนี้ไม่ต้องฝึกโมเดล จึงมักถูกเรียกว่า “lazy learning”

สูตรสำคัญ #

ให้ \(\mathcal{N}_k(\mathbf{x})\) คือเซตของเพื่อนบ้านที่ใกล้ที่สุด \(k\) จุดของ \(\mathbf{x}\) คะแนนโหวตสำหรับคลาส \(c\) คือ

$$ v_c = \sum_{i \in \mathcal{N}_k(\mathbf{x})} w_i ,\mathbb{1}(y_i = c), $$

โดย \(w_i\) คือค่าน้ำหนัก (เช่น กลับส่วนของระยะทาง) คลาสที่มีคะแนนสูงสุดคือผลพยากรณ์

ทดลองด้วย Python #

โค้ดต่อไปนี้ลองหลายค่า \(k\) แล้ววาดบริเวณการจำแนกของค่าที่ดีที่สุด

from __future__ import annotations

import japanize_matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def run_knn_demo(
    n_samples: int = 600,
    random_state: int = 7,
    weights: str = "distance",
    k_values: tuple[int, ...] = (1, 3, 5, 7, 11),
    validation_ratio: float = 0.3,
    title: str = "พื้นที่การจำแนกของ k-NN",
    xlabel: str = "คุณลักษณะที่ 1",
    ylabel: str = "คุณลักษณะที่ 2",
    class_label_prefix: str = "คลาส",
) -> dict[str, object]:
    """Evaluate k-NN for several neighbour counts and plot decision regions."""
    japanize_matplotlib.japanize()
    X, y = make_blobs(
        n_samples=n_samples,
        centers=3,
        cluster_std=[1.1, 1.0, 1.2],
        random_state=random_state,
    )

    rng = np.random.default_rng(random_state)
    indices = rng.permutation(len(X))
    split = int(len(X) * (1.0 - validation_ratio))
    train_idx, valid_idx = indices[:split], indices[split:]
    X_train, X_valid = X[train_idx], X[valid_idx]
    y_train, y_valid = y[train_idx], y[valid_idx]

    scores: dict[int, float] = {}
    for k in k_values:
        model = make_pipeline(
            StandardScaler(),
            KNeighborsClassifier(n_neighbors=k, weights=weights),
        )
        model.fit(X_train, y_train)
        scores[k] = float(model.score(X_valid, y_valid))

    best_k = max(scores, key=scores.get)
    best_model = make_pipeline(
        StandardScaler(),
        KNeighborsClassifier(n_neighbors=best_k, weights=weights),
    )
    best_model.fit(X, y)

    xx, yy = np.meshgrid(
        np.linspace(X[:, 0].min() - 1.5, X[:, 0].max() + 1.5, 300),
        np.linspace(X[:, 1].min() - 1.5, X[:, 1].max() + 1.5, 300),
    )
    grid = np.column_stack([xx.ravel(), yy.ravel()])
    predictions = best_model.predict(grid).reshape(xx.shape)

    unique_classes = np.unique(y)
    levels = np.arange(unique_classes.min(), unique_classes.max() + 2) - 0.5
    cmap = ListedColormap(["#fee0d2", "#deebf7", "#c7e9c0"])

    fig, ax = plt.subplots(figsize=(7, 5.5))
    contour = ax.contourf(xx, yy, predictions, levels=levels, cmap=cmap, alpha=0.85)
    scatter = ax.scatter(
        X[:, 0],
        X[:, 1],
        c=y,
        cmap="Set1",
        edgecolor="#1f2937",
        linewidth=0.6,
    )
    ax.set_title(f"{title} (k={best_k}, weights={weights})")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.grid(alpha=0.15)

    legend = ax.legend(
        handles=scatter.legend_elements()[0],
        labels=[f"{class_label_prefix} {cls}" for cls in unique_classes],
        loc="upper right",
        frameon=True,
    )
    legend.get_frame().set_alpha(0.9)
    fig.colorbar(contour, ax=ax, label="คลาสที่พยากรณ์")
    fig.tight_layout()
    plt.show()

    return {"scores": scores, "best_k": int(best_k), "validation_accuracy": scores[best_k]}


metrics = run_knn_demo(
    title="พื้นที่การจำแนกของ k-NN",
    xlabel="คุณลักษณะที่ 1",
    ylabel="คุณลักษณะที่ 2",
    class_label_prefix="คลาส",
)
print(f"k ที่ดีที่สุด: {metrics['best_k']}")
print(f"ความแม่นยำบนชุดตรวจสอบ: {metrics['validation_accuracy']:.3f}")
for candidate_k, score in metrics["scores"].items():
    print(f"k={candidate_k}: ความแม่นยำตรวจสอบ = {score:.3f}")

บริเวณการจำแนกของ k-NN สำหรับ 3 คลาส

เอกสารอ้างอิง #

  • Cover, T. M., & Hart, P. E. (1967). Nearest Neighbor Pattern Classification. IEEE Transactions on Information Theory, 13(1), 21 E7.
  • Hastie, T., Tibshirani, R., & Friedman, J. (2009). The Elements of Statistical Learning. Springer.