{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Accumulated Local Effects (ALE)\n", "\n", "1D ALE curves show how a single feature affects a model's predictions on average. Unlike partial dependence plots, ALE curves account for correlations between features. They average prediction changes over small intervals of the feature's observed values and accumulate them, instead of evaluating the model on unrealistic combinations of feature values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "import skexplain\n", "import plotting_config" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Load the training data and pre-fit models\n", "estimators = skexplain.load_models()\n", "X, y = skexplain.load_data()\n", "X = X.astype({'urban': 'category', 'rural': 'category'})\n", "\n", "explainer = skexplain.ExplainToolkit(estimators, X=X, y=y)\n", "\n", "# Use readable feature names and units in the plots\n", "explainer.set_plotting_config(\n", " display_feature_names=plotting_config.display_feature_names,\n", " display_units=plotting_config.display_units,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing 1D ALE\n", "\n", "The `ale` method computes 1D ALE curves. Key arguments:\n", "\n", "- `features`: a single feature, a list of features, or `'all'`\n", "- `n_bins`: number of percentile-based bins (default 30)\n", "- `n_bootstrap`: number of bootstrap iterations used for confidence intervals (1 means no bootstrapping)\n", "- `subsample`: number (int) or fraction (float) of examples to use\n", "- `n_jobs`: number of processes to run in parallel" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Compute ALE curves for every feature, without bootstrapping\n", "ale_1d_ds = explainer.ale(\n", " features='all',\n", " n_bins=20,\n", " n_bootstrap=1,\n", " subsample=10000,\n", " n_jobs=1,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting ALE Curves\n", "\n", "Plot the ALE curve for a single feature. The light blue histogram in the background shows the data distribution."
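, "\n", "\n", "To make the plotted curve concrete, the next cell is a minimal from-scratch sketch of how a 1D ALE curve can be computed for a numeric feature. It works under simplifying assumptions (NumPy array input, a `predict` callable, percentile-based bins, count-weighted centering) and is **not** skexplain's implementation. The function `ale_1d`, the toy model, and the synthetic data are hypothetical. Each example is moved to the lower and upper edge of its bin, the prediction differences are averaged within each bin, and the bin averages are accumulated and centered. Only the feature of interest changes and every other feature keeps its observed value, so the model is only evaluated near realistic feature combinations." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
"# Minimal 1D ALE sketch (hypothetical helper, not skexplain's implementation)\n",
"import numpy as np\n",
"\n",
"\n",
"def ale_1d(predict, X, feature, n_bins=20):\n",
"    x = X[:, feature]\n",
"    # Percentile-based bin edges; np.unique drops duplicate edges caused by tied values\n",
"    edges = np.unique(np.quantile(x, np.linspace(0, 1, n_bins + 1)))\n",
"    # Bin k holds values in (edges[k], edges[k + 1]]; the minimum goes into bin 0\n",
"    idx = np.clip(np.searchsorted(edges, x, side='left') - 1, 0, len(edges) - 2)\n",
"    # Move each example to the lower and upper edge of its bin, other features unchanged\n",
"    X_lo, X_hi = X.copy(), X.copy()\n",
"    X_lo[:, feature] = edges[idx]\n",
"    X_hi[:, feature] = edges[idx + 1]\n",
"    diffs = predict(X_hi) - predict(X_lo)\n",
"    # Average the prediction differences within each bin (empty bins contribute 0)\n",
"    counts = np.bincount(idx, minlength=len(edges) - 1)\n",
"    local = np.bincount(idx, weights=diffs, minlength=len(edges) - 1) / np.maximum(counts, 1)\n",
"    # Accumulate the local effects, then center using count-weighted bin midpoints\n",
"    ale = np.concatenate([[0.0], np.cumsum(local)])\n",
"    ale -= np.sum(0.5 * (ale[:-1] + ale[1:]) * counts) / counts.sum()\n",
"    return edges, ale\n",
"\n",
"\n",
"def predict(X):\n",
"    # Toy model: nonlinear in feature 0, linear (slope 2) in feature 1\n",
"    return X[:, 0] ** 2 + 2.0 * X[:, 1]\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x0 = rng.normal(size=2000)\n",
"x1 = x0 + 0.3 * rng.normal(size=2000)  # x1 is strongly correlated with x0\n",
"X_demo = np.column_stack([x0, x1])\n",
"\n",
"edges, ale = ale_1d(predict, X_demo, feature=1, n_bins=20)\n",
"# Slope of the ALE curve in each bin; should be close to 2.0 for every bin\n",
"print(np.round(np.diff(ale) / np.diff(edges), 2))"
], "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The cell below plots the ALE curve computed by skexplain for `sfc_temp`."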
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "fig, ax = explainer.plot_ale(\n", " ale=ale_1d_ds,\n", " features='sfc_temp',\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting Multiple Features\n", "\n", "Use `get_important_vars` to select top features from a permutation importance result, then plot their ALE curves together." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Load permutation importance results and get top features\n", "results = explainer.load(fnames='../tutorial_data/multipass_importance_naupdc.nc')\n", "important_vars = explainer.get_important_vars(\n", " results, multipass=True, n_vars=100, combine=True\n", ")\n", "\n", "fig, axes = explainer.plot_ale(\n", " ale=ale_1d_ds,\n", " features=important_vars,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customizing ALE Plots\n", "\n", "You can customize line colors, styles, and the background histogram color using `line_kws` and `hist_color`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "fig, axes = explainer.plot_ale(\n", " ale=ale_1d_ds,\n", " features=important_vars,\n", " line_kws={\n", " 'line_colors': ['b', 'orange', 'k'],\n", " 'linewidth': 3.0,\n", " 'linestyle': 'dashed',\n", " },\n", " hist_color='red',\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Confidence Intervals via Bootstrapping\n", "\n", "Set `n_bootstrap` > 1 to compute confidence intervals on the ALE curves. The shaded bands represent the uncertainty in the mean ALE value across bootstrap samples." 
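, "\n", "\n", "The idea behind these bands can be sketched in a few lines. The next cell is a minimal, hypothetical illustration and **not** skexplain's implementation. It fixes percentile-based bin edges on the full sample, recomputes a 1D ALE curve on `n_bootstrap` datasets resampled with replacement, and summarizes the resulting curves. Reporting the mean curve with a 95% percentile band is an assumption made here; skexplain may summarize its bootstrap curves differently. The functions `ale_on_grid` and `bootstrap_ale` and the toy model are hypothetical, and the model includes an interaction so that the curves actually vary across resamples." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
"# Bootstrap bands for a 1D ALE curve (hypothetical sketch, not skexplain's implementation)\n",
"import numpy as np\n",
"\n",
"\n",
"def ale_on_grid(predict, X, feature, edges):\n",
"    # 1D ALE on fixed bin edges: average prediction change per bin, accumulate, center\n",
"    idx = np.clip(np.searchsorted(edges, X[:, feature], side='left') - 1, 0, len(edges) - 2)\n",
"    X_lo, X_hi = X.copy(), X.copy()\n",
"    X_lo[:, feature], X_hi[:, feature] = edges[idx], edges[idx + 1]\n",
"    diffs = predict(X_hi) - predict(X_lo)\n",
"    counts = np.bincount(idx, minlength=len(edges) - 1)\n",
"    local = np.bincount(idx, weights=diffs, minlength=len(edges) - 1) / np.maximum(counts, 1)\n",
"    ale = np.concatenate([[0.0], np.cumsum(local)])\n",
"    return ale - np.sum(0.5 * (ale[:-1] + ale[1:]) * counts) / counts.sum()\n",
"\n",
"\n",
"def bootstrap_ale(predict, X, feature, n_bins=10, n_bootstrap=10, level=0.95, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    # Fix the bin edges on the full sample so all bootstrap curves share one grid\n",
"    edges = np.unique(np.quantile(X[:, feature], np.linspace(0, 1, n_bins + 1)))\n",
"    curves = []\n",
"    for _ in range(n_bootstrap):\n",
"        rows = rng.integers(0, len(X), size=len(X))  # resample examples with replacement\n",
"        curves.append(ale_on_grid(predict, X[rows], feature, edges))\n",
"    curves = np.array(curves)\n",
"    # Percentile band across the bootstrap curves at each bin edge\n",
"    tail = (1 - level) / 2\n",
"    lower, upper = np.quantile(curves, [tail, 1 - tail], axis=0)\n",
"    return edges, curves.mean(axis=0), lower, upper\n",
"\n",
"\n",
"def predict(X):\n",
"    # Toy model with an interaction, so local effects of feature 0 depend on feature 1\n",
"    return np.sin(X[:, 0]) + X[:, 0] * X[:, 1]\n",
"\n",
"\n",
"X_demo = np.random.default_rng(1).normal(size=(500, 2))\n",
"edges, mean_ale, lower, upper = bootstrap_ale(predict, X_demo, feature=0, n_bootstrap=10)\n",
"print('band width per bin edge:', np.round(upper - lower, 3))"
], "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "With skexplain, passing `n_bootstrap` to `ale` computes the bootstrapped curves directly for the fitted models, as in the cell below."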
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "ale_1d_ci = explainer.ale(\n", " features=important_vars,\n", " n_bootstrap=10,\n", " subsample=1000,\n", " n_jobs=4,\n", " n_bins=10,\n", ")\n", "\n", "fig, axes = explainer.plot_ale(ale=ale_1d_ci)" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These confidence intervals reflect sampling uncertainty in the mean ALE curve: each bootstrap iteration recomputes the curve on a different random sample of the data. They are **not** the same as the spread of individual conditional expectation (ICE) curves, which shows how the effect varies across individual examples, for instance because of feature interactions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ALE for Regression\n", "\n", "ALE works for regression problems as well. Here we use the California housing dataset as a quick example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "from sklearn.datasets import fetch_california_housing\n", "from sklearn.ensemble import RandomForestRegressor\n", "\n", "# Load the California housing data (downloaded on first use)\n", "data = fetch_california_housing()\n", "X_reg = data['data']\n", "y_reg = data['target']\n", "feature_names = data['feature_names']\n", "\n", "model = RandomForestRegressor(random_state=42)\n", "model.fit(X_reg, y_reg)\n", "\n", "# Wrap the single fitted model as a (name, estimator) tuple\n", "explainer_reg = skexplain.ExplainToolkit(\n", " ('Random Forest', model), X=X_reg, y=y_reg, feature_names=feature_names\n", ")\n", "\n", "ale_reg = explainer_reg.ale(\n", " features=feature_names,\n", " n_bootstrap=1,\n", " subsample=10000,\n", " n_jobs=6,\n", " n_bins=30,\n", ")\n", "\n", "fig, axes = explainer_reg.plot_ale(ale_reg)" ], "outputs": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 4 }