{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multiclass Classification\n",
    "\n",
    "scikit-explain supports multiclass problems for permutation importance, accumulated local effects (ALE), and SHAP.\n",
    "This notebook fits a logistic regression to the three-class Iris dataset, ranks its features with permutation importance,\n",
    "and then computes ALE curves and SHAP values separately for each class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import shap\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "import skexplain\n",
    "from skexplain.plot.base_plotting import PlotStructure"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Seed NumPy's global RNG so stochastic steps that draw from it\n",
    "# (e.g. SHAP's permutation sampling) give the same results on every run.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "\n",
    "MODEL_NAME = 'LogisticRegression'\n",
    "# Display names for Iris targets 0, 1 and 2, in that order.\n",
    "CLASS_NAMES = ['Setosa', 'Versicolour', 'Virginica']\n",
    "# One colour per class, shared by the ALE and SHAP figures.\n",
    "CLASS_COLORS = list(sns.color_palette('Set2'))"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load the Iris dataset and train a logistic regression model.\n",
    "# max_iter is raised from the default of 100: on these unscaled features lbfgs\n",
    "# can stop before converging and emit a ConvergenceWarning.\n",
    "X, y = load_iris(return_X_y=True, as_frame=True)\n",
    "lr = LogisticRegression(max_iter=1000).fit(X, y)\n",
    "\n",
    "features = X.columns\n",
    "classes = np.unique(y)\n",
    "assert len(classes) == len(CLASS_NAMES), 'CLASS_NAMES must name every target class'\n",
    "\n",
    "explainer = skexplain.ExplainToolkit((MODEL_NAME, lr), X=X, y=y)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Permutation Importance\n",
    "\n",
    "Features are ranked by backward single-pass permutation importance, scored with the ranked probability skill score (`rpss`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# RPSS is positively oriented (higher is better), so the most important\n",
    "# feature is the one whose permutation drives the score lowest: 'minimize'.\n",
    "results = explainer.permutation_importance(\n",
    "    n_vars=X.shape[1],\n",
    "    evaluation_fn='rpss',\n",
    "    scoring_strategy='minimize',\n",
    "    n_permute=5,\n",
    "    subsample=1.0,\n",
    "    n_jobs=X.shape[1],\n",
    "    verbose=True,\n",
    "    random_seed=SEED,\n",
    "    direction='backward',\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "fig = explainer.plot_importance(\n",
    "    data=results,\n",
    "    panels=[('backward_singlepass', MODEL_NAME)],\n",
    "    num_vars_to_plot=X.shape[1],\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ALE Curves per Class\n",
    "\n",
    "Accumulated local effects are computed once per target class, so each feature gets one curve per class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Compute ALE separately for each target class.\n",
    "ales = [\n",
    "    explainer.ale(features='all', n_bootstrap=1, n_jobs=4, n_bins=20, class_index=class_idx)\n",
    "    for class_idx in classes\n",
    "]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def create_feature_panels(n_panels):\n",
    "    \"\"\"Create a two-column grid of subplots with one panel per feature.\n",
    "\n",
    "    Returns the PlotStructure (needed afterwards for ``set_legend``), the\n",
    "    figure, and the axes returned by ``create_subplots``.\n",
    "    \"\"\"\n",
    "    plotter = PlotStructure(BASE_FONT_SIZE=16)\n",
    "    fig, axes = plotter.create_subplots(\n",
    "        n_panels=n_panels, n_columns=2, figsize=(8, 8), dpi=300,\n",
    "        wspace=0.4, hspace=0.35,\n",
    "    )\n",
    "    return plotter, fig, axes"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "plotter, fig, axes = create_feature_panels(len(features))\n",
    "\n",
    "for ax, feature in zip(axes.flat, features):\n",
    "    for color, ale in zip(CLASS_COLORS, ales):\n",
    "        explainer.plot_ale(\n",
    "            ale=ale, features=feature, ax=ax,\n",
    "            line_kws={'line_colors': [color], 'linewidth': 2.0},\n",
    "        )\n",
    "\n",
    "# `ax` is the last panel drawn in the loop above.\n",
    "plotter.set_legend(len(features), fig, ax, labels=CLASS_NAMES)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SHAP Values per Class\n",
    "\n",
    "SHAP values are likewise computed once per class, using the permutation algorithm, and shown as dependence plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# The masker depends only on X, so build it once and share it across classes.\n",
    "masker = shap.maskers.Partition(X, max_samples=10, clustering='correlation')\n",
    "\n",
    "# Compute SHAP values separately for each target class.\n",
    "shap_results = [\n",
    "    explainer.local_attributions(\n",
    "        method='shap',\n",
    "        shap_kws={'masker': masker, 'algorithm': 'permutation', 'class_idx': class_idx},\n",
    "    )\n",
    "    for class_idx in classes\n",
    "]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "plotter, fig, axes = create_feature_panels(len(features))\n",
    "\n",
    "for ax, feature in zip(axes.flat, features):\n",
    "    for color, shap_vals in zip(CLASS_COLORS, shap_results):\n",
    "        explainer.scatter_plot(\n",
    "            features=[feature],\n",
    "            plot_type='dependence',\n",
    "            dataset=shap_vals,\n",
    "            method=['shap'],\n",
    "            estimator_name=MODEL_NAME,\n",
    "            color=color,\n",
    "            interaction_index=None,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "# `ax` is the last panel drawn in the loop above.\n",
    "plotter.set_legend(len(features), fig, ax, labels=CLASS_NAMES)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpretation\n",
    "\n",
    "Each color represents a different target class. The ALE curves and SHAP dependence plots\n",
    "show how each feature's contribution differs depending on which class is being predicted."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}