aiice.loader

  1import csv
  2import functools
  3import io
  4from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
  5from datetime import date, datetime
  6from io import BytesIO
  7from typing import TypeAlias
  8
  9import numpy as np
 10import torch
 11
 12from aiice.constants import (
 13    DATASET_SHAPE,
 14    MASK_SEA_DATA_MAX_VALUE,
 15    MASK_SEA_DATA_PATH,
 16    MASK_SEA_IDX_PATH,
 17    MASK_SEA_NAME_COLUMN,
 18    MASK_SEA_NAME_ID,
 19)
 20from aiice.core.huggingface import HfDatasetClient
 21from aiice.core.utils import get_date_from_filename_template
 22
 23NpWithIdx: TypeAlias = tuple[list[date], np.ndarray]
 24TorchWithIdx: TypeAlias = tuple[list[date], torch.Tensor]
 25
 26
 27class Loader:
 28    """
 29    Dataset Loader with a Hugging Face dataset client.
 30
 31    Downloading a large number of files in parallel may lead to
 32    request timeouts or temporary server-side errors from
 33    Hugging Face. If this happens, reduce the number of threads
 34    or split the download into smaller date ranges.
 35    """
 36
 37    def __init__(self):
 38        self._hf = HfDatasetClient()
 39
 40        sea_csv_reader = csv.DictReader(
 41            io.StringIO(self._get_raw_file(MASK_SEA_IDX_PATH).decode("utf-8"))
 42        )
 43        self._sea_map: dict[str, int] = {
 44            row[MASK_SEA_NAME_COLUMN]: int(row[MASK_SEA_NAME_ID])
 45            for row in sea_csv_reader
 46        }
 47
 48        self._sea_mask: np.ndarray = self._decode_raw_matrix(
 49            self._get_raw_file(MASK_SEA_DATA_PATH)
 50        )
 51        self._sea_mask[self._sea_mask == MASK_SEA_DATA_MAX_VALUE] = np.nan
 52
 53    @property
 54    def seas(self) -> tuple[str, ...]:
 55        """
 56        Return available seas.
 57        """
 58        return tuple(self._sea_map.keys())
 59
 60    @property
 61    def shape(self) -> tuple[int, ...]:
 62        """
 63        Return shape of a single dataset sample.
 64        """
 65        return self._hf.shape
 66
 67    @property
 68    def dataset_start(self) -> date:
 69        """
 70        Return earliest available date in the dataset.
 71        """
 72        return self._hf.dataset_start
 73
 74    @property
 75    def dataset_end(self) -> date:
 76        """
 77        Return latest available date in the dataset.
 78        """
 79        return self._hf.dataset_end
 80
 81    def info(self, per_year: bool = False) -> dict[str, any]:
 82        """
 83        Collect dataset statistics.
 84
 85        Args:
 86            per_year (bool): If True, include per-year statistics.
 87        """
 88        return self._hf.info(per_year=per_year)
 89
 90    def download(
 91        self,
 92        local_dir: str,
 93        start: date | str | None = None,
 94        end: date | str | None = None,
 95        step: int | str | None = None,
 96        threads: int = 16,
 97    ) -> list[str | None]:
 98        """
 99        Download dataset files to a local directory in parallel.
100        Raw numpy matrices in the dataset have range values from 0 to 100.
101
102        Args:
103            local_dir (`str`): Directory to save downloaded files.
104            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
105            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
106            step (`int` or `str`, optional): Step between files. If `int` - number of days.
107                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
108                For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
109                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
110                Defaults to 1 day.
111            threads (`int`, optional): Number of parallel download threads. Defaults to 24.
112        """
113        start = self._convert_date(start)
114        end = self._convert_date(end)
115
116        filenames = self._hf.get_filenames(start=start, end=end, step=step)
117        with ThreadPoolExecutor(max_workers=threads) as pool:
118            return list(
119                pool.map(
120                    lambda f: self._hf.download_file(filename=f, local_dir=local_dir),
121                    filenames,
122                )
123            )
124
125    def get(
126        self,
127        start: date | str | None = None,
128        end: date | str | None = None,
129        step: int | str | None = None,
130        sea: str | None = None,
131        tensor_out: bool = False,
132        idx_out: bool = False,
133        threads: int = 16,
134        processes: int | None = None,
135    ) -> np.ndarray | torch.Tensor | NpWithIdx | TorchWithIdx:
136        """
137        Load dataset files into memory as numpy arrays or torch tensors.
138        Loaded matrices are normalized to float values in the range 0 to 1.
139
140        Args:
141            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
142            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
143            step (`int` or `str`, optional): Step between files. If `int` - number of days.
144                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
145                For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
146                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
147                Defaults to 1 day.
148            sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
149            tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
150            idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
151            threads (`int`, optional): Number of parallel download threads. Defaults to 16.
152            processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.
153        """
154        if sea is not None and sea not in self._sea_map:
155            raise ValueError(f"No such sea. Check available options: {self.seas}")
156
157        start = self._convert_date(start)
158        end = self._convert_date(end)
159
160        filenames = self._hf.get_filenames(start=start, end=end, step=step)
161        with ThreadPoolExecutor(max_workers=threads) as tpool:
162            raw_files = list(tpool.map(self._get_raw_file, filenames))
163
164        with ProcessPoolExecutor(max_workers=processes) as ppool:
165            arrays = list(
166                ppool.map(functools.partial(self._decode_and_crop, sea=sea), raw_files)
167            )
168
169        # numpy matrix values are ints in range 0...100
170        result: np.ndarray | torch.Tensor = np.stack(arrays).astype(np.float32) / 100.0
171
172        if tensor_out:
173            result = torch.from_numpy(result)
174
175        if idx_out:
176            dates = [get_date_from_filename_template(f) for f in filenames]
177            return dates, result
178
179        return result
180
181    def _decode_and_crop(self, raw: bytes, sea: str | None):
182        matrix = self._decode_raw_matrix(raw)
183        if sea is None:
184            return matrix
185        return self._get_sea_by_name(sea, matrix)
186
187    def _get_sea_by_name(self, sea: str, matrix: np.ndarray) -> np.array:
188        sea_id = self._sea_map[sea]
189        boolean_mask = self._sea_mask == sea_id
190
191        rows = np.any(boolean_mask, axis=1)
192        cols = np.any(boolean_mask, axis=0)
193        rmin, rmax = np.where(rows)[0][[0, -1]]
194        cmin, cmax = np.where(cols)[0][[0, -1]]
195
196        cropped_sea = matrix[rmin : rmax + 1, cmin : cmax + 1]
197        return cropped_sea
198
199    def _get_raw_file(self, filename: str) -> bytes:
200        raw = self._hf.read_file(filename=filename)
201        if raw is None:
202            raise ValueError(f"Remote file {filename} not found")
203        return raw
204
205    def _decode_raw_matrix(self, raw: bytes) -> np.ndarray:
206        matrix: np.ndarray = np.load(BytesIO(raw))
207        if tuple(matrix.shape) != DATASET_SHAPE:
208            raise ValueError(
209                f"Matrix shape ({matrix.shape}) is not the same as a default one {DATASET_SHAPE=}"
210            )
211        return matrix
212
213    def _convert_date(self, d: str | date) -> date:
214        if isinstance(d, str):
215            return datetime.strptime(d, "%Y-%m-%d").date()
216        return d
class Loader:
 28class Loader:
 29    """
 30    Dataset Loader with a Hugging Face dataset client.
 31
 32    Downloading a large number of files in parallel may lead to
 33    request timeouts or temporary server-side errors from
 34    Hugging Face. If this happens, reduce the number of threads
 35    or split the download into smaller date ranges.
 36    """
 37
 38    def __init__(self):
 39        self._hf = HfDatasetClient()
 40
 41        sea_csv_reader = csv.DictReader(
 42            io.StringIO(self._get_raw_file(MASK_SEA_IDX_PATH).decode("utf-8"))
 43        )
 44        self._sea_map: dict[str, int] = {
 45            row[MASK_SEA_NAME_COLUMN]: int(row[MASK_SEA_NAME_ID])
 46            for row in sea_csv_reader
 47        }
 48
 49        self._sea_mask: np.ndarray = self._decode_raw_matrix(
 50            self._get_raw_file(MASK_SEA_DATA_PATH)
 51        )
 52        self._sea_mask[self._sea_mask == MASK_SEA_DATA_MAX_VALUE] = np.nan
 53
 54    @property
 55    def seas(self) -> tuple[str, ...]:
 56        """
 57        Return available seas.
 58        """
 59        return tuple(self._sea_map.keys())
 60
 61    @property
 62    def shape(self) -> tuple[int, ...]:
 63        """
 64        Return shape of a single dataset sample.
 65        """
 66        return self._hf.shape
 67
 68    @property
 69    def dataset_start(self) -> date:
 70        """
 71        Return earliest available date in the dataset.
 72        """
 73        return self._hf.dataset_start
 74
 75    @property
 76    def dataset_end(self) -> date:
 77        """
 78        Return latest available date in the dataset.
 79        """
 80        return self._hf.dataset_end
 81
 82    def info(self, per_year: bool = False) -> dict[str, any]:
 83        """
 84        Collect dataset statistics.
 85
 86        Args:
 87            per_year (bool): If True, include per-year statistics.
 88        """
 89        return self._hf.info(per_year=per_year)
 90
 91    def download(
 92        self,
 93        local_dir: str,
 94        start: date | str | None = None,
 95        end: date | str | None = None,
 96        step: int | str | None = None,
 97        threads: int = 16,
 98    ) -> list[str | None]:
 99        """
100        Download dataset files to a local directory in parallel.
101        Raw numpy matrices in the dataset have range values from 0 to 100.
102
103        Args:
104            local_dir (`str`): Directory to save downloaded files.
105            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
106            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
107            step (`int` or `str`, optional): Step between files. If `int` - number of days.
108                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
109                For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
110                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
111                Defaults to 1 day.
112            threads (`int`, optional): Number of parallel download threads. Defaults to 24.
113        """
114        start = self._convert_date(start)
115        end = self._convert_date(end)
116
117        filenames = self._hf.get_filenames(start=start, end=end, step=step)
118        with ThreadPoolExecutor(max_workers=threads) as pool:
119            return list(
120                pool.map(
121                    lambda f: self._hf.download_file(filename=f, local_dir=local_dir),
122                    filenames,
123                )
124            )
125
126    def get(
127        self,
128        start: date | str | None = None,
129        end: date | str | None = None,
130        step: int | str | None = None,
131        sea: str | None = None,
132        tensor_out: bool = False,
133        idx_out: bool = False,
134        threads: int = 16,
135        processes: int | None = None,
136    ) -> np.ndarray | torch.Tensor | NpWithIdx | TorchWithIdx:
137        """
138        Load dataset files into memory as numpy arrays or torch tensors.
139        Loaded matrices are normalized to float values in the range 0 to 1.
140
141        Args:
142            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
143            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
144            step (`int` or `str`, optional): Step between files. If `int` - number of days.
145                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
146                For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
147                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
148                Defaults to 1 day.
149            sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
150            tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
151            idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
152            threads (`int`, optional): Number of parallel download threads. Defaults to 16.
153            processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.
154        """
155        if sea is not None and sea not in self._sea_map:
156            raise ValueError(f"No such sea. Check available options: {self.seas}")
157
158        start = self._convert_date(start)
159        end = self._convert_date(end)
160
161        filenames = self._hf.get_filenames(start=start, end=end, step=step)
162        with ThreadPoolExecutor(max_workers=threads) as tpool:
163            raw_files = list(tpool.map(self._get_raw_file, filenames))
164
165        with ProcessPoolExecutor(max_workers=processes) as ppool:
166            arrays = list(
167                ppool.map(functools.partial(self._decode_and_crop, sea=sea), raw_files)
168            )
169
170        # numpy matrix values are ints in range 0...100
171        result: np.ndarray | torch.Tensor = np.stack(arrays).astype(np.float32) / 100.0
172
173        if tensor_out:
174            result = torch.from_numpy(result)
175
176        if idx_out:
177            dates = [get_date_from_filename_template(f) for f in filenames]
178            return dates, result
179
180        return result
181
182    def _decode_and_crop(self, raw: bytes, sea: str | None):
183        matrix = self._decode_raw_matrix(raw)
184        if sea is None:
185            return matrix
186        return self._get_sea_by_name(sea, matrix)
187
188    def _get_sea_by_name(self, sea: str, matrix: np.ndarray) -> np.array:
189        sea_id = self._sea_map[sea]
190        boolean_mask = self._sea_mask == sea_id
191
192        rows = np.any(boolean_mask, axis=1)
193        cols = np.any(boolean_mask, axis=0)
194        rmin, rmax = np.where(rows)[0][[0, -1]]
195        cmin, cmax = np.where(cols)[0][[0, -1]]
196
197        cropped_sea = matrix[rmin : rmax + 1, cmin : cmax + 1]
198        return cropped_sea
199
200    def _get_raw_file(self, filename: str) -> bytes:
201        raw = self._hf.read_file(filename=filename)
202        if raw is None:
203            raise ValueError(f"Remote file {filename} not found")
204        return raw
205
206    def _decode_raw_matrix(self, raw: bytes) -> np.ndarray:
207        matrix: np.ndarray = np.load(BytesIO(raw))
208        if tuple(matrix.shape) != DATASET_SHAPE:
209            raise ValueError(
210                f"Matrix shape ({matrix.shape}) is not the same as a default one {DATASET_SHAPE=}"
211            )
212        return matrix
213
214    def _convert_date(self, d: str | date) -> date:
215        if isinstance(d, str):
216            return datetime.strptime(d, "%Y-%m-%d").date()
217        return d

