Multiclass Classification

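scikit-explain supports multiclass classification problems for permutation importance, accumulated local effects (ALE), and SHAP. In this example, ALE and SHAP are computed separately for each target class.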

[ ]:
import skexplain
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy as np
import shap
[ ]:
# Load the Iris dataset as pandas objects (so feature names are kept) and
# train a logistic regression model.
X, y = load_iris(return_X_y=True, as_frame=True)
# The default max_iter=100 is too small for the lbfgs solver to converge on
# the unscaled Iris features (sklearn emits a ConvergenceWarning), so raise it.
lr = LogisticRegression(max_iter=1000).fit(X, y)

# Register the fitted model under the name 'LogisticRegression'; that name is
# the estimator key used to select its results in the plotting cells below.
explainer = skexplain.ExplainToolkit(('LogisticRegression', lr), X=X, y=y)

Permutation Importance

[ ]:
# Backward permutation importance over every feature, scored with the 'rpss'
# evaluation function. Each feature is shuffled n_permute times on the full
# dataset (subsample=1.0), with one worker per feature and a fixed seed so the
# shuffles are reproducible.
perm_imp_kws = dict(
    n_vars=X.shape[1],
    evaluation_fn='rpss',
    scoring_strategy='minimize',
    n_permute=5,
    subsample=1.0,
    n_jobs=X.shape[1],
    verbose=True,
    random_seed=42,
    direction='backward',
)
results = explainer.permutation_importance(**perm_imp_kws)
[ ]:
# Plot the backward single-pass ranking for the LogisticRegression model.
# Show as many variables as were ranked (n_vars=X.shape[1] in the cell above)
# rather than a hard-coded count, so the plot stays in sync with the data.
fig = explainer.plot_importance(
    data=results,
    panels=[('backward_singlepass', 'LogisticRegression')],
    num_vars_to_plot=X.shape[1],
)

ALE Curves per Class

[ ]:
# Compute one ALE result per target class; class_index selects which class
# output each ALE describes. Results are stored in class order (np.unique
# returns the labels sorted).
ales = []
for target_class in np.unique(y):
    class_ale = explainer.ale(
        features='all',
        n_bins=20,
        n_bootstrap=1,
        n_jobs=4,
        class_index=target_class,
    )
    ales.append(class_ale)
[ ]:
from skexplain.plot.base_plotting import PlotStructure
import seaborn as sns

# One panel per feature in a 2-column grid; each panel overlays the ALE curve
# of every class so their effects can be compared side by side.
features = X.columns
plotter = PlotStructure(BASE_FONT_SIZE=16)
fig, axes = plotter.create_subplots(
    n_panels=len(features),
    n_columns=2,
    figsize=(8, 8),
    dpi=300,
    wspace=0.4,
    hspace=0.35,
)

# The class at position k in `ales` is always drawn with the k-th Set2 color.
colors = list(sns.color_palette('Set2'))
for ax, feature in zip(axes.flat, features):
    for class_pos, class_ale in enumerate(ales):
        curve_kws = {'line_colors': [colors[class_pos]], 'linewidth': 2.0}
        explainer.plot_ale(ale=class_ale, features=feature, ax=ax, line_kws=curve_kws)

# Legend labels follow the class order used to build `ales`.
# `ax` here is the last panel visited by the loop above.
class_labels = ['Setosa', 'Versicolour', 'Virginica']
plotter.set_legend(len(features), fig, ax, labels=class_labels)

SHAP Values per Class

[ ]:
# Compute SHAP values for each class.
# The Partition masker's arguments do not depend on the class, so build it
# once instead of reconstructing an identical masker on every iteration.
masker = shap.maskers.Partition(X, max_samples=10, clustering='correlation')

# Each call gets its own kwargs dict (in case the explainer modifies it);
# results are stored in class order.
shap_results = [
    explainer.local_attributions(
        method='shap',
        shap_kws={
            'masker': masker,
            'algorithm': 'permutation',
            'class_idx': class_idx,
        },
    )
    for class_idx in np.unique(y)
]
[ ]:
# Same 2-column, one-panel-per-feature layout as the ALE figure, overlaying a
# SHAP dependence scatter for every class in each panel.
plotter = PlotStructure(BASE_FONT_SIZE=16)
fig, axes = plotter.create_subplots(
    n_panels=len(features),
    n_columns=2,
    figsize=(8, 8),
    dpi=300,
    wspace=0.4,
    hspace=0.35,
)

# The class at position k in `shap_results` is drawn with the k-th Set2 color;
# interaction_index=None is passed so no interaction feature is selected.
colors = list(sns.color_palette('Set2'))
for ax, feature in zip(axes.flat, features):
    for class_pos, class_shap in enumerate(shap_results):
        explainer.scatter_plot(
            ax=ax,
            features=[feature],
            dataset=class_shap,
            method=['shap'],
            estimator_name='LogisticRegression',
            plot_type='dependence',
            interaction_index=None,
            color=colors[class_pos],
        )

# Legend labels follow the class order used to build `shap_results`.
class_labels = ['Setosa', 'Versicolour', 'Virginica']
plotter.set_legend(len(features), fig, ax, labels=class_labels)

Each color represents a different target class. The ALE curves and SHAP dependence plots show how the effect of each feature on the model's prediction differs from one class to another.