{ "cells": [
  { "cell_type": "markdown", "metadata": {}, "source": [
    "# SHAP-Style Plots\n",
    "Beyond waterfall plots, SHAP values can also be visualized as summary (beeswarm) plots and dependence plots.\n",
    "\n",
    "This notebook computes SHAP values for a Random Forest on a random subset of the example dataset, then shows:\n",
    "- a summary (beeswarm) plot covering all features,\n",
    "- dependence plots for selected features, without and with interaction coloring,\n",
    "- a feature-importance ranking derived from the SHAP values."
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "import numpy as np\n",
    "import shap\n",
    "import skexplain\n",
    "from skexplain.common.importance_utils import to_skexplain_importance\n",
    "\n",
    "import plotting_config  # feature display names, units and colors used by the plots"
  ], "outputs": [] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "# Configuration: everything a reader might want to tune\n",
    "SEED = 42          # seed for drawing the random subset\n",
    "N_SAMPLES = 1000   # number of examples to explain\n",
    "N_JOBS = 8         # parallel workers for the SHAP computation\n",
    "\n",
    "# Set to True to load pre-computed SHAP results instead of recomputing them (slow).\n",
    "# NOTE: the cached file is assumed to match this subset (same SEED and N_SAMPLES).\n",
    "LOAD_PRECOMPUTED = False\n",
    "PRECOMPUTED_PATH = '../tutorial_data/attr_values.nc'"
  ], "outputs": [] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "# Load the example models and dataset provided by skexplain\n",
    "estimators = skexplain.load_models()\n",
    "X, y = skexplain.load_data()\n",
    "\n",
    "# Draw a reproducible random subset of examples (without replacement)\n",
    "random_state = np.random.RandomState(SEED)\n",
    "ind = random_state.choice(len(X), size=min(N_SAMPLES, len(X)), replace=False)\n",
    "X_subset = X.iloc[ind].reset_index(drop=True)\n",
    "y_subset = np.asarray(y)[ind]  # positional selection whether y is an array or a Series\n",
    "\n",
    "assert len(X_subset) == len(y_subset)"
  ], "outputs": [] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "# Explain the first model on the subset; plot labels come from plotting_config\n",
    "explainer = skexplain.ExplainToolkit(estimators[0], X=X_subset)\n",
    "explainer.set_plotting_config(\n",
    "    display_feature_names=plotting_config.display_feature_names,\n",
    "    display_units=plotting_config.display_units,\n",
    "    feature_colors=plotting_config.color_dict,\n",
    ")"
  ], "outputs": [] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "## Computing SHAP Values\n",
    "SHAP values are computed with the permutation algorithm and a `Partition` masker built from the full dataset\n",
    "(at most 100 background samples, features clustered by correlation).\n",
    "This is the slowest step; set `LOAD_PRECOMPUTED = True` in the configuration cell to load cached results instead."
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "shap_kws = {\n",
    "    # Background data for masking features: the full dataset, subsampled to at most 100 rows\n",
    "    'masker': shap.maskers.Partition(X, max_samples=100, clustering='correlation'),\n",
    "    'algorithm': 'permutation',\n",
    "}\n",
    "\n",
    "if LOAD_PRECOMPUTED:\n",
    "    results = explainer.load(PRECOMPUTED_PATH)\n",
    "else:\n",
    "    results = explainer.local_attributions(\n",
    "        method='shap',\n",
    "        shap_kws=shap_kws,\n",
    "        n_jobs=N_JOBS,\n",
    "    )"
  ], "outputs": [] },
  { "cell_type": "markdown", "metadata": {},
"source": [
    "## Summary (Beeswarm) Plot\n",
    "A single panel showing the distribution of SHAP values for every feature across all explained examples."
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "# Name under which the explained model is stored in `results` (reused by all cells below)\n",
    "ESTIMATOR_NAME = 'Random Forest'\n",
    "\n",
    "explainer.scatter_plot(\n",
    "    plot_type='summary',\n",
    "    dataset=results,\n",
    "    method='shap',\n",
    "    estimator_name=ESTIMATOR_NAME,\n",
    ");"
  ], "outputs": [] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "## Dependence Plots\n",
    "Show how each feature's SHAP values vary with that feature's value."
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "# Features shown here and in the color-coded version below\n",
    "DEPENDENCE_FEATURES = ['sfc_temp', 'temp2m', 'dwpt2m', 'sat_irbt']\n",
    "\n",
    "explainer.scatter_plot(\n",
    "    features=DEPENDENCE_FEATURES,\n",
    "    plot_type='dependence',\n",
    "    dataset=results,\n",
    "    method=['shap'],\n",
    "    estimator_name=ESTIMATOR_NAME,\n",
    "    interaction_index=None,  # plain dependence: no coloring by a second feature\n",
    ");"
  ], "outputs": [] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "## Color-Coded Dependence\n",
    "Coloring each point by the value of a second, interacting feature can reveal interaction effects.\n",
    "With `interaction_index='auto'`, the coloring feature is selected automatically."
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "explainer.scatter_plot(\n",
    "    features=DEPENDENCE_FEATURES,\n",
    "    plot_type='dependence',\n",
    "    dataset=results,\n",
    "    method=['shap'],\n",
    "    estimator_name=ESTIMATOR_NAME,\n",
    "    interaction_index='auto',\n",
    "    # layout: room for the colorbars and spacing between panels\n",
    "    colorbar_pad=0.2,\n",
    "    wspace=0.7,\n",
    "    hspace=0.6,\n",
    ");"
  ], "outputs": [] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "## Converting Attributions to Importance Rankings\n",
    "The per-example SHAP values can be aggregated into one score per feature (`method='shap_sum'`)\n",
    "and displayed with the bar-style importance plot."
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
    "# Aggregate the per-example SHAP values into one importance score per feature\n",
    "shap_rank = to_skexplain_importance(\n",
    "    results[f'shap_values__{ESTIMATOR_NAME}'].values,\n",
    "    estimator_name=ESTIMATOR_NAME,\n",
    "    feature_names=X_subset.columns,  # columns of the data that was explained\n",
    "    method='shap_sum',\n",
    ")\n",
    "\n",
    "explainer.plot_importance(\n",
    "    data=[shap_rank],\n",
    "    panels=[('shap_sum', ESTIMATOR_NAME)],\n",
    "    xlabels=['SHAP'],\n",
    ");"
  ], "outputs": [] } ], "metadata": {
"kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 4 }