- KL divergence measures how much information is lost when a reference distribution $Q$ is used to approximate a true distribution $P$.
- Compute KL on discrete examples, including smoothing techniques for zero probabilities.
- Apply KL to model evaluation and drift monitoring while noting its asymmetry and instability.
1. Definition #
For discrete distributions $P$ and $Q$:
$$ \mathrm{KL}(P \parallel Q) = \sum_i p_i \log \frac{p_i}{q_i} $$
- $\mathrm{KL}(P \parallel Q) = 0$ if and only if $P = Q$.
- Asymmetric: $\mathrm{KL}(P \parallel Q) \neq \mathrm{KL}(Q \parallel P)$, so it is not a distance metric.
- Sensitive to support mismatches: if $q_i = 0$ while $p_i > 0$, KL diverges to infinity.
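These properties are easy to check numerically. A small sketch using SciPy's `rel_entr` (which computes the element-wise terms $p_i \log(p_i / q_i)$, returning 0 where $p_i = 0$ and infinity where $q_i = 0 < p_i$); the example distributions are illustrative:

```python
import numpy as np
from scipy.special import rel_entr  # element-wise p * log(p / q)

p = np.array([0.5, 0.5, 0.0])
q = np.array([0.5, 0.25, 0.25])

# KL(P || Q) is finite: q is positive wherever p is.
print(rel_entr(p, q).sum())  # ≈ 0.3466

# KL(Q || P) is infinite: p[2] = 0 while q[2] > 0.
print(rel_entr(q, p).sum())  # inf
```

Note the asymmetry: one direction is finite while the other diverges, which is exactly the support-mismatch problem described above.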
2. Computing in Python #
```python
import numpy as np
from scipy.special import rel_entr  # element-wise p * log(p / q)

def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """KL divergence D(P || Q) with epsilon smoothing for zero bins."""
    epsilon = 1e-12
    p = np.asarray(p, dtype=float) + epsilon
    q = np.asarray(q, dtype=float) + epsilon
    p = p / p.sum()  # renormalize after smoothing so both sum to 1
    q = q / q.sum()
    return float(np.sum(rel_entr(p, q)))
```
Adding a small epsilon before normalizing keeps every bin strictly positive, so the logarithm never sees a zero denominator. Adjust the amount of smoothing according to domain knowledge.
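When the distributions come from raw counts, a common alternative to a blanket epsilon is additive (Laplace/Dirichlet) smoothing: add $\alpha$ pseudo-counts to every bin before normalizing. A minimal sketch (the `smoothed_probs` helper and the choice $\alpha = 1$ are illustrative, not from any library):

```python
import numpy as np

def smoothed_probs(counts, alpha=1.0):
    """Additive (Laplace/Dirichlet) smoothing: alpha pseudo-counts per bin."""
    counts = np.asarray(counts, dtype=float)
    return (counts + alpha) / (counts.sum() + alpha * counts.size)

obs = np.array([30, 0, 70])   # one empty bin
probs = smoothed_probs(obs)
print(probs)                  # every bin is now strictly positive
```

Larger `alpha` pulls the estimate toward the uniform distribution, which is appropriate when counts are small and individual zeros are likely sampling artifacts.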
3. Histogram example #
```python
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib  # optional: Japanese-capable fonts for labels

a = np.array([0.1, 0.2, 0.3, 0.2, 0.1, 0.1])
b = np.array([0.05, 0.1, 0.2, 0.3, 0.3, 0.05])

plt.figure(figsize=(12, 4))
plt.bar(np.arange(a.size) - 0.1, a, width=0.2, label="P")
plt.bar(np.arange(b.size) + 0.1, b, width=0.2, label="Q")
plt.legend()
plt.show()

print(f"KL(P || P) = {kl_divergence(a, a):.4f}")
print(f"KL(P || Q) = {kl_divergence(a, b):.4f}")
print(f"KL(Q || P) = {kl_divergence(b, a):.4f}")
```
Swapping the arguments changes the value, so always be explicit about which distribution is the baseline ($Q$, the reference) and which is the target ($P$) when interpreting KL.
4. Link to Jensen–Shannon divergence #
To obtain a symmetric, bounded measure, compare each distribution with their mixture $M = \tfrac{1}{2}(P + Q)$ and average both directions of KL. The Jensen–Shannon divergence is defined as:

$$ \mathrm{JSD}(P \parallel Q) = \frac{1}{2} \mathrm{KL}(P \parallel M) + \frac{1}{2} \mathrm{KL}(Q \parallel M) $$

```python
import numpy as np
import matplotlib.pyplot as plt

# Two well-separated distributions P and Q, plus a third sample centered
# between them to suggest (schematically) where the mixture M places mass.
plt.hist(np.random.normal(1, 1, 1000), alpha=0.85, color="blue", label="P")
plt.hist(np.random.normal(4, 1, 1000), alpha=0.85, color="red", label="Q")
plt.hist(np.random.normal(2.5, 1, 1000), alpha=0.85, color="green", label="midpoint")
plt.legend()
plt.show()
```
JSD mitigates infinite values when supports differ and is more practical for many real-world comparisons.
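A minimal JSD sketch built on `rel_entr` (natural logarithm, so the value is bounded by $\log 2 \approx 0.6931$; the `js_divergence` name is illustrative):

```python
import numpy as np
from scipy.special import rel_entr

def js_divergence(p, q) -> float:
    """Jensen-Shannon divergence via the mixture M = (P + Q) / 2."""
    p = np.asarray(p, dtype=float); p = p / p.sum()
    q = np.asarray(q, dtype=float); q = q / q.sum()
    m = 0.5 * (p + q)
    return float(0.5 * rel_entr(p, m).sum() + 0.5 * rel_entr(q, m).sum())

# Finite even with disjoint supports, where KL(P || Q) would be infinite:
print(js_divergence([1.0, 0.0], [0.0, 1.0]))  # log(2) ≈ 0.6931, the maximum
```

Because $M$ is positive wherever either $P$ or $Q$ is, the log terms never divide by zero, which is why no epsilon smoothing is needed here.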
5. Practical use and caveats #
- Generative model evaluation: measure divergence between generated and real distributions during GAN/VAE training.
- Monitoring & drift detection: track how production data diverges from training distributions.
- Language models: compare n-gram distributions or token probabilities.
Use smoothing (e.g., Dirichlet priors) when data are sparse. A high KL does not automatically imply model failure; combine it with other metrics before acting.
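The drift-monitoring use case above can be sketched end to end: histogram a production feature against shared bin edges from the training data, smooth, and compute KL. Everything here (the synthetic samples, the 20-bin choice, the `alpha` value, and the idea of alerting on a threshold) is a hypothetical illustration:

```python
import numpy as np
from scipy.special import rel_entr

rng = np.random.default_rng(0)
train = rng.normal(0.0, 1.0, 10_000)  # reference (training) sample
prod = rng.normal(0.3, 1.0, 2_000)    # slightly shifted production sample

# Shared bin edges so the two histograms are directly comparable
edges = np.histogram_bin_edges(train, bins=20)
p, _ = np.histogram(prod, bins=edges)
q, _ = np.histogram(train, bins=edges)

alpha = 1.0  # additive smoothing so no bin is empty
p = (p + alpha) / (p.sum() + alpha * p.size)
q = (q + alpha) / (q.sum() + alpha * q.size)

drift = float(rel_entr(p, q).sum())
print(f"KL(prod || train) = {drift:.4f}")  # alert if above a tuned threshold
```

The direction matters here too: `KL(prod || train)` asks how surprising production data looks under the training distribution, which is usually the question drift monitoring wants answered.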
Summary #
KL divergence captures the relative information difference between two distributions. Its asymmetry and sensitivity to zero support require careful handling, but when paired with smoothing and complementary metrics, it provides valuable insight into distributional shifts.