aiice
AIICE is an open-source Python framework that provides a standardized benchmark for spatio-temporal forecasting of Arctic sea ice concentration. It offers reproducible pipelines for loading, preprocessing, and evaluating satellite-derived OSI-SAF data, and supports both short- and long-term prediction horizons.
Installation
The simplest way to install the framework is with pip:

```shell
pip install aiice-bench
```
Quickstart
The AIICE class provides a simple interface for loading Arctic ice data, preparing datasets, and benchmarking PyTorch models:

```python
from aiice import AIICE

# Initialize AIICE with a sliding window
# of past 30 days and a forecast of 7 days
aiice = AIICE(
    pre_history_len=30,
    forecast_len=7,
    batch_size=32,
    start="2022-01-01",
    end="2022-12-31",
)

# Define your PyTorch model
model = MyModel()

# Run benchmarking to compute metrics on the dataset
report = aiice.bench(model)
print(report)
```
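`MyModel` above is a placeholder for your own network. As a minimal stand-in (a hypothetical sketch, not part of the package), a persistence baseline that repeats the last observed frame already satisfies the expected shape contract: input `x` of shape `(batch, pre_history_len, H, W)` mapped to predictions of shape `(batch, forecast_len, H, W)`:

```python
import torch
from torch import nn


class PersistenceModel(nn.Module):
    """Hypothetical baseline: repeat the last observed frame forecast_len times."""

    def __init__(self, forecast_len: int):
        super().__init__()
        self.forecast_len = forecast_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, pre_history_len, H, W) -> (batch, forecast_len, H, W)
        last = x[:, -1:, ...]  # keep the time dim so we can tile along it
        return last.repeat(1, self.forecast_len, *([1] * (x.dim() - 2)))
```

A learned model would replace `forward` with real inference; only the input/output shapes matter to `bench`.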
Check the package documentation for more usage examples. You can also explore the raw dataset and work with it independently via Hugging Face.
Leaderboard
The leaderboard reports the mean performance of each model across the evaluation dataset. You can find each model's setup in the examples.
| Sea | Metric | baseline_mean | baseline_repeat | conv2d | conv3d |
|---|---|---|---|---|---|
| Barents Sea | bin_accuracy | 0.874963 | 0.848936 | 0.937071 | 0.891255 |
| | iou | 0.185126 | 0.331170 | 0.647688 | 0.420801 |
| | mae | 0.130236 | 0.151377 | 0.067575 | 0.113846 |
| | mse | 0.053554 | 0.106431 | 0.028444 | 0.064654 |
| | psnr | 12.712070 | 9.729317 | 15.460110 | 11.894089 |
| | rmse | 0.231418 | 0.326238 | 0.168653 | 0.254271 |
| | ssim | 0.540464 | 0.609196 | 0.696043 | 0.618139 |
| Chukchi Sea | bin_accuracy | 0.656515 | 0.675528 | 0.947459 | 0.789110 |
| | iou | 0.126601 | 0.364351 | 0.865943 | 0.585862 |
| | mae | 0.269926 | 0.300754 | 0.069100 | 0.198657 |
| | mse | 0.124306 | 0.246038 | 0.023475 | 0.125997 |
| | psnr | 9.055069 | 6.089983 | 16.293947 | 8.996499 |
| | rmse | 0.352571 | 0.496022 | 0.153215 | 0.354958 |
| | ssim | 0.405798 | 0.385161 | 0.651510 | 0.449680 |
| Kara Sea | bin_accuracy | 0.801598 | 0.797711 | 0.939550 | 0.844245 |
| | iou | 0.282630 | 0.412451 | 0.785852 | 0.559398 |
| | mae | 0.162785 | 0.185920 | 0.065702 | 0.149524 |
| | mse | 0.070723 | 0.136968 | 0.025262 | 0.092185 |
| | psnr | 11.504373 | 8.633821 | 15.975368 | 10.358038 |
| | rmse | 0.265939 | 0.370091 | 0.158939 | 0.303539 |
| | ssim | 0.604080 | 0.590542 | 0.725831 | 0.589535 |
| Laptev Sea | bin_accuracy | 0.839829 | 0.863018 | 0.964288 | 0.897629 |
| | iou | 0.387533 | 0.534633 | 0.859092 | 0.683309 |
| | mae | 0.115111 | 0.122237 | 0.043340 | 0.094628 |
| | mse | 0.051770 | 0.094377 | 0.015273 | 0.066438 |
| | psnr | 12.859248 | 10.251351 | 18.160892 | 11.784326 |
| | rmse | 0.227529 | 0.307208 | 0.123582 | 0.257630 |
| | ssim | 0.782073 | 0.746823 | 0.837163 | 0.802543 |
| Sea of Japan | bin_accuracy | 0.994356 | 0.989473 | 0.994356 | 0.995731 |
| | iou | 0.000000 | 0.035046 | 0.000000 | 0.000000 |
| | mae | 0.013824 | 0.016332 | 0.009841 | 0.008582 |
| | mse | 0.004467 | 0.009577 | 0.005990 | 0.004908 |
| | psnr | 23.499943 | 20.187490 | 22.225956 | 22.945065 |
| | rmse | 0.066835 | 0.097865 | 0.077393 | 0.069567 |
| | ssim | 0.841847 | 0.879064 | 0.922021 | 0.919443 |
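For intuition about the binary metrics in the leaderboard, here is an illustrative sketch of how `bin_accuracy` and `iou` can be computed over binarized ice masks. This is an assumption about the standard formulas, not necessarily the package's exact `Evaluator` implementation (thresholds and averaging may differ):

```python
import numpy as np


def bin_accuracy(y_true, y_pred, threshold=0.5):
    # Fraction of pixels whose binarized ice/no-ice label matches.
    t = np.asarray(y_true) > threshold
    p = np.asarray(y_pred) > threshold
    return float((t == p).mean())


def iou(y_true, y_pred, threshold=0.5):
    # Intersection-over-union of the binarized ice masks.
    t = np.asarray(y_true) > threshold
    p = np.asarray(y_pred) > threshold
    union = np.logical_or(t, p).sum()
    return float(np.logical_and(t, p).sum() / union) if union else 0.0


truth = np.array([0.9, 0.8, 0.1, 0.2])
pred = np.array([0.7, 0.2, 0.6, 0.1])
print(bin_accuracy(truth, pred))  # 0.5 (two of four pixels agree)
print(iou(truth, pred))           # 0.333... (intersection 1, union 3)
```

Under such a definition, an `iou` of 0.000000 for the nearly ice-free Sea of Japan is plausible: when there are no (or almost no) ice pixels, the intersection is empty even though `bin_accuracy` stays high.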
How to contribute?
We welcome bug reports, feature requests, and pull requests. This project uses uv for Python version, dependency, and project management.
Development
We try to follow good practices for readable open-source code. If you want to participate in development and open a pull request, pay attention to the following points:
To install the project with all development dependencies, run:

```shell
uv sync --locked --all-extras --dev --group=scripts
```

You can also use `pip` in your own Python environment, but `uv` is the preferred way because of possible dependency-resolution problems:

```shell
# dev version
pip install -e ".[dev]"
# scripts version for running examples
pip install -e ".[scripts]"
```

Before committing or pushing changes, run the formatters from the repository root:

```shell
uvx isort src tests && uvx black src tests
```

To run tests locally with coverage enabled:

```shell
uv run pytest --cov=. --cov-branch tests
```

To add new dependencies, run:

```shell
# prod version
uv add <new-package>
# dev version
uv add --dev <new-package>
# scripts version
uv add --group=scripts <new-package>
```

To build and serve the docs locally, run:

```shell
uv run pdoc --math -d google --no-include-undocumented -t .doc/ aiice
```

To run any debug scripts with the project env, run:

```shell
uv run --group=scripts <script.py>
```

To run Jupyter notebooks with the project env, run:

```shell
uv run --group=scripts --with jupyter jupyter lab
```
General tips
- Prefer contributing via forks, especially for external contributors
- Use clear, descriptive names for commits, issues, and pull requests
Release process
Although the framework is small, we want to keep it consistent. The release procedure looks like this:
- a pull request is approved by the maintainers and merged with squashed commits
- a new tag is published to the GitHub repository and PyPI via GitHub Actions
Documentation
1""" 2.. include:: ../../README.md 3 :start-line: 1 4<!-- The comment bellow is required as a marker for pdoc. See .doc/module.html.jinja2 --> 5<!-- MAIN_README_PDOC --> 6.. include:: ../../CONTRIBUTE.md 7# Documentation 8""" 9 10from aiice import core, loader, metrics, preprocess 11from aiice.benchmark import AIICE 12 13# visible modules to pdoc 14__all__ = ["AIICE", "core", "loader", "metrics", "preprocess"] 15 16from importlib.metadata import PackageNotFoundError, version 17 18try: 19 __version__ = version("aiice-bench") 20except PackageNotFoundError: 21 __version__ = "0.0.0"
The `AIICE` class from `aiice.benchmark`:

````python
class AIICE:
    """
    High-level interface for loading Arctic ice data, preparing datasets, and benchmarking models.

    This class provides a simple API to:
    1. Load historical ice data within a specified date range (see `aiice.loader.Loader`)
    2. Convert the data into sliding-window datasets (see `aiice.preprocess.SlidingWindowDataset`)
    3. Create a PyTorch DataLoader for batch processing
    4. Benchmark any PyTorch model on the OSI-SAF dataset with specified metrics

    Args:
        pre_history_len (int): Number of past time steps to include in each input sample (X).
        forecast_len (int): Number of future time steps to predict (Y) in each sample.
        batch_size (int, optional): Batch size for the DataLoader. Defaults to 16.
        start (date | str | None, optional): Start date of the data to load. If None, defaults to the earliest available data.
        end (date | str | None, optional): End date of the data to load. If None, defaults to the latest available data.
        step (int | None, optional): Step in days between data points. Defaults to 1 if not provided.
        sea (str, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
        threshold (float | None, optional): Threshold for binarizing the target Y. Values above threshold are set to 1, below or equal set to 0. Defaults to None.
        x_binarize (bool, optional): Whether to apply the same threshold binarization to input X. Defaults to False.
        threads (int, optional): Number of parallel download threads. You can reduce this value in case of rate limiting HuggingFace API errors. Defaults to 16.
        device (str | None, optional): Device to place tensors on ("cpu", "cuda", etc.). If None, uses PyTorch default device.

    Example:
        >>> aiice = AIICE(pre_history_len=30, forecast_len=7, batch_size=32, start="2022-01-01", end="2022-12-31")
        >>> model = MyModel()
        >>> results = aiice.bench(model, metrics={"mae", "psnr"})
    """

    def __init__(
        self,
        pre_history_len: int,
        forecast_len: int,
        batch_size: int = 16,
        start: date | str | None = None,
        end: date | str | None = None,
        step: int | None = None,
        sea: str | None = None,
        threshold: float | None = None,
        x_binarize: bool = False,
        threads: int = 16,
        device: str | None = None,
    ):
        self._device = device
        self._sea = sea

        raw_data = Loader().get(
            start=start,
            end=end,
            step=step,
            sea=sea,
            threads=threads,
            tensor_out=True,
            idx_out=True,
        )

        indices = raw_data[0]
        matrices = raw_data[1]

        dataset = SlidingWindowDataset(
            data=matrices,
            idx=indices,
            pre_history_len=pre_history_len,
            forecast_len=forecast_len,
            threshold=threshold,
            x_binarize=x_binarize,
            device=self._device,
        )

        self._dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=self._default_collate_fn,
        )

    def bench(
        self,
        model: nn.Module,
        metrics: dict[str, MetricFn] | list[str] | None = None,
        path: str | None = None,
        detailed: bool = True,
        plot_workers: int = 4,
        fps: int = 2,
    ) -> dict[str, list[float]]:
        """
        Run benchmarking evaluation of a model on the prepared dataset.

        The method iterates over the internal DataLoader, generates model
        predictions, computes evaluation metrics, and optionally produces
        visualization GIFs comparing ground truth and predicted forecasts.

        When `path` is provided, visualization generation is executed
        asynchronously using a thread pool so that plotting does not block
        model inference.

        Args:
            model (nn.Module):
                PyTorch model used to generate predictions. The model is expected
                to accept inputs `x` with shape `(batch, pre_history_len, ...)`
                and return predictions compatible with the selected metrics.

            metrics (dict[str, MetricFn] | list[str], optional):
                Metrics to compute during evaluation. If a list of metric names is
                provided, the metrics are resolved from the built-in registry.
                If `None`, default metrics are used.
                See `aiice.metrics.Evaluator` for details.

            path (str, optional):
                Directory where forecast visualizations will be saved.
                If provided, each sample in the dataset will produce a GIF
                animation showing the forecast horizon, comparing ground truth
                and model predictions frame by frame.

                The files are named: `<start_forecast_date>_<end_forecast_date>.gif`
                If `None`, visualization generation is skipped.

            detailed (bool, optional):
                If True, returns full statistics for each metric like
                mean, last value, count, min, and max.
                If False, returns only the mean value per metric.

            plot_workers (int, optional):
                Number of worker threads used for asynchronous plot generation.
                Increasing this value can speed up visualization when many samples
                are processed. Defaults to 4.

            fps (int, optional):
                Frames per second of the generated GIF animations. Defaults to 2.

        Returns:
            dict[str, list[float]]:
                Aggregated metric results returned by the evaluator.
        """
        if path is not None:
            os.makedirs(path, exist_ok=True)
            executor = ThreadPoolExecutor(max_workers=plot_workers)
            futures = []

        evaluator = Evaluator(metrics=metrics, accumulate=True)

        model.eval()
        with torch.no_grad():
            for batch in tqdm(self._dataloader, desc="Prediction"):
                dates, x, y = batch
                x, y = x.to(self._device), y.to(self._device)

                pred = model(x)
                evaluator.eval(y, pred)

                if path is None:
                    continue

                futures.append(
                    executor.submit(
                        self._save_batch_plot,
                        sea=self._sea,
                        path=path,
                        dates=dates,
                        y=y.detach().cpu().numpy(),
                        pred=pred.detach().cpu().numpy(),
                        fps=fps,
                    )
                )

        if path is not None:
            for f in tqdm(futures, desc="Saving plots"):
                f.result()
            executor.shutdown(wait=True)

        return evaluator.report(detailed=detailed)

    @staticmethod
    def _save_batch_plot(
        sea: str | None,
        path: str,
        dates: list[list[date]],
        y: np.ndarray,
        pred: np.ndarray,
        fps: int,
    ) -> None:
        """
        Generate GIF visualizations for a batch of forecast samples.

        For each sample in the batch, a GIF animation is created showing
        the temporal evolution of the forecast horizon. Each frame displays
        a side-by-side comparison between the ground truth ice map and the
        model prediction for the corresponding forecast date.

        The resulting GIF file is saved to `path` with the name:
        `<start_forecast_date>_<end_forecast_date>.gif`
        where the dates correspond to the forecast window of the sample.
        """
        matplotlib.use("Agg")

        batch_size, forecast_len = y.shape[:2]
        for i in range(batch_size):

            start_date = dates[i][-forecast_len].strftime("%d-%m-%Y")
            end_date = dates[i][-1].strftime("%d-%m-%Y")

            save_path = os.path.join(path, f"{start_date}_{end_date}.gif")
            fig, axes = plt.subplots(1, 2, figsize=(8, 4))

            im_gt = axes[0].imshow(y[i, 0])
            axes[0].set_title("Ground Truth")
            axes[0].axis("off")

            im_pred = axes[1].imshow(pred[i, 0])
            axes[1].set_title("Prediction")
            axes[1].axis("off")

            frames = []
            for j in range(forecast_len):

                im_gt.set_data(y[i, j])
                im_pred.set_data(pred[i, j])

                forecast_date = dates[i][-forecast_len + j]
                if sea is None:
                    fig.suptitle(f"Forecast: {forecast_date.strftime('%d-%m-%Y')}")
                else:
                    fig.suptitle(
                        f"{sea} | Forecast: {forecast_date.strftime('%d-%m-%Y')}"
                    )

                fig.canvas.draw()
                frame = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
                frames.append(frame)

            plt.close(fig)
            imageio.mimsave(save_path, frames, duration=1 / fps, loop=0)

    @staticmethod
    def _default_collate_fn(
        batch: list[tuple[list[date], torch.Tensor, torch.Tensor]],
    ) -> tuple[list[list[date]], torch.Tensor, torch.Tensor]:
        """
        Collates SlidingWindow dataset samples into a batch
        input -> batch of samples
        output -> batched tensors + list of date sequences

        Example:
            ```
            d1 = [date1...date2]
            x1.shape = (T, H, W)
            y1.shape = (H, W)

            batch = [
                (d1, x1, y1),
                (d2, x2, y2)
            ]

            Output:
            dates -> [d1, d2]
            x -> torch.Tensor (B, T, H, W)
            y -> torch.Tensor (B, H, W)
            ```
        """
        dates, x, y = zip(*batch)
        return list(dates), torch.stack(x), torch.stack(y)
````