aiice



AIICE is an open-source Python framework designed as a standardized benchmark for spatio-temporal forecasting of Arctic sea ice concentration. It provides reproducible pipelines for loading, preprocessing, and evaluating satellite-derived OSI-SAF data, supporting both short- and long-term prediction horizons.

Installation

The simplest way to install the framework is with pip:

pip install aiice-bench

Quickstart

The AIICE class provides a simple interface for loading Arctic ice data, preparing datasets, and benchmarking PyTorch models:


from aiice import AIICE

# Initialize AIICE with a sliding window 
# of past 30 days and forecast of 7 days
aiice = AIICE(
    pre_history_len=30,
    forecast_len=7,
    batch_size=32,
    start="2022-01-01",
    end="2022-12-31"
)

# Define your PyTorch model (MyModel stands for your own torch.nn.Module)
model = MyModel()

# Run benchmarking to compute metrics on the dataset
report = aiice.bench(model)
print(report)
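`MyModel` is left to the user. Below is a minimal sketch of a compatible model. It assumes, based on the collate function and plotting code in the class source, that `x` arrives as `(batch, pre_history_len, H, W)` and that predictions should have shape `(batch, forecast_len, H, W)`. The architecture is purely illustrative and is not the `conv2d` model from the leaderboard:

```python
import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical minimal model: treats the input days as channels and
    predicts the forecast days with a single 2D convolution."""

    def __init__(self, pre_history_len: int = 30, forecast_len: int = 7):
        super().__init__()
        # Map T input frames (as channels) to F output frames, keeping H and W.
        self.conv = nn.Conv2d(pre_history_len, forecast_len, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, pre_history_len, H, W) -> (batch, forecast_len, H, W)
        # Cast in case the loader yields float64; Conv2d weights default to float32.
        # Sigmoid keeps outputs in [0, 1], the range assumed for ice concentration.
        return torch.sigmoid(self.conv(x.float()))


model = MyModel()
x = torch.rand(2, 30, 64, 64)  # dummy batch: 2 samples of 64x64 maps
print(model(x).shape)  # torch.Size([2, 7, 64, 64])
```

Note that `bench` calls `model.eval()` but does not move the model to `device`, so when you pass `device="cuda"` you also need to call `model.to("cuda")` yourself.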

Check the package documentation for more usage examples. You can also explore the raw dataset and work with it independently via Hugging Face.

Leaderboard

The leaderboard reports the mean performance of each model across the evaluation dataset. Model setups are available in the examples. Lower is better for mae, mse and rmse; higher is better for bin_accuracy, iou, psnr and ssim.

Sea           Metric        baseline_mean  baseline_repeat  conv2d     conv3d
Barents Sea   bin_accuracy  0.874963       0.848936         0.937071   0.891255
              iou           0.185126       0.331170         0.647688   0.420801
              mae           0.130236       0.151377         0.067575   0.113846
              mse           0.053554       0.106431         0.028444   0.064654
              psnr          12.712070      9.729317         15.460110  11.894089
              rmse          0.231418       0.326238         0.168653   0.254271
              ssim          0.540464       0.609196         0.696043   0.618139
Chukchi Sea   bin_accuracy  0.656515       0.675528         0.947459   0.789110
              iou           0.126601       0.364351         0.865943   0.585862
              mae           0.269926       0.300754         0.069100   0.198657
              mse           0.124306       0.246038         0.023475   0.125997
              psnr          9.055069       6.089983         16.293947  8.996499
              rmse          0.352571       0.496022         0.153215   0.354958
              ssim          0.405798       0.385161         0.651510   0.449680
Kara Sea      bin_accuracy  0.801598       0.797711         0.939550   0.844245
              iou           0.282630       0.412451         0.785852   0.559398
              mae           0.162785       0.185920         0.065702   0.149524
              mse           0.070723       0.136968         0.025262   0.092185
              psnr          11.504373      8.633821         15.975368  10.358038
              rmse          0.265939       0.370091         0.158939   0.303539
              ssim          0.604080       0.590542         0.725831   0.589535
Laptev Sea    bin_accuracy  0.839829       0.863018         0.964288   0.897629
              iou           0.387533       0.534633         0.859092   0.683309
              mae           0.115111       0.122237         0.043340   0.094628
              mse           0.051770       0.094377         0.015273   0.066438
              psnr          12.859248      10.251351        18.160892  11.784326
              rmse          0.227529       0.307208         0.123582   0.257630
              ssim          0.782073       0.746823         0.837163   0.802543
Sea of Japan  bin_accuracy  0.994356       0.989473         0.994356   0.995731
              iou           0.000000       0.035046         0.000000   0.000000
              mae           0.013824       0.016332         0.009841   0.008582
              mse           0.004467       0.009577         0.005990   0.004908
              psnr          23.499943      20.187490        22.225956  22.945065
              rmse          0.066835       0.097865         0.077393   0.069567
              ssim          0.841847       0.879064         0.922021   0.919443
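The psnr and rmse rows appear consistent with PSNR = 10 * log10(1 / MSE) and RMSE = sqrt(MSE) for concentrations scaled to [0, 1]. This is an observation from the table, not a documented definition of the metrics in `aiice.metrics`. A quick standard-library check on a few rows copied from the leaderboard:

```python
import math

# (mse, psnr, rmse) triples copied from the leaderboard.
ROWS = {
    ("Barents Sea", "baseline_mean"): (0.053554, 12.712070, 0.231418),
    ("Barents Sea", "conv2d"): (0.028444, 15.460110, 0.168653),
    ("Chukchi Sea", "conv3d"): (0.125997, 8.996499, 0.354958),
    ("Sea of Japan", "baseline_mean"): (0.004467, 23.499943, 0.066835),
}


def psnr_from_mse(mse: float, data_range: float = 1.0) -> float:
    """PSNR in dB, assuming values lie in [0, data_range]."""
    return 10.0 * math.log10(data_range**2 / mse)


for (sea, model_name), (mse, psnr, rmse) in ROWS.items():
    print(
        f"{sea:13s} {model_name:14s} "
        f"psnr={psnr_from_mse(mse):.4f} (reported {psnr:.4f}) "
        f"rmse={math.sqrt(mse):.6f} (reported {rmse:.6f})"
    )
```

Small differences in the last digits come from the six-decimal rounding of the reported mse values.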

How to contribute?

We welcome bug reports, feature requests, and pull requests. This project uses uv for Python version, dependency, and project management.

Development

We try to keep the code readable and follow good open-source practices. If you want to take part in development and open a pull request, please pay attention to the following points:

  • To install the project with all development dependencies, run:

    uv sync --locked --all-extras --dev --group=scripts
    

    You can also use pip in your own Python environment, but uv is the preferred option because pip may run into dependency resolution problems.

    # dev version
    pip install -e ".[dev]"
    # scripts version for running examples
    pip install -e ".[scripts]"
    

  • Before committing or pushing changes, run the formatters from the repository root:

    uvx isort src tests && uvx black src tests
    
  • To run tests locally with coverage enabled:

    uv run pytest --cov=. --cov-branch tests
    
  • To add new dependencies, run:

    # prod version
    uv add <new-package>
    # dev version
    uv add --dev <new-package>
    # scripts version
    uv add --group=scripts <new-package>
    
  • To build and run the docs locally, run:

    uv run pdoc --math -d google --no-include-undocumented -t .doc/ aiice
    
  • To run any debug scripts with the project env, run:

    uv run --group=scripts <script.py>
    
  • To run Jupyter notebooks with the project env, run:

    uv run --group=scripts --with jupyter jupyter lab    
    

General tips

  • Prefer contributing via forks, especially for external contributors
  • Use clear, descriptive names for commits, issues, and pull requests

Release process

Although the framework is small, we want to keep releases consistent. The release procedure is as follows:

  • the pull request is approved by maintainers and merged with squashed commits
  • a new tag is released to the GitHub repository and to PyPI via GitHub Actions

Documentation

"""
.. include:: ../../README.md
   :start-line: 1
<!-- The comment below is required as a marker for pdoc. See .doc/module.html.jinja2 -->
<!-- MAIN_README_PDOC -->
.. include:: ../../CONTRIBUTE.md
# Documentation
"""

from aiice import core, loader, metrics, preprocess
from aiice.benchmark import AIICE

# Names visible to pdoc
__all__ = ["AIICE", "core", "loader", "metrics", "preprocess"]

from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("aiice-bench")
except PackageNotFoundError:
    __version__ = "0.0.0"
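The `try`/`except` at the end of `__init__.py` lets `aiice.__version__` resolve even when package metadata is unavailable, for example when the source tree is imported without being installed. The same pattern as a stand-alone sketch (`get_version` is a hypothetical helper, not part of aiice):

```python
from importlib.metadata import PackageNotFoundError, version


def get_version(dist_name: str, fallback: str = "0.0.0") -> str:
    """Return the installed version of `dist_name`, or `fallback` when no
    distribution metadata is found (e.g. an uninstalled source checkout)."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return fallback


# The lookup uses the distribution name passed to pip ("aiice-bench"),
# not the import name ("aiice").
print(get_version("aiice-bench"))
```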
class AIICE:
# Imports (the first lines of the module are not part of this excerpt;
# the aiice import paths follow the module names referenced in the docs).
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import date
from typing import Callable, Sequence

import imageio
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from aiice.loader import Loader
from aiice.metrics import Evaluator
from aiice.preprocess import SlidingWindowDataset

# Metric signature as rendered in the `bench` documentation.
MetricFn = Callable[[Sequence, Sequence], float]


class AIICE:
    """
    High-level interface for loading Arctic ice data, preparing datasets, and benchmarking models.

    This class provides a simple API to:
    1. Load historical ice data within a specified date range (see `aiice.loader.Loader`)
    2. Convert the data into sliding-window datasets (see `aiice.preprocess.SlidingWindowDataset`)
    3. Create a PyTorch DataLoader for batch processing
    4. Benchmark any PyTorch model on the OSI-SAF dataset with specified metrics

    Args:
        pre_history_len (int): Number of past time steps to include in each input sample (X).
        forecast_len (int): Number of future time steps to predict (Y) in each sample.
        batch_size (int, optional): Batch size for the DataLoader. Defaults to 16.
        start (date | str | None, optional): Start date of the data to load. If None, defaults to the earliest available data.
        end (date | str | None, optional): End date of the data to load. If None, defaults to the latest available data.
        step (int | None, optional): Step in days between data points. Defaults to 1 if not provided.
        sea (str | None, optional): Name of the sea (e.g., "Barents Sea"). Check `Loader.seas` for available ones.
        threshold (float | None, optional): Threshold for binarizing the target Y. Values above the threshold are set to 1; values below or equal to it are set to 0. Defaults to None.
        x_binarize (bool, optional): Whether to apply the same threshold binarization to input X. Defaults to False.
        threads (int, optional): Number of parallel download threads. Reduce this value if you hit Hugging Face API rate-limit errors. Defaults to 16.
        device (str | None, optional): Device to place tensors on ("cpu", "cuda", etc.). If None, uses the PyTorch default device.

    Example:
        >>> aiice = AIICE(pre_history_len=30, forecast_len=7, batch_size=32, start="2022-01-01", end="2022-12-31")
        >>> model = MyModel()
        >>> results = aiice.bench(model, metrics=["mae", "psnr"])
    """

    def __init__(
        self,
        pre_history_len: int,
        forecast_len: int,
        batch_size: int = 16,
        start: date | str | None = None,
        end: date | str | None = None,
        step: int | None = None,
        sea: str | None = None,
        threshold: float | None = None,
        x_binarize: bool = False,
        threads: int = 16,
        device: str | None = None,
    ):
        self._device = device
        self._sea = sea

        raw_data = Loader().get(
            start=start,
            end=end,
            step=step,
            sea=sea,
            threads=threads,
            tensor_out=True,
            idx_out=True,
        )

        indices = raw_data[0]
        matrices = raw_data[1]

        dataset = SlidingWindowDataset(
            data=matrices,
            idx=indices,
            pre_history_len=pre_history_len,
            forecast_len=forecast_len,
            threshold=threshold,
            x_binarize=x_binarize,
            device=self._device,
        )

        self._dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=self._default_collate_fn,
        )

    def bench(
        self,
        model: nn.Module,
        metrics: dict[str, MetricFn] | list[str] | None = None,
        path: str | None = None,
        detailed: bool = True,
        plot_workers: int = 4,
        fps: int = 2,
    ) -> dict[str, list[float]]:
        """
        Run benchmarking evaluation of a model on the prepared dataset.

        The method iterates over the internal DataLoader, generates model
        predictions, computes evaluation metrics, and optionally produces
        visualization GIFs comparing ground truth and predicted forecasts.

        When `path` is provided, visualization generation is executed
        asynchronously using a thread pool so that plotting does not block
        model inference.

        Args:
            model (nn.Module):
                PyTorch model used to generate predictions. The model is expected
                to accept inputs `x` with shape `(batch, pre_history_len, ...)`
                and return predictions compatible with the selected metrics.

            metrics (dict[str, MetricFn] | list[str], optional):
                Metrics to compute during evaluation. If a list of metric names is
                provided, the metrics are resolved from the built-in registry.
                If `None`, default metrics are used.
                See `aiice.metrics.Evaluator` for details.

            path (str, optional):
                Directory where forecast visualizations will be saved.
                If provided, each sample in the dataset will produce a GIF
                animation showing the forecast horizon, comparing ground truth
                and model predictions frame by frame.

                The files are named `<start_forecast_date>_<end_forecast_date>.gif`,
                with dates formatted as DD-MM-YYYY.
                If `None`, visualization generation is skipped.

            detailed (bool, optional):
                If True, returns full statistics for each metric, such as
                mean, last value, count, min, and max.
                If False, returns only the mean value per metric.

            plot_workers (int, optional):
                Number of worker threads used for asynchronous plot generation.
                Increasing this value can speed up visualization when many samples
                are processed. Defaults to 4.

            fps (int, optional):
                Frames per second of the generated GIF animations. Defaults to 2.

        Returns:
            dict[str, list[float]]:
                Aggregated metric results returned by the evaluator.
        """
        if path is not None:
            os.makedirs(path, exist_ok=True)
            executor = ThreadPoolExecutor(max_workers=plot_workers)
            futures = []

        evaluator = Evaluator(metrics=metrics, accumulate=True)

        model.eval()
        with torch.no_grad():
            for batch in tqdm(self._dataloader, desc="Prediction"):
                dates, x, y = batch
                x, y = x.to(self._device), y.to(self._device)

                pred = model(x)
                evaluator.eval(y, pred)

                if path is None:
                    continue

                futures.append(
                    executor.submit(
                        self._save_batch_plot,
                        sea=self._sea,
                        path=path,
                        dates=dates,
                        y=y.detach().cpu().numpy(),
                        pred=pred.detach().cpu().numpy(),
                        fps=fps,
                    )
                )

        if path is not None:
            for f in tqdm(futures, desc="Saving plots"):
                f.result()
            executor.shutdown(wait=True)

        return evaluator.report(detailed=detailed)

    @staticmethod
    def _save_batch_plot(
        sea: str | None,
        path: str,
        dates: list[list[date]],
        y: np.ndarray,
        pred: np.ndarray,
        fps: int,
    ) -> None:
        """
        Generate GIF visualizations for a batch of forecast samples.

        For each sample in the batch, a GIF animation is created showing
        the temporal evolution of the forecast horizon. Each frame displays
        a side-by-side comparison between the ground truth ice map and the
        model prediction for the corresponding forecast date.

        The resulting GIF file is saved to `path` with the name: `<start_forecast_date>_<end_forecast_date>.gif`
        where the dates correspond to the forecast window of the sample.
        """
        matplotlib.use("Agg")

        batch_size, forecast_len = y.shape[:2]
        for i in range(batch_size):

            start_date = dates[i][-forecast_len].strftime("%d-%m-%Y")
            end_date = dates[i][-1].strftime("%d-%m-%Y")

            save_path = os.path.join(path, f"{start_date}_{end_date}.gif")
            fig, axes = plt.subplots(1, 2, figsize=(8, 4))

            im_gt = axes[0].imshow(y[i, 0])
            axes[0].set_title("Ground Truth")
            axes[0].axis("off")

            im_pred = axes[1].imshow(pred[i, 0])
            axes[1].set_title("Prediction")
            axes[1].axis("off")

            frames = []
            for j in range(forecast_len):

                im_gt.set_data(y[i, j])
                im_pred.set_data(pred[i, j])

                forecast_date = dates[i][-forecast_len + j]
                if sea is None:
                    fig.suptitle(f"Forecast: {forecast_date.strftime('%d-%m-%Y')}")
                else:
                    fig.suptitle(
                        f"{sea} | Forecast: {forecast_date.strftime('%d-%m-%Y')}"
                    )

                fig.canvas.draw()
                frame = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
                frames.append(frame)

            plt.close(fig)
            imageio.mimsave(save_path, frames, duration=1 / fps, loop=0)

    @staticmethod
    def _default_collate_fn(
        batch: list[tuple[list[date], torch.Tensor, torch.Tensor]],
    ) -> tuple[list[list[date]], torch.Tensor, torch.Tensor]:
        """
        Collates SlidingWindowDataset samples into a batch.
        input  -> list of samples
        output -> list of date sequences + batched tensors

        Example:
        ```
        d1 = [date1...date2]
        x1.shape = (T, H, W)
        y1.shape = (F, H, W)

        batch = [
            (d1, x1, y1),
            (d2, x2, y2)
        ]

        Output:
            dates -> [d1, d2]
            x     -> torch.Tensor (B, T, H, W)
            y     -> torch.Tensor (B, F, H, W)
        ```
        """
        dates, x, y = zip(*batch)
        return list(dates), torch.stack(x), torch.stack(y)
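`_save_batch_plot` names each GIF after the last `forecast_len` dates of a sample. The hypothetical helper below mirrors that naming logic so you can predict output file names. It assumes, as the negative indexing in the source does, that each sample's date list ends with its forecast dates:

```python
from datetime import date, timedelta


def gif_name(sample_dates: list[date], forecast_len: int) -> str:
    """Hypothetical mirror of the naming in `_save_batch_plot`: the forecast
    window is the last `forecast_len` dates, formatted as DD-MM-YYYY."""
    start = sample_dates[-forecast_len].strftime("%d-%m-%Y")
    end = sample_dates[-1].strftime("%d-%m-%Y")
    return f"{start}_{end}.gif"


# One sample: 30 history days + 7 forecast days starting on 2022-01-01.
dates = [date(2022, 1, 1) + timedelta(days=i) for i in range(37)]
print(gif_name(dates, forecast_len=7))  # 31-01-2022_06-02-2022.gif
```

Because the file name contains only dates, running `bench` for two different seas with the same `path` would overwrite GIFs whose forecast windows coincide; a separate directory per sea avoids this.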

High-level interface for loading Arctic ice data, preparing datasets, and benchmarking models.

This class provides a simple API to:

  1. Load historical ice data within a specified date range (see aiice.loader.Loader)
  2. Convert the data into sliding-window datasets (see aiice.preprocess.SlidingWindowDataset)
  3. Create a PyTorch DataLoader for batch processing
  4. Benchmark any PyTorch model on the OSI-SAF dataset with specified metrics
Arguments:
  • pre_history_len (int): Number of past time steps to include in each input sample (X).
  • forecast_len (int): Number of future time steps to predict (Y) in each sample.
  • batch_size (int, optional): Batch size for the DataLoader. Defaults to 16.
  • start (date | str | None, optional): Start date of the data to load. If None, defaults to the earliest available data.
  • end (date | str | None, optional): End date of the data to load. If None, defaults to the latest available data.
  • step (int | None, optional): Step in days between data points. Defaults to 1 if not provided.
  • sea (str | None, optional): Name of the sea (e.g., "Barents Sea"). Check Loader.seas for available ones.
  • threshold (float | None, optional): Threshold for binarizing the target Y. Values above the threshold are set to 1; values below or equal to it are set to 0. Defaults to None.
  • x_binarize (bool, optional): Whether to apply the same threshold binarization to input X. Defaults to False.
  • threads (int, optional): Number of parallel download threads. Reduce this value if you hit Hugging Face API rate-limit errors. Defaults to 16.
  • device (str | None, optional): Device to place tensors on ("cpu", "cuda", etc.). If None, uses PyTorch default device.
Example:
>>> aiice = AIICE(pre_history_len=30, forecast_len=7, batch_size=32, start="2022-01-01", end="2022-12-31")
>>> model = MyModel()
>>> results = aiice.bench(model, metrics=["mae", "psnr"])
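The dataset preparation described above can be sketched in plain Python. This is a hypothetical stand-in for `SlidingWindowDataset`, assuming a stride of one frame and no missing days; the real class works on tensors and may handle gaps or strides differently. The sketch also estimates how many samples and batches the Quickstart configuration would yield under those assumptions:

```python
from datetime import date


def sliding_windows(frames, pre_history_len, forecast_len, threshold=None, x_binarize=False):
    """Hypothetical stand-in for SlidingWindowDataset over a list of 2D frames
    (nested lists of concentrations). Yields (x, y) pairs with stride 1."""

    def binarize(frame):
        # Values above the threshold become 1.0, values below or equal become 0.0.
        return [[1.0 if v > threshold else 0.0 for v in row] for row in frame]

    window = pre_history_len + forecast_len
    for start in range(len(frames) - window + 1):
        x = frames[start : start + pre_history_len]
        y = frames[start + pre_history_len : start + window]
        if threshold is not None:
            y = [binarize(f) for f in y]
            if x_binarize:
                x = [binarize(f) for f in x]
        yield x, y


# Toy data: 10 frames of 1x2 "ice maps".
frames = [[[0.1 * t, 0.9]] for t in range(10)]
pairs = list(sliding_windows(frames, pre_history_len=3, forecast_len=2, threshold=0.5))
print(len(pairs))  # 6 = 10 - 3 - 2 + 1

# Quickstart settings: daily frames for 2022, 30-day history, 7-day forecast, batches of 32.
n_days = (date(2022, 12, 31) - date(2022, 1, 1)).days + 1
n_samples = n_days - 30 - 7 + 1
n_batches = -(-n_samples // 32)  # ceiling division
print(n_days, n_samples, n_batches)  # 365 329 11
```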
def bench(self, model: torch.nn.Module, metrics: dict[str, Callable[[typing.Sequence, typing.Sequence], float]] | list[str] | None = None, path: str | None = None, detailed: bool = True, plot_workers: int = 4, fps: int = 2) -> dict[str, list[float]]:

Run benchmarking evaluation of a model on the prepared dataset.

The method iterates over the internal DataLoader, generates model predictions, computes evaluation metrics, and optionally produces visualization GIFs comparing ground truth and predicted forecasts.

When path is provided, visualization generation is executed asynchronously using a thread pool so that plotting does not block model inference.

Arguments:
  • model (nn.Module): PyTorch model used to generate predictions. The model is expected to accept inputs x with shape (batch, pre_history_len, ...) and return predictions compatible with the selected metrics.
  • metrics (dict[str, MetricFn] | list[str], optional): Metrics to compute during evaluation. If a list of metric names is provided, the metrics are resolved from the built-in registry. If None, default metrics are used. See aiice.metrics.Evaluator for details.
  • path (str, optional): Directory where forecast visualizations will be saved. If provided, each sample in the dataset will produce a GIF animation showing the forecast horizon, comparing ground truth and model predictions frame by frame.

    The files are named <start_forecast_date>_<end_forecast_date>.gif, with dates formatted as DD-MM-YYYY. If None, visualization generation is skipped.

  • detailed (bool, optional): If True, returns full statistics for each metric, such as mean, last value, count, min, and max. If False, returns only the mean value per metric.
  • plot_workers (int, optional): Number of worker threads used for asynchronous plot generation. Increasing this value can speed up visualization when many samples are processed. Defaults to 4.
  • fps (int, optional): Frames per second of the generated GIF animations. Defaults to 2.
Returns:

dict[str, list[float]]: Aggregated metric results returned by the evaluator.
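The `metrics` and `detailed` arguments can be illustrated with a self-contained sketch. Everything here is hypothetical: `mean_bias` is an example custom metric following the `(ground truth, prediction) -> float` shape of `MetricFn`, and `RunningReport` only imitates how per-batch values might be accumulated into detailed statistics (mean, last, count, min, max) or a single mean. The real `aiice.metrics.Evaluator` may pass tensors rather than flat sequences to metrics and may format its report differently:

```python
from typing import Callable, Sequence

# Same shape as MetricFn in the signature: (ground truth, prediction) -> float.
MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def mean_bias(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Hypothetical custom metric: average signed error (prediction minus truth)."""
    return sum(p - t for t, p in zip(y_true, y_pred)) / len(y_true)


class RunningReport:
    """Hypothetical accumulator illustrating detailed=True vs detailed=False."""

    def __init__(self, metrics: dict[str, MetricFn]):
        self.metrics = metrics
        self.values: dict[str, list[float]] = {name: [] for name in metrics}

    def eval(self, y_true, y_pred) -> None:
        # One call per batch, like `evaluator.eval(y, pred)` inside `bench`.
        for name, fn in self.metrics.items():
            self.values[name].append(fn(y_true, y_pred))

    def report(self, detailed: bool = True) -> dict:
        out = {}
        for name, vals in self.values.items():
            mean = sum(vals) / len(vals)
            if detailed:
                out[name] = {"mean": mean, "last": vals[-1], "count": len(vals),
                             "min": min(vals), "max": max(vals)}
            else:
                out[name] = mean
        return out


report = RunningReport({"bias": mean_bias})
report.eval([0.25, 0.5], [0.5, 0.75])  # batch 1: bias +0.25
report.eval([0.75, 0.5], [0.5, 0.25])  # batch 2: bias -0.25
print(report.report(detailed=False))   # {'bias': 0.0}
print(report.report(detailed=True))
```

With the real API, a custom metric would be passed as a dict, for example `aiice.bench(model, metrics={"bias": my_metric})`, where `my_metric` must accept whatever inputs `Evaluator` provides.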