import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import matplotlib.pyplot as plt
import seaborn as sns
from typing_extensions import Literal

from etna.analysis.feature_relevance.relevance import RelevanceTable
from etna.analysis.feature_relevance.relevance import StatisticsRelevanceTable
from etna.analysis.feature_relevance.utils import _get_fictitious_relevances
from etna.analysis.feature_selection import AGGREGATION_FN
from etna.analysis.feature_selection import AggregationMode
from etna.analysis.utils import _prepare_axes

if TYPE_CHECKING:
    from etna.datasets import TSDataset


def plot_feature_relevance(
ts: "TSDataset",
relevance_table: RelevanceTable,
normalized: bool = False,
relevance_aggregation_mode: Union[str, Literal["per-segment"]] = AggregationMode.mean,
relevance_params: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
alpha: float = 0.05,
segments: Optional[List[str]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
):
"""
Plot relevance of the features.
The most important features are at the top, the least important are at the bottom.
For :py:class:`~etna.analysis.feature_relevance.relevance.StatisticsRelevanceTable` also plot vertical line: transformed significance level.
* Values that lie to the right of this line have p-value < alpha.
* And the values that lie to the left have p-value > alpha.
Parameters
----------
ts:
TSDataset with timeseries data
relevance_table:
method to evaluate the feature relevance;
* if :py:class:`~etna.analysis.feature_relevance.relevance.StatisticsRelevanceTable` table is used then relevances are normalized p-values
* if :py:class:`~etna.analysis.feature_relevance.relevance.ModelRelevanceTable` table is used then relevances are importances from some model
normalized:
whether obtained relevances should be normalized to sum up to 1
relevance_aggregation_mode:
aggregation strategy for obtained feature relevance table;
all the strategies can be examined
at :py:class:`~etna.analysis.feature_selection.mrmr_selection.AggregationMode`
relevance_params:
additional keyword arguments for the ``__call__`` method of
:py:class:`~etna.analysis.feature_relevance.relevance.RelevanceTable`
top_k:
number of best features to plot, if None plot all the features
alpha:
significance level, default alpha = 0.05, only for :py:class:`~etna.analysis.feature_relevance.relevance.StatisticsRelevanceTable`
segments:
segments to use
columns_num:
if ``relevance_aggregation_mode="per-segment"`` number of columns in subplots, otherwise the value is ignored
figsize:
size of the figure per subplot with one segment in inches
"""
    if relevance_params is None:
        relevance_params = {}
    if segments is None:
        segments = sorted(ts.segments)
    border_value = None
    features = list(set(ts.columns.get_level_values("feature")) - {"target"})
    relevance_df = relevance_table(df=ts[:, segments, "target"], df_exog=ts[:, segments, features], **relevance_params)
    if relevance_aggregation_mode == "per-segment":
        _, ax = _prepare_axes(num_plots=len(segments), columns_num=columns_num, figsize=figsize)
        for i, segment in enumerate(segments):
            relevance = relevance_df.loc[segment]
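            # For the statistics-based table, p-values are converted into "fictitious" relevances
            # (bigger means more relevant), and ``alpha`` is transformed to the same scale
            # as ``border_value`` so it can be drawn as a vertical line on the plot.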
            if isinstance(relevance_table, StatisticsRelevanceTable):
                relevance, border_value = _get_fictitious_relevances(
                    relevance,
                    alpha,
                )
            # warning about NaNs
            if relevance.isna().any():
                na_relevance_features = relevance[relevance.isna()].index.tolist()
                warnings.warn(
                    f"Relevances on segment: {segment} of features: {na_relevance_features} can't be calculated."
                )
            relevance = relevance.sort_values(ascending=False)
            relevance = relevance.dropna()[:top_k]
            if normalized:
                if border_value is not None:
                    border_value = border_value / relevance.sum()
                relevance = relevance / relevance.sum()
            sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax[i])
            if border_value is not None:
                ax[i].axvline(border_value)
            ax[i].set_title(f"Feature relevance: {segment}")
    else:
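        # Aggregate each feature's relevance across segments with the chosen strategy (e.g. mean)
        # and draw a single bar plot instead of one plot per segment.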
        relevance_aggregation_fn = AGGREGATION_FN[AggregationMode(relevance_aggregation_mode)]
        relevance = relevance_df.apply(lambda x: relevance_aggregation_fn(x[~x.isna()]))  # type: ignore
        if isinstance(relevance_table, StatisticsRelevanceTable):
            relevance, border_value = _get_fictitious_relevances(
                relevance,
                alpha,
            )
        # warning about NaNs
        if relevance.isna().any():
            na_relevance_features = relevance[relevance.isna()].index.tolist()
            warnings.warn(f"Relevances of features: {na_relevance_features} can't be calculated.")
        # if top_k is None, all the values are selected
        relevance = relevance.sort_values(ascending=False)
        relevance = relevance.dropna()[:top_k]
        if normalized:
            if border_value is not None:
                border_value = border_value / relevance.sum()
            relevance = relevance / relevance.sum()
        _, ax = plt.subplots(figsize=figsize, constrained_layout=True)
        sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax)
        if border_value is not None:
            ax.axvline(border_value)  # type: ignore
        ax.set_title("Feature relevance")  # type: ignore
        ax.grid()  # type: ignore
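

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). It assumes a
    # synthetic two-segment daily dataset with a single exogenous column; the column name
    # "noise_feature" and the generated data are placeholders. StatisticsRelevanceTable
    # relies on tsfresh under the hood, which must be available in the environment.
    import numpy as np
    import pandas as pd

    from etna.datasets import TSDataset

    timestamps = pd.date_range("2021-01-01", periods=100, freq="D")
    segments_long = ["segment_0"] * 100 + ["segment_1"] * 100
    df = pd.DataFrame(
        {
            "timestamp": np.tile(timestamps, 2),
            "segment": segments_long,
            "target": np.random.normal(size=200),
        }
    )
    df_exog = pd.DataFrame(
        {
            "timestamp": np.tile(timestamps, 2),
            "segment": segments_long,
            "noise_feature": np.random.normal(size=200),
        }
    )
    ts = TSDataset(df=TSDataset.to_dataset(df), freq="D", df_exog=TSDataset.to_dataset(df_exog))

    # Plot per-segment relevances of the exogenous features, keeping the 5 most relevant ones.
    plot_feature_relevance(
        ts=ts,
        relevance_table=StatisticsRelevanceTable(),
        relevance_aggregation_mode="per-segment",
        top_k=5,
    )
    plt.show()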