{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Plot Configuration\n", "\n", "scikit-explain provides a **seaborn-style configuration API** that lets you set plotting\n", "defaults once and have them apply to all subsequent plot calls, so you no longer need to pass\n", "`display_feature_names`, `display_units`, and other options to every plot method.\n", "\n", "The key methods are:\n", "- `explainer.set_plotting_config(**kwargs)`: set defaults\n", "- `explainer.get_plotting_config()`: inspect the current config\n", "- `explainer.reset_plotting_config()`: restore defaults" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import skexplain\n", "\n", "# Local helper module (not part of skexplain) that defines display names, units, and colors for the features\n", "import plotting_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the example models and dataset bundled with scikit-explain\n", "estimators = skexplain.load_models()\n", "X, y = skexplain.load_data()\n", "explainer = skexplain.ExplainToolkit(estimators=estimators, X=X, y=y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting Plot Configuration\n", "\n", "Call `set_plotting_config()` once after creating the `ExplainToolkit`.\n", "All subsequent plot calls will use these settings as defaults."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.set_plotting_config(\n", "    # Readable labels, units, and colors for each feature\n", "    display_feature_names=plotting_config.display_feature_names,\n", "    display_units=plotting_config.display_units,\n", "    feature_colors=plotting_config.color_dict,\n", "    # Seaborn theme and base font size\n", "    style='ticks',\n", "    base_font_size=12,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot calls no longer need display names or units passed explicitly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compute ALE curves for four features, then plot them with the configured defaults\n", "ale = explainer.ale(features=['sfc_temp', 'temp2m', 'dwpt2m', 'sat_irbt'], n_bins=20)\n", "explainer.plot_ale(ale=ale)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Permutation importance for the top 10 features, scored with the normalized AUPDC\n", "perm_imp = explainer.permutation_importance(n_vars=10, evaluation_fn='norm_aupdc')\n", "# Plot the multipass ranking for the Random Forest model\n", "explainer.plot_importance(\n", "    data=perm_imp,\n", "    panels=[('multipass', 'Random Forest')],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Per-Call Overrides\n", "\n", "You can still pass arguments directly to any plot method.\n", "Per-call arguments **always override** the config defaults."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Override display names for this one plot only\n", "explainer.plot_ale(\n", "    ale=ale,\n", "    display_feature_names={'sfc_temp': 'Surface Temp (custom)'},\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspecting Configuration\n", "\n", "Use `get_plotting_config()` to see the current settings:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = explainer.get_plotting_config()\n", "print(config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Changing the Seaborn Theme\n", "\n", "The `style`, `palette`, `font_scale`, and `rc` options control the seaborn theme\n", "and mirror the arguments of the same name in `seaborn.set_theme()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Switch to a grid background and a larger base font\n", "explainer.set_plotting_config(\n", "    style='whitegrid',\n", "    base_font_size=16,\n", ")\n", "explainer.plot_ale(ale=ale)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Resetting to Defaults\n", "\n", "Call `reset_plotting_config()` to clear all custom settings and return to\n", "the default matplotlib/seaborn appearance."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.reset_plotting_config()\n", "\n", "# Now plots use default feature names (no pretty labels)\n", "explainer.plot_ale(ale=ale)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Available Configuration Options\n", "\n", "| Option | Type | Description |\n", "|--------|------|-------------|\n", "| `display_feature_names` | dict | Map feature names to display names |\n", "| `display_units` | dict | Map feature names to unit strings |\n", "| `feature_colors` | dict | Map feature names to colors |\n", "| `figsize` | tuple | Default figure size (width, height) |\n", "| `n_columns` | int | Columns in multi-panel layouts |\n", "| `wspace` / `hspace` | float | Subplot spacing |\n", "| `base_font_size` | int | Base font size |\n", "| `style` | str | Seaborn style ('ticks', 'whitegrid', etc.) |\n", "| `palette` | str | Seaborn color palette |\n", "| `font_scale` | float | Font scaling factor |\n", "| `rc` | dict | Matplotlib rcParams overrides |\n", "| `add_hist` | bool | Add histograms behind ALE/PD curves |\n", "| `to_probability` | bool | Display as percentages |\n", "| `num_vars_to_plot` | int | Default number of features to plot |" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 4 }