"""Nested cross-validation estimator for classification tasks.
Extends :class:`~nestkit._base._BaseNestedCV` with optional post-hoc
probability calibration (Platt scaling, isotonic regression, beta
calibration, Venn-ABERS) and decision-threshold optimization.
"""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
from sklearn.base import clone
from sklearn.metrics import (
accuracy_score,
balanced_accuracy_score,
confusion_matrix,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)
from sklearn.model_selection import check_cv
from nestkit._base import _BaseNestedCV
from nestkit._constants import _EPS
from nestkit._validation import (
extract_positive_proba,
validate_calibration_method,
validate_conformal_params,
validate_threshold_params,
)
from nestkit.calibration.calibrators import PostHocCalibrator
from nestkit.calibration.diagnostics import CalibrationDiagnostics
from nestkit.results.classifier_results import ClassifierOuterFoldResult, ClassifierResults
from nestkit.thresholding.criteria import (
balanced_accuracy_criterion,
cost_sensitive,
f_beta_criterion,
precision_at_recall,
youden_j,
)
from nestkit.thresholding.strategies import (
FoldSpecificThreshold,
PooledThreshold,
)
logger = logging.getLogger("nestkit")
class NestedCVClassifier(_BaseNestedCV):
    """Nested cross-validation for classification tasks.

    Supports binary and multiclass classification. Extends
    :class:`~nestkit._base._BaseNestedCV` with optional post-hoc
    probability calibration and decision-threshold optimization.
    Both features are disabled by default and must be explicitly enabled.

    When calibration is enabled, out-of-fold (OOF) predictions from the
    inner CV are used to fit a calibrator, which is then applied to the
    outer test-set probabilities. When threshold optimization is enabled,
    the optimal decision boundary is selected on the calibrated (or raw)
    OOF probabilities.

    Parameters
    ----------
    estimator : estimator object
        A scikit-learn compatible classifier that implements ``fit``
        and ``predict_proba``. Cloned for each outer fold.
    param_grid : dict or list of dict
        Hyperparameter search space. See
        :class:`~sklearn.model_selection.GridSearchCV`.
    search_strategy : {'grid', 'random', 'bayesian'}, default='grid'
        Inner hyperparameter search strategy.
    outer_cv : int, cross-validation generator, or iterable, default=5
        Outer cross-validation splitting strategy.
    inner_cv : int, cross-validation generator, or iterable, default=5
        Inner cross-validation splitting strategy.
    scoring : str, callable, list, tuple, or dict, default=None
        Scoring metric(s) for the inner search.
    refit : bool or str, default=True
        Whether to refit on the full outer training set.
    return_train_score : bool, default=False
        Whether to include training scores in inner CV results.
    return_estimator : bool, default=True
        Whether to store fitted estimators per outer fold.
    error_score : 'raise' or numeric, default='raise'
        Value assigned on inner CV fitting errors.
    n_jobs_outer : int or None, default=None
        Number of parallel jobs for outer folds.
    n_jobs_inner : int or None, default=None
        Number of parallel jobs for inner search.
    verbose : int, default=0
        Verbosity level.
    random_state : int, RandomState instance, or None, default=None
        Random state for reproducibility.
    callbacks : list of callback objects or None, default=None
        :class:`~nestkit.FoldCallback` instances for monitoring.
    pre_dispatch : int or str, default='2*n_jobs'
        Controls job dispatch for parallel execution.
    calibration_method : {'sigmoid', 'isotonic', 'beta', 'venn_abers'} or None, default=None
        Post-hoc calibration method. If ``None``, no calibration is
        applied. ``'sigmoid'`` corresponds to Platt scaling,
        ``'isotonic'`` to isotonic regression, ``'beta'`` to beta
        calibration, and ``'venn_abers'`` to Venn-ABERS prediction.
    threshold_strategy : {'pooled', 'fold_specific'} or None, default=None
        Threshold optimization strategy. If ``None``, no threshold
        optimization is performed. ``'pooled'`` selects a single
        threshold from all OOF predictions; ``'fold_specific'`` selects
        a per-fold threshold.
    threshold_criterion : str or callable, default='youden'
        Criterion for threshold selection. Built-in options:
        ``'youden'``, ``'f_beta'``, ``'cost'``, ``'balanced_accuracy'``,
        ``'precision_at_recall'``. A custom callable must accept
        ``(y_true, y_proba, threshold)`` and return a ``float`` to be
        maximised.
    threshold_beta : float, default=1.0
        Beta parameter for the F-beta criterion. Only used when
        ``threshold_criterion='f_beta'``.
    cost_matrix : array-like of shape (2, 2) or None, default=None
        Cost matrix ``[[TN_cost, FP_cost], [FN_cost, TP_cost]]`` for
        cost-sensitive threshold optimization. Required when
        ``threshold_criterion='cost'``.
    min_recall : float or None, default=None
        Minimum recall constraint for the ``'precision_at_recall'``
        criterion. Required when ``threshold_criterion='precision_at_recall'``.
    calibration_cv : int, cross-validation generator, or None, default=None
        CV strategy for generating OOF calibration predictions. If
        ``None``, uses the same ``inner_cv`` strategy. Note that when
        ``inner_cv`` is an integer, a **new** splitter instance is
        created for the calibration OOF loop, which may produce
        different fold assignments than the inner hyperparameter search.
    conformal_prediction : bool, default=False
        If ``True``, compute CV+ Mondrian conformal prediction sets
        using inner out-of-fold probabilities (calibrated if
        calibration is enabled). Each outer fold gets its own per-class
        q-hat threshold, applied to the held-out test fold.
    conformal_alpha : float, default=0.1
        Significance level (miscoverage rate) for conformal prediction.
        Target coverage is ``1 - alpha``. Must be in ``(0, 1)``.

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique labels of ``y``, set by :meth:`fit`.
    n_classes_ : int
        Number of entries in ``classes_``, set by :meth:`fit`.

    Notes
    -----
    Enabling calibration and/or threshold optimization roughly doubles
    computation time per outer fold, as the inner CV folds must be re-run
    to produce OOF probability estimates for the calibrator and threshold
    optimizer.

    For multiclass tasks, calibration is applied independently per class
    using a one-vs-rest (OVR) decomposition. After calibration the
    per-class probabilities are renormalized to sum to 1. Because each
    calibrator is fitted on a marginal binary problem, the resulting
    multiclass probabilities may not be jointly well-calibrated -- this
    is a known limitation of OVR calibration approaches.

    For binary tasks the positive class is ``classes_[1]``. Targets are
    encoded as 0/1 against that label before they reach the calibrator,
    the threshold optimizer and the calibration diagnostics, and
    predictions are reported in the original label space.

    Examples
    --------
    Basic classification:

    >>> from sklearn.datasets import load_breast_cancer
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from nestkit import NestedCVClassifier
    >>> X, y = load_breast_cancer(return_X_y=True)
    >>> ncv = NestedCVClassifier(
    ...     estimator=RandomForestClassifier(random_state=42),
    ...     param_grid={"n_estimators": [50, 100], "max_depth": [3, 5]},
    ...     outer_cv=5, inner_cv=3, random_state=42,
    ... )
    >>> ncv.fit(X, y)  # doctest: +SKIP
    >>> print(ncv.results_.summary_default_)  # doctest: +SKIP

    With calibration and threshold optimization:

    >>> ncv = NestedCVClassifier(
    ...     estimator=RandomForestClassifier(random_state=42),
    ...     param_grid={"n_estimators": [50, 100]},
    ...     outer_cv=5, inner_cv=3,
    ...     calibration_method="isotonic",
    ...     threshold_strategy="pooled",
    ...     threshold_criterion="youden",
    ...     random_state=42,
    ... )
    >>> ncv.fit(X, y)  # doctest: +SKIP

    See Also
    --------
    nestkit.NestedCVRegressor : Regression-specific nested CV.
    nestkit.calibration.PostHocCalibrator : Standalone calibrator.
    nestkit.thresholding.strategies.PooledThreshold : Pooled threshold strategy.
    """

    def __init__(
        self,
        estimator,
        param_grid,
        *,
        search_strategy="grid",
        outer_cv=5,
        inner_cv=5,
        scoring=None,
        refit=True,
        return_train_score=False,
        return_estimator=True,
        error_score="raise",
        n_jobs_outer=None,
        n_jobs_inner=None,
        verbose=0,
        random_state=None,
        callbacks=None,
        pre_dispatch="2*n_jobs",
        calibration_method=None,
        threshold_strategy=None,
        threshold_criterion="youden",
        threshold_beta=1.0,
        cost_matrix=None,
        min_recall=None,
        calibration_cv=None,
        conformal_prediction=False,
        conformal_alpha=0.1,
    ):
        # Generic nested-CV options are handled by the base class; only the
        # classification-specific options are stored here. Nothing is
        # validated at construction time -- validation happens in ``fit``.
        super().__init__(
            estimator=estimator,
            param_grid=param_grid,
            search_strategy=search_strategy,
            outer_cv=outer_cv,
            inner_cv=inner_cv,
            scoring=scoring,
            refit=refit,
            return_train_score=return_train_score,
            return_estimator=return_estimator,
            error_score=error_score,
            n_jobs_outer=n_jobs_outer,
            n_jobs_inner=n_jobs_inner,
            verbose=verbose,
            random_state=random_state,
            callbacks=callbacks,
            pre_dispatch=pre_dispatch,
        )
        self.calibration_method = calibration_method
        self.threshold_strategy = threshold_strategy
        self.threshold_criterion = threshold_criterion
        self.threshold_beta = threshold_beta
        self.cost_matrix = cost_matrix
        self.min_recall = min_recall
        self.calibration_cv = calibration_cv
        self.conformal_prediction = conformal_prediction
        self.conformal_alpha = conformal_alpha

    def fit(self, X, y, groups=None, **fit_params):
        """Run nested cross-validation with optional calibration and thresholding.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target labels.
        groups : array-like of shape (n_samples,) or None, default=None
            Group labels for group-aware CV splitters.
        **fit_params : dict
            Additional keyword arguments forwarded to the estimator's
            ``fit`` method.

        Returns
        -------
        self
            The fitted estimator. Results are accessible via
            :attr:`results_`.

        Raises
        ------
        ValueError
            If calibration or threshold parameters are invalid.
        """
        # Validate the classification-specific options before any work is done.
        validate_calibration_method(self.calibration_method)
        validate_threshold_params(
            self.threshold_strategy,
            self.threshold_criterion,
            self.cost_matrix,
            self.min_recall,
        )
        validate_conformal_params(self.conformal_prediction, self.conformal_alpha)

        # Sorted unique labels; the per-fold hooks use them to translate
        # probability columns back into labels.
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)
        return super().fit(X, y, groups=groups, **fit_params)

    def _build_results_container(self) -> type:
        """Return the results class used to aggregate the outer folds."""
        return ClassifierResults

    def _binary_labels(self) -> np.ndarray:
        """Return ``[negative, positive]`` labels for the binary code paths.

        Uses ``classes_`` when it holds exactly two labels; otherwise falls
        back to ``[0, 1]``, which reproduces plain 0/1 integer predictions.
        """
        classes = getattr(self, "classes_", None)
        if classes is not None and len(classes) == 2:
            return np.asarray(classes)
        return np.array([0, 1])

    def _binarize(self, y):
        """Encode binary targets as 0/1, with 1 marking ``classes_[1]``.

        ``y`` is returned unchanged when ``classes_`` does not hold exactly
        two labels, so no encoding is guessed in that situation.
        """
        classes = getattr(self, "classes_", None)
        if classes is None or len(classes) != 2:
            return y
        return (np.asarray(y) == classes[1]).astype(int)

    @staticmethod
    def _slice_fit_params(fit_params: dict, n_samples: int, indices) -> dict:
        """Restrict sample-aligned fit parameters to ``indices``.

        A value counts as sample-aligned when it is an ndarray, list, tuple
        or pandas-like object (has ``iloc``) whose length equals
        ``n_samples``. Every other value is passed through untouched.
        """
        sliced = {}
        for name, value in fit_params.items():
            try:
                aligned = len(value) == n_samples
            except TypeError:  # scalars, None, objects without a length
                aligned = False
            if aligned and hasattr(value, "iloc"):
                sliced[name] = value.iloc[indices]
            elif aligned and isinstance(value, (np.ndarray, list, tuple)):
                sliced[name] = np.asarray(value)[indices]
            else:
                sliced[name] = value
        return sliced

    def _post_inner_processing(self, search, X_train, y_train, groups_train, **fit_params) -> dict:
        """Phase 2 + Phase 3: calibration and threshold optimization.

        Note: The OOF loop uses ``search.best_params_`` which were selected
        using all of ``X_train``. The OOF validation folds therefore
        influenced hyperparameter selection. This is a widely accepted
        approximation -- the alternative (triple-nested CV) is
        computationally prohibitive for most practical use cases.

        Parameters
        ----------
        search : object
            Fitted inner search; only ``best_params_`` is read.
        X_train, y_train : indexable
            Outer-training data. Both are indexed with the integer index
            arrays produced by the calibration splitter.
        groups_train : array-like or None
            Forwarded to the splitter's ``split``.
        **fit_params : dict
            Forwarded to each OOF ``fit``; sample-aligned values are sliced
            to the inner-training rows.

        Returns
        -------
        dict
            Artifacts for the outer-fold evaluation: fitted calibrator(s),
            selected threshold(s), the OOF arrays and the conformal result.
            Entries keep their defaults for features that are disabled.
        """
        artifacts: dict[str, Any] = {
            "calibrator": None,
            "calibrators_ovr": None,
            "optimal_threshold": 0.5,
            "optimal_thresholds_ovr": None,
            "fold_thresholds": None,
            "threshold_result": None,
            "oof_probas_raw": None,
            "oof_probas_calibrated": None,
            "oof_y_true": None,
            "conformal_result": None,
        }

        # Fast path: nothing to do when every optional feature is disabled.
        if (
            self.calibration_method is None
            and self.threshold_strategy is None
            and not self.conformal_prediction
        ):
            return artifacts

        # Slow path: collect inner OOF predictions (always refit).
        # Explicit ``is not None`` so that a falsy-but-set ``calibration_cv``
        # is not silently replaced by ``inner_cv``.
        cv_spec = self.calibration_cv if self.calibration_cv is not None else self.inner_cv
        cal_cv = check_cv(cv_spec, y_train, classifier=True)
        best_params = search.best_params_
        base_estimator = clone(self.estimator).set_params(**best_params)
        n_train = len(y_train)

        oof_probas: list[np.ndarray] = []
        oof_y_true: list[np.ndarray] = []
        for inner_train_idx, inner_val_idx in cal_cv.split(X_train, y_train, groups_train):
            est_j = clone(base_estimator)
            # Per-sample fit parameters must follow the same row subset as
            # X/y, otherwise their lengths no longer match.
            fold_fit_params = self._slice_fit_params(fit_params, n_train, inner_train_idx)
            est_j.fit(X_train[inner_train_idx], y_train[inner_train_idx], **fold_fit_params)
            oof_probas.append(est_j.predict_proba(X_train[inner_val_idx]))
            oof_y_true.append(y_train[inner_val_idx])

        oof_probas_all = np.concatenate(oof_probas)
        oof_y_all = np.concatenate(oof_y_true)
        artifacts["oof_probas_raw"] = oof_probas_all
        artifacts["oof_y_true"] = oof_y_all

        n_classes = oof_probas_all.shape[1] if oof_probas_all.ndim == 2 else 2
        is_binary = n_classes == 2

        # --- Phase 2: Calibration ---
        if self.calibration_method is not None:
            if is_binary:
                calibrator = PostHocCalibrator(method=self.calibration_method)
                p_pos = extract_positive_proba(oof_probas_all)
                # Fit against 0/1 targets so that arbitrary label values
                # (e.g. {-1, 1} or strings) are handled.
                calibrator.fit(p_pos, self._binarize(oof_y_all))
                artifacts["calibrator"] = calibrator
                cal_probas_all = calibrator.predict_proba(p_pos)
                cal_probas_per_fold = [
                    calibrator.predict_proba(extract_positive_proba(p)) for p in oof_probas
                ]
            else:
                calibrators_ovr = []
                cal_probas_all = np.zeros_like(oof_probas_all)
                cal_probas_per_fold = [np.zeros_like(p) for p in oof_probas]
                for c in range(n_classes):
                    y_binary = (oof_y_all == self.classes_[c]).astype(int)
                    p_c = oof_probas_all[:, c]
                    cal_c = PostHocCalibrator(method=self.calibration_method)
                    cal_c.fit(p_c, y_binary)
                    calibrators_ovr.append(cal_c)
                    cal_probas_all[:, c] = cal_c.predict_proba(p_c)[:, 1]
                    for j, p_fold in enumerate(oof_probas):
                        cal_probas_per_fold[j][:, c] = cal_c.predict_proba(p_fold[:, c])[:, 1]
                # Renormalize rows to sum to 1; _EPS guards all-zero rows.
                cal_probas_all /= cal_probas_all.sum(axis=1, keepdims=True) + _EPS
                for fold_proba in cal_probas_per_fold:
                    fold_proba /= fold_proba.sum(axis=1, keepdims=True) + _EPS
                artifacts["calibrators_ovr"] = calibrators_ovr
        else:
            cal_probas_all = oof_probas_all
            cal_probas_per_fold = oof_probas
        artifacts["oof_probas_calibrated"] = cal_probas_all

        # --- Phase 3: Threshold optimization ---
        if self.threshold_strategy is not None:
            criterion_fn = self._resolve_criterion()
            criterion_name = (
                self.threshold_criterion
                if isinstance(self.threshold_criterion, str)
                else getattr(self.threshold_criterion, "__name__", "custom")
            )
            strategy = (
                FoldSpecificThreshold
                if self.threshold_strategy == "fold_specific"
                else PooledThreshold
            )
            if is_binary:
                cal_p_pos_per_fold = [extract_positive_proba(p) for p in cal_probas_per_fold]
                y_bin_per_fold = [self._binarize(y_fold) for y_fold in oof_y_true]
                tr = strategy.optimize(
                    y_bin_per_fold, cal_p_pos_per_fold, criterion_fn, criterion_name
                )
                artifacts["optimal_threshold"] = tr.optimal_threshold
                artifacts["threshold_result"] = tr
            else:
                # Multiclass OVR: apply threshold strategy per class
                thresholds_ovr = []
                for c in range(n_classes):
                    y_binary_per_fold = [
                        (y_fold == self.classes_[c]).astype(int) for y_fold in oof_y_true
                    ]
                    p_c_per_fold = [p[:, c] for p in cal_probas_per_fold]
                    tr_c = strategy.optimize(
                        y_binary_per_fold, p_c_per_fold, criterion_fn, criterion_name
                    )
                    thresholds_ovr.append(tr_c.optimal_threshold)
                artifacts["optimal_thresholds_ovr"] = np.array(thresholds_ovr)

        # --- Phase 2c: Conformal prediction ---
        if self.conformal_prediction:
            from nestkit.conformal.classifier_conformal import MondrianClassifierConformal

            conformal_result = MondrianClassifierConformal.fit(
                oof_probas=cal_probas_all,
                oof_y_true=oof_y_all,
                classes=self.classes_,
                alpha=self.conformal_alpha,
            )
            artifacts["conformal_result"] = conformal_result
        return artifacts

    def _evaluate_outer_fold(self, estimator, X_test, y_test, artifacts) -> dict:
        """Evaluate best estimator on outer test fold.

        Parameters
        ----------
        estimator : fitted classifier
            Must implement ``predict_proba``.
        X_test, y_test : array-like
            Held-out outer test data and labels.
        artifacts : dict
            Output of :meth:`_post_inner_processing` for this fold.

        Returns
        -------
        dict
            Raw/calibrated probabilities, default (0.5 / argmax) predictions
            with their metrics and confusion matrix, the threshold-optimized
            counterparts (``None`` when no threshold was optimized),
            calibration diagnostics (binary with calibration only) and, when
            conformal prediction is enabled, prediction sets, set sizes and
            empirical coverage.
        """
        raw_proba = estimator.predict_proba(X_test)
        n_classes = raw_proba.shape[1]
        is_binary = n_classes == 2
        # Fixing the label set keeps the confusion-matrix shape identical
        # across folds, even when a fold is missing a class.
        cm_labels = getattr(self, "classes_", None)

        # Apply calibration
        cal_proba = self._apply_calibration(raw_proba, artifacts)

        # Default predictions, reported in the original label space.
        if is_binary:
            binary_labels = self._binary_labels()
            effective_proba = extract_positive_proba(cal_proba)
            y_pred_default = binary_labels[(effective_proba >= 0.5).astype(int)]
        else:
            y_pred_default = self.classes_[np.argmax(cal_proba, axis=1)]
        scores_default = self._compute_metrics(y_test, y_pred_default, cal_proba, is_binary)
        cm_default = confusion_matrix(y_test, y_pred_default, labels=cm_labels)

        has_calibration = (
            artifacts["calibrator"] is not None or artifacts.get("calibrators_ovr") is not None
        )
        result = {
            "y_true": y_test,
            "y_proba_raw": raw_proba,
            "y_proba_calibrated": cal_proba if has_calibration else None,
            "y_pred_default": y_pred_default,
            "scores_default": scores_default,
            "confusion_matrix_default": cm_default,
            "y_pred_optimized": None,
            "scores_optimized": None,
            "confusion_matrix_optimized": None,
            "calibration_diagnostics": None,
        }

        # Calibration diagnostics on held-out test data (binary only for now)
        if has_calibration and is_binary:
            y_test_bin = self._binarize(y_test)
            raw_p = extract_positive_proba(raw_proba)
            cal_p = extract_positive_proba(cal_proba)
            diag = CalibrationDiagnostics
            result["calibration_diagnostics"] = {
                "ece_raw": diag.expected_calibration_error(y_test_bin, raw_p),
                "ece_calibrated": diag.expected_calibration_error(y_test_bin, cal_p),
                "mce_raw": diag.maximum_calibration_error(y_test_bin, raw_p),
                "mce_calibrated": diag.maximum_calibration_error(y_test_bin, cal_p),
                "brier_raw": diag.brier_score(y_test_bin, raw_p),
                "brier_calibrated": diag.brier_score(y_test_bin, cal_p),
            }

        # Optimized predictions
        has_threshold = (
            artifacts.get("threshold_result") is not None
            or artifacts.get("optimal_thresholds_ovr") is not None
        )
        if has_threshold:
            if is_binary:
                threshold = artifacts["optimal_threshold"]
                y_pred_opt = binary_labels[(effective_proba >= threshold).astype(int)]
            else:
                thresholds = artifacts["optimal_thresholds_ovr"]
                above = cal_proba >= thresholds[np.newaxis, :]
                n_above = above.sum(axis=1)
                # Use the thresholded class only when exactly one class
                # clears its threshold; otherwise fall back to argmax.
                idx_opt = np.where(
                    n_above == 1,
                    np.argmax(above, axis=1),
                    np.argmax(cal_proba, axis=1),
                )
                y_pred_opt = self.classes_[idx_opt]
            result["y_pred_optimized"] = y_pred_opt
            result["scores_optimized"] = self._compute_metrics(
                y_test, y_pred_opt, cal_proba, is_binary
            )
            result["confusion_matrix_optimized"] = confusion_matrix(
                y_test, y_pred_opt, labels=cm_labels
            )

        # Conformal prediction sets
        if artifacts.get("conformal_result") is not None:
            from nestkit.conformal.classifier_conformal import MondrianClassifierConformal

            conformal_output = MondrianClassifierConformal.predict(
                probas=cal_proba,
                conformal_result=artifacts["conformal_result"],
                classes=self.classes_,
            )
            result["conformal_prediction_sets"] = conformal_output["prediction_sets"]
            result["conformal_set_sizes"] = conformal_output["set_sizes"]
            # Empirical coverage: share of test samples whose true label is
            # contained in its prediction set.
            result["conformal_coverage"] = float(
                np.mean(
                    [
                        y_test[i] in conformal_output["prediction_sets"][i]
                        for i in range(len(y_test))
                    ]
                )
            )
        return result

    def _apply_calibration(self, raw_proba: np.ndarray, artifacts: dict) -> np.ndarray:
        """Apply calibration to raw probabilities.

        Uses the binary calibrator if present, else the per-class OVR
        calibrators (followed by row renormalization), else returns
        ``raw_proba`` unchanged.
        """
        if artifacts["calibrator"] is not None:
            return artifacts["calibrator"].predict_proba(extract_positive_proba(raw_proba))
        if artifacts.get("calibrators_ovr") is not None:
            cal_proba = np.zeros_like(raw_proba)
            for c, cal_c in enumerate(artifacts["calibrators_ovr"]):
                cal_proba[:, c] = cal_c.predict_proba(raw_proba[:, c])[:, 1]
            row_sums = cal_proba.sum(axis=1, keepdims=True)
            return cal_proba / (row_sums + _EPS)
        return raw_proba

    def _compute_metrics(self, y_true, y_pred, y_proba, is_binary: bool) -> dict[str, float]:
        """Compute classification metrics.

        Returns accuracy, balanced accuracy, precision, recall, F1 and
        ROC AUC. Multiclass precision/recall/F1 are macro-averaged and the
        multiclass AUC is one-vs-rest. ``roc_auc`` is NaN when scikit-learn
        raises ``ValueError`` for it.
        """
        metrics = {
            "accuracy": accuracy_score(y_true, y_pred),
            "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
        }
        if is_binary:
            # Score the class whose probability sits in column 1, whatever
            # its label value is (equals the default ``1`` for 0/1 labels).
            pos_label = self._binary_labels()[1]
            metrics["precision"] = precision_score(
                y_true, y_pred, pos_label=pos_label, zero_division=0.0
            )
            metrics["recall"] = recall_score(
                y_true, y_pred, pos_label=pos_label, zero_division=0.0
            )
            metrics["f1"] = f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0.0)
            try:
                metrics["roc_auc"] = roc_auc_score(y_true, extract_positive_proba(y_proba))
            except ValueError:
                metrics["roc_auc"] = float("nan")
        else:
            avg = "macro"
            metrics["precision"] = precision_score(y_true, y_pred, average=avg, zero_division=0.0)
            metrics["recall"] = recall_score(y_true, y_pred, average=avg, zero_division=0.0)
            metrics["f1"] = f1_score(y_true, y_pred, average=avg, zero_division=0.0)
            try:
                metrics["roc_auc"] = roc_auc_score(y_true, y_proba, multi_class="ovr", average=avg)
            except ValueError:
                metrics["roc_auc"] = float("nan")
        return metrics

    def _resolve_criterion(self):
        """Resolve threshold criterion to callable.

        A callable ``threshold_criterion`` is returned as-is. A string is
        looked up among the built-in criteria; an unknown name raises
        ``KeyError``.
        """
        if callable(self.threshold_criterion):
            return self.threshold_criterion
        # Lambdas defer construction so that only the selected criterion
        # touches its parameters (beta, cost matrix, minimum recall).
        mapping = {
            "youden": lambda: youden_j,
            "f_beta": lambda: f_beta_criterion(self.threshold_beta),
            "cost": lambda: cost_sensitive(self.cost_matrix),
            "balanced_accuracy": lambda: balanced_accuracy_criterion,
            "precision_at_recall": lambda: precision_at_recall(self.min_recall),
        }
        return mapping[self.threshold_criterion]()

    def _build_fold_result(self, **kwargs) -> ClassifierOuterFoldResult:
        """Assemble the per-fold result record.

        Combines the generic fold information in ``kwargs`` with the
        ``artifacts`` dict from :meth:`_post_inner_processing` and the
        ``eval_result`` dict from :meth:`_evaluate_outer_fold`.
        """
        artifacts = kwargs.pop("artifacts")
        eval_result = kwargs.pop("eval_result")
        # NOTE(review): only the binary ``calibrator`` is forwarded; the
        # multiclass ``calibrators_ovr`` artifacts are not stored on the
        # result. Also, ``oof_calibration_diagnostics`` receives diagnostics
        # computed on the outer test fold -- confirm the intended naming.
        return ClassifierOuterFoldResult(
            fold_idx=kwargs["fold_idx"],
            train_indices=kwargs["train_idx"],
            test_indices=kwargs["test_idx"],
            best_params=kwargs["best_params"],
            best_inner_score=kwargs["best_inner_score"],
            inner_cv_results=kwargs["inner_cv_results"],
            fit_time=kwargs["fit_time"],
            score_time=kwargs["score_time"],
            fitted_estimator=kwargs["estimator"],
            y_true=eval_result["y_true"],
            y_proba_raw=eval_result["y_proba_raw"],
            y_pred_default=eval_result["y_pred_default"],
            outer_scores_default=eval_result["scores_default"],
            confusion_matrix_default=eval_result["confusion_matrix_default"],
            y_proba_calibrated=eval_result["y_proba_calibrated"],
            calibration_method=self.calibration_method,
            calibrator=artifacts.get("calibrator"),
            oof_calibration_diagnostics=eval_result.get("calibration_diagnostics"),
            y_pred_optimized=eval_result["y_pred_optimized"],
            outer_scores_optimized=eval_result["scores_optimized"],
            confusion_matrix_optimized=eval_result["confusion_matrix_optimized"],
            threshold_result=artifacts.get("threshold_result"),
            conformal_result=artifacts.get("conformal_result"),
            conformal_prediction_sets=eval_result.get("conformal_prediction_sets"),
            conformal_set_sizes=eval_result.get("conformal_set_sizes"),
            conformal_coverage=eval_result.get("conformal_coverage"),
        )