aiice.loader
import csv
import functools
import io
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import date, datetime
from io import BytesIO
from typing import Any, TypeAlias

import numpy as np
import torch

from aiice.constants import (
    DATASET_SHAPE,
    MASK_SEA_DATA_MAX_VALUE,
    MASK_SEA_DATA_PATH,
    MASK_SEA_IDX_PATH,
    MASK_SEA_NAME_COLUMN,
    MASK_SEA_NAME_ID,
)
from aiice.core.huggingface import HfDatasetClient
from aiice.core.utils import get_date_from_filename_template

NpWithIdx: TypeAlias = tuple[list[date], np.ndarray]
TorchWithIdx: TypeAlias = tuple[list[date], torch.Tensor]


class Loader:
    """
    Dataset Loader with a Hugging Face dataset client.

    Downloading a large number of files in parallel may lead to
    request timeouts or temporary server-side errors from
    Hugging Face. If this happens, reduce the number of threads
    or split the download into smaller date ranges.
    """

    def __init__(self):
        """
        Create the Hugging Face client and load the sea lookup data.

        Reads the sea index CSV (`MASK_SEA_IDX_PATH`) and the sea mask
        matrix (`MASK_SEA_DATA_PATH`) through the client.

        Raises:
            ValueError: If the client returns no data for either file, or
                if the mask does not have the `DATASET_SHAPE` shape.
        """
        self._hf = HfDatasetClient()

        # Sea name -> integer sea id, as listed in the index CSV.
        sea_csv_reader = csv.DictReader(
            io.StringIO(self._get_raw_file(MASK_SEA_IDX_PATH).decode("utf-8"))
        )
        self._sea_map: dict[str, int] = {
            row[MASK_SEA_NAME_COLUMN]: int(row[MASK_SEA_NAME_ID])
            for row in sea_csv_reader
        }

        # Matrix whose cells are compared against sea ids when cropping.
        self._sea_mask: np.ndarray = self._decode_raw_matrix(
            self._get_raw_file(MASK_SEA_DATA_PATH)
        )
        # Cells equal to MASK_SEA_DATA_MAX_VALUE are replaced with NaN.
        # NOTE(review): this presumes the mask has a floating dtype; NumPy
        # refuses to store NaN in an integer array - confirm.
        self._sea_mask[self._sea_mask == MASK_SEA_DATA_MAX_VALUE] = np.nan

    @property
    def seas(self) -> tuple[str, ...]:
        """
        Return available seas.
        """
        return tuple(self._sea_map.keys())

    @property
    def shape(self) -> tuple[int, ...]:
        """
        Return shape of a single dataset sample.
        """
        return self._hf.shape

    @property
    def dataset_start(self) -> date:
        """
        Return earliest available date in the dataset.
        """
        return self._hf.dataset_start

    @property
    def dataset_end(self) -> date:
        """
        Return latest available date in the dataset.
        """
        return self._hf.dataset_end

    def info(self, per_year: bool = False) -> dict[str, Any]:
        """
        Collect dataset statistics.

        Args:
            per_year (bool): If True, include per-year statistics.
        """
        return self._hf.info(per_year=per_year)

    def download(
        self,
        local_dir: str,
        start: date | str | None = None,
        end: date | str | None = None,
        step: int | str | None = None,
        threads: int = 16,
    ) -> list[str | None]:
        """
        Download dataset files to a local directory in parallel.
        Raw numpy matrices in the dataset have range values from 0 to 100.

        Args:
            local_dir (`str`): Directory to save downloaded files.
            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
            step (`int` or `str`, optional): Step between files. If `int` - number of days.
                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
                For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
                Defaults to 1 day.
            threads (`int`, optional): Number of parallel download threads. Defaults to 16.

        Returns:
            One entry per requested file, in the order of the file names,
            holding the value returned by the client's `download_file`.
        """
        start = self._convert_date(start)
        end = self._convert_date(end)

        filenames = self._hf.get_filenames(start=start, end=end, step=step)

        def fetch(filename: str) -> str | None:
            return self._hf.download_file(filename=filename, local_dir=local_dir)

        with ThreadPoolExecutor(max_workers=threads) as pool:
            return list(pool.map(fetch, filenames))

    def get(
        self,
        start: date | str | None = None,
        end: date | str | None = None,
        step: int | str | None = None,
        sea: str | None = None,
        tensor_out: bool = False,
        idx_out: bool = False,
        threads: int = 16,
        processes: int | None = None,
    ) -> np.ndarray | torch.Tensor | NpWithIdx | TorchWithIdx:
        """
        Load dataset files into memory as numpy arrays or torch tensors.
        Loaded matrices are normalized to float values in the range 0 to 1.

        Args:
            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
            step (`int` or `str`, optional): Step between files. If `int` - number of days.
                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
                For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
                Defaults to 1 day.
            sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
            tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
            idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
            threads (`int`, optional): Number of parallel download threads. Defaults to 16.
            processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.

        Raises:
            ValueError: If `sea` is not a known sea name, if a remote file is
                not found, or if a decoded matrix has an unexpected shape.
        """
        if sea is not None and sea not in self._sea_map:
            raise ValueError(f"No such sea. Check available options: {self.seas}")

        start = self._convert_date(start)
        end = self._convert_date(end)

        filenames = self._hf.get_filenames(start=start, end=end, step=step)
        # Fetching is done in threads, decoding in separate processes.
        with ThreadPoolExecutor(max_workers=threads) as tpool:
            raw_files = list(tpool.map(self._get_raw_file, filenames))

        # NOTE(review): the bound method makes `self` travel to the worker
        # processes together with each task - confirm this is acceptable.
        with ProcessPoolExecutor(max_workers=processes) as ppool:
            arrays = list(
                ppool.map(functools.partial(self._decode_and_crop, sea=sea), raw_files)
            )

        # numpy matrix values are ints in range 0...100
        result: np.ndarray | torch.Tensor = np.stack(arrays).astype(np.float32) / 100.0

        if tensor_out:
            result = torch.from_numpy(result)

        if idx_out:
            dates = [get_date_from_filename_template(f) for f in filenames]
            return dates, result

        return result

    def _decode_and_crop(self, raw: bytes, sea: str | None) -> np.ndarray:
        """Decode one raw file and, when `sea` is given, crop it to that sea."""
        matrix = self._decode_raw_matrix(raw)
        if sea is None:
            return matrix
        return self._get_sea_by_name(sea, matrix)

    def _get_sea_by_name(self, sea: str, matrix: np.ndarray) -> np.ndarray:
        """
        Crop `matrix` to the bounding box of the mask cells belonging to `sea`.

        The box spans the first to the last row and column in which the sea
        mask equals the id of `sea` (both ends inclusive).
        """
        sea_id = self._sea_map[sea]
        boolean_mask = self._sea_mask == sea_id

        # Rows / columns that contain at least one cell of this sea.
        rows = np.any(boolean_mask, axis=1)
        cols = np.any(boolean_mask, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        cropped_sea = matrix[rmin : rmax + 1, cmin : cmax + 1]
        return cropped_sea

    def _get_raw_file(self, filename: str) -> bytes:
        """
        Read a remote file through the client.

        Raises:
            ValueError: If the client returns None for `filename`.
        """
        raw = self._hf.read_file(filename=filename)
        if raw is None:
            raise ValueError(f"Remote file {filename} not found")
        return raw

    def _decode_raw_matrix(self, raw: bytes) -> np.ndarray:
        """
        Load a numpy matrix from raw bytes and check its shape.

        Raises:
            ValueError: If the matrix shape differs from `DATASET_SHAPE`.
        """
        matrix: np.ndarray = np.load(BytesIO(raw))
        if tuple(matrix.shape) != DATASET_SHAPE:
            raise ValueError(
                f"Matrix shape ({matrix.shape}) is not the same as a default one {DATASET_SHAPE=}"
            )
        return matrix

    def _convert_date(self, d: date | str | None) -> date | None:
        """
        Parse a `"YYYY-MM-DD"` string into a date.

        Any non-string value (a date, or None for "not specified") is
        returned unchanged.
        """
        if isinstance(d, str):
            return datetime.strptime(d, "%Y-%m-%d").date()
        return d
class
Loader:
class Loader:
    """
    Dataset Loader with a Hugging Face dataset client.

    Downloading a large number of files in parallel may lead to
    request timeouts or temporary server-side errors from
    Hugging Face. If this happens, reduce the number of threads
    or split the download into smaller date ranges.
    """

    def __init__(self):
        """
        Create the Hugging Face client and load the sea lookup data.

        Reads the sea index CSV (`MASK_SEA_IDX_PATH`) and the sea mask
        matrix (`MASK_SEA_DATA_PATH`) through the client.

        Raises:
            ValueError: If the client returns no data for either file, or
                if the mask does not have the `DATASET_SHAPE` shape.
        """
        self._hf = HfDatasetClient()

        # Sea name -> integer sea id, as listed in the index CSV.
        sea_csv_reader = csv.DictReader(
            io.StringIO(self._get_raw_file(MASK_SEA_IDX_PATH).decode("utf-8"))
        )
        self._sea_map: dict[str, int] = {
            row[MASK_SEA_NAME_COLUMN]: int(row[MASK_SEA_NAME_ID])
            for row in sea_csv_reader
        }

        # Matrix whose cells are compared against sea ids when cropping.
        self._sea_mask: np.ndarray = self._decode_raw_matrix(
            self._get_raw_file(MASK_SEA_DATA_PATH)
        )
        # Cells equal to MASK_SEA_DATA_MAX_VALUE are replaced with NaN.
        # NOTE(review): this presumes the mask has a floating dtype; NumPy
        # refuses to store NaN in an integer array - confirm.
        self._sea_mask[self._sea_mask == MASK_SEA_DATA_MAX_VALUE] = np.nan

    @property
    def seas(self) -> tuple[str, ...]:
        """
        Return available seas.
        """
        return tuple(self._sea_map.keys())

    @property
    def shape(self) -> tuple[int, ...]:
        """
        Return shape of a single dataset sample.
        """
        return self._hf.shape

    @property
    def dataset_start(self) -> date:
        """
        Return earliest available date in the dataset.
        """
        return self._hf.dataset_start

    @property
    def dataset_end(self) -> date:
        """
        Return latest available date in the dataset.
        """
        return self._hf.dataset_end

    # NOTE(review): `any` in the return annotation is the builtin function;
    # `typing.Any` was probably intended, but it is not imported in this view.
    def info(self, per_year: bool = False) -> dict[str, any]:
        """
        Collect dataset statistics.

        Args:
            per_year (bool): If True, include per-year statistics.
        """
        return self._hf.info(per_year=per_year)

    def download(
        self,
        local_dir: str,
        start: date | str | None = None,
        end: date | str | None = None,
        step: int | str | None = None,
        threads: int = 16,
    ) -> list[str | None]:
        """
        Download dataset files to a local directory in parallel.
        Raw numpy matrices in the dataset have range values from 0 to 100.

        Args:
            local_dir (`str`): Directory to save downloaded files.
            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
            step (`int` or `str`, optional): Step between files. If `int` - number of days.
                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
                For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
                Defaults to 1 day.
            threads (`int`, optional): Number of parallel download threads. Defaults to 16.

        Returns:
            One entry per requested file, in the order of the file names,
            holding the value returned by the client's `download_file`.
        """
        start = self._convert_date(start)
        end = self._convert_date(end)

        filenames = self._hf.get_filenames(start=start, end=end, step=step)

        def fetch(filename: str) -> str | None:
            return self._hf.download_file(filename=filename, local_dir=local_dir)

        with ThreadPoolExecutor(max_workers=threads) as pool:
            return list(pool.map(fetch, filenames))

    def get(
        self,
        start: date | str | None = None,
        end: date | str | None = None,
        step: int | str | None = None,
        sea: str | None = None,
        tensor_out: bool = False,
        idx_out: bool = False,
        threads: int = 16,
        processes: int | None = None,
    ) -> np.ndarray | torch.Tensor | NpWithIdx | TorchWithIdx:
        """
        Load dataset files into memory as numpy arrays or torch tensors.
        Loaded matrices are normalized to float values in the range 0 to 1.

        Args:
            start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
            end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
            step (`int` or `str`, optional): Step between files. If `int` - number of days.
                If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
                For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
                of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
                Defaults to 1 day.
            sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
            tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
            idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
            threads (`int`, optional): Number of parallel download threads. Defaults to 16.
            processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.

        Raises:
            ValueError: If `sea` is not a known sea name, if a remote file is
                not found, or if a decoded matrix has an unexpected shape.
        """
        if sea is not None and sea not in self._sea_map:
            raise ValueError(f"No such sea. Check available options: {self.seas}")

        start = self._convert_date(start)
        end = self._convert_date(end)

        filenames = self._hf.get_filenames(start=start, end=end, step=step)
        # Fetching is done in threads, decoding in separate processes.
        with ThreadPoolExecutor(max_workers=threads) as tpool:
            raw_files = list(tpool.map(self._get_raw_file, filenames))

        # NOTE(review): the bound method makes `self` travel to the worker
        # processes together with each task - confirm this is acceptable.
        with ProcessPoolExecutor(max_workers=processes) as ppool:
            arrays = list(
                ppool.map(functools.partial(self._decode_and_crop, sea=sea), raw_files)
            )

        # numpy matrix values are ints in range 0...100
        result: np.ndarray | torch.Tensor = np.stack(arrays).astype(np.float32) / 100.0

        if tensor_out:
            result = torch.from_numpy(result)

        if idx_out:
            dates = [get_date_from_filename_template(f) for f in filenames]
            return dates, result

        return result

    def _decode_and_crop(self, raw: bytes, sea: str | None) -> np.ndarray:
        """Decode one raw file and, when `sea` is given, crop it to that sea."""
        matrix = self._decode_raw_matrix(raw)
        if sea is None:
            return matrix
        return self._get_sea_by_name(sea, matrix)

    def _get_sea_by_name(self, sea: str, matrix: np.ndarray) -> np.ndarray:
        """
        Crop `matrix` to the bounding box of the mask cells belonging to `sea`.

        The box spans the first to the last row and column in which the sea
        mask equals the id of `sea` (both ends inclusive).
        """
        sea_id = self._sea_map[sea]
        boolean_mask = self._sea_mask == sea_id

        # Rows / columns that contain at least one cell of this sea.
        rows = np.any(boolean_mask, axis=1)
        cols = np.any(boolean_mask, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        cropped_sea = matrix[rmin : rmax + 1, cmin : cmax + 1]
        return cropped_sea

    def _get_raw_file(self, filename: str) -> bytes:
        """
        Read a remote file through the client.

        Raises:
            ValueError: If the client returns None for `filename`.
        """
        raw = self._hf.read_file(filename=filename)
        if raw is None:
            raise ValueError(f"Remote file {filename} not found")
        return raw

    def _decode_raw_matrix(self, raw: bytes) -> np.ndarray:
        """
        Load a numpy matrix from raw bytes and check its shape.

        Raises:
            ValueError: If the matrix shape differs from `DATASET_SHAPE`.
        """
        matrix: np.ndarray = np.load(BytesIO(raw))
        if tuple(matrix.shape) != DATASET_SHAPE:
            raise ValueError(
                f"Matrix shape ({matrix.shape}) is not the same as a default one {DATASET_SHAPE=}"
            )
        return matrix

    def _convert_date(self, d: date | str | None) -> date | None:
        """
        Parse a `"YYYY-MM-DD"` string into a date.

        Any non-string value (a date, or None for "not specified") is
        returned unchanged.
        """
        if isinstance(d, str):
            return datetime.strptime(d, "%Y-%m-%d").date()
        return d
Dataset Loader with a Hugging Face dataset client.
Downloading a large number of files in parallel may lead to request timeouts or temporary server-side errors from Hugging Face. If this happens, reduce the number of threads or split the download into smaller date ranges.
seas: tuple[str, ...]
@property
def seas(self) -> tuple[str, ...]:
    """
    Names of the seas this loader knows about.

    The names are the keys of the internal sea-name-to-id mapping,
    in that mapping's iteration order.
    """
    known_names = self._sea_map.keys()
    return tuple(known_names)
Return available seas.
shape: tuple[int, ...]
@property
def shape(self) -> tuple[int, ...]:
    """
    Shape of one dataset sample, as reported by the dataset client.
    """
    client = self._hf
    return client.shape
Return shape of a single dataset sample.
dataset_start: datetime.date
@property
def dataset_start(self) -> date:
    """
    Earliest date available in the dataset, as reported by the dataset client.
    """
    client = self._hf
    return client.dataset_start
Return earliest available date in the dataset.
dataset_end: datetime.date
@property
def dataset_end(self) -> date:
    """
    Latest date available in the dataset, as reported by the dataset client.
    """
    client = self._hf
    return client.dataset_end
Return latest available date in the dataset.
def
info(self, per_year: bool = False) -> dict[str, any]:
# NOTE(review): `any` in the return annotation is the builtin function;
# `typing.Any` was probably intended.
def info(self, per_year: bool = False) -> dict[str, any]:
    """
    Gather statistics about the dataset from the dataset client.

    Args:
        per_year (bool): If True, include per-year statistics.
    """
    stats = self._hf.info(per_year=per_year)
    return stats
Collect dataset statistics.
Arguments:
- per_year (bool): If True, include per-year statistics.
def
download( self, local_dir: str, start: datetime.date | str | None = None, end: datetime.date | str | None = None, step: int | str | None = None, threads: int = 16) -> list[str | None]:
91 def download( 92 self, 93 local_dir: str, 94 start: date | str | None = None, 95 end: date | str | None = None, 96 step: int | str | None = None, 97 threads: int = 16, 98 ) -> list[str | None]: 99 """ 100 Download dataset files to a local directory in parallel. 101 Raw numpy matrices in the dataset have range values from 0 to 100. 102 103 Args: 104 local_dir (`str`): Directory to save downloaded files. 105 start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date. 106 end (`date` or `str`, optional): End date for files. Defaults to latest dataset date. 107 step (`int` or `str`, optional): Step between files. If `int` - number of days. 108 If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`. 109 For month or years steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day 110 of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31). 111 Defaults to 1 day. 112 threads (`int`, optional): Number of parallel download threads. Defaults to 24. 113 """ 114 start = self._convert_date(start) 115 end = self._convert_date(end) 116 117 filenames = self._hf.get_filenames(start=start, end=end, step=step) 118 with ThreadPoolExecutor(max_workers=threads) as pool: 119 return list( 120 pool.map( 121 lambda f: self._hf.download_file(filename=f, local_dir=local_dir), 122 filenames, 123 ) 124 )
Download dataset files to a local directory in parallel. Raw numpy matrices in the dataset have range values from 0 to 100.
Arguments:
- local_dir (`str`): Directory to save downloaded files.
- start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
- end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
- step (`int` or `str`, optional): Step between files. If `int` - number of days. If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`. For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31). Defaults to 1 day.
- threads (`int`, optional): Number of parallel download threads. Defaults to 16.
def
get( self, start: datetime.date | str | None = None, end: datetime.date | str | None = None, step: int | str | None = None, sea: str | None = None, tensor_out: bool = False, idx_out: bool = False, threads: int = 16, processes: int | None = None) -> numpy.ndarray | torch.Tensor | tuple[list[datetime.date], numpy.ndarray] | tuple[list[datetime.date], torch.Tensor]:
def get(
    self,
    start: date | str | None = None,
    end: date | str | None = None,
    step: int | str | None = None,
    sea: str | None = None,
    tensor_out: bool = False,
    idx_out: bool = False,
    threads: int = 16,
    processes: int | None = None,
) -> np.ndarray | torch.Tensor | NpWithIdx | TorchWithIdx:
    """
    Read dataset files into memory and stack them into one array.
    Values are scaled from the raw 0...100 range to floats in 0...1.

    Args:
        start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
        end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
        step (`int` or `str`, optional): Step between files. If `int` - number of days.
            If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`.
            For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day
            of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31).
            Defaults to 1 day.
        sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
        tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
        idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
        threads (`int`, optional): Number of parallel download threads. Defaults to 16.
        processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.

    Raises:
        ValueError: If `sea` is given but is not a known sea name.
    """
    sea_is_unknown = sea is not None and sea not in self._sea_map
    if sea_is_unknown:
        raise ValueError(f"No such sea. Check available options: {self.seas}")

    first_day = self._convert_date(start)
    last_day = self._convert_date(end)
    filenames = self._hf.get_filenames(start=first_day, end=last_day, step=step)

    # Stage 1: fetch the raw bytes of every file using a thread pool.
    with ThreadPoolExecutor(max_workers=threads) as fetch_pool:
        payloads = list(fetch_pool.map(self._get_raw_file, filenames))

    # Stage 2: decode (and optionally crop) the payloads in worker processes.
    decode_one = functools.partial(self._decode_and_crop, sea=sea)
    with ProcessPoolExecutor(max_workers=processes) as decode_pool:
        matrices = list(decode_pool.map(decode_one, payloads))

    # numpy matrix values are ints in range 0...100
    stacked = np.stack(matrices).astype(np.float32) / 100.0
    result: np.ndarray | torch.Tensor = (
        torch.from_numpy(stacked) if tensor_out else stacked
    )

    if not idx_out:
        return result

    dates = [get_date_from_filename_template(name) for name in filenames]
    return dates, result
Load dataset files into memory as numpy arrays or torch tensors. Loaded matrices are normalized to float values in the range 0 to 1.
Arguments:
- start (`date` or `str`, optional): Start date for files. Defaults to earliest dataset date.
- end (`date` or `str`, optional): End date for files. Defaults to latest dataset date.
- step (`int` or `str`, optional): Step between files. If `int` - number of days. If `str` - format like `"1d"`, `"1w"`, `"1m"`, `"1y"`. For month or year steps (`"1m"`, `"2m"`, etc.), the date always lands on the last day of the month (e.g., Jan 31 + 1 month = Feb 28/29, then Mar 31). Defaults to 1 day.
- sea (`str`, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
- tensor_out (`bool`, optional): If True, returns a torch.Tensor instead of numpy array. Defaults to False.
- idx_out (`bool`, optional): If True, returns a tuple of (date indexes, matrices). Defaults to False.
- threads (`int`, optional): Number of parallel download threads. Defaults to 16.
- processes (`int`, optional): Number of worker processes for decoding raw bytes. Defaults to CPU core count.