# Source code for skexplain.main.explain_toolkit
import numpy as np
import pandas as pd
from ..common.attributes import Attributes
from .local_explainer import LocalExplainer
from .global_explainer import GlobalExplainer
from ..common.utils import is_str, is_list, is_tuple
from ..plot.config import PlotConfig
from ._importance_mixin import ImportanceMixin
from ._curves_mixin import CurvesMixin
from ._interaction_mixin import InteractionMixin
from ._attribution_mixin import AttributionMixin
from ._plotting_mixin import PlottingMixin
from ._io_mixin import IOMixin
# [docs]
class ExplainToolkit(
    Attributes,
    ImportanceMixin,
    CurvesMixin,
    InteractionMixin,
    AttributionMixin,
    PlottingMixin,
    IOMixin,
):
    """
    ExplainToolkit is the primary interface of scikit-explain. The modules contained
    within compute several machine learning explainability methods such as

    Feature importance:
        - `permutation_importance`
        - `ale_variance`

    Feature Attributions:
        - `ale`
        - `pd`
        - `ice`
        - `local_attributions`

    Feature Interactions:
        - `interaction_strength`
        - `ale_variance`
        - `perm_based_interaction`
        - `friedman_h_stat`
        - `main_effect_complexity`
        - `ale`
        - `pd`

    Additionally, there are corresponding plotting modules for
    each method, which are designed to produce publication-quality graphics.

    .. note::
        ExplainToolkit is designed to work with estimators that implement
        predict or predict_proba.

    .. caution::
        ExplainToolkit is only designed to work with binary classification and
        regression problems. In future versions of skexplain, we hope to be
        compatible with multi-class classification.

    Parameters
    -----------
    estimators : list of tuples of (estimator name, fitted estimator)
        Tuple of (estimator name, fitted estimator object) or list thereof where the
        fitted estimator must implement ``predict`` or ``predict_proba``.
        Multioutput-multiclass classifiers are not supported.

    X : {array-like or dataframe} of shape (n_samples, n_features)
        Training or validation data used to compute the IML methods.
        If numpy.ndarray, must specify `feature_names`.
        Defaults to an empty ``pandas.DataFrame`` when not given.

    y : {list or numpy.array} of shape (n_samples,)
        The target values (class labels in classification, real numbers in regression).
        Defaults to an empty ``numpy.ndarray`` when not given.

    estimator_output : ``"raw"`` or ``"probability"``
        What output of the estimator should be explained. Determined internally by
        ExplainToolkit. However, if using a classification model, the user
        can set to "raw" for non-probabilistic output.

    feature_names : array-like of shape (n_features,), dtype=str, default=None
        Name of each feature; ``feature_names[i]`` holds the name of the feature
        with index ``i``. By default, the name of the feature corresponds to their
        numerical index for NumPy array and their column name for pandas dataframe.
        Feature names are only required if ``X`` is a numpy.ndarray, as it will be
        converted to a pandas.DataFrame internally.

    seaborn_kws : dict, None, or False (default is None)
        Arguments for the seaborn.set_theme(). By default, we use the following
        settings.

            custom_params = {"axes.spines.right": False, "axes.spines.top": False}
            sns.set_theme(style="ticks", rc=custom_params)

        If False, then seaborn settings are not used. When a dict is given, its
        ``style``, ``palette``, ``font_scale`` and ``rc`` entries (if present) are
        also copied onto the toolkit's plot configuration.

    Raises
    ---------
    AssertionError
        Number of estimator objects is not equal to the number of estimator names given!
    TypeError
        y variable must be numpy array or pandas.DataFrame.
    Exception
        Feature names must be specified if X is a numpy.ndarray.
    ValueError
        estimator_output is not an accepted option.
    """

    def __init__(
        self,
        estimators=None,
        X=None,
        y=None,
        estimator_output=None,
        feature_names=None,
        seaborn_kws=None,
    ):
        # Fall back to empty containers so the attribute setters always get data.
        if X is None:
            X = pd.DataFrame(np.array([]))
        if y is None:
            y = np.array([])

        self.seaborn_kws = seaborn_kws

        estimator_names = None
        if estimators is not None:
            # A single (name, estimator) tuple is accepted as shorthand for a list.
            if not is_list(estimators) and estimators:
                estimators = [estimators]

            # Check that the estimator name is provided!
            for entry in estimators:
                if not is_tuple(entry):
                    raise TypeError(
                        "The estimators arg must be a tuple of (estimator_name, estimator)!"
                    )
                if not is_str(entry[0]):
                    raise TypeError(
                        "Estimator name is supposed to be a string. Make sure that the tuple is (estimator_name, estimator)."
                    )

            estimator_names = [entry[0] for entry in estimators]
            estimators = [entry[1] for entry in estimators]

        self.set_estimator_attribute(estimators, estimator_names)
        self.set_y_attribute(y)
        self.set_X_attribute(X, feature_names)
        self.set_estimator_output(estimator_output, estimators)
        self.checked_attributes = True

        # The global and local explainers are built from the same validated state.
        explainer_kwargs = {
            "estimators": self.estimators,
            "estimator_names": self.estimator_names,
            "X": self.X,
            "y": self.y,
            "estimator_output": self.estimator_output,
            "checked_attributes": self.checked_attributes,
        }
        # Initialize a global interpret object
        self.global_obj = GlobalExplainer(**explainer_kwargs)
        # Initialize a local interpret object
        self.local_obj = LocalExplainer(**explainer_kwargs)

        # Metadata attached to result objects by ``_append_attributes``.
        self.attrs_dict = {
            "estimator_output": self.estimator_output,
            "estimators used": self.estimator_names,
        }

        # Initialize plot configuration
        self._plot_config = PlotConfig()
        if isinstance(seaborn_kws, dict):
            # Map legacy seaborn_kws to the PlotConfig fields of the same name.
            for field in ("style", "palette", "font_scale", "rc"):
                if field in seaborn_kws:
                    setattr(self._plot_config, field, seaborn_kws[field])

    def __repr__(self):
        """Return a multi-line summary of the estimators, data and output type."""
        # Adjacent literals (rather than backslash continuations inside a single
        # literal) keep the text independent of how this method is indented.
        template = (
            "ExplainToolkit(estimator=%s \n "
            "estimator_names=%s \n "
            "X=%s length:%d \n "
            "y=%s length:%d \n "
            "estimator_output=%s \n "
            "feature_names=%s length %d)"
        )
        return template % (
            self.estimators,
            self.estimator_names,
            type(self.X),
            len(self.X),
            type(self.y),
            len(self.y),
            self.estimator_output,
            type(self.feature_names),
            len(self.feature_names),
        )

    def set_plotting_config(self, **kwargs):
        """Set plot configuration. Works like seaborn.set_theme().

        Parameters set here become defaults for all subsequent plot calls.
        Per-call arguments override these defaults.

        All keys are validated before any of them is applied, so a call that
        raises leaves the current configuration untouched.

        Parameters
        ----------
        **kwargs
            Any attribute of :class:`~skexplain.plot.config.PlotConfig`.
            See ``PlotConfig`` for the full list.

        Raises
        ------
        ValueError
            If a key is not one of ``PlotConfig.field_names()``.

        Examples
        --------
        >>> explainer.set_plotting_config(
        ...     figsize=(12, 8),
        ...     display_feature_names={"temp2m": "Temperature (2m)"},
        ...     display_units={"temp2m": "K"},
        ...     style="whitegrid",
        ...     base_font_size=14,
        ... )
        >>> # All subsequent plots use these settings:
        >>> explainer.plot_ale(ale_data)
        """
        valid_fields = PlotConfig.field_names()

        # Validate first: rejecting a bad key after earlier keys were already
        # applied would leave the config and ``seaborn_kws`` out of sync.
        for key in kwargs:
            if key not in valid_fields:
                raise ValueError(
                    f"Unknown config key '{key}'. "
                    f"Available: {valid_fields}"
                )

        for key, val in kwargs.items():
            setattr(self._plot_config, key, val)

        # Update seaborn_kws for backward compatibility with plot classes
        self.seaborn_kws = self._plot_config.to_seaborn_kws()

    def get_plotting_config(self):
        """Return the current plot configuration.

        Returns
        -------
        PlotConfig
            The current plot configuration dataclass.
        """
        return self._plot_config

    def reset_plotting_config(self):
        """Reset all plot configuration to defaults.

        Restores the default PlotConfig (no custom display names,
        units, colors, or seaborn overrides). Equivalent to creating
        a fresh ExplainToolkit with no seaborn_kws.

        Examples
        --------
        >>> explainer.set_plotting_config(style="whitegrid", base_font_size=16)
        >>> # ... do some plotting ...
        >>> explainer.reset_plotting_config()  # back to defaults
        """
        self._plot_config = PlotConfig()
        self.seaborn_kws = None

    def _append_attributes(self, ds):
        """
        FOR INTERNAL PURPOSES ONLY.

        Append attributes to a xarray.Dataset or pandas.DataFrame

        Parameters
        ----------
        ds : xarray.Dataset or pandas.DataFrame
            Results data from the IML methods

        Returns
        -------
        ds : xarray.Dataset or pandas.DataFrame
            The same object, with every entry of ``self.attrs_dict`` written
            into ``ds.attrs``.
        """
        for key, value in self.attrs_dict.items():
            ds.attrs[key] = value
        return ds