aiice.metrics

  1from collections.abc import Callable, Sequence
  2from typing import Sequence
  3
  4import pytorch_msssim
  5import torch
  6
  7from aiice.constants import (
  8    BIN_ACCURACY_METRIC,
  9    COUNT_STAT,
 10    DEFAULT_SSIM_KERNEL_WINDOW_SIZE,
 11    IOU_METRIC,
 12    LAST_STAT,
 13    MAE_METRIC,
 14    MAX_STAT,
 15    MEAN_STAT,
 16    MIN_STAT,
 17    MSE_METRIC,
 18    PSNR_METRIC,
 19    RMSE_METRIC,
 20    SSIM_METRIC,
 21)
 22from aiice.preprocess import apply_threshold
 23
 24
 25def _as_tensor(y_true: Sequence, y_pred: Sequence, device=None):
 26    y_true = torch.as_tensor(y_true, dtype=torch.float32, device=device).detach()
 27    y_pred = torch.as_tensor(y_pred, dtype=torch.float32, device=device).detach()
 28    return y_true, y_pred
 29
 30
 31def mae(y_true: Sequence, y_pred: Sequence) -> float:
 32    """
 33    MAE (mean absolute error) - determines absolute values range coincidence with real data.
 34    """
 35    y_true, y_pred = _as_tensor(y_true, y_pred)
 36    return torch.abs(y_true - y_pred).mean().item()
 37
 38
 39def mse(y_true: Sequence, y_pred: Sequence) -> float:
 40    """
 41    MSE (mean squared error) - similar to MAE but emphasizes larger errors by squaring differences.
 42    """
 43    y_true, y_pred = _as_tensor(y_true, y_pred)
 44    return ((y_true - y_pred) ** 2).mean().item()
 45
 46
 47def rmse(y_true: Sequence, y_pred: Sequence) -> float:
 48    """
 49    RMSE (root mean square error) - determines absolute values range coincidence as MAE
 50    but making emphasis on spatial error distribution of prediction.
 51    """
 52    y_true, y_pred = _as_tensor(y_true, y_pred)
 53    return torch.sqrt(((y_true - y_pred) ** 2).mean()).item()
 54
 55
 56def psnr(y_true: Sequence, y_pred: Sequence) -> float:
 57    """
 58    PSNR (peak signal-to-noise ratio) - reflects noise and distortion level on predicted images identifying artifacts.
 59    """
 60    y_true, y_pred = _as_tensor(y_true, y_pred)
 61
 62    mse_val = torch.mean((y_true - y_pred) ** 2)
 63    if mse_val == 0:
 64        return float("inf")
 65
 66    max_val = torch.max(y_true)
 67    return (20 * torch.log10(max_val) - 10 * torch.log10(mse_val)).item()
 68
 69
 70def bin_accuracy(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
 71    """
 72    Binary accuracy - binarization of ice concentration continuous field with threshold which causing the presence of an ice edge
 73    gives us possibility to compare binary masks of real ice extent and predicted one.
 74    """
 75    y_true, y_pred = _as_tensor(y_true, y_pred)
 76
 77    y_true = apply_threshold(y_true, threshold)
 78    y_pred = apply_threshold(y_pred, threshold)
 79
 80    return (y_true == y_pred).float().mean().item()
 81
 82
 83def ssim(y_true: Sequence, y_pred: Sequence) -> float:
 84    """
 85    SSIM (structural similarity index measure) - determines spatial patterns coincidence on predicted and target images
 86
 87    Raises:
 88        ValueError:
 89            - If input tensors are not 4D ([N, C, H, W]) or 5D ([N, C, D, H, W]).
 90            - If any spatial or temporal dimension is smaller than 11 (minimum SSIM kernel window size)
 91    """
 92    spatial_dims = y_true.shape[2:]
 93    if any(dim < DEFAULT_SSIM_KERNEL_WINDOW_SIZE for dim in spatial_dims):
 94        raise ValueError(
 95            f"All spatial dimensions {spatial_dims} must be >= win_size={DEFAULT_SSIM_KERNEL_WINDOW_SIZE}"
 96        )
 97
 98    y_true, y_pred = _as_tensor(y_true, y_pred)
 99    return float(pytorch_msssim.ssim(y_true, y_pred, data_range=1.0))
100
101
def iou(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
    """
    IoU (Intersection over Union) - measures overlap between binary masks
    of ground truth and prediction.

    Similar to bin_accuracy but focuses on overlap quality instead of per-pixel equality.
    """
    true_t, pred_t = _as_tensor(y_true, y_pred)

    true_mask = apply_threshold(true_t, threshold)
    pred_mask = apply_threshold(pred_t, threshold)

    # Flatten everything but the batch dimension so IoU is computed per sample.
    true_flat = true_mask.view(true_mask.size(0), -1)
    pred_flat = pred_mask.view(pred_mask.size(0), -1)

    overlap = (true_flat * pred_flat).sum(dim=1)
    # union = |A| + |B| - |A ∩ B|
    union = true_flat.sum(dim=1) + pred_flat.sum(dim=1) - overlap

    # Epsilon guards against division by zero when both masks are empty.
    eps = 1e-7
    per_sample_iou = overlap / (union + eps)
    return per_sample_iou.mean().item()
124
125
126MetricFn = Callable[[Sequence, Sequence], float]
127
128
class Evaluator:
    """
    Compute and aggregate evaluation metrics over multiple evaluation steps.

    Args:
        metrics (`dict[str, MetricFn]`, `list[str]`, optional):
            Metrics to use. If a list of strings is provided, metrics are resolved
            from the built-in registry. If None, default metrics are used.
        accumulate (`bool`, optional):
            Whether to accumulate metric values across multiple `eval` calls. Defaults to True.
    """

    # Shared registry of built-in metrics; instances must never mutate it.
    _metrics_registry: dict[str, MetricFn] = {
        MAE_METRIC: mae,
        MSE_METRIC: mse,
        RMSE_METRIC: rmse,
        PSNR_METRIC: psnr,
        BIN_ACCURACY_METRIC: bin_accuracy,
        SSIM_METRIC: ssim,
        IOU_METRIC: iou,
    }

    def __init__(
        self,
        metrics: dict[str, MetricFn] | list[str] | None = None,
        accumulate: bool = True,
    ):
        if metrics is None:
            # Bug fix: copy the registry. The original assigned the class-level
            # dict directly, so any later mutation of `self._metrics` would have
            # silently corrupted the shared registry for every other instance.
            self._metrics = dict(self._metrics_registry)
        elif isinstance(metrics, list):
            self._metrics = self._init_metrics(metrics)
        else:
            # Copy the caller's mapping for the same isolation reason.
            self._metrics = dict(metrics)

        self._accumulate = accumulate
        # One value list per metric; grows on each `eval` call in accumulate mode.
        self._report: dict[str, list[float]] = {k: [] for k in self._metrics}

    def _init_metrics(self, metrics: list[str]) -> dict[str, MetricFn]:
        """Resolve metric names against the built-in registry.

        Raises:
            ValueError: If a name is not present in the registry.
        """
        result: dict[str, MetricFn] = {}
        for name in metrics:
            try:
                result[name] = self._metrics_registry[name]
            except KeyError:
                # Suppress the KeyError context: the ValueError message already
                # carries all the relevant information.
                raise ValueError(
                    f"Unknown metric '{name}', choose from {list(self._metrics_registry.keys())}"
                ) from None
        return result

    @property
    def metrics(self) -> list[str]:
        """Names of the metrics this evaluator computes."""
        return list(self._metrics.keys())

    def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]:
        """
        Evaluate all metrics on a single batch or sample and updates the internal
        report state depending on the ``accumulate`` mode.
        """
        step_result: dict[str, float] = {}

        for name, fn in self._metrics.items():
            value = fn(y_true, y_pred)
            step_result[name] = value

            if self._accumulate:
                self._report[name].append(value)
            else:
                # Non-accumulate mode keeps only the most recent value.
                self._report[name] = [value]

        return step_result

    def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]:
        """
        Return aggregated statistics for all evaluated metrics.

        Args:
            detailed (`bool`, optional):
                If True, returns full statistics for each metric including:
                mean, last value, count, min, and max.
                If False, returns only the mean value per metric.
        """
        summary: dict[str, dict[str, float] | float] = {}
        for name, values in self._report.items():
            # Metrics that were never evaluated are omitted from the summary.
            if not values:
                continue

            if detailed:
                summary[name] = {
                    MEAN_STAT: sum(values) / len(values),
                    LAST_STAT: values[-1],
                    COUNT_STAT: len(values),
                    MIN_STAT: min(values),
                    MAX_STAT: max(values),
                }
            else:
                summary[name] = sum(values) / len(values)

        return summary
