from __future__ import annotations
import warnings
from typing import Any, Callable, Literal
from pandas import DataFrame, Index
from sklearn.base import BaseEstimator, TransformerMixin
from typing_extensions import Self
from .types import TXPandas
[docs]
class ReindexMissingColumns(BaseEstimator, TransformerMixin):
"""Reindex X to match the columns of the training data to avoid errors."""
def __init__(
    self,
    *,
    if_missing: (
        Literal["warn", "raise"] | Callable[[Index[Any], Index[Any]], None]
    ) = "warn",
    reindex_kwargs: dict[
        Literal["method", "copy", "level", "fill_value", "limit", "tolerance"], Any
    ] = {},
) -> None:
    """Store the transformer configuration unchanged on the instance.

    Both arguments are keyword-only and are saved as given; nothing is
    validated or computed here.

    Parameters
    ----------
    if_missing : Literal['warn', 'raise'] | Callable[[Index[Any], Index[Any]], None], optional
        Either the string ``'warn'`` or ``'raise'``, or a callable. A
        callable receives the expected columns as its first argument and
        the actual columns as its second. By default ``'warn'``.
    reindex_kwargs : dict[Literal['method', 'copy', 'level', 'fill_value',
        'limit', 'tolerance'], Any], optional
        Keyword arguments to pass to reindex, by default ``{}``.

    Notes
    -----
    The ``{}`` default is evaluated once, when the function is defined, so
    every instance created without an explicit ``reindex_kwargs`` holds the
    very same dict object. Treat that dict as read-only.
    """
    # Single unpacking assignment; the attributes are still set in the
    # order if_missing, then reindex_kwargs.
    self.if_missing, self.reindex_kwargs = if_missing, reindex_kwargs
[docs]
def fit(self, X: DataFrame, y: Any = None, **fit_params: Any) -> Self:
    """Record the column labels of the training frame.

    Parameters
    ----------
    X : DataFrame
        Training data; only its ``columns`` attribute is read.
    y : Any, optional
        Accepted but not used by this method.
    **fit_params : Any
        Accepted but not used by this method.

    Returns
    -------
    Self
        This same instance, with ``feature_names_in_`` set to ``X.columns``.
    """
    training_columns = X.columns
    self.feature_names_in_ = training_columns
    return self