Source code for nestkit.plotting.threshold

"""Threshold optimization visualizations."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from nestkit.plotting._style import _apply_axis_limits, _get_ax

if TYPE_CHECKING:
    from matplotlib.axes import Axes

# The [0, 1] range used as the "natural" axis limits for every plot in this
# module. It is passed as ``natural_xlim`` / ``natural_ylim`` to
# ``_apply_axis_limits``. Per the plot docstrings, ``full_range=True`` sets the
# corresponding axes to [0, 1].
_UNIT: tuple[float, float] = (0.0, 1.0)


def plot_threshold_sensitivity(
    results,
    fold_idx: int = 0,
    line_alpha: float = 0.8,
    full_range: bool = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Metrics as a function of decision threshold for a single fold.

    Draws one curve per available metric column (``sensitivity``,
    ``specificity``, ``precision``, ``recall``, ``f1``). When a
    ``criterion_value`` column is present, the optimization criterion is drawn
    as a black dotted line. A red dashed vertical line marks the fold's optimal
    threshold.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object.
    fold_idx : int, optional
        Index of the outer fold to visualize.
    line_alpha : float, optional
        Opacity of metric curves.
    full_range : bool, optional
        If ``True``, set both axes to [0, 1].
    xlim, ylim : tuple of float or None, optional
        Explicit axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Additional keyword arguments passed to
        :meth:`matplotlib.axes.Axes.plot` for each metric curve. They take
        precedence over defaults derived from other parameters (e.g. an
        explicit ``alpha`` overrides *line_alpha*). ``label`` is assigned per
        curve and cannot be passed here.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot. If the fold has no threshold result, or its
        sensitivity table is empty, the axes carries only a
        "No threshold data" placeholder.
    """
    ax = _get_ax(ax)
    fold = results.fold_results_[fold_idx]
    thr = fold.threshold_result
    if thr is None or thr.threshold_sensitivity.empty:
        # Axes coordinates keep the placeholder centred even when the caller's
        # axes already holds data with arbitrary limits.
        ax.text(
            0.5, 0.5, "No threshold data", ha="center", va="center", transform=ax.transAxes
        )
        return ax

    df = thr.threshold_sensitivity
    # Caller kwargs used to be accepted but silently dropped. Merge them over
    # the defaults so they actually reach matplotlib.
    curve_kw = {"alpha": line_alpha, **kwargs}
    for col in ("sensitivity", "specificity", "precision", "recall", "f1"):
        if col in df.columns:
            ax.plot(df["threshold"], df[col], label=col, **curve_kw)

    if "criterion_value" in df.columns:
        ax.plot(
            df["threshold"],
            df["criterion_value"],
            label=thr.criterion_name or "criterion",
            color="black",
            linewidth=2,
            linestyle=":",
        )
    ax.axvline(thr.optimal_threshold, color="red", linestyle="--", label="Optimal")

    ax.set_xlabel("Threshold")
    ax.set_ylabel("Metric value")
    ax.set_title(f"Threshold Sensitivity (Fold {fold_idx})")
    ax.legend(fontsize=7)
    _apply_axis_limits(
        ax, xlim=xlim, ylim=ylim, full_range=full_range, natural_xlim=_UNIT, natural_ylim=_UNIT
    )
    return ax
