{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local Feature Attributions\n",
    "Local attributions explain an individual prediction by decomposing it into per-feature contributions.\n",
    "\n",
    "scikit-explain supports:\n",
    "- **SHAP** \u2014 SHapley Additive exPlanations\n",
    "- **LIME** \u2014 Local Interpretable Model-agnostic Explanations\n",
    "- **Tree Interpreter** \u2014 decomposition of tree-based model predictions\n",
    "\n",
    "This notebook computes all three for a single example, averages tree-interpreter attributions over performance-based subsets of the data, and finishes with a regression example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import shap\n",
    "import skexplain\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "\n",
    "import plotting_config  # local module providing display feature names and units"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Seed everything stochastic so Restart & Run All reproduces the same results:\n",
    "# SHAP's permutation algorithm shuffles with the global NumPy RNG, and the\n",
    "# random forest in the regression example bootstraps its training data.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load the training data and the pre-fit models\n",
    "estimators = skexplain.load_models()\n",
    "X, y = skexplain.load_data()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Pick a single example for local attribution; the double brackets keep it 2-D (one row)\n",
    "single_example = X.iloc[[0]]\n",
    "\n",
    "single_explainer = skexplain.ExplainToolkit(estimators[0], X=single_example)\n",
    "single_explainer.set_plotting_config(\n",
    "    display_feature_names=plotting_config.display_feature_names,\n",
    "    display_units=plotting_config.display_units,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing Local Attributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# The Partition masker hides features using a correlation-based clustering of the\n",
    "# features, with the background data subsampled to at most 100 rows\n",
    "shap_kws = {\n",
    "    'masker': shap.maskers.Partition(X, max_samples=100, clustering='correlation'),\n",
    "    'algorithm': 'permutation',\n",
    "}\n",
    "\n",
    "# LIME requires the training data\n",
    "lime_kws = {\n",
    "    'training_data': X.values,\n",
    "    # NOTE(review): presumably the names of X's categorical features \u2014 confirm against the skexplain docs\n",
    "    'categorical_names': ['rural', 'urban'],\n",
    "}\n",
    "\n",
    "# Compute all three attribution methods in one call\n",
    "contrib_ds = single_explainer.local_attributions(\n",
    "    method=['shap', 'lime', 'tree_interpreter'],\n",
    "    shap_kws=shap_kws,\n",
    "    lime_kws=lime_kws,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# The result is an xarray.Dataset with attribution values for each method\n",
    "contrib_ds"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting Attributions (Waterfall Plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "fig, axes = single_explainer.plot_contributions(contrib=contrib_ds)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance-Based Attributions\n",
    "Average attributions over subsets of examples grouped by prediction quality (e.g., best hits, worst false alarms, worst misses) to see which features drive the model's best and worst predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# A separate toolkit over the full dataset plus targets (y), which the\n",
    "# performance-based grouping presumably needs to score each example\n",
    "perf_explainer = skexplain.ExplainToolkit(estimators[0], X=X, y=y)\n",
    "perf_explainer.set_plotting_config(\n",
    "    display_feature_names=plotting_config.display_feature_names,\n",
    "    display_units=plotting_config.display_units,\n",
    ")\n",
    "\n",
    "# Compute performance-based attributions using the tree interpreter\n",
    "tree_results = perf_explainer.average_attributions(\n",
    "    method='tree_interpreter',\n",
    "    performance_based=True,\n",
    "    n_samples=100,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Plot the top 5 features for selected performance categories\n",
    "perf_keys = ['Best Hits', 'Worst False Alarms', 'Worst Misses']\n",
    "\n",
    "fig, axes = perf_explainer.plot_contributions(\n",
    "    contrib=tree_results,\n",
    "    perf_keys=perf_keys,\n",
    "    num_features=5,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression Example\n",
    "The same workflow applies to regression models: fit a random forest on the California housing data and explain a single prediction with SHAP. Note that `fetch_california_housing` downloads the dataset on first use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "housing = fetch_california_housing()\n",
    "X_reg = housing['data']\n",
    "y_reg = housing['target']\n",
    "feature_names = housing['feature_names']\n",
    "\n",
    "# Fix random_state so the fitted forest (and hence its attributions) is reproducible\n",
    "reg_model = RandomForestRegressor(random_state=SEED)\n",
    "reg_model.fit(X_reg, y_reg)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Index with a list so the example stays 2-D: shape (1, n_features)\n",
    "reg_example = X_reg[[0]]\n",
    "\n",
    "reg_explainer = skexplain.ExplainToolkit(\n",
    "    ('Random Forest', reg_model),\n",
    "    X=reg_example,\n",
    "    feature_names=feature_names,\n",
    ")\n",
    "\n",
    "reg_shap_kws = {\n",
    "    'masker': shap.maskers.Partition(X_reg, max_samples=100, clustering='correlation'),\n",
    "    'algorithm': 'auto',\n",
    "}\n",
    "\n",
    "reg_results = reg_explainer.local_attributions(method='shap', shap_kws=reg_shap_kws)\n",
    "fig, axes = reg_explainer.plot_contributions(contrib=reg_results, figsize=(4, 8))"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}