Validación cruzada estratificada

Eval

Validación cruzada estratificada

Creado: Última actualización: Tiempo de lectura: 2 min
まとめ
  • Stratified k-fold mantiene la proporción de clases en cada pliegue, indispensable en datasets desbalanceados.
  • Compara la variante estratificada con k-fold estándar para visualizar las diferencias de sesgo.
  • Recoge recomendaciones para casos de desequilibrio extremo y cómo interpretar los resultados en la práctica.
import matplotlib.pyplot as plt
import numpy as np
import japanize_matplotlib
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score

RND = 42

Construcción del modelo y validación cruzada #

Dataset de experimentación #

n_classes = 10
X, y = make_classification(
    n_samples=210,
    n_classes=n_classes,
    n_informative=n_classes,
    n_features=12,
    n_clusters_per_class=1,
    weights=[0.82, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
    random_state=RND,
)
plt.title("Número de muestras por clase en y")
plt.hist(y)
plt.xlabel("Etiqueta")
plt.ylabel("Cantidad")
plt.show()

png

Proporciones tras el split #

Dividimos los datos y comprobamos las proporciones de clase en los conjuntos de entrenamiento y validación.

StratifiedKFold #

La proporción de clases se mantiene estable en entrenamiento y validación.

skf = StratifiedKFold(n_splits=4)
for train_index, valid_index in skf.split(X, y):
    X_train, X_valid = X[train_index], X[valid_index]
    y_train, y_valid = y[train_index], y[valid_index]
    plt.figure(figsize=(8, 2))
    plt.subplot(121)
    plt.title("Entrenamiento")
    train_label_cnt = [(y_train == i).sum() for i in range(n_classes)]
    plt.ylabel("Cantidad")
    plt.bar(np.arange(n_classes), train_label_cnt)
    plt.subplot(122)
    plt.title("Validación")
    valid_label_cnt = [(y_valid == i).sum() for i in range(n_classes)]
    plt.bar(np.arange(n_classes), valid_label_cnt)
    plt.show()

png

KFold #

K-fold estándar puede generar pliegues que carecen por completo de algunas clases minoritarias.

kf = KFold(n_splits=4)
for train_index, valid_index in kf.split(X, y):
    X_train, X_valid = X[train_index], X[valid_index]
    y_train, y_valid = y[train_index], y[valid_index]
    plt.figure(figsize=(8, 2))
    plt.subplot(121)
    plt.title("Entrenamiento")
    train_label_cnt = [(y_train == i).sum() for i in range(n_classes)]
    plt.ylabel("Cantidad")
    plt.bar(np.arange(n_classes), train_label_cnt)
    plt.subplot(122)
    plt.title("Validación")
    valid_label_cnt = [(y_valid == i).sum() for i in range(n_classes)]
    plt.bar(np.arange(n_classes), valid_label_cnt)
    plt.show()

png


Consideraciones prácticas #

  • Desequilibrio extremo: cuando las clases minoritarias apenas tienen muestras, combina estratificación con validación cruzada repetida para reducir la varianza.
  • Regresión: discretiza el objetivo en bins y aplica StratifiedKFold si necesitas pliegues equilibrados.
  • Política de mezclado: activa shuffle=True (con semilla fija) cuando el dataset tiene orden temporal o por grupos que puedan sesgar los pliegues.

Stratified k-fold es un reemplazo directo de k-fold cuando importa el balance de clases. Produce divisiones de validación más justas, estabiliza métricas como ROC-AUC y mejora la comparabilidad entre modelos entrenados sobre datos desbalanceados.