Source code for sklearn_utilities.torch.skorch.reshaper

from __future__ import annotations

import warnings
from typing import Any, Generic, TypeVar

import numpy as np
import torch
from numpy.lib.stride_tricks import sliding_window_view
from skorch import NeuralNet
from typing_extensions import Self

from sklearn_utilities.estimator_wrapper import EstimatorWrapperBase
from sklearn_utilities.types import TX, TY

# Type variable for the wrapped estimator. It is bound to skorch's NeuralNet so
# type checkers know the wrapped object exposes NeuralNet attributes used by
# the wrappers in this module (e.g. ``device``, ``fit``, ``predict``).
TEstimator = TypeVar("TEstimator", bound=NeuralNet)


class SkorchReshaper(EstimatorWrapperBase[TEstimator], Generic[TEstimator]):
    """skorch wrapper that reshapes tabular data for NNs."""

    def __init__(self, estimator: TEstimator) -> None:
        """skorch wrapper that reshapes tabular data for NNs.

        X: [B, F] -> [B, F]
        y: [B] -> [B, 1] or [B, NY] -> [B, NY]
        where B: batch, F: features, NY: number of outputs

        X and y are cast to float32 before being handed to the wrapped
        estimator. If the model was fitted on a 1-D target, single-column
        predictions are squeezed back to shape [B].

        Parameters
        ----------
        estimator : TEstimator
            The estimator to wrap.
        """
        super().__init__(estimator)
        if self.estimator.device == "cpu":
            warnings.warn(
                "You are using a CPU for training, which may be very slow. "
                f"Consider using a GPU. CUDA Availability: {torch.cuda.is_available()}",
                UserWarning,
            )

    def fit(self, X: TX, y: TY, **fit_params: Any) -> Self:
        """Fit the wrapped estimator on float32 copies of X and y.

        Parameters
        ----------
        X : TX
            Input features, nominally [B, F]. ``ensure_2d=False`` and
            ``allow_nd=True`` mean other ranks are not rejected here.
        y : TY
            Targets of shape [B] or [B, NY] (multi-output is allowed).
        **fit_params : Any
            Passed through unchanged to ``estimator.fit``.

        Returns
        -------
        Self
            The fitted wrapper.
        """
        # allow multioutput; non-finite values are not rejected here
        X_, y_ = self._validate_data(
            X,
            y,
            validate_separately=(
                {"force_all_finite": False, "allow_nd": True, "ensure_2d": False},
                {"force_all_finite": False, "allow_nd": True, "ensure_2d": False},
            ),
        )
        X_ = X_.astype(np.float32)
        y_ = y_.astype(np.float32)
        # Record the ORIGINAL target rank *before* reshaping it. predict()
        # relies on this to decide whether to squeeze its output back to 1-D;
        # recording it after the expansion would make it always >= 2.
        self.y_ndim_ = y_.ndim
        if y_.ndim == 1:
            # [B] -> [B, 1], as documented in __init__.
            y_ = np.expand_dims(y_, axis=1)
        self.estimator.fit(X_, y_, **fit_params)
        return self

    def predict(self, X: TX, **predict_params: Any) -> TY:
        """Predict with the wrapped estimator.

        Parameters
        ----------
        X : TX
            Input features, nominally [B, F].
        **predict_params : Any
            Passed through unchanged to ``estimator.predict``.

        Returns
        -------
        TY
            Predictions. If the model was fitted on a 1-D target and the
            estimator returns a single 2-D output column, the result is
            squeezed to shape [B]; otherwise it is returned as produced by
            the wrapped estimator.
        """
        # reset=False: check X against the state recorded during fit instead
        # of overwriting it (sklearn convention for non-fit methods).
        X_: np.ndarray = self._validate_data(
            X, reset=False, force_all_finite=False, allow_nd=True, ensure_2d=False
        )
        X_ = X_.astype(np.float32)
        y = self.estimator.predict(X_, **predict_params)
        # Undo the [B] -> [B, 1] expansion applied in fit(). The ndim guard
        # avoids an IndexError if the estimator already returns a 1-D array.
        if self.y_ndim_ == 1 and y.ndim == 2 and y.shape[1] == 1:
            y = y.squeeze(axis=1)
        return y
