Learning curve

Summary
  • A learning curve plots training vs. validation performance as the training sample size grows.
  • Use learning_curve to draw both curves, inspect bias/variance behaviour, and judge whether more data will help.
  • Apply the insights to data collection, model capacity, and feature engineering decisions.

1. What is a learning curve? #

A learning curve tracks training score and validation score while gradually increasing the number of training samples. It helps answer:

  • Is the model underfitting (high bias) or overfitting (high variance)?
  • Would collecting more data meaningfully improve performance?
  • Should we revisit hyperparameters or model architecture?

2. Python example (Ridge regression) #

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def plot_learning_curve_for_ridge() -> None:
    """Plot RMSE versus training sample size."""
    features, targets = make_regression(
        n_samples=1500,
        n_features=25,
        n_informative=8,
        noise=12.0,
        random_state=42,
    )

    model = make_pipeline(StandardScaler(), Ridge(alpha=10.0))
    train_sizes, train_scores, valid_scores = learning_curve(
        estimator=model,
        X=features,
        y=targets,
        train_sizes=np.linspace(0.1, 1.0, 8),
        cv=5,
        scoring="neg_mean_squared_error",
        shuffle=True,
        random_state=42,
        n_jobs=None,
    )

    # Convert the per-fold neg-MSE scores to RMSE first, then aggregate,
    # so both the mean and the spread are on the RMSE scale.
    train_rmse_folds = np.sqrt(-train_scores)
    valid_rmse_folds = np.sqrt(-valid_scores)
    train_rmse = train_rmse_folds.mean(axis=1)
    valid_rmse = valid_rmse_folds.mean(axis=1)
    train_std = train_rmse_folds.std(axis=1)
    valid_std = valid_rmse_folds.std(axis=1)

    plt.figure(figsize=(7, 5))
    plt.plot(train_sizes, train_rmse, color="#1d4ed8", label="Train RMSE")
    plt.fill_between(
        train_sizes,
        train_rmse - train_std,
        train_rmse + train_std,
        alpha=0.2,
        color="#1d4ed8",
    )
    plt.plot(train_sizes, valid_rmse, color="#ea580c", label="Validation RMSE")
    plt.fill_between(
        train_sizes,
        valid_rmse - valid_std,
        valid_rmse + valid_std,
        alpha=0.2,
        color="#ea580c",
    )
    plt.xlabel("Training samples")
    plt.ylabel("RMSE")
    plt.title("Learning Curve for Ridge Regression (RMSE)")
    plt.legend(loc="upper right")
    plt.grid(alpha=0.3)
    plt.show()


plot_learning_curve_for_ridge()
[Figure: Learning curve for Ridge regression]

As the sample size increases, training RMSE rises (the model can no longer fit every point closely) while validation RMSE falls and eventually stabilises. Once the two curves converge, adding more data yields diminishing returns.
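If you only need RMSE, the manual sqrt-and-negate conversion above can be skipped entirely by requesting the "neg_root_mean_squared_error" scorer, which scikit-learn has provided since version 0.22. A minimal sketch on the same kind of synthetic data (the smaller `n_samples` and grid here are just to keep the example fast):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_regression(n_samples=500, n_features=25, noise=12.0, random_state=42)
model = make_pipeline(StandardScaler(), Ridge(alpha=10.0))

# "neg_root_mean_squared_error" returns negated RMSE per fold directly,
# so no manual conversion from MSE is needed.
sizes, train_scores, valid_scores = learning_curve(
    model, X, y,
    train_sizes=np.linspace(0.2, 1.0, 5),
    cv=5,
    scoring="neg_root_mean_squared_error",
    shuffle=True,
    random_state=42,
)
train_rmse = -train_scores.mean(axis=1)
valid_rmse = -valid_scores.mean(axis=1)
```

The resulting `train_rmse` and `valid_rmse` arrays plug straight into the same plotting code.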


3. Interpreting the curves #

  • High variance / overfitting: training score is very good but validation score lags far behind. Try stronger regularisation, fewer features, or more data.
  • High bias / underfitting: both curves are high (poor). Use a more expressive model, engineer features, or loosen regularisation.
  • Converged curves: training and validation scores meet; more data will not change much and you may need a different model or features.
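These rules of thumb can be turned into a quick programmatic check on the final point of each curve. A heuristic sketch — the `gap_tol` and `bias_tol` thresholds are illustrative assumptions, not standard values, and should be tuned per problem:

```python
import numpy as np

def diagnose_fit(train_rmse: np.ndarray, valid_rmse: np.ndarray,
                 gap_tol: float = 0.10, bias_tol: float = 1.0) -> str:
    """Label the fit regime from the last point of each RMSE curve.

    gap_tol:  relative train/validation gap above which we flag overfitting
              (assumed threshold, problem-specific).
    bias_tol: absolute RMSE level above which we flag underfitting
              (assumed threshold; 1.0 is only a placeholder).
    """
    rel_gap = (valid_rmse[-1] - train_rmse[-1]) / valid_rmse[-1]
    if rel_gap > gap_tol:
        return "high variance (overfitting)"
    if train_rmse[-1] > bias_tol and valid_rmse[-1] > bias_tol:
        return "high bias (underfitting)"
    return "converged"

# Training error far below validation error -> flagged as overfitting
print(diagnose_fit(np.array([0.2, 0.3]), np.array([1.5, 1.2])))
```

Such a check is useful inside automated training pipelines, where no one eyeballs the plot.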

4. Practical applications #

  • Data collection ROI: if the validation curve is still improving, additional data is valuable; if it plateaus, prioritise other work.
  • Model capacity & regularisation: inspect the curve before adjusting tree depth, neural network width, or regularisation strength.
  • Feature engineering: when both curves run parallel at a high error level, richer features can unlock performance.
  • Combine with other diagnostics: pair with validation curves or time-series evaluations to plan iteration cycles.
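The data-collection ROI question in particular reduces to asking whether the validation curve is still sloping downward. One hedged sketch, where `rel_tol` is an assumed sensitivity threshold rather than a standard value:

```python
import numpy as np

def more_data_likely_helps(valid_rmse: np.ndarray, rel_tol: float = 0.02) -> bool:
    """True if validation RMSE still dropped by more than rel_tol
    (relative) between the last two training-set sizes."""
    drop = (valid_rmse[-2] - valid_rmse[-1]) / valid_rmse[-2]
    return bool(drop > rel_tol)

print(more_data_likely_helps(np.array([2.0, 1.5, 1.2])))      # curve still falling
print(more_data_likely_helps(np.array([1.20, 1.19, 1.188])))  # curve has plateaued
```

With noisy curves it is safer to compare fold-averaged RMSE over the last few sizes rather than just the final two points.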

Summary #

  • Learning curves expose overfitting vs. underfitting and quantify the impact of training data size.
  • learning_curve makes it easy to generate the plot; leverage it when planning data acquisition and tuning cadence.
  • Use it alongside other diagnostics to balance data, model complexity, and business priorities.