"""Fold callback protocol and built-in callback implementations.
Defines the :class:`FoldCallback` runtime-checkable protocol that all
callbacks must satisfy, together with three ready-made implementations:
* :class:`ProgressCallback` -- tqdm progress bar.
* :class:`LoggingCallback` -- structured ``logging``-based messages.
* :class:`CheckpointCallback` -- pickle intermediate fold results.
Custom callbacks need only implement the five hook methods specified by
:class:`FoldCallback`.
"""
from __future__ import annotations
import logging
import pickle
import time
from pathlib import Path
from typing import Any, Protocol, runtime_checkable
import numpy as np
# Logger shared by the callbacks in this module (registered under the name "nestkit").
logger = logging.getLogger("nestkit")
@runtime_checkable
class FoldCallback(Protocol):
    """Structural interface for objects notified during nested CV.

    An object qualifies as a callback as soon as it provides the five
    hook methods listed below; no inheritance is required.  The nested
    CV engine invokes the hooks at fixed points of its execution.

    Methods
    -------
    on_outer_fold_start(fold_idx, train_idx, test_idx)
        Invoked before the inner search of an outer fold begins.
    on_inner_search_complete(fold_idx, search)
        Invoked once the inner hyperparameter search has finished.
    on_post_processing_complete(fold_idx, artifacts)
        Invoked after any post-processing (e.g., threshold tuning).
    on_outer_fold_complete(fold_idx, result)
        Invoked when evaluation of the outer fold is done.
    on_nested_cv_complete(results)
        Invoked a single time after every outer fold was processed.

    Examples
    --------
    >>> class MyCallback:  # doctest: +SKIP
    ...     def on_outer_fold_start(self, fold_idx, train_idx, test_idx):
    ...         print(f"Starting fold {fold_idx}")
    ...     def on_inner_search_complete(self, fold_idx, search): ...
    ...     def on_post_processing_complete(self, fold_idx, artifacts): ...
    ...     def on_outer_fold_complete(self, fold_idx, result): ...
    ...     def on_nested_cv_complete(self, results): ...

    See Also
    --------
    ProgressCallback, LoggingCallback, CheckpointCallback
    """

    def on_outer_fold_start(
        self,
        fold_idx: int,
        train_idx: np.ndarray,
        test_idx: np.ndarray,
    ) -> None:
        """Hook: an outer fold is about to start its inner search."""
        ...

    def on_inner_search_complete(
        self,
        fold_idx: int,
        search: Any,
    ) -> None:
        """Hook: the inner hyperparameter search of a fold has finished."""
        ...

    def on_post_processing_complete(
        self,
        fold_idx: int,
        artifacts: dict,
    ) -> None:
        """Hook: post-processing of a fold has finished."""
        ...

    def on_outer_fold_complete(
        self,
        fold_idx: int,
        result: Any,
    ) -> None:
        """Hook: evaluation of an outer fold has finished."""
        ...

    def on_nested_cv_complete(
        self,
        results: Any,
    ) -> None:
        """Hook: all outer folds have been processed."""
        ...
class ProgressCallback:
    """Display a tqdm progress bar during nested cross-validation.

    The progress bar is created lazily on the first
    ``on_outer_fold_start`` call and advances by one step after each
    outer fold completes.  If ``tqdm`` is not installed the callback
    silently does nothing.  When the run completes the bar is closed
    and discarded, so the same instance can be reused for another run
    and will create a fresh bar.

    Parameters
    ----------
    n_outer_folds : int or None, optional
        Total number of outer folds.  Passed as the ``total`` argument
        to ``tqdm``.  If ``None``, the progress bar will have
        indeterminate length.

    Examples
    --------
    >>> cb = ProgressCallback(n_outer_folds=5)  # doctest: +SKIP

    See Also
    --------
    LoggingCallback : Text-based logging alternative.
    """

    def __init__(self, n_outer_folds: int | None = None):
        self._n_folds = n_outer_folds
        # Active tqdm bar, or None while no bar is open.
        self._pbar = None

    def on_outer_fold_start(self, fold_idx, train_idx, test_idx):
        """Create the progress bar if none is currently open."""
        if self._pbar is not None:
            return
        # Only the optional import is guarded; an ImportError raised while
        # constructing the bar itself should not be silently swallowed.
        try:
            from tqdm.auto import tqdm
        except ImportError:
            return
        self._pbar = tqdm(total=self._n_folds, desc="Outer folds")

    def on_inner_search_complete(self, fold_idx, search):
        """No-op; progress is tracked per outer fold only."""

    def on_post_processing_complete(self, fold_idx, artifacts):
        """No-op; progress is tracked per outer fold only."""

    def on_outer_fold_complete(self, fold_idx, result):
        """Advance the progress bar by one completed outer fold."""
        if self._pbar is not None:
            self._pbar.update(1)

    def on_nested_cv_complete(self, results):
        """Close the progress bar and forget it so a later run starts fresh."""
        # Detach first so the instance never keeps a reference to a closed bar.
        bar, self._pbar = self._pbar, None
        if bar is not None:
            bar.close()
