Source code for nestkit.plotting.calibration

"""Calibration visualizations."""

from __future__ import annotations

from typing import TYPE_CHECKING

from nestkit._validation import extract_positive_proba
from nestkit.calibration.diagnostics import CalibrationDiagnostics
from nestkit.plotting._style import _apply_axis_limits, _get_ax

if TYPE_CHECKING:
    from matplotlib.axes import Axes

# Unit interval, passed to ``_apply_axis_limits`` as the natural x/y range
# of the plots defined in this module.
_UNIT = (0.0, 1.0)


def _plot_reliability_curve(ax, y_true, y_proba, fmt, **line_kwargs):
    """Draw a single reliability curve for one set of probabilities on *ax*.

    Rows of the reliability-diagram table whose ``fraction_positive`` is
    missing are dropped before plotting.
    """
    p_pos = extract_positive_proba(y_proba)
    table = CalibrationDiagnostics.reliability_diagram_data(y_true, p_pos)
    table = table.dropna(subset=["fraction_positive"])
    ax.plot(table["mean_predicted"], table["fraction_positive"], fmt, **line_kwargs)


def plot_calibration_curves(
    results,
    fold_idx: int | list[int] | None = None,
    fold_alpha: float = 0.4,
    full_range: bool = False,
    ylim: tuple[float, float] | None = None,
    xlim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Reliability diagrams showing raw vs calibrated probabilities per fold.

    A dashed diagonal marks perfect calibration.  Each selected fold
    contributes one "raw" curve and, when the fold carries calibrated
    probabilities, one "calibrated" curve in the same colour.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object.
    fold_idx : int, list of int, or None, optional
        Outer fold index or indices to plot. If ``None`` (default), all
        folds are shown. Any integral scalar (including NumPy integers)
        is accepted as a single index.
    fold_alpha : float, optional
        Opacity of individual fold curves.
    full_range : bool, optional
        If ``True``, set both axes to [0, 1].
    ylim, xlim : tuple of float or None, optional
        Explicit axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Additional keyword arguments passed to the ``ax.plot`` call of
        every fold curve. They take precedence over the per-curve
        defaults (``color``, ``alpha``, ``label``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    from numbers import Integral

    import matplotlib.pyplot as plt

    ax = _get_ax(ax)
    ax.plot([0, 1], [0, 1], "k--", label="Perfect")

    # Determine which folds to plot.  ``Integral`` (rather than ``int``) so
    # that NumPy integer scalars are treated as a single index instead of
    # being handed to ``set()`` and failing as non-iterable.
    if fold_idx is None:
        selected = results.fold_results_
    else:
        wanted = {fold_idx} if isinstance(fold_idx, Integral) else set(fold_idx)
        selected = [fr for fr in results.fold_results_ if fr.fold_idx in wanted]

    # Same colour per fold; marker/line style distinguishes raw vs calibrated.
    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    for position, fr in enumerate(selected):
        # Caller-supplied kwargs come last so they override these defaults.
        base_style = {"color": palette[position % len(palette)], "alpha": fold_alpha}

        raw_style = {**base_style, "label": f"Raw fold {fr.fold_idx}", **kwargs}
        _plot_reliability_curve(ax, fr.y_true, fr.y_proba_raw, "o-", **raw_style)

        if fr.y_proba_calibrated is not None:
            cal_style = {**base_style, "label": f"Cal fold {fr.fold_idx}", **kwargs}
            _plot_reliability_curve(
                ax, fr.y_true, fr.y_proba_calibrated, "s--", **cal_style
            )

    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Fraction of positives")
    ax.set_title("Reliability Diagram")
    ax.legend(fontsize=7)
    _apply_axis_limits(
        ax,
        xlim=xlim,
        ylim=ylim,
        full_range=full_range,
        natural_xlim=_UNIT,
        natural_ylim=_UNIT,
    )
    return ax
def _show_placeholder(ax, message: str):
    """Write *message* at the centre of *ax* and return the axes."""
    # Axes coordinates keep the message centred even when a caller-supplied
    # axes already has data limits other than [0, 1].
    ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
    return ax


def plot_calibration_improvement(
    results,
    annot: bool = False,
    annot_fmt: str = ".3f",
    full_range: bool = False,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Paired bar plot of ECE before vs after calibration per fold.

    Shows raw and calibrated ECE side by side for each fold. The mean and
    sample standard deviation of the per-fold gap (raw minus calibrated)
    are reported in the legend. When the results carry no calibration
    data, or the calibration summary is empty or lacks the ECE columns, a
    placeholder message is drawn instead.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object.
    annot : bool, optional
        If ``True``, annotate each bar with its ECE value and display the
        gap above each pair.
    annot_fmt : str, optional
        Format string for annotations and for the gap statistics in the
        legend.
    full_range : bool, optional
        If ``True``, set y-axis to [0, 1].
    ylim : tuple of float or None, optional
        Explicit y-axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Additional keyword arguments passed to the underlying matplotlib
        ``bar`` calls.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    import numpy as np

    ax = _get_ax(ax)

    if not results.has_calibration:
        return _show_placeholder(ax, "No calibration data")

    cal_df = results.calibration_summary_
    # Both ECE columns are read below, so both must be present; an empty
    # summary would only yield NaN statistics and a NumPy warning.
    ece_columns = ("ece_raw", "ece_calibrated")
    if len(cal_df) == 0 or any(col not in cal_df.columns for col in ece_columns):
        return _show_placeholder(ax, "No ECE data")

    folds = np.asarray(cal_df["fold_idx"])
    ece_raw = np.asarray(cal_df["ece_raw"], dtype=float)
    ece_cal = np.asarray(cal_df["ece_calibrated"], dtype=float)
    gaps = ece_raw - ece_cal

    n_folds = len(folds)
    x = np.arange(n_folds)
    width = 0.35

    gap_mean = np.mean(gaps)
    # Sample standard deviation is undefined for a single fold.
    gap_std = np.std(gaps, ddof=1) if n_folds > 1 else 0.0

    ax.bar(x - width / 2, ece_raw, width, label="ECE raw", **kwargs)
    ax.bar(x + width / 2, ece_cal, width, label="ECE calibrated", **kwargs)
    # Invisible bar entry for gap stats in legend
    ax.bar([], [], width=0, label=f"Gap: {gap_mean:{annot_fmt}} \u00b1 {gap_std:{annot_fmt}}")

    if annot:
        text_style = {"ha": "center", "va": "bottom", "fontsize": 7}
        for xi, raw, cal, gap in zip(x, ece_raw, ece_cal, gaps):
            ax.text(xi - width / 2, raw, format(raw, annot_fmt), **text_style)
            ax.text(xi + width / 2, cal, format(cal, annot_fmt), **text_style)
            # Gap label sits slightly above the taller bar of the pair.
            ax.text(
                xi,
                max(raw, cal) * 1.05,
                f"\u0394{format(gap, annot_fmt)}",
                color="gray",
                **text_style,
            )

    ax.set_xticks(x)
    ax.set_xticklabels([f"Fold {int(f)}" for f in folds])
    ax.set_ylabel("ECE")
    ax.set_title("Calibration Improvement")
    ax.legend(fontsize=8)
    _apply_axis_limits(ax, ylim=ylim, full_range=full_range, natural_ylim=_UNIT)
    return ax