Dataset Loader with a Hugging Face dataset client.

Downloading a large number of files in parallel may lead to request timeouts or temporary server-side errors from Hugging Face. If this happens, reduce the number of threads or split the download into smaller date ranges.

seas: tuple[str, ...]
54    @property
55    def seas(self) -> tuple[str, ...]:
56        """
57        Return available seas.
58        """
59        return tuple(self._sea_map.keys())

Return available seas.

shape: tuple[int, ...]
61    @property
62    def shape(self) -> tuple[int, ...]:
63        """
64        Return shape of a single dataset sample.
65        """
66        return self._hf.shape

Return shape of a single dataset sample.

dataset_start: datetime.date
68    @property
69    def dataset_start(self) -> date:
70        """
71        Return earliest available date in the dataset.
72        """
73        return self._hf.dataset_start

Return earliest available date in the dataset.

dataset_end: datetime.date
75    @property
76    def dataset_end(self) -> date:
77        """
78        Return latest available date in the dataset.
79        """
80        return self._hf.dataset_end

Return latest available date in the dataset.

def info(self, per_year: bool = False) -> dict[str, any]:
82    def info(self, per_year: bool = False) -> dict[str, any]:
83        """
84        Collect dataset statistics.
85
86        Args:
87            per_year (bool): If True, include per-year statistics.
88        """
89        return self._hf.info(per_year=per_year)

