Feature Interactions

scikit-explain provides several methods for detecting and quantifying feature interactions, along with a related measure of main-effect complexity:

  • Friedman H-statistic

  • Interaction Strength (IAS)

  • Main Effect Complexity (MEC)

  • Sobol Indices

import skexplain
import plotting_config
# Load data and models
estimators = skexplain.load_models()
X, y = skexplain.load_data()

explainer = skexplain.ExplainToolkit(estimators[1], X=X, y=y)
explainer.set_plotting_config(
    display_feature_names=plotting_config.display_feature_names,
    display_units=plotting_config.display_units,
)
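
The `plotting_config` module imported above is presumably a local helper shipped with the tutorial, not part of scikit-explain, and its contents are not shown here. Judging only from the attributes this notebook reads, it exposes three dictionaries keyed by column name. A hypothetical minimal version, with placeholder labels, units, and colors, might look like this:

```python
# plotting_config.py: hypothetical minimal version of the local helper module.
# Only the three attribute names come from this notebook; every label, unit,
# and color below is a placeholder.

# Readable feature names, keyed by the column names in X
display_feature_names = {
    "sfc_temp": "Surface Temperature",
    "temp2m": "2-m Temperature",
    "date_marker": "Date Marker",
    "sfcT_hrs_bl_frez": "Hours Below Freezing",
    "uplwav_flux": "Upward Longwave Flux",
}

# Units shown alongside feature names on plot axes (placeholders)
display_units = {
    "sfc_temp": "K",
    "temp2m": "K",
    "date_marker": "",
    "sfcT_hrs_bl_frez": "hrs",
    "uplwav_flux": "W m$^{-2}$",
}

# One matplotlib color per feature, passed as feature_colors for the Sobol plots
color_dict = {
    "sfc_temp": "tab:blue",
    "temp2m": "tab:orange",
    "date_marker": "tab:green",
    "sfcT_hrs_bl_frez": "tab:red",
    "uplwav_flux": "tab:purple",
}
```

A working version presumably needs an entry for every column of X that can appear in a plot, not only the five features that appear in the H-statistic pairs.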

Friedman H-Statistic

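Measures the strength of pairwise interactions from accumulated local effects (ALE): for each feature pair, the 2D ALE is compared with the sum of the two features’ 1D ALE curves, and the statistic reflects how much of the joint effect the 1D curves fail to explain. A value near 0 means the pair acts additively.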
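
Friedman's H-statistic compares a pair's centered joint effect f_jk with the sum of the two centered single-feature effects f_j and f_k, summed over the evaluation points: H²_jk = Σ (f_jk − f_j − f_k)² / Σ f_jk². The pure-Python sketch below evaluates that formula on a regular grid for toy functions of two independent features. It is not scikit-explain's implementation: it uses partial dependence (averaging over the other feature) where scikit-explain uses ALE curves, and the helper names and test functions are hypothetical.

```python
from itertools import product


def centered(values):
    """Subtract the mean so an effect averages to zero over the grid."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]


def h_statistic_squared(f, grid):
    """Friedman's H^2 for a two-feature function f evaluated on grid x grid.

    With only two independent, uniformly weighted features, the joint
    partial dependence is f itself, and each 1D partial dependence is the
    average of f over the other feature.
    """
    points = list(product(grid, grid))
    f_jk = centered([f(a, b) for a, b in points])
    pd_j = {a: sum(f(a, b) for b in grid) / len(grid) for a in grid}
    pd_k = {b: sum(f(a, b) for a in grid) / len(grid) for b in grid}
    f_j = centered([pd_j[a] for a, _ in points])
    f_k = centered([pd_k[b] for _, b in points])
    # Share of the joint effect that the two 1D effects do not explain
    numerator = sum((jk - j - k) ** 2 for jk, j, k in zip(f_jk, f_j, f_k))
    denominator = sum(jk ** 2 for jk in f_jk)
    return numerator / denominator


grid = [-1.0, -0.5, 0.0, 0.5, 1.0]
print(h_statistic_squared(lambda a, b: 2 * a + b ** 2, grid))  # -> 0.0 (additive)
print(h_statistic_squared(lambda a, b: a * b, grid))  # -> 1.0 (pure interaction)
```

On this symmetric grid an additive function scores 0 and a pure product of the two features scores 1; a mix such as `a + b + a * b` falls in between.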

# Define feature pairs to test for interactions
features = [
    ('sfc_temp', 'temp2m'),
    ('sfc_temp', 'date_marker'),
    ('sfc_temp', 'sfcT_hrs_bl_frez'),
    ('sfc_temp', 'uplwav_flux'),
]

# Compute 1D ALE for all features and 2D ALE for the pairs
ale_1d_ds = explainer.ale(
    features='all', n_bootstrap=1, subsample=0.25, n_jobs=1, n_bins=20
)

ale_2d_ds = explainer.ale(
    features=features, n_bootstrap=1, subsample=0.25,
    n_jobs=len(features), n_bins=20,
)
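
As a rough picture of what the ALE calls above compute, a first-order ALE curve splits one feature into quantile bins, averages the change in prediction as that feature moves across each bin (other features kept at their observed values), accumulates those averages, and centers the result. The sketch below follows that recipe for a generic `predict` callable on a list of rows. The function `ale_1d`, the toy model, and the data are hypothetical, and it leaves out the bootstrapping, subsampling, and parallelism that `n_bootstrap`, `subsample`, and `n_jobs` control in scikit-explain.

```python
import random
from bisect import bisect_left


def ale_1d(predict, rows, feature, n_bins=20):
    """Sketch of a first-order ALE curve for column `feature` of `rows`.

    `predict` maps one row (a list of numbers) to a float. Returns the bin
    edges and the centered ALE value at each edge.
    """
    values = sorted(row[feature] for row in rows)
    n = len(values)
    # Quantile bin edges drawn from the data itself; duplicates are dropped
    edges = sorted({values[round(i * (n - 1) / n_bins)] for i in range(n_bins + 1)})
    if len(edges) < 2:
        return edges, [0.0] * len(edges)  # constant feature: no effect to measure
    diff_sums = [0.0] * (len(edges) - 1)
    counts = [0] * (len(edges) - 1)
    for row in rows:
        # Bin k covers (edges[k], edges[k + 1]]; the minimum joins bin 0
        k = max(bisect_left(edges, row[feature]) - 1, 0)
        lower, upper = list(row), list(row)
        lower[feature], upper[feature] = edges[k], edges[k + 1]
        # Local effect: prediction change across the bin, other columns fixed
        diff_sums[k] += predict(upper) - predict(lower)
        counts[k] += 1
    # Accumulate the average local effect of each bin
    ale = [0.0]
    for s, c in zip(diff_sums, counts):
        ale.append(ale[-1] + (s / c if c else 0.0))
    # Center so the count-weighted average effect is zero
    center = sum(c * (ale[k] + ale[k + 1]) / 2 for k, c in enumerate(counts)) / n
    return edges, [a - center for a in ale]


# Toy example (hypothetical data and model): a linear model in two features
rng = random.Random(0)
rows = [[rng.uniform(0, 1), rng.uniform(0, 1)] for _ in range(500)]
edges, ale = ale_1d(lambda r: 2 * r[0] + r[1], rows, feature=0)
print(len(edges), [round(a, 3) for a in ale[:3]])
```

For this linear toy model every step of the curve equals 2 × the bin width, as expected for a feature that enters the model with slope 2.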
# Compute the H-statistic
hstat_results = explainer.friedman_h_stat(ale_1d_ds, ale_2d_ds, features=features)
# Create display names for the feature pairs
adict = plotting_config.display_feature_names
display_feature_names = {
    f'{f[0]}__{f[1]}': f'{adict[f[0]]} & {adict[f[1]]}'
    for f in features
}

explainer.plot_importance(
    data=[hstat_results],
    panels=[('hstat', 'Gradient Boosting')],
    display_feature_names=display_feature_names,
)

Interaction Strength (IAS)

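A global measure of how much the model relies on feature interactions rather than main effects: the fraction of the variance in the model’s predictions that is not explained by the sum of its first-order (1D ALE) effects. An IAS of 0 means the model is purely additive; values near 0 mean it is close to additive.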
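
Written out, IAS compares the model f with its first-order approximation f_1st(x) = f_0 + Σ_j f_j(x_j), where f_0 is the mean prediction and f_j are the centered main effects: IAS = Σ_i (f(x_i) − f_1st(x_i))² / Σ_i (f(x_i) − f_0)². The sketch below computes this for toy two-feature functions on a full grid. It builds the main effects with partial dependence instead of ALE; with independent features the two agree up to a constant, which centering removes. The function name and test functions are hypothetical.

```python
from itertools import product


def interaction_strength(f, grid):
    """IAS of a two-feature function f on the full grid x grid (uniform weights)."""
    points = list(product(grid, grid))
    preds = [f(a, b) for a, b in points]
    f0 = sum(preds) / len(preds)
    # First-order effects: average over the other feature, then center
    pd_a = {a: sum(f(a, b) for b in grid) / len(grid) for a in grid}
    pd_b = {b: sum(f(a, b) for a in grid) / len(grid) for b in grid}
    mean_a = sum(pd_a.values()) / len(grid)
    mean_b = sum(pd_b.values()) / len(grid)
    first_order = [f0 + (pd_a[a] - mean_a) + (pd_b[b] - mean_b) for a, b in points]
    # Variance left over after the first-order approximation, relative to total
    unexplained = sum((p - q) ** 2 for p, q in zip(preds, first_order))
    total = sum((p - f0) ** 2 for p in preds)
    return unexplained / total


grid = [-1.0, 0.0, 1.0]
print(interaction_strength(lambda a, b: a + b, grid))  # -> 0.0
print(interaction_strength(lambda a, b: a + b + a * b, grid))  # -> 0.25
```

For `a + b + a * b` the first-order approximation is `a + b`, so the leftover `a * b` term accounts for 4 of the 16 units of squared deviation from the mean, which gives 0.25.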

ias = explainer.interaction_strength(ale=ale_1d_ds, n_bootstrap=5, subsample=0.1)
ias

A low IAS indicates that the model is composed mainly of additive, first-order effects and that feature interactions play a relatively minor role.

Main Effect Complexity (MEC)

Measures the complexity of the main-effect (1D ALE) curves by the number of linear segments needed to approximate them. MEC = 1 means the effects are linear; higher values mean more complex shapes.
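
The segment counting behind MEC can be pictured with a simplified greedy procedure: walk along a sampled main-effect curve and extend each straight segment as far as it stays within a tolerance of every point it spans. This is a stand-in for scikit-explain's fitting procedure, not a reproduction of it; `count_linear_segments`, the tolerance, the toy curves, and the unweighted average over curves are all assumptions.

```python
def count_linear_segments(xs, ys, tol=1e-6):
    """Greedy count of straight segments needed to trace the curve (xs, ys)."""
    n = len(xs)
    segments, start = 0, 0
    while start < n - 1:
        # Try the longest segment first and shrink until every point fits
        for end in range(n - 1, start, -1):
            slope = (ys[end] - ys[start]) / (xs[end] - xs[start])
            fits = all(
                abs(ys[start] + slope * (xs[i] - xs[start]) - ys[i]) <= tol
                for i in range(start + 1, end)
            )
            if fits:
                break
        segments += 1
        start = end
    return max(segments, 1)


xs = [x / 2 for x in range(-4, 5)]  # -2.0, -1.5, ..., 2.0
curves = {
    "linear": [3 * x - 1 for x in xs],  # a straight line
    "v_shape": [abs(x) for x in xs],  # two straight pieces
}
counts = {name: count_linear_segments(xs, ys) for name, ys in curves.items()}
mec_sketch = sum(counts.values()) / len(counts)  # unweighted average over curves
print(counts, mec_sketch)
```

The straight line needs one segment and the V shape two, so their unweighted average here is 1.5.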

mec = explainer.main_effect_complexity(ale=ale_1d_ds)

for var_name in mec.data_vars:
    print(f"{var_name}: {float(mec[var_name].values)}")

A MEC close to 1 indicates nearly linear main effects, while higher values indicate more complex, non-linear feature effects.

Sobol Indices

Decomposes the variance of the model’s predictions into first-order contributions (each feature acting alone) and contributions from interactions between features.
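
As a concrete, hypothetical sketch of that decomposition, the code below estimates first-order and total Sobol indices by Monte Carlo for toy models with independent inputs uniform on [0, 1]. It pairs a Saltelli-style estimator for the first-order index with the Jansen estimator for the total index; these are standard choices from the sensitivity-analysis literature, and nothing here assumes that scikit-explain uses the same estimators or sampling scheme. The interaction share is taken as total minus first-order.

```python
import random


def sobol_indices_mc(f, n_features, n_samples=20000, seed=0):
    """Monte Carlo first-order and total Sobol indices.

    Assumes n_features independent inputs, each uniform on [0, 1]. Returns
    two lists (first_order, total) with one entry per feature.
    """
    rng = random.Random(seed)
    A = [[rng.random() for _ in range(n_features)] for _ in range(n_samples)]
    B = [[rng.random() for _ in range(n_features)] for _ in range(n_samples)]
    f_A = [f(row) for row in A]
    f_B = [f(row) for row in B]
    pooled = f_A + f_B
    mean = sum(pooled) / len(pooled)
    variance = sum((y - mean) ** 2 for y in pooled) / len(pooled)
    first_order, total = [], []
    for i in range(n_features):
        # Rows of A with column i swapped in from B
        f_AB = [f(a[:i] + [b[i]] + a[i + 1:]) for a, b in zip(A, B)]
        # Saltelli-style first-order estimator (f(B) centered to cut noise)
        v_first = sum(
            (yb - mean) * (yab - ya) for ya, yb, yab in zip(f_A, f_B, f_AB)
        ) / n_samples
        # Jansen total-effect estimator
        v_total = sum((ya - yab) ** 2 for ya, yab in zip(f_A, f_AB)) / (2 * n_samples)
        first_order.append(v_first / variance)
        total.append(v_total / variance)
    return first_order, total


# Additive model: no interactions, so total and first-order indices agree
s1_add, st_add = sobol_indices_mc(lambda x: x[0] + 2 * x[1], n_features=2)
# Product model: part of each feature's effect comes from the interaction
s1_mul, st_mul = sobol_indices_mc(lambda x: x[0] * x[1], n_features=2)
interaction_mul = [t - s for s, t in zip(s1_mul, st_mul)]
print(s1_add, st_add)
print(interaction_mul)  # theoretical value is 1/7 for each feature
```

For the additive model the total indices match the first-order ones (about 0.2 and 0.8, because `2 * x[1]` carries 4/5 of the output variance). For the product model each feature's interaction share lands near the theoretical value of 1/7.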

# Use a fresh ExplainToolkit with the Random Forest for Sobol indices
explainer_rf = skexplain.ExplainToolkit(estimators[0], X=X, y=y)
explainer_rf.set_plotting_config(
    display_feature_names=plotting_config.display_feature_names,
    display_units=plotting_config.display_units,
    feature_colors=plotting_config.color_dict,
)

sobol_results = explainer_rf.sobol_indices(n_bootstrap=5000, class_index=1)
# Plot total, first-order, and interaction Sobol indices side by side
explainer_rf.plot_importance(
    data=[sobol_results, sobol_results, sobol_results],
    panels=[
        ('sobol_total', 'Random Forest'),
        ('sobol_1st', 'Random Forest'),
        ('sobol_interact', 'Random Forest'),
    ],
    figsize=(12, 4),
)

The total Sobol index measures a feature’s overall contribution to the prediction variance, including its interactions. The first-order index measures the contribution of the feature acting alone, while the interaction index captures the remainder of the total, which is due to interactions with other features.
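
For reference, the standard variance-based definitions, assuming independent inputs X_1, …, X_p, model output Y = f(X), and X_{∼i} for all inputs except X_i, are:

```latex
S_i = \frac{\operatorname{Var}_{X_i}\big(\mathbb{E}[Y \mid X_i]\big)}{\operatorname{Var}(Y)},
\qquad
S_{T_i} = \frac{\mathbb{E}_{X_{\sim i}}\big[\operatorname{Var}(Y \mid X_{\sim i})\big]}{\operatorname{Var}(Y)},
\qquad
S_{T_i} - S_i \ge 0 .
```

The gap S_{T_i} − S_i is the part of feature i's contribution that comes from interactions; for a purely additive model it is zero for every feature and the first-order indices sum to 1. Whether scikit-explain computes its interaction index exactly as this difference is an assumption here.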