{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Individual Conditional Expectations (ICE)\n", "ICE curves show how each individual example's prediction changes as a single feature is varied. When the curves are not parallel (they diverge or cross), the feature's effect depends on the values of other features, which indicates feature interactions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "import skexplain\n", "import plotting_config" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Load the training data and pre-fit models\n", "estimators = skexplain.load_models()\n", "X, y = skexplain.load_data()\n", "X = X.astype({'urban': 'category', 'rural': 'category'})\n", "\n", "# Use the Random Forest model only\n", "explainer = skexplain.ExplainToolkit(estimators[0], X=X, y=y)\n", "\n", "explainer.set_plotting_config(\n", "    display_feature_names=plotting_config.display_feature_names,\n", "    display_units=plotting_config.display_units,\n", "    feature_colors=plotting_config.color_dict,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing ICE Curves\n", "\n", "ICE curves are computed with `explainer.ice()`. The `subsample` argument controls how many individual curves are drawn. Plots with more than about 200 curves tend to be hard to read." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "features = ['sfc_temp', 'temp2m', 'dwpt2m', 'sfcT_hrs_bl_frez']\n", "\n", "ice_ds = explainer.ice(\n", "    features=features,\n", "    subsample=200,\n", "    n_jobs=4,\n", "    n_bins=20,\n", ")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Compute ALE for the same features (drawn as the bold overlay line)\n", "ale_1d_ds = explainer.ale(\n", "    features=features,\n", "    n_bootstrap=1,\n", "    subsample=10000,\n", "    n_jobs=1,\n", "    n_bins=20,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting ALE with ICE Overlay\n", "\n", "Pass the ICE dataset to `plot_ale` via the `ice_curves` argument. The bold line is the ALE curve, which summarizes the feature's average effect, and the thin lines are individual ICE curves. ICE curves whose shapes differ from one another indicate feature interactions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "fig, axes = explainer.plot_ale(\n", "    ale=ale_1d_ds,\n", "    ice_curves=ice_ds,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Color-Coded ICE Curves\n", "\n", "Color-coding ICE curves by another feature's value can reveal interactions between the two features. Use the `color_by` argument to specify which feature to color by." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "fig, axes = explainer.plot_ale(\n", "    ale=ale_1d_ds,\n", "    features=features,\n", "    ice_curves=ice_ds,\n", "    color_by='temp2m',\n", "    figsize=(10, 6),\n", "    wspace=0.25,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the colored ICE curves separate into distinct bands (e.g., warm colors trending differently than cool colors), it indicates that the feature being plotted interacts with the `color_by` feature. 
If all colored lines follow the same trend regardless of color, there is little interaction between the two features.\n", "\n", "ICE curves are computed by replacing one feature's value with each grid value while holding all other features fixed. When features are correlated, this can create unrealistic feature combinations, so the curves can be misleading in those regions. Use them as an exploratory tool alongside ALE and 2D ALE for a more complete picture." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 4 }