Collect dataset statistics.

Arguments:
  • per_year (bool): If True, include per-year statistics.
def download( self, local_dir: str, start: datetime.date | str | None = None, end: datetime.date | str | None = None, step: int | str | None = None, threads: int = 16) -> list[str | None]:
 91    def download(
 92        self,
 93        local_dir: str,
 94        start: date | str | None = None,
 95        end: date | str | None = None,
 96        step: int | str | None = None,
 97        threads: int = 16,
 98    ) -> list[str | None]:
 99        """
100        Download dataset files to a local directory in parallel.
101        Raw numpy matrices in the dataset have range values from 0 to 100.
102
103        Args:
104            local_dir (`str`): Directory to save downloaded files.
105            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
106            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
107            step (`int` or `str`, optional): Step between files. If `int` - number of days.
108                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
109                For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
110                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
111                Defaults to 1 day.
112            threads (`int`, optional): Number of parallel download threads. Defaults to 24.
113        """
114        start = self._convert_date(start)
115        end = self._convert_date(end)
116
117        filenames = self._hf.get_filenames(start=start, end=end, step=step)
118        with ThreadPoolExecutor(max_workers=threads) as pool:
119            return list(
120                pool.map(
121                    lambda f: self._hf.download_file(filename=f, local_dir=local_dir),
122                    filenames,
123                )
124            )

