Multiclass Classification
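scikit-explain supports multiclass classification problems for permutation importance, accumulated local effects (ALE), and SHAP. This example trains a logistic regression model on the three-class Iris dataset and explains it with each of these methods.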
[ ]:
import skexplain
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy as np
import shap
[ ]:
# Load the Iris dataset (as pandas objects, so feature names are preserved)
# and train a logistic regression model on all three classes.
X, y = load_iris(return_X_y=True, as_frame=True)
# The Iris features are unscaled, and with the default max_iter=100 the lbfgs
# solver typically stops before converging (sklearn emits a ConvergenceWarning).
# Allow enough iterations so the model being explained is a converged fit.
lr = LogisticRegression(max_iter=1000).fit(X, y)
# Wrap the fitted model for scikit-explain. The name 'LogisticRegression' is
# the key later cells use to select this estimator's results.
explainer = skexplain.ExplainToolkit(('LogisticRegression', lr), X=X, y=y)
Permutation Importance
[ ]:
# Rank every feature with backward permutation importance. 'rpss' is
# presumably scikit-explain's ranked probability skill score; confirm that
# scoring_strategy='minimize' matches that metric's orientation.
perm_kws = dict(
    n_vars=X.shape[1],      # rank all features
    evaluation_fn='rpss',
    scoring_strategy='minimize',
    n_permute=5,
    subsample=1.0,          # use every row
    n_jobs=X.shape[1],      # parallelism tied to the number of features
    verbose=True,
    random_seed=42,
    direction='backward',
)
results = explainer.permutation_importance(**perm_kws)
[ ]:
# Plot the backward single-pass importance ranking for the logistic
# regression model. Show every feature: the count is derived from X, so it
# stays in sync with the n_vars used to compute the importances, instead of
# a hard-coded 4 that only fits Iris.
fig = explainer.plot_importance(
    data=results,
    panels=[('backward_singlepass', 'LogisticRegression')],
    num_vars_to_plot=X.shape[1],
)
ALE Curves per Class
[ ]:
# Compute one ALE result per target class. `class_index` selects which class
# is explained; the remaining settings are identical for every class.
ale_kws = dict(features='all', n_bootstrap=1, n_jobs=4, n_bins=20)
ales = [
    explainer.ale(class_index=target_class, **ale_kws)
    for target_class in np.unique(y)
]
[ ]:
from skexplain.plot.base_plotting import PlotStructure
import seaborn as sns
# Plot the per-class ALE curves: one panel per feature, and within each panel
# one line per class, colored from the Set2 palette.
features = X.columns
# One label per entry in `ales`, in class-index order.
class_labels = ['Setosa', 'Versicolour', 'Virginica']
plotter = PlotStructure(BASE_FONT_SIZE=16)
fig, axes = plotter.create_subplots(
    n_panels=len(features),
    n_columns=2,
    figsize=(8, 8),
    dpi=300,
    wspace=0.4,
    hspace=0.35,
)
colors = list(sns.color_palette('Set2'))

for panel_ax, feature_name in zip(axes.flat, features):
    for class_pos, class_ale in enumerate(ales):
        explainer.plot_ale(
            ale=class_ale,
            features=feature_name,
            ax=panel_ax,
            line_kws={'line_colors': [colors[class_pos]], 'linewidth': 2.0},
        )

# set_legend receives the last panel's axes after the loop.
plotter.set_legend(len(features), fig, panel_ax, labels=class_labels)
SHAP Values per Class
[ ]:
# Compute SHAP values separately for each target class.
# The Partition masker depends only on X, not on the class being explained,
# so build it once. Constructing it computes a correlation-based clustering
# of the features, which was previously redone inside the loop for every class.
shap_masker = shap.maskers.Partition(X, max_samples=10, clustering='correlation')
shap_results = [
    explainer.local_attributions(
        method='shap',
        shap_kws={
            'masker': shap_masker,
            'algorithm': 'permutation',
            'class_idx': class_idx,  # which class's output to attribute
        },
    )
    for class_idx in np.unique(y)
]
[ ]:
# Plot per-class SHAP dependence scatter plots: one panel per feature, and
# within each panel one color per class.
plotter = PlotStructure(BASE_FONT_SIZE=16)
fig, axes = plotter.create_subplots(
    n_panels=len(features),
    n_columns=2,
    figsize=(8, 8),
    dpi=300,
    wspace=0.4,
    hspace=0.35,
)
colors = list(sns.color_palette('Set2'))
# Scalar settings shared by every scatter_plot call below.
dependence_kws = dict(
    plot_type='dependence',
    estimator_name='LogisticRegression',
    interaction_index=None,
)

for panel_ax, feature_name in zip(axes.flat, features):
    for class_pos, class_shap in enumerate(shap_results):
        explainer.scatter_plot(
            features=[feature_name],
            dataset=class_shap,
            method=['shap'],
            color=colors[class_pos],
            ax=panel_ax,
            **dependence_kws,
        )

# set_legend receives the last panel's axes after the loop.
plotter.set_legend(
    len(features), fig, panel_ax, labels=['Setosa', 'Versicolour', 'Virginica']
)
Each color corresponds to one target class (Setosa, Versicolour, Virginica). Together, the ALE curves and SHAP dependence plots show how each feature's contribution to the model's output differs across the three classes.