"""Aggregate result plots - confusion matrices, ROC, precision-recall, residuals, and more."""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from nestkit.plotting._style import _apply_axis_limits, _get_ax
if TYPE_CHECKING:
from matplotlib.axes import Axes
_UNIT = (0.0, 1.0)
[docs]
def plot_confusion_matrices(
    results,
    threshold: str = "default",
    normalize: str | None = None,
    cmap: str = "Blues",
    fontsize: int = 10,
    ax=None,
    **kwargs,
) -> Axes:
    """Per-fold and aggregate confusion matrices.

    One heatmap is drawn per fold, followed by a final "Aggregate" panel,
    all on a single row of newly created subplots.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object.
    threshold : {'default', 'optimized'}, optional
        Which threshold's confusion matrices to display. The optimized
        matrices are used only when ``threshold == 'optimized'`` *and* the
        results object has a ``confusion_matrices_optimized_`` attribute;
        in every other case the default-threshold matrices are shown.
    normalize : {None, 'true', 'pred', 'all'}, optional
        Normalization mode. ``'true'`` normalizes by row (true label),
        ``'pred'`` by column (predicted label), ``'all'`` by total count.
        ``None`` displays raw counts.
    cmap : str, optional
        Colormap for the heatmaps.
    fontsize : int, optional
        Font size for cell values.
    ax : matplotlib.axes.Axes or None, optional
        Ignored (subplots are always created). Kept for API consistency.
    **kwargs
        Additional keyword arguments passed to ``imshow``.

    Returns
    -------
    matplotlib.axes.Axes
        The first axes of the created subplots.

    Raises
    ------
    ValueError
        If *normalize* is a non-empty value other than ``'true'``,
        ``'pred'`` or ``'all'``.
    """
    # Fail fast on a misspelled mode: without this check the cells would be
    # rendered as raw counts formatted like fractions, which is misleading.
    # Falsy values (None, "") keep meaning "raw counts".
    if normalize and normalize not in ("true", "pred", "all"):
        raise ValueError(
            f"normalize must be one of None, 'true', 'pred' or 'all'; got {normalize!r}"
        )

    import matplotlib.pyplot as plt

    if threshold == "optimized" and hasattr(results, "confusion_matrices_optimized_"):
        fold_cms = results.confusion_matrices_optimized_
        aggregate_cm = results.confusion_matrix_aggregate_optimized_
    else:
        fold_cms = results.confusion_matrices_default_
        aggregate_cm = results.confusion_matrix_aggregate_default_

    def _normalize_cm(cm):
        """Return *cm* as floats, scaled according to *normalize*."""
        cm = np.asarray(cm, dtype=float)
        if normalize == "true":
            row_sums = cm.sum(axis=1, keepdims=True)
            row_sums[row_sums == 0] = 1  # leave all-zero rows at zero
            return cm / row_sums
        if normalize == "pred":
            col_sums = cm.sum(axis=0, keepdims=True)
            col_sums[col_sums == 0] = 1  # leave all-zero columns at zero
            return cm / col_sums
        if normalize == "all":
            total = cm.sum()
            return cm / total if total > 0 else cm
        return cm

    # Every panel (folds first, aggregate last) is drawn by the same routine.
    panels = [(f"Fold {i}", cm) for i, cm in enumerate(fold_cms)]
    panels.append(("Aggregate", aggregate_cm))

    # Take the class count from the aggregate matrix, which is always
    # present, so the plot still works when there are no per-fold matrices.
    n_classes = np.asarray(aggregate_cm).shape[0]
    labels = getattr(results, "classes_", None)
    tick_labels = labels if labels is not None else list(range(n_classes))

    def _draw_panel(panel_ax, title, cm):
        """Render one confusion matrix as an annotated heatmap."""
        raw = np.asarray(cm)
        shown = _normalize_cm(raw)
        panel_ax.imshow(shown, cmap=cmap, **kwargs)
        panel_ax.set_title(title)
        panel_ax.set_xticks(range(n_classes))
        panel_ax.set_xticklabels(tick_labels)
        panel_ax.set_yticks(range(n_classes))
        panel_ax.set_yticklabels(tick_labels)
        panel_ax.set_xlabel("Predicted")
        panel_ax.set_ylabel("True")
        for r in range(shown.shape[0]):
            for c in range(shown.shape[1]):
                # Fractions when normalized, integer counts otherwise.
                cell = f"{shown[r, c]:.2f}" if normalize else str(int(raw[r, c]))
                panel_ax.text(c, r, cell, ha="center", va="center", fontsize=fontsize)

    n_panels = len(panels)
    # squeeze=False always yields a 2-D array, so the single-panel case
    # needs no special handling.
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3), squeeze=False)
    axes = axes_grid[0]
    for panel_ax, (title, cm) in zip(axes, panels):
        _draw_panel(panel_ax, title, cm)

    # Lay out the figure created here rather than whatever pyplot
    # considers the current figure.
    fig.tight_layout()
    return axes[0]