def mae(y_true: Sequence, y_pred: Sequence) -> float:
32def mae(y_true: Sequence, y_pred: Sequence) -> float:
33    """
34    MAE (mean absolute error) - determines absolute values range coincidence with real data.
35    """
36    y_true, y_pred = _as_tensor(y_true, y_pred)
37    return torch.abs(y_true - y_pred).mean().item()

MAE (mean absolute error) - determines absolute values range coincidence with real data.

def mse(y_true: Sequence, y_pred: Sequence) -> float:
40def mse(y_true: Sequence, y_pred: Sequence) -> float:
41    """
42    MSE (mean squared error) - similar to MAE but emphasizes larger errors by squaring differences.
43    """
44    y_true, y_pred = _as_tensor(y_true, y_pred)
45    return ((y_true - y_pred) ** 2).mean().item()

MSE (mean squared error) - similar to MAE but emphasizes larger errors by squaring differences.

def rmse(y_true: Sequence, y_pred: Sequence) -> float:
48def rmse(y_true: Sequence, y_pred: Sequence) -> float:
49    """
50    RMSE (root mean square error) - determines absolute values range coincidence as MAE
51    but making emphasis on spatial error distribution of prediction.
52    """
53    y_true, y_pred = _as_tensor(y_true, y_pred)
54    return torch.sqrt(((y_true - y_pred) ** 2).mean()).item()

