Source code for sklearn_utilities.intersect

from typing import Any, Generic

from typing_extensions import Self

from .estimator_wrapper import EstimatorWrapperBase
from .types import TEstimator
from .utils import intersect_X_y


class IntersectXY(EstimatorWrapperBase[TEstimator], Generic[TEstimator]):
    """Estimator wrapper that intersects X and y indices before fitting.

    ``fit`` passes ``X`` and ``y`` through ``intersect_X_y`` and forwards
    the returned pair to the wrapped estimator's own ``fit``.
    """

    def __init__(self, estimator: TEstimator) -> None:
        """Create a wrapper around ``estimator``.

        Parameters
        ----------
        estimator : TEstimator
            The estimator to be wrapped. It is handed unchanged to the
            base-class initializer.
        """
        super().__init__(estimator)

    def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
        """Fit the wrapped estimator on the intersected ``X`` and ``y``.

        Parameters
        ----------
        X : Any
            Input data; forwarded to ``intersect_X_y`` together with ``y``.
        y : Any
            Target data; forwarded to ``intersect_X_y`` together with ``X``.
        **fit_params : Any
            Extra keyword arguments passed verbatim to the wrapped
            estimator's ``fit``.

        Returns
        -------
        Self
            This wrapper instance, so calls can be chained.
        """
        # NOTE(review): presumably ``intersect_X_y`` keeps only the index
        # labels common to X and y -- confirm against ``.utils``.
        X_common, y_common = intersect_X_y(X, y)
        # ``self.estimator`` is expected to be set by the base class
        # initializer; the return value of the inner ``fit`` is discarded.
        self.estimator.fit(X_common, y_common, **fit_params)
        return self