SHAP-Style Plots

Beyond waterfall plots, SHAP values can be visualized as summary (beeswarm) plots and dependence plots.

[ ]:
import skexplain
from skexplain.common.importance_utils import to_skexplain_importance
import plotting_config
import numpy as np
import shap
[ ]:
# Load the bundled example models and data
estimators = skexplain.load_models()
X, y = skexplain.load_data()

# Draw a reproducible random subset of N examples (without replacement) to keep
# the SHAP computation below tractable. The size is capped at len(X) so the
# draw cannot fail if the dataset has fewer than N rows.
N = 1000
random_state = np.random.RandomState(42)
ind = random_state.choice(len(X), size=min(N, len(X)), replace=False)

# Reset the index in the same expression instead of mutating in place, so this
# cell is idempotent when re-run.
X_subset = X.iloc[ind].reset_index(drop=True)
# NOTE(review): assumes y is a NumPy array, so `y[ind]` is positional — confirm
# (a pandas Series would index by label here).
y_subset = y[ind]
[ ]:
# Wrap the first loaded estimator in an ExplainToolkit backed by the data subset,
# then register the display names, units and feature colors used by the plots.
explainer = skexplain.ExplainToolkit(estimators[0], X=X_subset)
plotting_kws = dict(
    display_feature_names=plotting_config.display_feature_names,
    display_units=plotting_config.display_units,
    feature_colors=plotting_config.color_dict,
)
explainer.set_plotting_config(**plotting_kws)

Computing SHAP Values

[ ]:
# SHAP masker: a Partition masker built from the full dataset X (not just the
# subset), with max_samples=100 and correlation-based feature clustering.
masker = shap.maskers.Partition(X, max_samples=100, clustering='correlation')
shap_kws = {'masker': masker, 'algorithm': 'permutation'}

# Compute SHAP values with the explainer built on X_subset, using 8 parallel jobs.
# To save time, you can instead load pre-computed results:
#   results = explainer.load('../tutorial_data/attr_values.nc')
results = explainer.local_attributions(
    method='shap',
    shap_kws=shap_kws,
    n_jobs=8,
)

Summary (Beeswarm) Plot

[ ]:
# Summary (beeswarm) plot of the Random Forest SHAP values stored in `results`.
summary_kws = dict(
    dataset=results,
    estimator_name='Random Forest',
    method='shap',
    plot_type='summary',
)
explainer.scatter_plot(**summary_kws)

Dependence Plots

Dependence plots show how each feature's SHAP values vary with the values of that feature.

[ ]:
# One dependence panel per selected feature; interaction_index=None requests no
# interaction feature for coloring.
dependence_features = ['sfc_temp', 'temp2m', 'dwpt2m', 'sat_irbt']
explainer.scatter_plot(
    dataset=results,
    features=dependence_features,
    plot_type='dependence',
    method=['shap'],
    estimator_name='Random Forest',
    interaction_index=None,
)

Color-Coded Dependence

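Coloring each point by the value of a second (interaction) feature can reveal interaction effects; `interaction_index='auto'` selects that feature automatically.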

[ ]:
# Same dependence panels as above, but with interaction_index='auto' so points
# are colored by an automatically chosen interaction feature.
colored_features = ['sfc_temp', 'temp2m', 'dwpt2m', 'sat_irbt']
# Extra spacing so the per-panel colorbars do not overlap neighbouring panels.
layout_kws = dict(colorbar_pad=0.2, wspace=0.7, hspace=0.6)
explainer.scatter_plot(
    dataset=results,
    features=colored_features,
    plot_type='dependence',
    method=['shap'],
    estimator_name='Random Forest',
    interaction_index='auto',
    **layout_kws,
)

Converting Attributions to Importance Rankings

[ ]:
# Turn the per-example Random Forest SHAP values into skexplain importance scores
# (method='shap_sum') so they can be drawn with the bar-style ranking plot.
rf_shap_values = results['shap_values__Random Forest'].values
shap_rank = to_skexplain_importance(
    rf_shap_values,
    method='shap_sum',
    estimator_name='Random Forest',
    feature_names=X.columns,
)

# The panel key must match the method/estimator names used in the conversion above.
explainer.plot_importance(
    data=[shap_rank],
    xlabels=['SHAP'],
    panels=[('shap_sum', 'Random Forest')],
)