การถดถอยพหุนาม

Basic

การถดถอยพหุนาม | จับรูปแบบไม่เชิงเส้นด้วยโมเดลเชิงเส้น

まとめ
  • การถดถอยพหุนามสร้างฟีเจอร์กำลังต่างๆ แล้วส่งให้โมเดลเชิงเส้น ทำให้จับความสัมพันธ์ที่ไม่เชิงเส้นได้
  • โมเดลยังคงเป็นการรวมเชิงเส้นของสัมประสิทธิ์ จึงคงข้อดีด้านคำตอบแบบปิดรูปและการตีความของการถดถอยเชิงเส้น
  • ยิ่งเพิ่มอันดับพหุนาม ยิ่งมีพลังในการแทนค่า แต่ก็เสี่ยงเกิด overfitting จึงควบคุมด้วยการทำให้เป็นระเบียบหรือ cross-validation
  • ควรทำมาตรฐานฟีเจอร์ก่อน และเลือกอันดับ/ความแรงของการทำให้เป็นระเบียบอย่างระมัดระวังเพื่อให้ผลพยากรณ์นิ่ง

ภาพรวมเชิงสัญชาติญาณ #

ถ้าเส้นตรงอธิบายข้อมูลไม่ได้ เช่น มีรูปโดมหรือคลื่น เราสามารถขยายอินพุตเป็น \(x, x^2, x^3, \dots\) แล้วส่งให้โมเดลเชิงเส้นแทน สำหรับปัญหาหลายตัวแปรก็ใส่เทอมไขว้และกำลังของแต่ละตัวแปรเช่นเดียวกัน

สูตรสำคัญ #

ให้เวกเตอร์อินพุต \(\mathbf{x} = (x_1, \dots, x_m)\) และกำหนดอันดับพหุนาม \(d\) เราสร้างแผนที่ฟีเจอร์ \(\phi(\mathbf{x})\) และทำการถดถอยเชิงเส้นบนฟีเจอร์เหล่านี้ เช่น เมื่อ \(m = 2, d = 2\)

$$ \phi(\mathbf{x}) = (1, x_1, x_2, x_1^2, x_1 x_2, x_2^2) $$

แบบจำลองมีรูป

$$ y = \mathbf{w}^\top \phi(\mathbf{x}) $$

เมื่อเพิ่ม \(d\) จำนวนฟีเจอร์จะโตเร็วมาก จึงมักเริ่มจากอันดับ 2 หรือ 3 แล้วเลือกใช้วิธีทำให้เป็นระเบียบ (เช่น Ridge หรือ Lasso) ร่วมด้วยตามความจำเป็น

ทดลองด้วย Python #

ตัวอย่างต่อไปนี้เพิ่มฟีเจอร์พหุนามอันดับ 3 เพื่อเรียนรู้ความสัมพันธ์รูปเส้นโค้ง

from __future__ import annotations

import japanize_matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures


def compare_polynomial_regression(
    n_samples: int = 200,
    degree: int = 3,
    noise_scale: float = 2.0,
    label_observations: str = "observations",
    label_true_curve: str = "true curve",
    label_linear: str = "linear regression",
    label_poly_template: str = "degree-{degree} polynomial",
) -> tuple[float, float]:
    """Fit linear vs. polynomial regression to a cubic trend and plot the results.

    Args:
        n_samples: Number of synthetic samples generated along the curve.
        degree: Polynomial degree used in the feature expansion.
        noise_scale: Standard deviation of the Gaussian noise added to targets.
        label_observations: Legend label for scatter observations.
        label_true_curve: Legend label for the underlying true curve.
        label_linear: Legend label for the linear regression fit.
        label_poly_template: Format string for the polynomial label.

    Returns:
        A tuple containing the mean-squared errors of (linear, polynomial) models.
    """
    japanize_matplotlib.japanize()
    rng = np.random.default_rng(seed=42)

    x: np.ndarray = np.linspace(-3.0, 3.0, n_samples, dtype=float)
    y_true: np.ndarray = 0.5 * x**3 - 1.2 * x**2 + 2.0 * x + 1.5
    y_noisy: np.ndarray = y_true + rng.normal(scale=noise_scale, size=x.shape)

    X: np.ndarray = x[:, np.newaxis]

    linear_model = LinearRegression()
    linear_model.fit(X, y_noisy)
    poly_model = make_pipeline(
        PolynomialFeatures(degree=degree, include_bias=False),
        LinearRegression(),
    )
    poly_model.fit(X, y_noisy)

    grid: np.ndarray = np.linspace(-3.5, 3.5, 300, dtype=float)[:, np.newaxis]
    linear_pred: np.ndarray = linear_model.predict(grid)
    poly_pred: np.ndarray = poly_model.predict(grid)
    true_curve: np.ndarray = (
        0.5 * grid.ravel()**3 - 1.2 * grid.ravel()**2 + 2.0 * grid.ravel() + 1.5
    )

    linear_mse: float = float(mean_squared_error(y_noisy, linear_model.predict(X)))
    poly_mse: float = float(mean_squared_error(y_noisy, poly_model.predict(X)))

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(
        X,
        y_noisy,
        s=20,
        color="#ff7f0e",
        alpha=0.6,
        label=label_observations,
    )
    ax.plot(
        grid,
        true_curve,
        color="#2ca02c",
        linewidth=2,
        label=label_true_curve,
    )
    ax.plot(
        grid,
        linear_pred,
        color="#1f77b4",
        linestyle="--",
        linewidth=2,
        label=label_linear,
    )
    ax.plot(
        grid,
        poly_pred,
        color="#d62728",
        linewidth=2,
        label=label_poly_template.format(degree=degree),
    )
    ax.set_xlabel("อินพุต $x$")
    ax.set_ylabel("เอาต์พุต $y$")
    ax.legend()
    fig.tight_layout()
    plt.show()

    return linear_mse, poly_mse


degree = 3
linear_mse, poly_mse = compare_polynomial_regression(
    degree=degree,
    label_observations="ข้อมูลที่สังเกต",
    label_true_curve="เส้นโค้งจริง",
    label_linear="การถดถอยเชิงเส้น",
    label_poly_template="พหุนามอันดับ {degree}",
)
print(f"MSE ของการถดถอยเชิงเส้น: {linear_mse:.3f}")
print(f"MSE ของพหุนามอันดับ {degree}: {poly_mse:.3f}")

เปรียบเทียบโมเดลเชิงเส้นกับพหุนามอันดับ 3 บนข้อมูลเส้นโค้ง

วิเคราะห์ผลลัพธ์ #

  • โมเดลเชิงเส้นทั่วไปไม่สามารถตามส่วนโค้งตรงกลางได้ ในขณะที่พหุนามอันดับ 3 จับรูปทรงได้ใกล้เคียงกับความจริง
  • การเพิ่มอันดับพหุนามช่วยให้ฟิตชุดฝึกดีขึ้นแต่ทำให้การคาดการณ์นอกช่วงไม่เสถียร จึงต้องระวัง overfitting
  • การใช้การทำให้เป็นระเบียบ (เช่น Ridge/Lasso) ในพาเลตเดียวกันช่วยควบคุมโมเดลเมื่อมีฟีเจอร์จำนวนมาก

เอกสารอ้างอิง #

  • Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.
  • Hastie, T., Tibshirani, R., & Friedman, J. (2009). The Elements of Statistical Learning. Springer.