[docs]
def plot_roc_curves(
    results,
    fold_alpha: float = 0.4,
    mean_color: str = "b",
    mean_lw: float = 2,
    band_alpha: float = 0.2,
    full_range: bool = False,
    ylim: tuple[float, float] | None = None,
    xlim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Per-fold ROC curves with mean curve and +/- 1 std band.

    Each fold's ROC curve is drawn, then interpolated onto a common grid of
    100 false-positive rates so that a pointwise mean and standard
    deviation across folds can be plotted.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object. Each entry of
        ``results.fold_results_`` must provide ``y_true`` and
        ``y_proba_raw``; ``y_proba_calibrated`` is preferred when it is
        not ``None``.
    fold_alpha : float, optional
        Opacity of individual fold curves.
    mean_color : str, optional
        Color of the mean ROC curve.
    mean_lw : float, optional
        Line width of the mean ROC curve.
    band_alpha : float, optional
        Opacity of the +/- 1 std band.
    full_range : bool, optional
        If ``True``, set both axes to [0, 1].
    ylim, xlim : tuple of float or None, optional
        Explicit axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.

    Raises
    ------
    ValueError
        If ``results.fold_results_`` is empty.
    """
    fold_results = list(results.fold_results_)
    if not fold_results:
        # Without folds there is nothing to average; fail with a clear
        # message instead of an obscure error from the mean computation.
        raise ValueError("results.fold_results_ is empty; cannot plot ROC curves")

    from sklearn.metrics import roc_curve

    ax = _get_ax(ax)
    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    for fr in fold_results:
        proba = fr.y_proba_calibrated if fr.y_proba_calibrated is not None else fr.y_proba_raw
        if proba.ndim == 2:
            # NOTE(review): column 1 is presumably the positive class of a
            # binary problem -- confirm against the producer of y_proba_*.
            proba = proba[:, 1]
        fpr, tpr, _ = roc_curve(fr.y_true, proba)
        ax.plot(fpr, tpr, alpha=fold_alpha, lw=1)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0  # pin the curve to the origin
        tprs.append(interp_tpr)

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0  # pin the curve to (1, 1)
    std_tpr = np.std(tprs, axis=0)

    # A true-positive rate cannot leave [0, 1]; clip the band so it does
    # not spill outside the valid region of the plot.
    band_lower = np.maximum(mean_tpr - std_tpr, 0.0)
    band_upper = np.minimum(mean_tpr + std_tpr, 1.0)

    ax.plot(mean_fpr, mean_tpr, color=mean_color, lw=mean_lw, label="Mean ROC")
    ax.fill_between(mean_fpr, band_lower, band_upper, alpha=band_alpha, color=mean_color)
    ax.plot([0, 1], [0, 1], "k--", lw=1)  # diagonal reference line
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curves")
    ax.legend(loc="lower right")
    _apply_axis_limits(
        ax, xlim=xlim, ylim=ylim, full_range=full_range, natural_xlim=_UNIT, natural_ylim=_UNIT
    )
    return ax
[docs]
def plot_precision_recall_curves(
    results,
    fold_alpha: float = 0.5,
    full_range: bool = False,
    ylim: tuple[float, float] | None = None,
    xlim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Draw one precision-recall curve per outer fold.

    Parameters
    ----------
    results : ClassifierResults
        Fitted nested CV classification results object. For every entry of
        ``results.fold_results_`` the calibrated probabilities are used
        when they are not ``None``, otherwise the raw probabilities.
    fold_alpha : float, optional
        Opacity of individual fold curves.
    full_range : bool, optional
        If ``True``, set both axes to [0, 1].
    ylim, xlim : tuple of float or None, optional
        Explicit axis limits (override *full_range*).
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    from sklearn.metrics import precision_recall_curve

    ax = _get_ax(ax)

    for fold_number, fold in enumerate(results.fold_results_):
        scores = fold.y_proba_calibrated
        if scores is None:
            scores = fold.y_proba_raw
        if scores.ndim == 2:
            # Two-dimensional probabilities: keep only column 1.
            scores = scores[:, 1]
        prec, rec, _thresholds = precision_recall_curve(fold.y_true, scores)
        ax.plot(rec, prec, alpha=fold_alpha, label=f"Fold {fold_number}")

    ax.set_title("Precision-Recall Curves")
    ax.set_ylabel("Precision")
    ax.set_xlabel("Recall")
    ax.legend(fontsize=7)
    _apply_axis_limits(
        ax,
        xlim=xlim,
        ylim=ylim,
        full_range=full_range,
        natural_xlim=_UNIT,
        natural_ylim=_UNIT,
    )
    return ax
[docs]
def plot_rank_stability(
    results,
    top_k: int = 5,
    fold_alpha: float = 0.6,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Plot the scores of the top-ranked configurations for each fold.

    For every report in ``results.inner_reports_`` the table returned by
    ``report.ranking()`` is read; reports whose table has no
    ``mean_test_score`` column are skipped.

    Parameters
    ----------
    results : ClassifierResults or RegressorResults
        Fitted nested CV results object.
    top_k : int, optional
        Number of top configurations to display per fold.
    fold_alpha : float, optional
        Opacity of fold lines.
    ylim : tuple of float or None, optional
        Explicit y-axis limits.
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    ax = _get_ax(ax)

    for fold_number, report in enumerate(results.inner_reports_):
        ranking = report.ranking()
        if "mean_test_score" not in ranking.columns:
            continue
        # A fold may have fewer than top_k configurations.
        n_shown = min(top_k, len(ranking))
        top_scores = ranking["mean_test_score"].values[:top_k]
        ax.plot(
            range(n_shown),
            top_scores,
            marker="o",
            alpha=fold_alpha,
            label=f"Fold {fold_number}",
        )

    ax.set_title("Inner CV Rank Stability")
    ax.set_xlabel("Configuration Rank")
    ax.set_ylabel("Mean Test Score")
    ax.legend(fontsize=7)
    _apply_axis_limits(ax, ylim=ylim)
    return ax
[docs]
def plot_residuals(
    results,
    fold_idx: int | list[int] | None = None,
    bins: int = 30,
    fold_alpha: float = 0.5,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Residual distributions per fold.

    One histogram of ``fold.residuals`` is overlaid per selected fold.

    Parameters
    ----------
    results : RegressorResults
        Fitted nested CV regression results object.
    fold_idx : int, list of int, or None, optional
        Outer fold index or indices to plot. A single index may be a
        Python ``int`` or a NumPy integer scalar. If ``None`` (default),
        all folds are shown.
    bins : int, optional
        Number of histogram bins.
    fold_alpha : float, optional
        Opacity of fold histograms.
    xlim, ylim : tuple of float or None, optional
        Explicit axis limits.
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    ax = _get_ax(ax)
    fold_results = results.fold_results_

    if fold_idx is None:
        indices = list(range(len(fold_results)))
    elif isinstance(fold_idx, (int, np.integer)):
        # NumPy integer scalars (e.g. from np.arange) are not instances of
        # ``int``; treat them as a single index rather than trying to
        # iterate over them.
        indices = [fold_idx]
    else:
        indices = list(fold_idx)

    for i in indices:
        fr = fold_results[i]
        ax.hist(fr.residuals, bins=bins, alpha=fold_alpha, label=f"Fold {i}")

    ax.set_xlabel("Residual")
    ax.set_ylabel("Count")
    ax.set_title("Residual Distributions")
    ax.legend(fontsize=7)
    _apply_axis_limits(ax, xlim=xlim, ylim=ylim, full_range=False)
    return ax
[docs]
def plot_predicted_vs_actual(
    results,
    point_alpha: float = 0.4,
    point_size: float = 12,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Scatter predictions against actual values, with an identity line.

    Reads the ``y_true`` and ``y_pred`` entries of ``results.predictions_``.

    Parameters
    ----------
    results : RegressorResults
        Fitted nested CV regression results object.
    point_alpha : float, optional
        Opacity of scatter points.
    point_size : float, optional
        Size of scatter points.
    xlim, ylim : tuple of float or None, optional
        Explicit axis limits.
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    ax = _get_ax(ax)

    predictions = results.predictions_
    actual = predictions["y_true"]
    predicted = predictions["y_pred"]
    ax.scatter(actual, predicted, alpha=point_alpha, s=point_size)

    # The identity line spans the combined range of both series.
    lowest = min(actual.min(), predicted.min())
    highest = max(actual.max(), predicted.max())
    diagonal = [lowest, highest]
    ax.plot(diagonal, diagonal, "r--", lw=1)

    ax.set_title("Predicted vs Actual")
    ax.set_xlabel("Actual")
    ax.set_ylabel("Predicted")
    _apply_axis_limits(ax, xlim=xlim, ylim=ylim, full_range=False)
    return ax
[docs]
def plot_prediction_intervals(
    results,
    band_alpha: float = 0.25,
    point_size: float = 8,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Show actual and predicted values over a prediction-interval band.

    Samples are ordered by their actual value. When
    ``results.predictions_`` has no ``pi_lower`` column, a centered
    "No prediction intervals" message is drawn instead and the axes are
    returned without further decoration.

    Parameters
    ----------
    results : RegressorResults
        Fitted nested CV regression results object.
    band_alpha : float, optional
        Opacity of the prediction interval band.
    point_size : float, optional
        Size of scatter points.
    ylim : tuple of float or None, optional
        Explicit y-axis limits.
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    ax = _get_ax(ax)
    preds = results.predictions_

    if "pi_lower" not in preds.columns:
        ax.text(
            0.5,
            0.5,
            "No prediction intervals",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return ax

    order = preds["y_true"].argsort()
    positions = np.arange(len(preds))

    def _in_order(column):
        """Return the values of *column* arranged by ascending actual value."""
        return preds[column].values[order]

    ax.fill_between(
        positions,
        _in_order("pi_lower"),
        _in_order("pi_upper"),
        alpha=band_alpha,
        color="blue",
        label="Prediction Interval",
    )
    ax.scatter(positions, _in_order("y_true"), s=point_size, color="red", label="Actual")
    ax.scatter(positions, _in_order("y_pred"), s=point_size, color="blue", label="Predicted")

    ax.set_title("Prediction Intervals")
    ax.set_xlabel("Sample (sorted by actual)")
    ax.set_ylabel("Value")
    ax.legend(fontsize=7)
    _apply_axis_limits(ax, ylim=ylim, full_range=False)
    return ax
[docs]
def plot_residual_qq(
    results,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    ax=None,
    **kwargs,
) -> Axes:
    """Normal QQ plot of the residuals pooled over all folds.

    The ``residuals`` of every entry in ``results.fold_results_`` are
    concatenated and passed to ``scipy.stats.probplot`` with
    ``dist="norm"``.

    Parameters
    ----------
    results : RegressorResults
        Fitted nested CV regression results object.
    xlim, ylim : tuple of float or None, optional
        Explicit axis limits.
    ax : matplotlib.axes.Axes or None, optional
        Axes to plot on. If ``None``, a new figure is created.
    **kwargs
        Accepted but currently unused by this function.

    Returns
    -------
    matplotlib.axes.Axes
        The axes with the plot.
    """
    from scipy import stats

    ax = _get_ax(ax)

    per_fold_residuals = [fold.residuals for fold in results.fold_results_]
    pooled = np.concatenate(per_fold_residuals)
    stats.probplot(pooled, dist="norm", plot=ax)

    # Title is set after the probplot call.
    ax.set_title("Residual QQ Plot")
    _apply_axis_limits(ax, xlim=xlim, ylim=ylim, full_range=False)
    return ax