{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Feature Interactions\n", "scikit-explain provides several methods to detect and quantify feature interactions:\n", "- **Friedman H-statistic**\n", "- **Interaction Strength (IAS)**\n", "- **Main Effect Complexity (MEC)**\n", "- **Sobol Indices**" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "import skexplain\n", "import plotting_config  # local module with display names, units, and colors for the features" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Load the example models and dataset provided by scikit-explain\n", "estimators = skexplain.load_models()\n", "X, y = skexplain.load_data()\n", "\n", "# estimators[1] is the gradient-boosting model (labeled 'Gradient Boosting' in the H-statistic plot below)\n", "explainer = skexplain.ExplainToolkit(estimators[1], X=X, y=y)\n", "explainer.set_plotting_config(\n", " display_feature_names=plotting_config.display_feature_names,\n", " display_units=plotting_config.display_units,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Friedman H-Statistic\n", "Friedman's H-statistic measures the strength of a pairwise interaction by comparing the 2D ALE of a feature pair with the sum of the two features' 1D ALE curves. Values near 0 indicate little or no interaction."
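], "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"The cell below is a minimal, self-contained sketch of the quantity the H-statistic estimates. It is not scikit-explain's implementation. The sketch follows Friedman and Popescu's original formulation, which uses centered partial dependence (PD) functions instead of ALE:\n",
"\n",
"$$H^2_{jk} = \\frac{\\sum_i \\left[PD_{jk}(x_{ij}, x_{ik}) - PD_j(x_{ij}) - PD_k(x_{ik})\\right]^2}{\\sum_i PD_{jk}^2(x_{ij}, x_{ik})}$$\n",
"\n",
"The toy function `toy_model` and the helpers `centered_pd` and `toy_h_statistic` are hypothetical illustrations. In `toy_model`, `x0` and `x1` interact while `x2` has no effect. $H^2$ should therefore be clearly positive for the pair `(x0, x1)` and essentially 0 for `(x0, x2)`. The synthetic data are stored in `X_toy` so that the notebook's `X` is not overwritten."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
"import numpy as np\n",
"\n",
"\n",
"# Hypothetical toy model: x0 and x1 interact through the product term; x2 is ignored.\n",
"def toy_model(data):\n",
"    return data[:, 0] + data[:, 1] + 2.0 * data[:, 0] * data[:, 1]\n",
"\n",
"\n",
"def centered_pd(f, data, cols):\n",
"    # Partial dependence on the columns in `cols`, evaluated at each row's own values:\n",
"    # set those columns to row i's values for every sample, then average the predictions.\n",
"    values = np.empty(len(data))\n",
"    for i in range(len(data)):\n",
"        data_fixed = data.copy()\n",
"        data_fixed[:, cols] = data[i, cols]\n",
"        values[i] = f(data_fixed).mean()\n",
"    return values - values.mean()  # center so each PD function has mean 0\n",
"\n",
"\n",
"def toy_h_statistic(f, data, j, k):\n",
"    # Fraction of the joint (j, k) effect's variation not explained by the two main effects.\n",
"    pd_j = centered_pd(f, data, [j])\n",
"    pd_k = centered_pd(f, data, [k])\n",
"    pd_jk = centered_pd(f, data, [j, k])\n",
"    return np.sum((pd_jk - pd_j - pd_k) ** 2) / np.sum(pd_jk ** 2)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_toy = rng.uniform(-1, 1, size=(200, 3))  # independent synthetic features on [-1, 1]\n",
"print('H^2(x0, x1):', round(toy_h_statistic(toy_model, X_toy, 0, 1), 3))\n",
"print('H^2(x0, x2):', round(toy_h_statistic(toy_model, X_toy, 0, 2), 3))"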
], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Define feature pairs to test for interactions\n", "features = [\n", " ('sfc_temp', 'temp2m'),\n", " ('sfc_temp', 'date_marker'),\n", " ('sfc_temp', 'sfcT_hrs_bl_frez'),\n", " ('sfc_temp', 'uplwav_flux'),\n", "]\n", "\n", "# Compute 1D ALE for all features and 2D ALE for each pair (subsample=0.25 uses 25% of the rows)\n", "ale_1d_ds = explainer.ale(\n", " features='all', n_bootstrap=1, subsample=0.25, n_jobs=1, n_bins=20\n", ")\n", "\n", "ale_2d_ds = explainer.ale(\n", " features=features, n_bootstrap=1, subsample=0.25,\n", " n_jobs=len(features), n_bins=20,\n", ")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Compute the H-statistic for each pair from the 1D and 2D ALE results\n", "hstat_results = explainer.friedman_h_stat(ale_1d_ds, ale_2d_ds, features=features)" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Map each pair's result name ('feature1__feature2') to a readable label\n", "adict = plotting_config.display_feature_names\n", "display_feature_names = {\n", " f'{f[0]}__{f[1]}': f'{adict[f[0]]} & {adict[f[1]]}'\n", " for f in features\n", "}\n", "\n", "explainer.plot_importance(\n", " data=[hstat_results],\n", " panels=[('hstat', 'Gradient Boosting')],\n", " display_feature_names=display_feature_names,\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Interaction Strength (IAS)\n", "Interaction strength (IAS) is a global measure of how much a model relies on feature interactions rather than main effects. It is the fraction of the prediction variance that the first-order (1D ALE) effects leave unexplained.\n", "A value of 0 indicates a purely additive model."
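], "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"The cell below is a minimal sketch of the interaction-strength idea (as defined by Molnar et al.). It is not scikit-explain's implementation. IAS compares the predictions $f(x)$ with a first-order approximation $f_0 + \\sum_j f_j(x_j)$, where $f_0$ is the mean prediction and the $f_j$ are centered main-effect curves:\n",
"\n",
"$$\\mathrm{IAS} = \\frac{\\sum_i \\left[f(x_i) - f_0 - \\sum_j f_j(x_{ij})\\right]^2}{\\sum_i \\left[f(x_i) - f_0\\right]^2}$$\n",
"\n",
"scikit-explain builds the main effects from the 1D ALE curves. This sketch instead assumes independent features and substitutes centered 1D partial dependence, which is a reasonable stand-in in that setting. `additive_model`, `product_model`, and the helper functions are hypothetical. The additive model should give an IAS of essentially 0, and the pure product model an IAS close to 1."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
"import numpy as np\n",
"\n",
"\n",
"def centered_pd_1d(f, data, j):\n",
"    # 1D partial dependence of feature j at each row's own value, centered to mean 0.\n",
"    values = np.empty(len(data))\n",
"    for i in range(len(data)):\n",
"        data_fixed = data.copy()\n",
"        data_fixed[:, j] = data[i, j]  # fix feature j to row i's value for every sample\n",
"        values[i] = f(data_fixed).mean()\n",
"    return values - values.mean()\n",
"\n",
"\n",
"def toy_interaction_strength(f, data):\n",
"    # Share of the prediction variance not captured by the first-order approximation\n",
"    # f0 + sum_j f_j(x_j), built here from centered 1D partial dependence.\n",
"    preds = f(data)\n",
"    f0 = preds.mean()\n",
"    first_order = f0 + sum(centered_pd_1d(f, data, j) for j in range(data.shape[1]))\n",
"    return np.sum((preds - first_order) ** 2) / np.sum((preds - f0) ** 2)\n",
"\n",
"\n",
"def additive_model(data):\n",
"    # Hypothetical purely additive model: no interactions.\n",
"    return data[:, 0] + data[:, 1] ** 2\n",
"\n",
"\n",
"def product_model(data):\n",
"    # Hypothetical pure-interaction model: its main effects are close to 0 for inputs centered on 0.\n",
"    return data[:, 0] * data[:, 1]\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_toy = rng.uniform(-1, 1, size=(300, 2))  # synthetic data; the notebook's X is left untouched\n",
"print('additive model IAS:', round(toy_interaction_strength(additive_model, X_toy), 3))\n",
"print('product model IAS: ', round(toy_interaction_strength(product_model, X_toy), 3))"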
], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Estimate IAS from the 1D ALE curves (5 bootstrap resamples, 10% row subsample)\n", "ias = explainer.interaction_strength(ale=ale_1d_ds, n_bootstrap=5, subsample=0.1)" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "ias" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Low IAS values indicate that the model is dominated by additive, first-order effects\n", "and that feature interactions play a relatively minor role." ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main Effect Complexity (MEC)\n", "Main effect complexity (MEC) measures how complex the model's main-effect (1D ALE) curves are:\n", "roughly, it is the variance-weighted average number of linear segments needed to approximate them." ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "mec = explainer.main_effect_complexity(ale=ale_1d_ds)\n", "\n", "# Print the MEC stored under each variable of the returned dataset\n", "for var_name in mec.data_vars:\n", " print(f\"{var_name}: {float(mec[var_name].values)}\")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MEC = 1 means the main effects are linear. Values close to 1 indicate nearly linear main effects,\n", "while higher values indicate more complex, non-linear feature effects." ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sobol Indices\n", "Sobol indices decompose the variance of the model's predictions into the part explained by each feature alone (first-order effects) and the part due to its interactions with other features."
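], "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"The cell below is a minimal Monte Carlo sketch of first-order and total Sobol indices. It is not necessarily the estimator scikit-explain uses. The sketch draws two independent sample matrices $A$ and $B$ and forms $A_B^{(j)}$, which is $A$ with column $j$ taken from $B$. It then applies the Saltelli (2010) first-order estimator and the Jansen total-effect estimator, with $\\mathrm{Var}(Y)$ estimated from the pooled outputs:\n",
"\n",
"$$S_j \\approx \\frac{\\frac{1}{N}\\sum_i f(B)_i \\left[f(A_B^{(j)})_i - f(A)_i\\right]}{\\mathrm{Var}(Y)}, \\qquad S_{T_j} \\approx \\frac{\\frac{1}{2N}\\sum_i \\left[f(A)_i - f(A_B^{(j)})_i\\right]^2}{\\mathrm{Var}(Y)}$$\n",
"\n",
"The test function `toy_model` and the helper `toy_sobol_indices` are hypothetical. `toy_model` has independent $U(0, 1)$ inputs and an interaction between `x0` and `x2`. The interaction share of each feature is $S_{T_j} - S_j$."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [
"import numpy as np\n",
"\n",
"\n",
"def toy_model(data):\n",
"    # Hypothetical test function: main effects of x0, x1, x2 plus an x0*x2 interaction.\n",
"    return data[:, 0] + 2.0 * data[:, 1] + data[:, 0] * data[:, 2]\n",
"\n",
"\n",
"def toy_sobol_indices(f, n_features, n_samples, rng):\n",
"    # Two independent sample matrices with U(0, 1) inputs.\n",
"    A = rng.uniform(size=(n_samples, n_features))\n",
"    B = rng.uniform(size=(n_samples, n_features))\n",
"    f_A, f_B = f(A), f(B)\n",
"    total_var = np.var(np.concatenate([f_A, f_B]))\n",
"    first, total = np.empty(n_features), np.empty(n_features)\n",
"    for j in range(n_features):\n",
"        AB_j = A.copy()\n",
"        AB_j[:, j] = B[:, j]  # A with column j taken from B\n",
"        f_AB_j = f(AB_j)\n",
"        first[j] = np.mean(f_B * (f_AB_j - f_A)) / total_var  # Saltelli (2010) first-order estimator\n",
"        total[j] = 0.5 * np.mean((f_A - f_AB_j) ** 2) / total_var  # Jansen total-effect estimator\n",
"    return first, total\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"first, total = toy_sobol_indices(toy_model, n_features=3, n_samples=100_000, rng=rng)\n",
"# Analytic values for this function: first-order ~ [0.34, 0.61, 0.04], total ~ [0.35, 0.61, 0.05]\n",
"print('first-order:', np.round(first, 3))\n",
"print('total:      ', np.round(total, 3))\n",
"print('interaction:', np.round(total - first, 3))"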
], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Create a separate ExplainToolkit for the random-forest model (estimators[0]) to compute Sobol indices\n", "explainer_rf = skexplain.ExplainToolkit(estimators[0], X=X, y=y)\n", "explainer_rf.set_plotting_config(\n", " display_feature_names=plotting_config.display_feature_names,\n", " display_units=plotting_config.display_units,\n", " feature_colors=plotting_config.color_dict,\n", ")\n", "\n", "# class_index=1: compute the indices for the predicted probability of class 1\n", "sobol_results = explainer_rf.sobol_indices(n_bootstrap=5000, class_index=1)" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Plot total, first-order, and interaction Sobol indices side by side\n", "explainer_rf.plot_importance(\n", " data=[sobol_results, sobol_results, sobol_results],\n", " panels=[\n", " ('sobol_total', 'Random Forest'),\n", " ('sobol_1st', 'Random Forest'),\n", " ('sobol_interact', 'Random Forest'),\n", " ],\n", " figsize=(12, 4),\n", ")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The total Sobol index measures a feature's overall importance: its first-order effect plus all of its interactions.\n", "The first-order index measures the feature's contribution on its own,\n", "and the interaction index (total minus first-order) measures the portion due to interactions with other features." ], "outputs": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 4 }