pragmatic data science with python

Model Registries and Drift Detection

Chapter 32 of 33
Summary

A model in production without a registry is a model you cannot reproduce, audit, or roll back. This section builds a complete MLflow workflow: logging parameters, metrics, and artifacts during training, registering model versions with stage transitions, and loading production models by name. It then confronts the three ways models degrade silently. Data drift — when feature distributions shift between training and serving — is detected with the Kolmogorov-Smirnov test and the Population Stability Index. Concept drift — when the relationship between features and targets changes — is detected by monitoring prediction-outcome correlation over time. Feature degradation — when an upstream pipeline breaks and a feature becomes constant, null, or nonsensical — is caught with variance, null rate, and cardinality monitors. Each detector produces actionable alerts with thresholds calibrated to avoid both alert fatigue and missed regressions.

11.1 — Model Registries

Ask yourself four questions about the model currently serving production traffic. What hyperparameters was it trained with? What version of the training data did it use? How did it perform on the evaluation set? When was it last retrained?

If you cannot answer all four without digging through Slack messages, Jupyter notebooks, or someone’s memory, you have a registry problem. And a registry problem becomes an incident response problem the first time production predictions go wrong and you need to roll back to the last known-good model.

MLflow: Track Everything, Regret Nothing

MLflow is the de facto open-source standard for experiment tracking and model management. It solves three problems: recording what you tried (experiment tracking), storing what you produced (model registry), and loading what you need (model serving).

Here is a complete workflow — training a model, logging everything, registering the result, and loading it for inference:

import mlflow
import mlflow.sklearn
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split


def train_and_register(
    n_estimators: int = 200,
    max_depth: int = 5,
    learning_rate: float = 0.1,
    experiment_name: str = "churn_prediction",
    model_name: str = "churn_model",
) -> str:
    """
    Train a model, log everything to MLflow, and register it.

    Returns the run ID for downstream reference.
    """
    mlflow.set_experiment(experiment_name)

    # Generate synthetic data (replace with your actual data loading)
    X, y = make_classification(
        n_samples=10_000, n_features=20, n_informative=12,
        n_redundant=3, weights=[0.7, 0.3], random_state=42,
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42,
    )

    with mlflow.start_run() as run:
        # Log training parameters — every knob you turned
        mlflow.log_params({
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "learning_rate": learning_rate,
            "train_samples": len(X_train),
            "test_samples": len(X_test),
            "positive_rate": float(y_train.mean()),
        })

        # Train
        model = GradientBoostingClassifier(
            n_estimators=n_estimators,
            max_depth=max_depth,
            learning_rate=learning_rate,
            random_state=42,
        )
        model.fit(X_train, y_train)

        # Evaluate
        y_pred = model.predict(X_test)
        metrics = {
            "f1": f1_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred),
            "recall": recall_score(y_test, y_pred),
        }
        mlflow.log_metrics(metrics)

        # Log the model artifact — this stores the serialized model
        mlflow.sklearn.log_model(
            model, "model",
            registered_model_name=model_name,
        )

        # Log feature importance as an artifact
        importance = dict(zip(
            [f"feature_{i}" for i in range(X_train.shape[1])],
            model.feature_importances_.tolist(),
        ))
        mlflow.log_dict(importance, "feature_importance.json")

        print(f"Run ID: {run.info.run_id}")
        print(f"Metrics: {metrics}")
        return run.info.run_id


def load_production_model(model_name: str = "churn_model") -> object:
    """
    Load the production model by name and stage.

    In MLflow, models transition through stages:
    None -> Staging -> Production -> Archived
    """
    model_uri = f"models:/{model_name}/Production"
    model = mlflow.sklearn.load_model(model_uri)
    return model


# Usage:
# run_id = train_and_register(n_estimators=300, max_depth=4)
# model = load_production_model()
# predictions = model.predict(new_data)

Three things to notice. First, log_params captures every decision you made: not just hyperparameters, but data characteristics like the positive class rate and sample counts. When you are debugging a model six months from now, this metadata is the difference between “I think we trained on about 10K samples” and knowing exactly what happened. Second, log_model with registered_model_name does two things in one call: it stores the serialized model as an artifact and registers a new version in the model registry. Third, load_model with the stage URI (models:/churn_model/Production) means your inference code never references a file path. It references a logical name and a stage. You promote a model to Production in the MLflow UI, and the inference service picks it up the next time it loads the model; a long-running service has to reload on a schedule or on a signal to see the change. One caveat: MLflow 2.9 and later deprecate stages in favor of model version aliases, which use a URI of the form models:/churn_model@champion (the alias name is your choice), so check which mechanism your MLflow version supports before building on stages.

Beyond MLflow: Alternatives and Model Cards

Weights & Biases offers richer visualization and better team collaboration features, but it is a SaaS product with usage-based pricing. Neptune provides similar capabilities with a focus on experiment comparison. DVC is the lightweight alternative — it versions data and models using Git-like semantics but lacks the experiment tracking UI and model registry. For teams under ten people with straightforward workflows, DVC plus a naming convention can be sufficient. For anything larger, the registry abstraction that MLflow provides saves more time than it costs.

Regardless of your registry choice, every registered model should have a model card: a document that states the model’s intended use, known limitations, training data characteristics, fairness considerations, and performance across subgroups. A model card is not bureaucracy — it is the documentation that prevents someone from using your churn model to make lending decisions, or deploying a model trained on US data to serve European users without revalidation.
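A model card is easier to enforce when it is structured data rather than a free-form document. Here is one possible shape, sketched as a dataclass whose fields mirror the sections listed above. The class name, field names, and the `missing_sections` check are illustrative assumptions, not a standard schema, and the example values are placeholders.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ModelCard:
    model_name: str
    intended_use: str
    out_of_scope_uses: list[str]
    training_data: str
    known_limitations: list[str]
    fairness_considerations: str
    # Subgroup name -> metric name -> value, filled in from your evaluation run.
    subgroup_metrics: dict[str, dict[str, float]] = field(default_factory=dict)

    def missing_sections(self) -> list[str]:
        """Names of sections left empty, so a CI check can reject incomplete cards."""
        return [name for name, value in asdict(self).items() if not value]

    def to_json(self) -> str:
        return json.dumps(asdict(self), indent=2, sort_keys=True)


# Placeholder content for illustration only.
card = ModelCard(
    model_name="churn_model",
    intended_use="Rank existing subscribers by churn risk for retention outreach.",
    out_of_scope_uses=["lending decisions", "users outside the training region"],
    training_data="Describe source, date range, and sampling here.",
    known_limitations=["Not validated on accounts younger than 30 days."],
    fairness_considerations="List the protected attributes you evaluated.",
)

print(card.missing_sections())  # subgroup metrics have not been filled in yet
```

The JSON from `to_json()` can be stored next to the model, for example with the same `mlflow.log_dict` call used for feature importance, so the card travels with each registered version.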


11.2 — Detecting Drift

Your model learned a function that maps features to a target. That function was correct for the training data distribution. The moment production data deviates from that distribution, the function’s guarantees evaporate. This deviation comes in three forms, each with different causes and different detectors.

Data Drift: Features Shift

Data drift occurs when the distribution of input features changes between training and serving. The model itself has not changed. The relationship between features and target has not changed. But the inputs the model receives in production look different from the inputs it was trained on.

Examples: a credit scoring model trained on pre-2020 income distributions receives post-pandemic applications where income volatility has doubled. A recommender trained on desktop browsing behavior receives mobile traffic with shorter sessions and different click patterns. A fraud detector trained on card-present transactions receives a surge of card-not-present transactions during a holiday sale.

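The Kolmogorov-Smirnov test is the workhorse for detecting univariate data drift. It compares two samples and returns a statistic equal to the maximum absolute difference between their empirical cumulative distribution functions, plus a p-value for the hypothesis that both samples come from the same distribution: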

from dataclasses import dataclass

import numpy as np
from scipy import stats


@dataclass
class DriftResult:
    feature: str
    statistic: float
    p_value: float
    is_drifted: bool


def detect_data_drift(
    reference: dict[str, np.ndarray],
    current: dict[str, np.ndarray],
    p_threshold: float = 0.01,
) -> list[DriftResult]:
    """
    Compare feature distributions between reference (training) and current
    (production) data using the two-sample KS test.

    Args:
        reference: Feature name -> array of values from training data.
        current: Feature name -> array of values from production window.
        p_threshold: P-value below which drift is flagged. Use 0.01, not 0.05.
            With large production samples, 0.05 triggers on statistically
            significant but practically irrelevant shifts.

    Returns:
        List of DriftResult, one per feature, sorted by severity.
    """
    results: list[DriftResult] = []

    for feature_name in reference:
        if feature_name not in current:
            # Feature missing entirely: strictly degradation rather than drift,
            # but flag it with the maximum statistic so it sorts first
            results.append(DriftResult(
                feature=feature_name, statistic=1.0,
                p_value=0.0, is_drifted=True,
            ))
            continue

        ref_values = reference[feature_name]
        cur_values = current[feature_name]

        ks_stat, p_value = stats.ks_2samp(ref_values, cur_values)
        results.append(DriftResult(
            feature=feature_name,
            statistic=ks_stat,
            p_value=p_value,
            is_drifted=p_value < p_threshold,
        ))

    # Sort by KS statistic descending — worst drift first
    results.sort(key=lambda r: r.statistic, reverse=True)
    return results

The p-threshold of 0.01 deserves explanation. The KS test gains power as samples grow, so with production samples of 10,000+ observations the conventional 0.05 threshold flags tiny distribution shifts that have no practical impact on model performance. Use 0.01 or lower, read the KS statistic itself as the effect size, and always pair statistical drift detection with a check on the actual prediction impact.
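To see the sample-size effect in numbers, the sketch below holds the KS statistic fixed at a small value and varies only the sample size. It uses the asymptotic Kolmogorov series for the p-value, which is an approximation chosen for illustration; `scipy.stats.ks_2samp` may use an exact method instead. The `min_statistic` floor of 0.05 is a hypothetical effect-size cutoff, not a recommended value.

```python
import math


def ks_pvalue_asymptotic(d: float, n: int, m: int, terms: int = 100) -> float:
    """Asymptotic two-sided p-value for a two-sample KS statistic d.

    Q(lam) = 2 * sum_{k>=1} (-1)^(k-1) * exp(-2 * k^2 * lam^2),
    with lam = d * sqrt(n * m / (n + m)).
    """
    lam = d * math.sqrt(n * m / (n + m))
    if lam < 0.2:
        return 1.0  # Q is indistinguishable from 1 here; skip the series
    total = sum(
        (-1) ** (k - 1) * math.exp(-2.0 * k * k * lam * lam)
        for k in range(1, terms + 1)
    )
    return min(1.0, max(0.0, 2.0 * total))


def is_material_drift(
    d: float, n: int, m: int,
    p_threshold: float = 0.01,
    min_statistic: float = 0.05,  # hypothetical effect-size floor
) -> bool:
    """Flag drift only when it is both statistically and practically significant."""
    return ks_pvalue_asymptotic(d, n, m) < p_threshold and d >= min_statistic


# Same small shift (d = 0.02), three sample sizes per window.
for size in (1_000, 10_000, 100_000):
    p = ks_pvalue_asymptotic(0.02, size, size)
    print(f"n = m = {size:>7,}: p = {p:.3g}, material = {is_material_drift(0.02, size, size)}")
```

Under this approximation the same 0.02 gap is far from significant at 1,000 samples, falls between 0.01 and 0.05 at 10,000, and is overwhelmingly significant at 100,000. The effect-size floor keeps it from being flagged in any of the three cases.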

Population Stability Index (PSI)

The PSI is more interpretable than a p-value for stakeholder reporting. It quantifies how much a distribution has shifted by comparing bin proportions between reference and current data:

def compute_psi(
    reference: np.ndarray,
    current: np.ndarray,
    n_bins: int = 10,
    eps: float = 1e-4,
) -> float:
    """
    Population Stability Index.

    PSI < 0.1: no significant shift
    PSI 0.1–0.2: moderate shift, investigate
    PSI > 0.2: significant shift, action required

    Uses quantile-based binning from the reference distribution
    to handle skewed features correctly.
    """
    # Create bins from reference distribution (not uniform bins)
    quantiles = np.linspace(0, 100, n_bins + 1)
    bin_edges = np.percentile(reference, quantiles)
    bin_edges[0] = -np.inf
    bin_edges[-1] = np.inf

    ref_counts = np.histogram(reference, bins=bin_edges)[0]
    cur_counts = np.histogram(current, bins=bin_edges)[0]

    # Normalize to proportions, add epsilon to avoid division by zero
    ref_pct = ref_counts / len(reference) + eps
    cur_pct = cur_counts / len(current) + eps

    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return float(psi)
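
To make the formula concrete, here is a hand-sized example that skips the binning step and applies the PSI sum directly to bin proportions. The helper names `psi_from_proportions` and `psi_band` are illustrative, and the production proportions are made up for the arithmetic.

```python
import math


def psi_from_proportions(ref_pct: list[float], cur_pct: list[float]) -> float:
    """PSI = sum over bins of (cur - ref) * ln(cur / ref)."""
    return sum((c - r) * math.log(c / r) for r, c in zip(ref_pct, cur_pct))


def psi_band(psi: float) -> str:
    """Map a PSI value onto the bands from the compute_psi docstring."""
    if psi < 0.1:
        return "no significant shift"
    if psi <= 0.2:
        return "moderate shift, investigate"
    return "significant shift, action required"


reference = [0.25, 0.25, 0.25, 0.25]  # quartile bins hold equal mass by construction
current = [0.40, 0.30, 0.20, 0.10]    # made-up production proportions

psi = psi_from_proportions(reference, current)
print(f"PSI = {psi:.3f} -> {psi_band(psi)}")
```

Every term in the sum is non-negative, because the difference and the log ratio always share a sign. That is why PSI can only grow as bins diverge, and why it is exactly zero when the proportions match.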

Use quantile-based binning from the reference distribution, not uniform bins. Uniform bins produce misleading PSI values for skewed features because most of the data lands in one or two bins and the rest are nearly empty.
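The difference is easy to see on a synthetic right-skewed sample. The sketch below uses only the standard library; the lognormal parameters and the helper `bin_shares` are assumptions chosen for illustration, standing in for something like a transaction amount.

```python
import bisect
import random


def bin_shares(values: list[float], edges: list[float]) -> list[float]:
    """Fraction of values per bin; values at or beyond the outer edges go to the end bins."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        idx = bisect.bisect_right(edges, v) - 1
        counts[min(max(idx, 0), len(counts) - 1)] += 1
    return [c / len(values) for c in counts]


rng = random.Random(0)
sample = sorted(rng.lognormvariate(0.0, 1.5) for _ in range(10_000))

n_bins = 10
lo, hi = sample[0], sample[-1]

# Uniform bins: equal width between the observed minimum and maximum.
uniform_edges = [lo + (hi - lo) * i / n_bins for i in range(n_bins + 1)]

# Quantile bins: every (n / n_bins)-th order statistic as an approximate decile edge.
quantile_edges = [
    sample[min(i * len(sample) // n_bins, len(sample) - 1)]
    for i in range(n_bins + 1)
]

uniform = bin_shares(sample, uniform_edges)
quantile = bin_shares(sample, quantile_edges)

print("uniform bins: ", [f"{s:.1%}" for s in uniform])
print("quantile bins:", [f"{s:.1%}" for s in quantile])
```

With uniform edges the first bin absorbs nearly all of the sample, so the remaining bins rest on a handful of observations each. With quantile edges every bin holds roughly a tenth of the reference data.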

Feature Degradation: The Silent Killer

Drift detection assumes features are present and valid but distributed differently. Feature degradation is worse — a feature has stopped being informative entirely. An upstream ETL pipeline changed its output format. A third-party API started returning nulls. A database migration silently set a column to its default value.

These failures produce no errors. The model receives a valid input, makes a prediction, and returns a response. The prediction is garbage because a critical feature is now constant, null, or nonsensical. Here is a detector that catches the three most common degradation patterns:

@dataclass
class DegradationAlert:
    feature: str
    alert_type: str  # "missing_feature", "null_rate", "low_variance", "cardinality_collapse"
    current_value: float
    baseline_value: float
    message: str


def detect_feature_degradation(
    reference: dict[str, np.ndarray],
    current: dict[str, np.ndarray],
    null_rate_threshold: float = 0.05,
    variance_ratio_threshold: float = 0.1,
    cardinality_ratio_threshold: float = 0.3,
) -> list[DegradationAlert]:
    """
    Detect features that have degraded — not shifted in distribution,
    but broken structurally.

    Checks three failure modes (plus features that are absent entirely):
    1. Null rate spike: null rate exceeds null_rate_threshold (default 5%)
       and is more than 3x the training null rate
    2. Variance collapse: feature variance dropped below 10% of training variance
    3. Cardinality collapse: number of unique values dropped below 30% of training
    """
    alerts: list[DegradationAlert] = []

    for feature_name, ref_values in reference.items():
        cur_values = current.get(feature_name)
        if cur_values is None:
            alerts.append(DegradationAlert(
                feature=feature_name, alert_type="missing_feature",
                current_value=0.0, baseline_value=1.0,
                message=f"Feature '{feature_name}' is entirely absent from "
                        f"production data.",
            ))
            continue

        # 1. Null rate spike
        ref_null_rate = np.isnan(ref_values.astype(float)).mean()
        cur_null_rate = np.isnan(cur_values.astype(float)).mean()
        if cur_null_rate > null_rate_threshold and cur_null_rate > ref_null_rate * 3:
            alerts.append(DegradationAlert(
                feature=feature_name, alert_type="null_rate",
                current_value=cur_null_rate, baseline_value=ref_null_rate,
                message=f"Null rate for '{feature_name}' jumped from "
                        f"{ref_null_rate:.1%} to {cur_null_rate:.1%}.",
            ))

        # 2. Variance collapse
        ref_var = np.nanvar(ref_values.astype(float))
        cur_var = np.nanvar(cur_values.astype(float))
        if ref_var > 0:
            variance_ratio = cur_var / ref_var
            if variance_ratio < variance_ratio_threshold:
                alerts.append(DegradationAlert(
                    feature=feature_name, alert_type="low_variance",
                    current_value=cur_var, baseline_value=ref_var,
                    message=f"Variance of '{feature_name}' collapsed to "
                            f"{variance_ratio:.1%} of training variance. "
                            f"Feature may be constant.",
                ))

        # 3. Cardinality collapse
        ref_unique = len(np.unique(ref_values[~np.isnan(ref_values.astype(float))]))
        cur_unique = len(np.unique(cur_values[~np.isnan(cur_values.astype(float))]))
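        # Only check features with meaningful cardinality. Unique counts scale
        # with sample size, so this comparison assumes the current window is
        # roughly as large as the reference; a much smaller window will make
        # continuous features look collapsed.
        if ref_unique > 10: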
            cardinality_ratio = cur_unique / ref_unique
            if cardinality_ratio < cardinality_ratio_threshold:
                alerts.append(DegradationAlert(
                    feature=feature_name, alert_type="cardinality_collapse",
                    current_value=float(cur_unique), baseline_value=float(ref_unique),
                    message=f"Unique values for '{feature_name}' dropped from "
                            f"{ref_unique} to {cur_unique}.",
                ))

    return alerts

Alert Thresholds: The Goldilocks Problem

Two failure modes of alerting are equally dangerous. Too sensitive: you alert on every minor distribution shift, your team develops alert fatigue, and when a real drift event occurs, nobody investigates because the alert channel is full of noise. Too lenient: you only alert on catastrophic drift, and by the time the alert fires, the model has been serving degraded predictions for weeks.

The calibration approach that works in practice:

  1. Start lenient. Set thresholds that would have caught the most severe drift event in your historical data. PSI > 0.25, KS p-value < 0.001, null rate > 10%.
  2. Track false negatives. Every time a model issue is discovered by someone other than the monitoring system, ask: which threshold would have caught this, and how much sooner?
  3. Tighten gradually. Make thresholds stricter in small steps (a lower PSI or null-rate cutoff, a higher p-value cutoff) until the monitoring system catches issues before humans do, then stop.
  4. Window size matters. Compare weekly production windows against the training baseline, not individual batches. Daily windows are noisy. Hourly windows are chaos.
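
The calibration loop above can be sketched as a small threshold object plus a replay helper. The defaults are the lenient starting values from step 1. The class and function names, the 0.8 tightening factor, and the signal values for the missed incident are all hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DriftThresholds:
    psi_max: float = 0.25        # alert when PSI exceeds this
    ks_p_min: float = 0.001      # alert when the KS p-value falls below this
    null_rate_max: float = 0.10  # alert when the null rate exceeds this

    def tightened(self, factor: float = 0.8) -> "DriftThresholds":
        """Stricter copy: lower PSI and null-rate ceilings, higher p-value floor."""
        return DriftThresholds(
            psi_max=self.psi_max * factor,
            ks_p_min=self.ks_p_min / factor,
            null_rate_max=self.null_rate_max * factor,
        )


def breaches(psi: float, ks_p_value: float, null_rate: float,
             t: DriftThresholds) -> list[str]:
    """Names of the detectors that would fire for one monitoring window."""
    fired = []
    if psi > t.psi_max:
        fired.append("psi")
    if ks_p_value < t.ks_p_min:
        fired.append("ks")
    if null_rate > t.null_rate_max:
        fired.append("null_rate")
    return fired


def steps_to_catch(psi: float, ks_p_value: float, null_rate: float,
                   start: DriftThresholds, factor: float = 0.8,
                   max_steps: int = 10) -> Optional[int]:
    """Replay a missed incident: how many tightening steps until an alert fires?"""
    t = start
    for step in range(max_steps + 1):
        if breaches(psi, ks_p_value, null_rate, t):
            return step
        t = t.tightened(factor)
    return None


start = DriftThresholds()
# Hypothetical signals from a weekly window in which a human found the issue first.
missed = {"psi": 0.18, "ks_p_value": 0.004, "null_rate": 0.02}
print(steps_to_catch(**missed, start=start))
```

Replaying each false negative this way answers the question in step 2 directly. It also shows how far the thresholds would have to move, which is useful to know before deciding whether the extra alert volume is worth it.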

[Figure: Drift Detection Pipeline]

The goal is not zero drift. The goal is catching drift that matters — drift that degrades predictions enough to affect the business metric you care about — before the business notices.