Source code for sklearn_utilities.proba.transformed_target_estimator

from __future__ import annotations

import warnings
from typing import Any, Callable, Generic

from pandas import Series
from typing_extensions import Self

from ..estimator_wrapper import EstimatorWrapperBase
from ..id_transformer import IdTransformer
from ..types import TX, TY, TEstimator, TTransformer


def _parse_2d(func: Callable[..., TY]) -> Callable[..., TY]:
    def wrapper(X: TX, **kwargs: Any) -> TY:
        ndim = X.ndim
        if ndim == 1:
            if isinstance(X, Series):
                X = X.to_frame()
            else:
                X = X[:, None]
        res = func(X, **kwargs)
        if ndim == 1:
            try:
                return res.squeeze(axis=1)
            except ValueError as e:
                warnings.warn(f"Failed to squeeze: {e}")
        return res

    return wrapper


[docs] class TransformedTargetEstimatorVar( EstimatorWrapperBase[TEstimator], Generic[TEstimator, TTransformer] ): """TransformTargetRegressor with std/var support.""" def __init__( self, estimator: TEstimator, *, transformer: TTransformer = IdTransformer(), inverse_transform_separately: bool = False, ) -> None: super().__init__(estimator) self.transformer = transformer self.inverse_transform_separately = inverse_transform_separately
[docs] def fit(self, X: TX, y: TY, **fit_params: Any) -> Self: y = _parse_2d(self.transformer.fit_transform)(y) self.estimator.fit(X, y, **fit_params) return self
[docs] def predict(self, X: TX, **predict_params: Any) -> TY | tuple[TY, TY]: if predict_params.get("return_std", False): pred, pred_std = self.estimator.predict(X, **predict_params) if self.inverse_transform_separately: return _parse_2d(self.transformer.inverse_transform)(pred), _parse_2d( self.transformer.inverse_transform )(pred_std, return_std=True) else: return _parse_2d(self.transformer.inverse_transform)( (pred, pred_std), return_std=True ) pred = self.estimator.predict(X, **predict_params) return _parse_2d(self.transformer.inverse_transform)(pred)
[docs] def predict_var(self, X: TX, **predict_params: Any) -> TY: pred_var = self.estimator.predict_var(X, **predict_params) return _parse_2d(self.transformer.inverse_transform_var)(pred_var)