## TL;DR
- Cross-validation splits the dataset into folds so that model performance can be estimated more reliably.
- Compare it with a single hold-out split and observe how the number of folds affects the result.
- Review practical tips on fold design, computation cost, and how to report the scores.
> Cross-validation partitions the sample data, trains the model on part of it, tests on the remainder, and examines the validity of the analysis as a whole.
>
> — Cross-validation (Wikipedia)
## 1. What is cross-validation?
- Split the data into multiple “folds” and alternate which fold is used for validation.
- A single `train_test_split` can have high variance; cross-validation averages performance across folds.
- Common choices are k-fold and stratified k-fold, which preserves the class balance in each fold; the sketch below shows stratification in action.
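To make the stratification point concrete, here is a minimal sketch (the 20/80 label mix and five splits are illustrative choices, not taken from the dataset used later) showing that `StratifiedKFold` keeps the class ratio roughly constant in every validation fold:

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels: 20% class 0, 80% class 1.
labels = np.array([0] * 20 + [1] * 80)
features = np.zeros((100, 1))  # placeholder features; only the labels matter here

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for fold, (_, valid_idx) in enumerate(skf.split(features, labels)):
    # Every validation fold contains roughly 4 of class 0 and 16 of class 1.
    print(f"fold {fold}: {Counter(labels[valid_idx])}")
```

A plain `KFold` gives no such guarantee; with rare classes, a fold can end up with no positive samples at all.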
## 2. Basic example on Python 3.13
Assuming Python 3.13 and scikit-learn:
```bash
python --version  # Python 3.13.0
pip install scikit-learn matplotlib
```
The snippet below compares a single hold-out evaluation with 5-fold cross-validation on a synthetic imbalanced dataset.
```python
from __future__ import annotations

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_validate, train_test_split

RANDOM_STATE = 42


def make_dataset() -> tuple[np.ndarray, np.ndarray]:
    """Generate an imbalanced binary classification dataset."""
    features, labels = make_classification(
        n_samples=300,
        n_classes=2,
        weights=[0.2, 0.8],
        n_informative=4,
        n_features=6,
        n_clusters_per_class=2,
        shuffle=True,
        random_state=RANDOM_STATE,
    )
    return features, labels


def holdout_score() -> float:
    """Compute ROC-AUC on a single hold-out split."""
    features, labels = make_dataset()
    x_train, x_valid, y_train, y_valid = train_test_split(
        features,
        labels,
        test_size=0.2,
        stratify=labels,
        random_state=RANDOM_STATE,
    )
    model = RandomForestClassifier(max_depth=4, random_state=RANDOM_STATE)
    model.fit(x_train, y_train)
    # Note: hard class labels are scored here; for a probability-based AUC,
    # use model.predict_proba(x_valid)[:, 1] instead.
    predictions = model.predict(x_valid)
    return roc_auc_score(y_valid, predictions)


def cross_validation_scores() -> dict[str, float]:
    """Run 5-fold cross-validation and average ROC-AUC and Accuracy."""
    features, labels = make_dataset()
    model = RandomForestClassifier(max_depth=4, random_state=RANDOM_STATE)
    scores = cross_validate(
        model,
        features,
        labels,
        cv=5,  # an int cv uses StratifiedKFold automatically for classifiers
        scoring=("roc_auc", "accuracy"),
        return_train_score=False,
        n_jobs=None,
    )
    return {
        "roc_auc": float(np.mean(scores["test_roc_auc"])),
        "accuracy": float(np.mean(scores["test_accuracy"])),
    }


if __name__ == "__main__":
    holdout = holdout_score()
    print(f"Hold-out ROC-AUC: {holdout:.3f}")

    cv_result = cross_validation_scores()
    print(f"5-fold ROC-AUC: {cv_result['roc_auc']:.3f}")
    print(f"5-fold Accuracy: {cv_result['accuracy']:.3f}")
```
Example output:

```text
Hold-out ROC-AUC: 0.528
5-fold ROC-AUC: 0.844
5-fold Accuracy: 0.858
```
The single hold-out split reports an ROC-AUC close to chance level, while cross-validation yields a much more stable estimate. Two things drive the gap: a single 20% split of 300 samples leaves only 60 validation points, so the score is noisy, and `roc_auc_score` is computed here from hard class labels, whereas the `roc_auc` scorer inside `cross_validate` uses predicted probabilities.
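One way to see the variance component directly is to repeat the hold-out evaluation across different split seeds. The sketch below reuses `make_dataset` and the model settings from the snippet above; the choice of ten seeds is arbitrary:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

RANDOM_STATE = 42  # same seed as in the snippet above
features, labels = make_dataset()  # defined in the snippet above

scores = []
for seed in range(10):  # ten different hold-out splits of the same data
    x_train, x_valid, y_train, y_valid = train_test_split(
        features, labels, test_size=0.2, stratify=labels, random_state=seed
    )
    model = RandomForestClassifier(max_depth=4, random_state=RANDOM_STATE)
    model.fit(x_train, y_train)
    scores.append(roc_auc_score(y_valid, model.predict(x_valid)))

# The spread across seeds is exactly the variance a single split hides.
print(f"mean={np.mean(scores):.3f}, std={np.std(scores):.3f}")
```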
## 3. Design tips

- **Choosing the number of folds**: five or ten folds are typical. For small datasets you can consider `LeaveOneOut`, but the cost grows drastically with sample size.
- **Stratification**: when classes are imbalanced, use `StratifiedKFold` (or the `stratify` argument) to maintain label proportions in each fold.
- **Multiple metrics**: pass a tuple to `scoring` to compute several metrics at once; combining ROC-AUC and Accuracy reveals trade-offs a single number hides.
- **Integration with hyperparameter search**: `GridSearchCV` / `RandomizedSearchCV` run cross-validation internally, helping avoid overfitting while tuning; a sketch follows this list.
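To illustrate the last tip, the sketch below wraps the same model in `GridSearchCV`; the parameter grid is a hypothetical example, so adapt the ranges to your own problem:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

RANDOM_STATE = 42
features, labels = make_dataset()  # from the snippet above

# Hypothetical grid for illustration only.
param_grid = {"max_depth": [2, 4, 8], "n_estimators": [50, 100]}

search = GridSearchCV(
    RandomForestClassifier(random_state=RANDOM_STATE),
    param_grid=param_grid,
    cv=5,  # 5-fold cross-validation for every parameter combination
    scoring="roc_auc",
)
search.fit(features, labels)
print(search.best_params_, f"best CV ROC-AUC: {search.best_score_:.3f}")
```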
## 4. Checklist for real projects

- **Is the split strategy appropriate?** For time-series data, switch to `TimeSeriesSplit` instead of random folds (see the sketch after this list).
- **Are stakeholder metrics covered?** Include the decision-driving metrics in `scoring` so that reports stay aligned.
- **Have you estimated the runtime?** Cross-validation trains the model k times, so plan compute resources accordingly.
- **Can others reproduce the setup?** Record the Python version, random seeds, and split configuration in notebooks or scripts.
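For the time-series point, here is a minimal sketch (the twelve ordered values are stand-ins for real timestamped rows) showing that `TimeSeriesSplit` always trains on the past and validates on the future:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# Twelve ordered observations standing in for a time-indexed dataset.
timestamps = np.arange(12)

tscv = TimeSeriesSplit(n_splits=4)
for fold, (train_idx, valid_idx) in enumerate(tscv.split(timestamps)):
    # Training indices always precede validation indices: no future leakage.
    print(f"fold {fold}: train={timestamps[train_idx]} -> valid={timestamps[valid_idx]}")
```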
## Summary

- Cross-validation is a core technique for reducing variance in performance estimates and understanding generalisation.
- `cross_validate` in scikit-learn makes it easy to compute several metrics in one run.
- Design the folds, metrics, and compute budget deliberately, and embed the process into your production workflow.