def plot_threshold_distribution(
    results,
    bar_alpha: float = 0.7,
    bins: int | None = None,
    full_range: bool = False,
    xlim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Distribution of optimal thresholds across folds.

    Draws a histogram of the per-fold optimal thresholds. A red dashed vertical
    line marks their mean.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object.
    bar_alpha : float, optional
        Opacity of histogram bars.
    bins : int or None, optional
        Number of histogram bins. ``None`` uses ``max(3, n // 2)``, where
        ``n`` is the number of finite per-fold thresholds.
    full_range : bool, optional
        If ``True``, set x-axis to [0, 1].
    xlim : tuple of float or None, optional
        Explicit x-axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Additional keyword arguments passed to
        :meth:`matplotlib.axes.Axes.hist`. They take precedence over the
        defaults (e.g. an explicit ``alpha`` overrides *bar_alpha*, and
        ``edgecolor`` overrides the black outline).

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot. If threshold optimization was not run, or no
        finite per-fold threshold is available, the axes carries only a
        "No threshold data" placeholder.
    """
    ax = _get_ax(ax)
    if results.has_threshold_optimization:
        thresholds = np.asarray(results.thresholds_per_fold_, dtype=float)
        # Non-finite entries (e.g. a missing threshold converted to NaN) would
        # make the histogram's automatic range invalid, so drop them.
        thresholds = thresholds[np.isfinite(thresholds)]
    else:
        thresholds = np.empty(0)

    if thresholds.size == 0:
        # With nothing to plot, np.mean would yield NaN plus a RuntimeWarning.
        # Show the placeholder instead, centred in axes coordinates.
        ax.text(
            0.5, 0.5, "No threshold data", ha="center", va="center", transform=ax.transAxes
        )
        return ax

    n_bins = bins if bins is not None else max(3, thresholds.size // 2)
    # Caller kwargs used to be accepted but silently dropped. Merge them over
    # the defaults so they actually reach matplotlib.
    hist_kw = {"edgecolor": "black", "alpha": bar_alpha, **kwargs}
    ax.hist(thresholds, bins=n_bins, **hist_kw)

    mean_thr = float(thresholds.mean())
    ax.axvline(mean_thr, color="red", linestyle="--", label=f"Mean={mean_thr:.3f}")

    ax.set_xlabel("Optimal threshold")
    ax.set_ylabel("Count")
    ax.set_title("Threshold Distribution Across Folds")
    ax.legend()
    _apply_axis_limits(ax, xlim=xlim, full_range=full_range, natural_xlim=_UNIT)
    return ax
def plot_threshold_comparison(
    results,
    full_range: bool = False,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Default vs optimized metrics side-by-side.

    For each metric row of ``results.threshold_comparison()``, draws a pair of
    bars: the mean score at the default threshold and at the optimized
    threshold. The ``std_*`` columns are used as error bars.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object.
    full_range : bool, optional
        If ``True``, set y-axis to [0, 1].
    ylim : tuple of float or None, optional
        Explicit y-axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Additional keyword arguments passed to both
        :meth:`matplotlib.axes.Axes.bar` calls. ``label`` and ``yerr`` are
        set per bar group and cannot be passed here.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot. If threshold optimization was not run, or the
        comparison table is empty, the axes carries only a
        "No threshold data" placeholder.
    """
    ax = _get_ax(ax)
    comp = results.threshold_comparison() if results.has_threshold_optimization else None
    if comp is None or len(comp) == 0:
        # Axes coordinates keep the placeholder centred regardless of any
        # existing data limits on a caller-supplied axes.
        ax.text(
            0.5, 0.5, "No threshold data", ha="center", va="center", transform=ax.transAxes
        )
        return ax

    x = np.arange(len(comp))
    width = 0.35
    # Caller kwargs used to be accepted but silently dropped. They are now
    # forwarded to both bar groups.
    ax.bar(
        x - width / 2,
        comp["mean_default"],
        width,
        label="Default (0.5)",
        yerr=comp["std_default"],
        **kwargs,
    )
    ax.bar(
        x + width / 2,
        comp["mean_optimized"],
        width,
        label="Optimized",
        yerr=comp["std_optimized"],
        **kwargs,
    )

    ax.set_xticks(x)
    ax.set_xticklabels(comp["metric"], rotation=45, ha="right")
    ax.set_ylabel("Score")
    ax.set_title("Default vs Optimized Threshold")
    ax.legend()
    _apply_axis_limits(ax, ylim=ylim, full_range=full_range, natural_ylim=_UNIT)
    return ax