Source code for sklearn_utilities.drop_missing_rows_y

from __future__ import annotations

from typing import Any, Generic

from pandas import DataFrame
from typing_extensions import Self

from .estimator_wrapper import EstimatorWrapperBase
from .types import TEstimator


[docs] class DropMissingRowsY(EstimatorWrapperBase[TEstimator], Generic[TEstimator]): """A wrapper for estimators that drops NaN values from y before fitting.""" def __init__(self, estimator: TEstimator) -> None: """A wrapper for estimators that drops NaN values from y before fitting. Parameters ---------- estimator : TEstimator The estimator to be wrapped. """ super().__init__(estimator)
[docs] def fit(self, X: DataFrame, y: Any = None, **fit_params: Any) -> Self: y = y.dropna() idx = X.index.intersection(y.index) self.estimator.fit(X.loc[idx], y.loc[idx], **fit_params) return self