# Source code for sklearn_utilities.proba.compose_var

from __future__ import annotations

from typing import Any, Generic, Literal, overload

from typing_extensions import Self

from ..estimator_wrapper import EstimatorWrapperBase
from ..types import TX, TY, TEstimator, TEstimatorVar
from .dummy_regressor import DummyRegressorVar


class ComposeVarEstimator(
    EstimatorWrapperBase[TEstimator], Generic[TEstimator, TEstimatorVar]
):
    """Compose an estimator with a variance estimator.

    Point predictions come from ``estimator``; the uncertainty returned by
    ``predict(..., return_std=True)`` and ``predict_var`` comes from
    ``estimator_var``.
    """

    def __init__(
        self,
        estimator: TEstimator,
        estimator_var: TEstimatorVar | None = None,
    ) -> None:
        """Compose an estimator with a variance estimator.

        Parameters
        ----------
        estimator : TEstimator
            The estimator to be wrapped.
        estimator_var : TEstimatorVar, optional
            The variance estimator to be wrapped. If None (the default),
            a fresh ``DummyRegressorVar()`` is created for this instance.
        """
        super().__init__(estimator)
        # Build the default per instance: a ``DummyRegressorVar()`` default in
        # the signature is evaluated only once, so all instances constructed
        # without ``estimator_var`` would share (and re-fit) the same object.
        self.estimator_var = (
            DummyRegressorVar() if estimator_var is None else estimator_var
        )

    def fit(self, X: TX, y: TY, **fit_params: Any) -> Self:
        """Fit the point estimator and then the variance estimator.

        Parameters
        ----------
        X : TX
            Training features.
        y : TY
            Training targets.
        **fit_params : Any
            Extra keyword arguments, forwarded unchanged to both
            ``estimator.fit`` and ``estimator_var.fit``.

        Returns
        -------
        Self
            This instance.
        """
        self.estimator.fit(X, y, **fit_params)
        self.estimator_var.fit(X, y, **fit_params)
        return self

    @overload
    def predict(
        self, X: TX, return_std: Literal[False] = ..., **predict_params: Any
    ) -> TY:
        ...

    @overload
    def predict(
        self, X: TX, return_std: Literal[True], **predict_params: Any
    ) -> tuple[TY, TY]:
        ...

    def predict(
        self, X: TX, return_std: bool = False, **predict_params: Any
    ) -> TY | tuple[TY, TY]:
        """Predict with the point estimator, optionally adding uncertainty.

        Parameters
        ----------
        X : TX
            Features to predict on.
        return_std : bool, default False
            If True, also return the second element of
            ``estimator_var.predict(X, return_std=True, ...)``
            (conventionally the standard deviation).
        **predict_params : Any
            Extra keyword arguments, forwarded to ``estimator.predict`` and,
            when ``return_std`` is True, to ``estimator_var.predict``.

        Returns
        -------
        TY or tuple[TY, TY]
            ``estimator.predict(X)`` alone, or a ``(prediction, std)`` pair
            when ``return_std`` is True.
        """
        prediction = self.estimator.predict(X, **predict_params)
        if not return_std:
            return prediction
        # Only the uncertainty part of the variance estimator's output is used;
        # its own point prediction (index 0) is discarded.
        std = self.estimator_var.predict(X, return_std=True, **predict_params)[1]
        return prediction, std

    def predict_var(self, X: TX, **predict_params: Any) -> TY:
        """Return ``estimator_var.predict_var(X, **predict_params)``.

        Parameters
        ----------
        X : TX
            Features to predict on.
        **predict_params : Any
            Extra keyword arguments forwarded to ``estimator_var.predict_var``.

        Returns
        -------
        TY
            The variance estimator's ``predict_var`` output.
        """
        return self.estimator_var.predict_var(X, **predict_params)