aiice.preprocess

  1from typing import Sequence
  2
  3import torch
  4from torch.utils.data import Dataset
  5
  6
  7def apply_threshold(tensor: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
  8    "Binarize tensor with a threshold"
  9    return (tensor > threshold).to(tensor.dtype)
 10
 11
 12def apply_downsample(
 13    t: torch.Tensor, i: int, axes: tuple[int, ...] = (-1,)
 14) -> torch.Tensor:
 15    """
 16    Downsample a tensor by keeping every i-th element along specified axes.
 17
 18    Args:
 19        t (`torch.Tensor`): Input tensor.
 20        i (`int`): Step for downsampling. Must be greater than 0.
 21        axes (`tuple[int]`): Axes along which to downsample. Negative axes are supported.
 22    """
 23    if i <= 0:
 24        raise ValueError("i must be > 0")
 25
 26    out = t
 27    for axis in axes:
 28        axis = axis if axis >= 0 else t.dim() + axis
 29
 30        idx = torch.arange(out.shape[axis], device=out.device)
 31        keep = idx % i == 0
 32        out = torch.index_select(out, axis, idx[keep])
 33
 34    return out
 35
 36
 37class SlidingWindowDataset(Dataset):
 38    """
 39    Convert a time series into (X, Y) pairs using sliding windows.
 40
 41    X represents past observations of length `pre_history_len`,
 42    Y represents future observations of length `forecast_len`.
 43
 44    ![image](../../.doc/media/sliding-window.png)
 45
 46    The dataset is generated lazily: windows are sliced on demand from the
 47    original tensor without materializing the full dataset in memory.
 48    The time dimension is assumed to be the first axis of the input tensor.
 49
 50    Args:
 51        data (`Sequence`): Time series data of shape `[T, ...]` where `T` is the time dimension
 52            and remaining dimensions represent features or channels.
 53        pre_history_len (`int`): Number of time steps in each input window (X).
 54        forecast_len (`int`): Number of time steps in each output window (Y).
 55        idx (`Sequence`, optional): Optional sequence of any indeces corresponding
 56            to each time step in `data`. Must have the same length as the time dimension `T`.
 57            If provided, `__getitem__` returns a tuple `(id, X, Y)` containing the
 58            corresponding timestamps for the selected window, otherwise it returns only `(X, Y)`.
 59        threshold (`float`, optional): If provided, binarizes the target tensor Y using this threshold.
 60            Values strictly greater than the threshold are set to 1, and values less than or equal to
 61            the threshold are set to 0. Defaults to None.
 62        x_binarize (`bool`, optional): If True and `threshold` is provided, applies the same binarization
 63            to the input tensor X. Defaults to False.
 64        device (`str`, optional): Device on which to place the tensors (e.g., "cpu", "cuda"). Defaults to None.
 65        dtype (torch.dtype, optional): Data type used to convert the input sequence. Defaults to torch.float32.
 66    """
 67
 68    def __init__(
 69        self,
 70        data: Sequence,
 71        pre_history_len: int,
 72        forecast_len: int,
 73        idx: Sequence | None = None,
 74        threshold: float | None = None,
 75        x_binarize: bool = False,
 76        device: str | None = None,
 77        dtype: torch.dtype = torch.float32,
 78    ):
 79        self._data = torch.as_tensor(data, dtype=dtype, device=device)
 80        self._indices = idx
 81
 82        self._threshold = threshold
 83        self._x_binarize = x_binarize
 84
 85        if self._data.ndim == 1:
 86            self._data = self._data.unsqueeze(-1)  # [T] -> [T, 1]
 87
 88        self._pre_history_len = pre_history_len
 89        self._forecast_len = forecast_len
 90
 91        self._T = self._data.shape[0]
 92        if self._indices is not None and self._T != len(self._indices):
 93            raise ValueError(
 94                f"Data length (got {self._T}) should be equal to indices length (got {len(self._indices)})"
 95            )
 96
 97        self._length = self._T - pre_history_len - forecast_len + 1
 98
 99        if self._length <= 0:
100            raise ValueError(
101                f"Not enough data: got {self._T}, need at least {pre_history_len + forecast_len}"
102            )
103
104    def __len__(self):
105        return self._length
106
107    def __getitem__(self, idx: int):
108        if not isinstance(idx, int):
109            raise TypeError("index must be int")
110
111        if idx < 0 or idx >= self._length:
112            raise IndexError("index out of range")
113
114        x = self._data[idx : idx + self._pre_history_len]
115        y = self._data[
116            idx
117            + self._pre_history_len : idx
118            + self._pre_history_len
119            + self._forecast_len
120        ]
121
122        if isinstance(self._threshold, float):
123            y = apply_threshold(y, self._threshold)
124            x = apply_threshold(x, self._threshold) if self._x_binarize else x
125
126        if self._indices is not None:
127            idx_slice = self._indices[
128                idx : idx + self._pre_history_len + self._forecast_len
129            ]
130            return idx_slice, x, y
131
132        return x, y
def apply_threshold(tensor: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
 8def apply_threshold(tensor: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
 9    "Binarize tensor with a threshold"
10    return (tensor > threshold).to(tensor.dtype)

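Binarize a tensor with a threshold: values strictly greater than threshold become 1, all other values become 0, and the input dtype is preserved.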

def apply_downsample(t: torch.Tensor, i: int, axes: tuple[int, ...] = (-1,)) -> torch.Tensor:
def apply_downsample(
    t: torch.Tensor, i: int, axes: tuple[int, ...] = (-1,)
) -> torch.Tensor:
    """
    Downsample a tensor by keeping every i-th element along specified axes.

    Along each axis, the elements at positions 0, i, 2*i, ... are kept. The
    result is produced with `torch.index_select`, so it is a new tensor and the
    input is left unchanged.

    Args:
        t (`torch.Tensor`): Input tensor.
        i (`int`): Step for downsampling. Must be greater than 0.
        axes (`tuple[int]`): Axes along which to downsample. Negative axes are supported.
            Each axis may appear at most once after normalizing negative values.

    Raises:
        ValueError: If `i <= 0` or an axis is repeated (e.g. `(1, -1)` on a 2-D tensor).
        IndexError: If an axis is out of range for `t`.
    """
    if i <= 0:
        raise ValueError("i must be > 0")

    ndim = t.dim()
    normalized = []
    for axis in axes:
        if not -ndim <= axis < ndim:
            raise IndexError(
                f"axis {axis} is out of range for a tensor with {ndim} dimensions"
            )
        normalized.append(axis % ndim)

    # Without this check a repeated axis would be downsampled twice,
    # i.e. by i**2 instead of i.
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"repeated axis in axes={tuple(axes)}")

    out = t
    for axis in normalized:
        # Indices 0, i, 2*i, ... directly, instead of building a full range and a mask.
        keep = torch.arange(0, out.shape[axis], i, device=out.device)
        out = torch.index_select(out, axis, keep)

    return out

Downsample a tensor by keeping every i-th element along specified axes.

Arguments:
  • t (torch.Tensor): Input tensor.
  • i (int): Step for downsampling. Must be greater than 0.
  • axes (tuple[int]): Axes along which to downsample. Negative axes are supported.
class SlidingWindowDataset(typing.Generic[+_T_co]):
 38class SlidingWindowDataset(Dataset):
 39    """
 40    Convert a time series into (X, Y) pairs using sliding windows.
 41
 42    X represents past observations of length `pre_history_len`,
 43    Y represents future observations of length `forecast_len`.
 44
 45    ![image](../../.doc/media/sliding-window.png)
 46
 47    The dataset is generated lazily: windows are sliced on demand from the
 48    original tensor without materializing the full dataset in memory.
 49    The time dimension is assumed to be the first axis of the input tensor.
 50
 51    Args:
 52        data (`Sequence`): Time series data of shape `[T, ...]` where `T` is the time dimension
 53            and remaining dimensions represent features or channels.
 54        pre_history_len (`int`): Number of time steps in each input window (X).
 55        forecast_len (`int`): Number of time steps in each output window (Y).
 56        idx (`Sequence`, optional): Optional sequence of any indeces corresponding
 57            to each time step in `data`. Must have the same length as the time dimension `T`.
 58            If provided, `__getitem__` returns a tuple `(id, X, Y)` containing the
 59            corresponding timestamps for the selected window, otherwise it returns only `(X, Y)`.
 60        threshold (`float`, optional): If provided, binarizes the target tensor Y using this threshold.
 61            Values strictly greater than the threshold are set to 1, and values less than or equal to
 62            the threshold are set to 0. Defaults to None.
 63        x_binarize (`bool`, optional): If True and `threshold` is provided, applies the same binarization
 64            to the input tensor X. Defaults to False.
 65        device (`str`, optional): Device on which to place the tensors (e.g., "cpu", "cuda"). Defaults to None.
 66        dtype (torch.dtype, optional): Data type used to convert the input sequence. Defaults to torch.float32.
 67    """
 68
 69    def __init__(
 70        self,
 71        data: Sequence,
 72        pre_history_len: int,
 73        forecast_len: int,
 74        idx: Sequence | None = None,
 75        threshold: float | None = None,
 76        x_binarize: bool = False,
 77        device: str | None = None,
 78        dtype: torch.dtype = torch.float32,
 79    ):
 80        self._data = torch.as_tensor(data, dtype=dtype, device=device)
 81        self._indices = idx
 82
 83        self._threshold = threshold
 84        self._x_binarize = x_binarize
 85
 86        if self._data.ndim == 1:
 87            self._data = self._data.unsqueeze(-1)  # [T] -> [T, 1]
 88
 89        self._pre_history_len = pre_history_len
 90        self._forecast_len = forecast_len
 91
 92        self._T = self._data.shape[0]
 93        if self._indices is not None and self._T != len(self._indices):
 94            raise ValueError(
 95                f"Data length (got {self._T}) should be equal to indices length (got {len(self._indices)})"
 96            )
 97
 98        self._length = self._T - pre_history_len - forecast_len + 1
 99
100        if self._length <= 0:
101            raise ValueError(
102                f"Not enough data: got {self._T}, need at least {pre_history_len + forecast_len}"
103            )
104
105    def __len__(self):
106        return self._length
107
108    def __getitem__(self, idx: int):
109        if not isinstance(idx, int):
110            raise TypeError("index must be int")
111
112        if idx < 0 or idx >= self._length:
113            raise IndexError("index out of range")
114
115        x = self._data[idx : idx + self._pre_history_len]
116        y = self._data[
117            idx
118            + self._pre_history_len : idx
119            + self._pre_history_len
120            + self._forecast_len
121        ]
122
123        if isinstance(self._threshold, float):
124            y = apply_threshold(y, self._threshold)
125            x = apply_threshold(x, self._threshold) if self._x_binarize else x
126
127        if self._indices is not None:
128            idx_slice = self._indices[
129                idx : idx + self._pre_history_len + self._forecast_len
130            ]
131            return idx_slice, x, y
132
133        return x, y

Convert a time series into (X, Y) pairs using sliding windows.

X represents past observations of length pre_history_len, Y represents future observations of length forecast_len.

image

The dataset is generated lazily: windows are sliced on demand from the original tensor without materializing the full dataset in memory. The time dimension is assumed to be the first axis of the input tensor.

Arguments:
  • data (Sequence): Time series data of shape [T, ...] where T is the time dimension and remaining dimensions represent features or channels.
  • pre_history_len (int): Number of time steps in each input window (X).
  • forecast_len (int): Number of time steps in each output window (Y).
  • idx (Sequence, optional): Optional sequence of arbitrary indices (e.g. timestamps) corresponding to each time step in data. Must have the same length as the time dimension T. If provided, __getitem__ returns a tuple (idx_slice, X, Y), where idx_slice holds the indices of the whole selected window (the pre_history_len steps of X followed by the forecast_len steps of Y); otherwise it returns only (X, Y).
  • threshold (float, optional): If provided, binarizes the target tensor Y using this threshold. Values strictly greater than the threshold are set to 1, and values less than or equal to the threshold are set to 0. Defaults to None.
  • x_binarize (bool, optional): If True and threshold is provided, applies the same binarization to the input tensor X. Defaults to False.
  • device (str, optional): Device on which to place the tensors (e.g., "cpu", "cuda"). Defaults to None.
  • dtype (torch.dtype, optional): Data type used to convert the input sequence. Defaults to torch.float32.