SHAP-Style Plots
Beyond waterfall plots, SHAP values can be visualized as summary (beeswarm) plots and dependence plots.
[ ]:
import skexplain
from skexplain.common.importance_utils import to_skexplain_importance
import plotting_config
import numpy as np
import shap
[ ]:
# Load the estimators and the dataset through skexplain's loaders.
estimators = skexplain.load_models()
X, y = skexplain.load_data()

# Draw N distinct row positions with a fixed seed so the subset is reproducible.
N = 1000
random_state = np.random.RandomState(42)
ind = random_state.choice(len(X), size=N, replace=False)

# Select the sampled targets and rows; the sampled frame is renumbered from 0
# (the old index is discarded rather than kept as a column).
y_subset = y[ind]
X_subset = X.iloc[ind].reset_index(drop=True)
[ ]:
# Build the toolkit from the first loaded estimator and the X_subset features.
explainer = skexplain.ExplainToolkit(estimators[0], X=X_subset)

# Register the display names, units and colours defined in plotting_config.
plot_settings = {
    'display_feature_names': plotting_config.display_feature_names,
    'display_units': plotting_config.display_units,
    'feature_colors': plotting_config.color_dict,
}
explainer.set_plotting_config(**plot_settings)
Computing SHAP Values
[ ]:
# Partition masker built from the full feature matrix X, limited to 100 samples
# and using clustering='correlation'.
masker = shap.maskers.Partition(X, max_samples=100, clustering='correlation')

# Options handed to the SHAP backend: the masker above and the permutation algorithm.
shap_kws = dict(masker=masker, algorithm='permutation')

# Compute SHAP values for the subset held by the explainer (8 parallel jobs).
# To save time, you can load pre-computed results instead:
# results = explainer.load('../tutorial_data/attr_values.nc')
results = explainer.local_attributions(method='shap', shap_kws=shap_kws, n_jobs=8)
Summary (Beeswarm) Plot
[ ]:
# Summary (beeswarm) view of the SHAP values in `results` for the Random Forest.
summary_kwargs = {
    'plot_type': 'summary',
    'method': 'shap',
    'estimator_name': 'Random Forest',
}
explainer.scatter_plot(dataset=results, **summary_kwargs)
Dependence Plots
Show how SHAP values relate to feature values.
[ ]:
# Features to draw, one dependence panel each.
dependence_features = ['sfc_temp', 'temp2m', 'dwpt2m', 'sat_irbt']

# interaction_index=None: no interaction feature is selected for these panels.
explainer.scatter_plot(
    plot_type='dependence',
    features=dependence_features,
    dataset=results,
    method=['shap'],
    estimator_name='Random Forest',
    interaction_index=None,
)
Color-Coded Dependence
Color each point by the value of a second feature to reveal interaction effects.
[ ]:
# Same four features as the plain dependence plot.
dependence_features = ['sfc_temp', 'temp2m', 'dwpt2m', 'sat_irbt']

# Spacing arguments forwarded unchanged to scatter_plot.
layout_kwargs = {'colorbar_pad': 0.2, 'wspace': 0.7, 'hspace': 0.6}

# interaction_index='auto' — presumably lets the plotting routine pick the
# interaction feature used for colouring; confirm against the scikit-explain docs.
explainer.scatter_plot(
    plot_type='dependence',
    features=dependence_features,
    dataset=results,
    method=['shap'],
    estimator_name='Random Forest',
    interaction_index='auto',
    **layout_kwargs,
)
Converting Attributions to Importance Rankings
[ ]:
# Raw SHAP array for the Random Forest, taken from the attribution results.
shap_values = results['shap_values__Random Forest'].values

# Convert the SHAP values into skexplain's importance format ('shap_sum' method)
# so they can be shown in the bar-style ranking plot.
shap_rank = to_skexplain_importance(
    shap_values,
    estimator_name='Random Forest',
    feature_names=X.columns,
    method='shap_sum',
)

# Single-panel ranking plot labelled 'SHAP'.
explainer.plot_importance(
    data=[shap_rank],
    panels=[('shap_sum', 'Random Forest')],
    xlabels=['SHAP'],
)