"""Feature importance visualizations."""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from nestkit.plotting._style import _apply_axis_limits, _get_ax
if TYPE_CHECKING:
from matplotlib.axes import Axes
_UNIT = (0.0, 1.0)
[docs]
def plot_importance(
aggregator,
top_k: int = 20,
show_folds: bool = True,
bar_alpha: float = 0.7,
fold_color: str = "red",
fold_alpha: float = 0.5,
fold_size: float = 18,
ax=None,
**kwargs,
) -> Axes:
"""Bar plot of mean feature importance with optional per-fold jitter.
Parameters
----------
aggregator : FeatureImportanceAggregator
Fitted importance aggregator with per-fold importance data.
top_k : int, optional
Number of top features to display.
show_folds : bool, optional
Whether to overlay per-fold importance values.
bar_alpha : float, optional
Opacity of the bars.
fold_color : str, optional
Color of per-fold scatter markers.
fold_alpha : float, optional
Opacity of per-fold scatter markers.
fold_size : float, optional
Size of per-fold scatter markers.
ax : matplotlib.axes.Axes or None, optional
Axes to plot on. If ``None``, a new figure is created.
**kwargs
Additional keyword arguments passed to the underlying matplotlib call.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
"""
ax = _get_ax(ax)
df = aggregator.summary_.head(top_k)
y = np.arange(len(df))
ax.barh(y, df["mean_importance"], xerr=df["std_importance"], alpha=bar_alpha)
if show_folds:
features = df["feature"].values
names = aggregator.feature_names or [
f"feature_{i}" for i in range(aggregator.importances_matrix_.shape[1])
]
for i, feat in enumerate(features):
feat_idx = names.index(feat) if feat in names else i
fold_vals = aggregator.importances_matrix_[:, feat_idx]
ax.scatter(
fold_vals,
np.full(len(fold_vals), i),
alpha=fold_alpha,
s=fold_size,
color=fold_color,
)
ax.set_yticks(y)
ax.set_yticklabels(df["feature"])
ax.set_xlabel("Importance")
ax.set_title(f"Feature Importance (top {top_k})")
ax.invert_yaxis()
return ax
[docs]
def plot_rank_stability_features(
aggregator,
top_k: int = 20,
cmap: str = "YlOrRd",
label_fontsize: int = 7,
ax=None,
**kwargs,
) -> Axes:
"""Feature rank stability heatmap across folds.
Parameters
----------
aggregator : FeatureImportanceAggregator
Fitted importance aggregator with per-fold rank data.
top_k : int, optional
Number of top features to display.
cmap : str, optional
Colormap for the heatmap.
label_fontsize : int, optional
Font size for y-axis feature labels.
ax : matplotlib.axes.Axes or None, optional
Axes to plot on. If ``None``, a new figure is created.
**kwargs
Additional keyword arguments passed to the underlying matplotlib call.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
"""
ax = _get_ax(ax)
import matplotlib.pyplot as plt
df = aggregator.summary_.head(top_k)
features = df["feature"].values
names = aggregator.feature_names or [
f"feature_{i}" for i in range(aggregator.ranks_matrix_.shape[1])
]
rank_data = []
for feat in features:
feat_idx = names.index(feat) if feat in names else 0
rank_data.append(aggregator.ranks_matrix_[:, feat_idx])
rank_data = np.array(rank_data)
im = ax.imshow(rank_data, aspect="auto", cmap=cmap)
ax.set_yticks(range(len(features)))
ax.set_yticklabels(features, fontsize=label_fontsize)
ax.set_xlabel("Fold")
ax.set_title("Feature Rank Stability")
plt.colorbar(im, ax=ax, label="Rank")
return ax
[docs]
def plot_shap_summary(aggregator, top_k: int = 20, ax=None, **kwargs) -> Axes:
"""Pooled SHAP beeswarm plot.
Concatenates raw SHAP values from all outer test folds and renders
a beeswarm summary plot using the ``shap`` package. Requires the
aggregator to have been created with ``method='shap'``.
Parameters
----------
aggregator : FeatureImportanceAggregator
Fitted importance aggregator with raw SHAP values stored.
top_k : int, optional
Maximum number of features to display.
ax : matplotlib.axes.Axes or None, optional
Axes to plot on. If ``None``, a new figure is created.
**kwargs
Additional keyword arguments passed to the underlying matplotlib call.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
Raises
------
ValueError
If raw SHAP values are not available in the aggregator.
"""
ax = _get_ax(ax)
if not aggregator.raw_importances_:
raise ValueError("SHAP raw values not available. Use method='shap'.")
try:
import shap
pooled = np.concatenate(aggregator.raw_importances_)
shap.summary_plot(pooled, max_display=top_k, show=False)
except ImportError:
ax.text(0.5, 0.5, "shap package required", ha="center", va="center")
return ax
[docs]
def plot_selection_frequency(
aggregator,
top_k: int = 10,
bar_alpha: float = 0.7,
label_fontsize: int = 7,
full_range: bool = False,
xlim: tuple[float, float] | None = None,
ax=None,
**kwargs,
) -> Axes:
"""Feature selection frequency across folds.
Parameters
----------
aggregator : FeatureImportanceAggregator
Fitted importance aggregator with per-fold importance data.
top_k : int, optional
The top-*k* threshold used to count selection frequency.
bar_alpha : float, optional
Opacity of the bars.
label_fontsize : int, optional
Font size for y-axis feature labels.
full_range : bool, optional
If ``True``, set x-axis to [0, 1].
xlim : tuple of float or None, optional
Explicit x-axis limits (override *full_range*).
ax : matplotlib.axes.Axes or None, optional
Axes to plot on. If ``None``, a new figure is created.
**kwargs
Additional keyword arguments passed to the underlying matplotlib call.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
"""
ax = _get_ax(ax)
n_folds, n_features = aggregator.importances_matrix_.shape
names = aggregator.feature_names or [f"feature_{i}" for i in range(n_features)]
frequency = np.zeros(n_features)
for i in range(n_folds):
top_idx = np.argsort(-aggregator.importances_matrix_[i])[:top_k]
frequency[top_idx] += 1
frequency /= n_folds
sorted_idx = np.argsort(-frequency)[: top_k * 2]
y = np.arange(len(sorted_idx))
ax.barh(y, frequency[sorted_idx], alpha=bar_alpha)
ax.set_yticks(y)
ax.set_yticklabels([names[i] for i in sorted_idx], fontsize=label_fontsize)
ax.set_xlabel(f"Frequency in top-{top_k}")
ax.set_title("Feature Selection Frequency")
ax.invert_yaxis()
_apply_axis_limits(ax, xlim=xlim, full_range=full_range, natural_xlim=_UNIT)
return ax