class CheckpointCallback:
    """Pickle intermediate fold results to disk after each outer fold.

    After every outer fold, the fold result is saved as
    ``fold_<idx>.pkl`` inside the given directory.  When the full
    nested CV completes, the final results object is saved as
    ``final_results.pkl``.

    Each file is first written to a sibling ``<name>.tmp`` file and then
    moved over the final name, so an interrupted or failed write never
    leaves a truncated checkpoint under the final name.

    Parameters
    ----------
    path : str or pathlib.Path
        Directory in which checkpoint files are written.  Created
        automatically (including parents) if it does not exist.

    Attributes
    ----------
    path : pathlib.Path
        Checkpoint directory.

    Notes
    -----
    The files are pickles; only unpickle checkpoints that come from a
    trusted source.

    Examples
    --------
    >>> cb = CheckpointCallback("/tmp/ncv_checkpoints")  # doctest: +SKIP
    """

    def __init__(self, path: str | Path):
        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)

    def _dump(self, obj, filename):
        """Pickle ``obj`` to ``self.path / filename`` and return that path."""
        target = self.path / filename
        scratch = target.with_name(target.name + ".tmp")
        try:
            with open(scratch, "wb") as f:
                pickle.dump(obj, f)
            # Path.replace overwrites an existing target in a single rename.
            scratch.replace(target)
        except BaseException:
            # Do not leave a partial temp file behind; re-raise unchanged.
            if scratch.exists():
                scratch.unlink()
            raise
        return target

    def on_outer_fold_start(self, fold_idx, train_idx, test_idx):
        """No-op; nothing is checkpointed at fold start."""

    def on_inner_search_complete(self, fold_idx, search):
        """No-op; the inner search object is not checkpointed."""

    def on_post_processing_complete(self, fold_idx, artifacts):
        """No-op; post-processing artifacts are not checkpointed."""

    def on_outer_fold_complete(self, fold_idx, result):
        """Save ``result`` as ``fold_<fold_idx>.pkl`` and log the location."""
        filepath = self._dump(result, f"fold_{fold_idx}.pkl")
        logger.info("Checkpointed fold %d to %s", fold_idx, filepath)

    def on_nested_cv_complete(self, results):
        """Save ``results`` as ``final_results.pkl`` and log the location."""
        filepath = self._dump(results, "final_results.pkl")
        logger.info("Checkpointed final results to %s", filepath)
class LoggingCallback:
    """Emit structured log messages at each nested CV lifecycle event.

    Logs fold start (with train/test sizes), inner search completion
    (with best parameters and score), post-processing completion, fold
    completion (with elapsed time), and overall completion.

    Parameters
    ----------
    level : int, default=logging.INFO
        Python logging level for all emitted messages.

    Attributes
    ----------
    level : int
        Logging level.
    _fold_start_times : dict[int, float]
        Mapping from fold index to the monotonic-clock reading taken at
        fold start, used to compute elapsed seconds.  An entry is
        removed once its fold completes.

    Examples
    --------
    >>> import logging
    >>> cb = LoggingCallback(level=logging.DEBUG)  # doctest: +SKIP

    See Also
    --------
    ProgressCallback : Visual progress bar alternative.
    """

    def __init__(self, level: int = logging.INFO):
        self.level = level
        self._fold_start_times: dict[int, float] = {}

    def on_outer_fold_start(self, fold_idx, train_idx, test_idx):
        """Record the fold start time and log the train/test sizes."""
        # Monotonic clock: elapsed time stays correct if the system clock changes.
        self._fold_start_times[fold_idx] = time.monotonic()
        logger.log(
            self.level,
            "Fold %d: started (train=%d, test=%d)",
            fold_idx,
            len(train_idx),
            len(test_idx),
        )

    def on_inner_search_complete(self, fold_idx, search):
        """Log the best parameters and best score found by ``search``.

        ``best_params_`` and ``best_score_`` are read with a ``None``
        fallback so that a search object lacking them cannot make this
        purely informational hook raise.
        """
        best_params = getattr(search, "best_params_", None)
        best_score = getattr(search, "best_score_", None)
        try:
            score_value = float(best_score)
        except (TypeError, ValueError):
            # Missing or non-numeric score: report it verbatim instead of
            # forcing it through the fixed-precision float format.
            logger.log(
                self.level,
                "Fold %d: inner search complete, best_params=%s, best_score=%s",
                fold_idx,
                best_params,
                best_score,
            )
            return
        logger.log(
            self.level,
            "Fold %d: inner search complete, best_params=%s, best_score=%.4f",
            fold_idx,
            best_params,
            score_value,
        )

    def on_post_processing_complete(self, fold_idx, artifacts):
        """Log that post-processing of the fold has finished."""
        logger.log(self.level, "Fold %d: post-processing complete", fold_idx)

    def on_outer_fold_complete(self, fold_idx, result):
        """Log fold completion with the seconds elapsed since its start.

        A fold whose start was never recorded is reported with 0.0s.
        """
        # pop() so finished folds do not accumulate across runs.
        started = self._fold_start_times.pop(fold_idx, None)
        elapsed = 0.0 if started is None else time.monotonic() - started
        logger.log(self.level, "Fold %d: complete (%.1fs)", fold_idx, elapsed)

    def on_nested_cv_complete(self, results):
        """Log overall completion using ``results.n_outer_folds_``."""
        logger.log(self.level, "Nested CV complete: %d folds", results.n_outer_folds_)