class SkorchCNNReshaper(EstimatorWrapperBase[TEstimator], Generic[TEstimator]):
    """skorch wrapper that reshapes tabular data for CNNs using sliding windows."""

    def __init__(self, estimator: TEstimator, *, window_size: int | None) -> None:
        """skorch wrapper that reshapes tabular data for CNNs using sliding windows.

        X: [B, F] -> [B - H + 1, 1, F, H] if window_size is not None (for Conv2d)
           [B, F] -> [B, 1, F] if window_size is None (for Conv1d)
        y: [B] -> [B - H + 1, 1] or [B, NY] -> [B - H + 1, NY]
           (when window_size is None, only [B] -> [B, 1] is applied)
        where C = 1: channels, B: batch, H: window, F: features,
        NY: number of outputs

        Note that ``sliding_window_view`` appends the window axis LAST, so
        each sample's spatial layout is (F, H), not (H, F).

        Parameters
        ----------
        estimator : TEstimator
            The estimator to wrap.
        window_size : int | None
            The size of the sliding window. Make sure that the CNN kernel
            size is equal to or larger than this.
            If None, no sliding window is applied.
        """
        # Same initialisation path as SkorchReshaper.
        super().__init__(estimator)
        self.window_size = window_size
        if self.estimator.device == "cpu":
            warnings.warn(
                "You are using a CPU for training, which may be very slow. "
                f"Consider using a GPU. CUDA availability: {torch.cuda.is_available()}",
                UserWarning,
            )

    def fit(self, X: TX, y: TY, **fit_params: Any) -> Self:
        """Window/reshape X and y, then fit the wrapped estimator.

        Parameters
        ----------
        X : TX
            Input features, nominally [B, F].
        y : TY
            Targets of shape [B] or [B, NY].
        **fit_params : Any
            Passed through unchanged to ``estimator.fit``.

        Returns
        -------
        Self
            The fitted wrapper.

        Raises
        ------
        AssertionError
            If window_size is set and X and y have a different number of rows.
        """
        X_, y_ = self._validate_data(
            X,
            y,
            validate_separately=(
                {"force_all_finite": False, "allow_nd": True, "ensure_2d": False},
                {"force_all_finite": False, "allow_nd": True, "ensure_2d": False},
            ),
        )
        X_ = X_.astype(np.float32)
        y_ = y_.astype(np.float32)
        if self.window_size is not None:
            # [B, F] -> [B - H + 1, F, H]; the window axis is appended last.
            X_ = sliding_window_view(X_, self.window_size, axis=0)
            # X and y are validated separately above, so their row counts
            # are compared explicitly here.
            if X_.shape[0] != y_.shape[0] - self.window_size + 1:
                raise AssertionError(
                    f"X.shape[0] = {X_.shape[0]} "
                    "!= y.shape[0] - self.window_size + 1 = "
                    f"{y_.shape[0] - self.window_size + 1}"
                )
            # Window i covers rows i .. i + H - 1; pair it with the target
            # of its LAST row.
            y_ = y_[self.window_size - 1 :]
        # Insert the singleton channel axis: [N, ...] -> [N, 1, ...].
        X_ = np.expand_dims(X_, axis=1)
        # Record the target rank before expansion so predict() can undo it.
        self.y_ndim_ = y_.ndim
        if y_.ndim == 1:
            y_ = np.expand_dims(y_, axis=1)
        self.estimator.fit(X_, y_, **fit_params)
        return self

    def predict(self, X: TX, **predict_params: Any) -> TY:
        """Window/reshape X the same way as in fit, then predict.

        Parameters
        ----------
        X : TX
            Input features, nominally [B, F].
        **predict_params : Any
            Passed through unchanged to ``estimator.predict``.

        Returns
        -------
        TY
            Predictions. With a window, there are B - H + 1 rows, and row i
            corresponds to input row i + H - 1. If the model was fitted on a
            1-D target and the estimator returns a single 2-D output column,
            the result is squeezed to 1-D.
        """
        # reset=False: check X against the state recorded during fit instead
        # of overwriting it (sklearn convention for non-fit methods).
        X_: np.ndarray = self._validate_data(
            X, reset=False, force_all_finite=False, allow_nd=True, ensure_2d=False
        )
        X_ = X_.astype(np.float32)
        if self.window_size is not None:
            X_ = sliding_window_view(X_, self.window_size, axis=0)
        X_ = np.expand_dims(X_, axis=1)
        y = self.estimator.predict(X_, **predict_params)
        # Undo the [N] -> [N, 1] expansion applied in fit(). The ndim guard
        # avoids an IndexError if the estimator already returns a 1-D array.
        if self.y_ndim_ == 1 and y.ndim == 2 and y.shape[1] == 1:
            y = y.squeeze(axis=1)
        return y