# Source code for skexplain.plot.config

"""Plot configuration for scikit-explain, inspired by seaborn's set_theme() API.

Users can set plot configuration once and have it apply to all subsequent
plot calls. Per-call arguments always override these defaults.

Examples
--------
>>> explainer.set_plotting_config(
...     figsize=(12, 8),
...     display_feature_names={"temp2m": "Temperature (2m)"},
...     display_units={"temp2m": "K"},
...     style="whitegrid",
...     base_font_size=14,
... )
>>> # All subsequent plots use these settings:
>>> explainer.plot_ale(ale_data)
"""

from dataclasses import dataclass, fields, field
from typing import Optional, Dict, Tuple


@dataclass
class PlotConfig:
    """Plot defaults shared by every plotting call on an explainer.

    An instance is installed once on ``ExplainToolkit`` through
    ``set_plotting_config()`` and is then consulted by each subsequent plot.
    Any argument passed directly to a plotting call takes precedence over
    the value stored here.

    Attributes
    ----------
    figsize : tuple of (float, float), optional
        Default figure size as ``(width, height)`` in inches.
    n_columns : int, optional
        How many columns to use in multi-panel layouts.
    wspace : float, optional
        Horizontal spacing between subplots.
    hspace : float, optional
        Vertical spacing between subplots.
    base_font_size : int, default=12
        Base font size applied to text elements.
    display_feature_names : dict, optional
        Mapping from internal feature names to the labels shown on plots,
        e.g. ``{"dwpt2m": "$T_{d}$"}``.
    display_units : dict, optional
        Mapping from feature names to unit strings,
        e.g. ``{"dwpt2m": "$^\\circ$C"}``.
    feature_colors : dict, optional
        Mapping from feature names to colors, used to color-code groups.
    style : str, default="ticks"
        Seaborn style forwarded as ``sns.set_theme(style=...)``; one of
        "darkgrid", "whitegrid", "dark", "white" or "ticks".
    palette : str, optional
        Seaborn color palette; left out of the theme kwargs when ``None``.
    font_scale : float, default=1.0
        Font scaling factor; left out of the theme kwargs when it is 1.0.
    rc : dict, optional
        Matplotlib rcParams overrides forwarded as ``sns.set_theme(rc=...)``.
        They are applied on top of built-in defaults that hide the right and
        top axis spines, so a user value for those keys wins.
    add_hist : bool, default=True
        Whether ALE/PD plots get background histograms.
    to_probability : bool, optional
        When True, values are multiplied by 100 for probability display.
    num_vars_to_plot : int, default=10
        Default number of top-ranked features shown in importance plots.
    """

    # NOTE: field order defines the positional order of the generated
    # __init__, so it must not be rearranged.

    # --- Layout -------------------------------------------------------------
    figsize: Optional[Tuple[float, float]] = None
    n_columns: Optional[int] = None
    wspace: Optional[float] = None
    hspace: Optional[float] = None

    # --- Typography ---------------------------------------------------------
    base_font_size: int = 12

    # --- Feature labelling / coloring ---------------------------------------
    # Dict-valued options default to None (not {}) because dataclasses
    # reject mutable default values.
    display_feature_names: Optional[Dict[str, str]] = None
    display_units: Optional[Dict[str, str]] = None
    feature_colors: Optional[Dict[str, str]] = None

    # --- Seaborn theme (consumed by to_seaborn_kws) -------------------------
    style: str = "ticks"
    palette: Optional[str] = None
    font_scale: float = 1.0
    rc: Optional[dict] = None

    # --- Per-plot defaults --------------------------------------------------
    add_hist: bool = True
    to_probability: Optional[bool] = None
    num_vars_to_plot: int = 10

    def to_seaborn_kws(self):
        """Assemble the ``seaborn_kws`` mapping handed to PlotStructure.

        Returns
        -------
        dict
            Always contains ``"style"`` and ``"rc"``. ``"palette"`` is added
            only when a palette is set, and ``"font_scale"`` only when it
            differs from 1.0. ``"rc"`` is a newly built dict: the right/top
            spine switches default to False and are then overridden by
            ``self.rc`` (which itself is never modified).
        """
        theme_kwargs = dict(style=self.style)

        if self.palette is not None:
            theme_kwargs["palette"] = self.palette
        if self.font_scale != 1.0:
            theme_kwargs["font_scale"] = self.font_scale

        # Hide the right/top spines by default; user rc entries take priority.
        rc_params = dict.fromkeys(("axes.spines.right", "axes.spines.top"), False)
        if self.rc is not None:
            rc_params.update(self.rc)
        theme_kwargs["rc"] = rc_params

        return theme_kwargs

    @classmethod
    def field_names(cls):
        """Return the configuration option names, in declaration order."""
        return [spec.name for spec in fields(cls)]