Download dataset files to a local directory in parallel. Raw numpy matrices in the dataset contain values ranging from 0 to 100.

Arguments:
  • local_dir (str): Directory to save downloaded files.
  • start (date or str, optional): Start date for files. Defaults to earliest dataset date.
  • end (date or str, optional): End date for files. Defaults to latest dataset date.
  • step (int or str, optional): Step between files. If int - number of days. If str - format like "1d", "1w", "1m", "1y". For month or year steps ("1m", "2m", etc.), the date always lands on the last day of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31). Defaults to 1 day.
  • threads (int, optional): Number of parallel download threads. Defaults to 16.
def get( self, start: datetime.date | str | None = None, end: datetime.date | str | None = None, step: int | str | None = None, sea: str | None = None, tensor_out: bool = False, idx_out: bool = False, threads: int = 16, processes: int | None = None) -> numpy.ndarray | torch.Tensor | tuple[list[datetime.date], numpy.ndarray] | tuple[list[datetime.date], torch.Tensor]:
126    def get(
127        self,
128        start: date | str | None = None,
129        end: date | str | None = None,
130        step: int | str | None = None,
131        sea: str | None = None,
132        tensor_out: bool = False,
133        idx_out: bool = False,
134        threads: int = 16,
135        processes: int | None = None,
136    ) -> np.ndarray | torch.Tensor | NpWithIdx | TorchWithIdx:
137        """
138        Load dataset files into memory as numpy arrays or torch tensors.
139        Loaded matrices are normalized to float values in the range 0 to 1.
140
141        Args:
142            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
143            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
144            step (`int` or `str`, optional): Step between files. If `int` - number of days.
145                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
146                For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
147                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
148                Defaults to 1 day.
149            sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
150            tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
151            idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
152            threads (`int`, optional): Number of parallel download threads. Defaults to 16.
153            processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.
154        """
155        if sea is not None and sea not in self._sea_map:
156            raise ValueError(f"No such sea. Check available options: {self.seas}")
157
158        start = self._convert_date(start)
159        end = self._convert_date(end)
160
161        filenames = self._hf.get_filenames(start=start, end=end, step=step)
162        with ThreadPoolExecutor(max_workers=threads) as tpool:
163            raw_files = list(tpool.map(self._get_raw_file, filenames))
164
165        with ProcessPoolExecutor(max_workers=processes) as ppool:
166            arrays = list(
167                ppool.map(functools.partial(self._decode_and_crop, sea=sea), raw_files)
168            )
169
170        # numpy matrix values are ints in range 0...100
171        result: np.ndarray | torch.Tensor = np.stack(arrays).astype(np.float32) / 100.0
172
173        if tensor_out:
174            result = torch.from_numpy(result)
175
176        if idx_out:
177            dates = [get_date_from_filename_template(f) for f in filenames]
178            return dates, result
179
180        return result

Load dataset files into memory as numpy arrays or torch tensors. Loaded matrices are normalized to float values in the range 0 to 1.

Arguments:
  • start (date or str, optional): Start date for files. Defaults to earliest dataset date.
  • end (date or str, optional): End date for files. Defaults to latest dataset date.
  • step (int or str, optional): Step between files. If int - number of days. If str - format like "1d", "1w", "1m", "1y". For month or year steps ("1m", "2m", etc.), the date always lands on the last day of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31). Defaults to 1 day.
  • sea (str, optional): Name of the sea (e.g., "Barents Sea"). Check Loader.seas for available ones.
  • tensor_out (bool, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
  • idx_out (bool, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
  • threads (int, optional): Number of parallel download threads. Defaults to 16.
  • processes (int, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.