RMSE (root mean square error) - determines coincidence of absolute value ranges like MAE, but places emphasis on the spatial error distribution of the prediction.

def psnr(y_true: Sequence, y_pred: Sequence) -> float:
57def psnr(y_true: Sequence, y_pred: Sequence) -> float:
58    """
59    PSNR (peak signal-to-noise ratio) - reflects noise and distortion level on predicted images identifying artifacts.
60    """
61    y_true, y_pred = _as_tensor(y_true, y_pred)
62
63    mse_val = torch.mean((y_true - y_pred) ** 2)
64    if mse_val == 0:
65        return float("inf")
66
67    max_val = torch.max(y_true)
68    return (20 * torch.log10(max_val) - 10 * torch.log10(mse_val)).item()

PSNR (peak signal-to-noise ratio) - reflects noise and distortion level on predicted images identifying artifacts.

def bin_accuracy(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
71def bin_accuracy(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
72    """
73    Binary accuracy - binarization of ice concentration continuous field with threshold which causing the presence of an ice edge
74    gives us possibility to compare binary masks of real ice extent and predicted one.
75    """
76    y_true, y_pred = _as_tensor(y_true, y_pred)
77
78    y_true = apply_threshold(y_true, threshold)
79    y_pred = apply_threshold(y_pred, threshold)
80
81    return (y_true == y_pred).float().mean().item()

Binary accuracy - binarizing the continuous ice-concentration field with a threshold, which defines an ice edge, makes it possible to compare the binary masks of real and predicted ice extent.

def ssim(y_true: Sequence, y_pred: Sequence) -> float:
 84def ssim(y_true: Sequence, y_pred: Sequence) -> float:
 85    """
 86    SSIM (structural similarity index measure) - determines spatial patterns coincidence on predicted and target images
 87
 88    Raises:
 89        ValueError:
 90            - If input tensors are not 4D ([N, C, H, W]) or 5D ([N, C, D, H, W]).
 91            - If any spatial or temporal dimension is smaller than 11 (minimum SSIM kernel window size)
 92    """
 93    spatial_dims = y_true.shape[2:]
 94    if any(dim < DEFAULT_SSIM_KERNEL_WINDOW_SIZE for dim in spatial_dims):
 95        raise ValueError(
 96            f"All spatial dimensions {spatial_dims} must be >= win_size={DEFAULT_SSIM_KERNEL_WINDOW_SIZE}"
 97        )
 98
 99    y_true, y_pred = _as_tensor(y_true, y_pred)
100    return float(pytorch_msssim.ssim(y_true, y_pred, data_range=1.0))

SSIM (structural similarity index measure) - determines spatial patterns coincidence on predicted and target images

Raises:
  • ValueError: - If input tensors are not 4D ([N, C, H, W]) or 5D ([N, C, D, H, W]).
    • If any spatial or temporal dimension is smaller than 11 (minimum SSIM kernel window size)
def iou(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
103def iou(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
104    """
105    IoU (Intersection over Union) - measures overlap between binary masks
106    of ground truth and prediction.
107
108    Similar to bin_accuracy but focuses on overlap quality instead of per-pixel equality.
109    """
110    y_true, y_pred = _as_tensor(y_true, y_pred)
111
112    y_true = apply_threshold(y_true, threshold)
113    y_pred = apply_threshold(y_pred, threshold)
114
115    y_true = y_true.view(y_true.size(0), -1)
116    y_pred = y_pred.view(y_pred.size(0), -1)
117
118    intersection = (y_true * y_pred).sum(dim=1)
119    # union = |A| + |B| - |A ∩ B|
120    union = y_true.sum(dim=1) + y_pred.sum(dim=1) - intersection
121
122    eps = 1e-7
123    iou = intersection / (union + eps)
124    return iou.mean().item()

IoU (Intersection over Union) - measures overlap between binary masks of ground truth and prediction.

Similar to bin_accuracy but focuses on overlap quality instead of per-pixel equality.

class Evaluator:
130class Evaluator:
131    """
132    Compute and aggregate evaluation metrics over multiple evaluation steps.
133
134    Args:
135        metrics (`dict[str, MetricFn]`, `list[str]`, optional):
136            Metrics to use. If a list of strings is provided, metrics are resolved
137            from the built-in registry. If None, default metrics are used.
138        accumulate (`bool`, optional):
139            Whether to accumulate metric values across multiple `eval` calls. Defaults to True.
140    """
141
142    _metrics_registry: dict[str, MetricFn] = {
143        MAE_METRIC: mae,
144        MSE_METRIC: mse,
145        RMSE_METRIC: rmse,
146        PSNR_METRIC: psnr,
147        BIN_ACCURACY_METRIC: bin_accuracy,
148        SSIM_METRIC: ssim,
149        IOU_METRIC: iou,
150    }
151
152    def __init__(
153        self,
154        metrics: dict[str, MetricFn] | list[str] | None = None,
155        accumulate: bool = True,
156    ):
157        if metrics is None:
158            self._metrics = self._metrics_registry
159        elif isinstance(metrics, list):
160            self._metrics = self._init_metrics(metrics)
161        else:
162            self._metrics = metrics
163
164        self._accumulate = accumulate
165        self._report: dict[str, list[float]] = {k: [] for k in self._metrics}
166
167    def _init_metrics(self, metrics: list[str]) -> dict[str, MetricFn]:
168        result: dict[str, MetricFn] = {}
169        for name in metrics:
170            try:
171                result[name] = self._metrics_registry[name]
172            except KeyError:
173                raise ValueError(
174                    f"Unknown metric '{name}', choose from {list(self._metrics_registry.keys())}"
175                )
176        return result
177
178    @property
179    def metrics(self) -> list[str]:
180        return list(self._metrics.keys())
181
182    def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]:
183        """
184        Evaluate all metrics on a single batch or sample and updates the internal
185        report state depending on the ``accumulate`` mode.
186        """
187        step_result: dict[str, float] = {}
188
189        for name, fn in self._metrics.items():
190            value = fn(y_true, y_pred)
191            step_result[name] = value
192
193            if self._accumulate:
194                self._report[name].append(value)
195            else:
196                self._report[name] = [value]
197
198        return step_result
199
200    def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]:
201        """
202        Return aggregated statistics for all evaluated metrics.
203
204        Args:
205            detailed (`bool`, optional):
206                If True, returns full statistics for each metric including:
207                mean, last value, count, min, and max.
208                If False, returns only the mean value per metric.
209        """
210        summary: dict[str, dict[str, float] | float] = {}
211        for name, values in self._report.items():
212            if not values:
213                continue
214
215            if detailed:
216                summary[name] = {
217                    MEAN_STAT: sum(values) / len(values),
218                    LAST_STAT: values[-1],
219                    COUNT_STAT: len(values),
220                    MIN_STAT: min(values),
221                    MAX_STAT: max(values),
222                }
223            else:
224                summary[name] = sum(values) / len(values)
225
226        return summary

Compute and aggregate evaluation metrics over multiple evaluation steps.

Arguments:
  • metrics (dict[str, MetricFn], list[str], optional): Metrics to use. If a list of strings is provided, metrics are resolved from the built-in registry. If None, default metrics are used.
  • accumulate (bool, optional): Whether to accumulate metric values across multiple eval calls. Defaults to True.
def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]:
182    def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]:
183        """
184        Evaluate all metrics on a single batch or sample and updates the internal
185        report state depending on the ``accumulate`` mode.
186        """
187        step_result: dict[str, float] = {}
188
189        for name, fn in self._metrics.items():
190            value = fn(y_true, y_pred)
191            step_result[name] = value
192
193            if self._accumulate:
194                self._report[name].append(value)
195            else:
196                self._report[name] = [value]
197
198        return step_result

Evaluate all metrics on a single batch or sample and updates the internal report state depending on the accumulate mode.

def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]:
200    def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]:
201        """
202        Return aggregated statistics for all evaluated metrics.
203
204        Args:
205            detailed (`bool`, optional):
206                If True, returns full statistics for each metric including:
207                mean, last value, count, min, and max.
208                If False, returns only the mean value per metric.
209        """
210        summary: dict[str, dict[str, float] | float] = {}
211        for name, values in self._report.items():
212            if not values:
213                continue
214
215            if detailed:
216                summary[name] = {
217                    MEAN_STAT: sum(values) / len(values),
218                    LAST_STAT: values[-1],
219                    COUNT_STAT: len(values),
220                    MIN_STAT: min(values),
221                    MAX_STAT: max(values),
222                }
223            else:
224                summary[name] = sum(values) / len(values)
225
226        return summary

Return aggregated statistics for all evaluated metrics.

Arguments:
  • detailed (bool, optional): If True, returns full statistics for each metric including: mean, last value, count, min, and max. If False, returns only the mean value per metric.