aiice.metrics
from collections.abc import Callable, Sequence

import pytorch_msssim
import torch

from aiice.constants import (
    BIN_ACCURACY_METRIC,
    COUNT_STAT,
    DEFAULT_SSIM_KERNEL_WINDOW_SIZE,
    IOU_METRIC,
    LAST_STAT,
    MAE_METRIC,
    MAX_STAT,
    MEAN_STAT,
    MIN_STAT,
    MSE_METRIC,
    PSNR_METRIC,
    RMSE_METRIC,
    SSIM_METRIC,
)
from aiice.preprocess import apply_threshold


def _as_tensor(y_true: Sequence, y_pred: Sequence, device=None):
    """Convert both inputs to detached float32 tensors (optionally on *device*)."""
    y_true = torch.as_tensor(y_true, dtype=torch.float32, device=device).detach()
    y_pred = torch.as_tensor(y_pred, dtype=torch.float32, device=device).detach()
    return y_true, y_pred


def mae(y_true: Sequence, y_pred: Sequence) -> float:
    """
    MAE (mean absolute error) - determines absolute values range coincidence with real data.
    """
    y_true, y_pred = _as_tensor(y_true, y_pred)
    return torch.abs(y_true - y_pred).mean().item()


def mse(y_true: Sequence, y_pred: Sequence) -> float:
    """
    MSE (mean squared error) - similar to MAE but emphasizes larger errors by squaring differences.
    """
    y_true, y_pred = _as_tensor(y_true, y_pred)
    return ((y_true - y_pred) ** 2).mean().item()


def rmse(y_true: Sequence, y_pred: Sequence) -> float:
    """
    RMSE (root mean square error) - determines absolute values range coincidence as MAE
    but making emphasis on spatial error distribution of prediction.
    """
    y_true, y_pred = _as_tensor(y_true, y_pred)
    return torch.sqrt(((y_true - y_pred) ** 2).mean()).item()


def psnr(y_true: Sequence, y_pred: Sequence) -> float:
    """
    PSNR (peak signal-to-noise ratio) - reflects noise and distortion level on predicted
    images, identifying artifacts.

    Returns ``inf`` when the prediction matches the target exactly (zero MSE).
    """
    y_true, y_pred = _as_tensor(y_true, y_pred)

    mse_val = torch.mean((y_true - y_pred) ** 2)
    if mse_val == 0:
        # Identical images: PSNR is infinite by definition.
        return float("inf")

    # PSNR = 20*log10(MAX) - 10*log10(MSE), with MAX taken from the ground truth.
    max_val = torch.max(y_true)
    return (20 * torch.log10(max_val) - 10 * torch.log10(mse_val)).item()


def bin_accuracy(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
    """
    Binary accuracy - binarization of the continuous ice-concentration field with a
    threshold (which defines the ice edge) lets us compare binary masks of the real
    ice extent against the predicted one.

    Args:
        threshold: concentration cut-off used to binarize both fields.
    """
    y_true, y_pred = _as_tensor(y_true, y_pred)

    y_true = apply_threshold(y_true, threshold)
    y_pred = apply_threshold(y_pred, threshold)

    return (y_true == y_pred).float().mean().item()


def ssim(y_true: Sequence, y_pred: Sequence) -> float:
    """
    SSIM (structural similarity index measure) - determines spatial patterns
    coincidence on predicted and target images.

    Raises:
        ValueError:
            - If input tensors are not 4D ([N, C, H, W]) or 5D ([N, C, D, H, W]).
            - If any spatial or temporal dimension is smaller than the minimum
              SSIM kernel window size.
    """
    # Convert first so the shape checks below also work for plain sequences /
    # numpy arrays, not only for tensors (raw sequences have no ``.shape``).
    y_true, y_pred = _as_tensor(y_true, y_pred)

    # Enforce the dimensionality contract promised by the docstring.
    if y_true.dim() not in (4, 5):
        raise ValueError(
            f"Expected 4D [N, C, H, W] or 5D [N, C, D, H, W] input, got {y_true.dim()}D"
        )

    spatial_dims = y_true.shape[2:]
    if any(dim < DEFAULT_SSIM_KERNEL_WINDOW_SIZE for dim in spatial_dims):
        raise ValueError(
            f"All spatial dimensions {spatial_dims} must be >= win_size={DEFAULT_SSIM_KERNEL_WINDOW_SIZE}"
        )

    return float(pytorch_msssim.ssim(y_true, y_pred, data_range=1.0))


def iou(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float:
    """
    IoU (Intersection over Union) - measures overlap between binary masks
    of ground truth and prediction.

    Similar to bin_accuracy but focuses on overlap quality instead of per-pixel equality.

    Args:
        threshold: concentration cut-off used to binarize both fields.
    """
    y_true, y_pred = _as_tensor(y_true, y_pred)

    y_true = apply_threshold(y_true, threshold)
    y_pred = apply_threshold(y_pred, threshold)

    # Flatten each sample; ``reshape`` (unlike ``view``) also handles
    # non-contiguous tensors.
    y_true = y_true.reshape(y_true.size(0), -1)
    y_pred = y_pred.reshape(y_pred.size(0), -1)

    intersection = (y_true * y_pred).sum(dim=1)
    # union = |A| + |B| - |A ∩ B|
    union = y_true.sum(dim=1) + y_pred.sum(dim=1) - intersection

    # eps guards against division by zero when both masks are empty.
    eps = 1e-7
    iou_per_sample = intersection / (union + eps)
    return iou_per_sample.mean().item()


MetricFn = Callable[[Sequence, Sequence], float]


class Evaluator:
    """
    Compute and aggregate evaluation metrics over multiple evaluation steps.

    Args:
        metrics (`dict[str, MetricFn]`, `list[str]`, optional):
            Metrics to use. If a list of strings is provided, metrics are resolved
            from the built-in registry. If None, default metrics are used.
        accumulate (`bool`, optional):
            Whether to accumulate metric values across multiple `eval` calls. Defaults to True.
    """

    _metrics_registry: dict[str, MetricFn] = {
        MAE_METRIC: mae,
        MSE_METRIC: mse,
        RMSE_METRIC: rmse,
        PSNR_METRIC: psnr,
        BIN_ACCURACY_METRIC: bin_accuracy,
        SSIM_METRIC: ssim,
        IOU_METRIC: iou,
    }

    def __init__(
        self,
        metrics: dict[str, MetricFn] | list[str] | None = None,
        accumulate: bool = True,
    ):
        if metrics is None:
            # Copy the registry so per-instance mutation can never alter the
            # shared class-level dict.
            self._metrics = dict(self._metrics_registry)
        elif isinstance(metrics, list):
            self._metrics = self._init_metrics(metrics)
        else:
            self._metrics = metrics

        self._accumulate = accumulate
        self._report: dict[str, list[float]] = {k: [] for k in self._metrics}

    def _init_metrics(self, metrics: list[str]) -> dict[str, MetricFn]:
        """Resolve metric names against the built-in registry.

        Raises:
            ValueError: if a name is not present in the registry.
        """
        result: dict[str, MetricFn] = {}
        for name in metrics:
            try:
                result[name] = self._metrics_registry[name]
            except KeyError:
                # The KeyError carries no extra information for the caller;
                # suppress it in favour of the descriptive ValueError.
                raise ValueError(
                    f"Unknown metric '{name}', choose from {list(self._metrics_registry.keys())}"
                ) from None
        return result

    @property
    def metrics(self) -> list[str]:
        """Names of the metrics this evaluator computes."""
        return list(self._metrics.keys())

    def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]:
        """
        Evaluate all metrics on a single batch or sample and updates the internal
        report state depending on the ``accumulate`` mode.
        """
        step_result: dict[str, float] = {}

        for name, fn in self._metrics.items():
            value = fn(y_true, y_pred)
            step_result[name] = value

            if self._accumulate:
                self._report[name].append(value)
            else:
                # Non-accumulating mode keeps only the latest value.
                self._report[name] = [value]

        return step_result

    def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]:
        """
        Return aggregated statistics for all evaluated metrics.

        Metrics that have never been evaluated are omitted from the result.

        Args:
            detailed (`bool`, optional):
                If True, returns full statistics for each metric including:
                mean, last value, count, min, and max.
                If False, returns only the mean value per metric.
        """
        summary: dict[str, dict[str, float] | float] = {}
        for name, values in self._report.items():
            if not values:
                continue

            if detailed:
                summary[name] = {
                    MEAN_STAT: sum(values) / len(values),
                    LAST_STAT: values[-1],
                    COUNT_STAT: len(values),
                    MIN_STAT: min(values),
                    MAX_STAT: max(values),
                }
            else:
                summary[name] = sum(values) / len(values)

        return summary
32def mae(y_true: Sequence, y_pred: Sequence) -> float: 33 """ 34 MAE (mean absolute error) - determines absolute values range coincidence with real data. 35 """ 36 y_true, y_pred = _as_tensor(y_true, y_pred) 37 return torch.abs(y_true - y_pred).mean().item()
MAE (mean absolute error) - determines absolute values range coincidence with real data.
40def mse(y_true: Sequence, y_pred: Sequence) -> float: 41 """ 42 MSE (mean squared error) - similar to MAE but emphasizes larger errors by squaring differences. 43 """ 44 y_true, y_pred = _as_tensor(y_true, y_pred) 45 return ((y_true - y_pred) ** 2).mean().item()
MSE (mean squared error) - similar to MAE but emphasizes larger errors by squaring differences.
48def rmse(y_true: Sequence, y_pred: Sequence) -> float: 49 """ 50 RMSE (root mean square error) - determines absolute values range coincidence as MAE 51 but making emphasis on spatial error distribution of prediction. 52 """ 53 y_true, y_pred = _as_tensor(y_true, y_pred) 54 return torch.sqrt(((y_true - y_pred) ** 2).mean()).item()
RMSE (root mean square error) - determines absolute values range coincidence as MAE but making emphasis on spatial error distribution of prediction.
57def psnr(y_true: Sequence, y_pred: Sequence) -> float: 58 """ 59 PSNR (peak signal-to-noise ratio) - reflects noise and distortion level on predicted images identifying artifacts. 60 """ 61 y_true, y_pred = _as_tensor(y_true, y_pred) 62 63 mse_val = torch.mean((y_true - y_pred) ** 2) 64 if mse_val == 0: 65 return float("inf") 66 67 max_val = torch.max(y_true) 68 return (20 * torch.log10(max_val) - 10 * torch.log10(mse_val)).item()
PSNR (peak signal-to-noise ratio) - reflects noise and distortion level on predicted images identifying artifacts.
71def bin_accuracy(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float: 72 """ 73 Binary accuracy - binarization of ice concentration continuous field with threshold which causing the presence of an ice edge 74 gives us possibility to compare binary masks of real ice extent and predicted one. 75 """ 76 y_true, y_pred = _as_tensor(y_true, y_pred) 77 78 y_true = apply_threshold(y_true, threshold) 79 y_pred = apply_threshold(y_pred, threshold) 80 81 return (y_true == y_pred).float().mean().item()
Binary accuracy - binarization of ice concentration continuous field with threshold which causing the presence of an ice edge gives us possibility to compare binary masks of real ice extent and predicted one.
84def ssim(y_true: Sequence, y_pred: Sequence) -> float: 85 """ 86 SSIM (structural similarity index measure) - determines spatial patterns coincidence on predicted and target images 87 88 Raises: 89 ValueError: 90 - If input tensors are not 4D ([N, C, H, W]) or 5D ([N, C, D, H, W]). 91 - If any spatial or temporal dimension is smaller than 11 (minimum SSIM kernel window size) 92 """ 93 spatial_dims = y_true.shape[2:] 94 if any(dim < DEFAULT_SSIM_KERNEL_WINDOW_SIZE for dim in spatial_dims): 95 raise ValueError( 96 f"All spatial dimensions {spatial_dims} must be >= win_size={DEFAULT_SSIM_KERNEL_WINDOW_SIZE}" 97 ) 98 99 y_true, y_pred = _as_tensor(y_true, y_pred) 100 return float(pytorch_msssim.ssim(y_true, y_pred, data_range=1.0))
SSIM (structural similarity index measure) - determines spatial patterns coincidence on predicted and target images
Raises:
- ValueError: if input tensors are not 4D ([N, C, H, W]) or 5D ([N, C, D, H, W]), or if any spatial or temporal dimension is smaller than 11 (the minimum SSIM kernel window size).
103def iou(y_true: Sequence, y_pred: Sequence, threshold: float = 0.15) -> float: 104 """ 105 IoU (Intersection over Union) - measures overlap between binary masks 106 of ground truth and prediction. 107 108 Similar to bin_accuracy but focuses on overlap quality instead of per-pixel equality. 109 """ 110 y_true, y_pred = _as_tensor(y_true, y_pred) 111 112 y_true = apply_threshold(y_true, threshold) 113 y_pred = apply_threshold(y_pred, threshold) 114 115 y_true = y_true.view(y_true.size(0), -1) 116 y_pred = y_pred.view(y_pred.size(0), -1) 117 118 intersection = (y_true * y_pred).sum(dim=1) 119 # union = |A| + |B| - |A ∩ B| 120 union = y_true.sum(dim=1) + y_pred.sum(dim=1) - intersection 121 122 eps = 1e-7 123 iou = intersection / (union + eps) 124 return iou.mean().item()
IoU (Intersection over Union) - measures overlap between binary masks of ground truth and prediction.
Similar to bin_accuracy but focuses on overlap quality instead of per-pixel equality.
130class Evaluator: 131 """ 132 Compute and aggregate evaluation metrics over multiple evaluation steps. 133 134 Args: 135 metrics (`dict[str, MetricFn]`, `list[str]`, optional): 136 Metrics to use. If a list of strings is provided, metrics are resolved 137 from the built-in registry. If None, default metrics are used. 138 accumulate (`bool`, optional): 139 Whether to accumulate metric values across multiple `eval` calls. Defaults to True. 140 """ 141 142 _metrics_registry: dict[str, MetricFn] = { 143 MAE_METRIC: mae, 144 MSE_METRIC: mse, 145 RMSE_METRIC: rmse, 146 PSNR_METRIC: psnr, 147 BIN_ACCURACY_METRIC: bin_accuracy, 148 SSIM_METRIC: ssim, 149 IOU_METRIC: iou, 150 } 151 152 def __init__( 153 self, 154 metrics: dict[str, MetricFn] | list[str] | None = None, 155 accumulate: bool = True, 156 ): 157 if metrics is None: 158 self._metrics = self._metrics_registry 159 elif isinstance(metrics, list): 160 self._metrics = self._init_metrics(metrics) 161 else: 162 self._metrics = metrics 163 164 self._accumulate = accumulate 165 self._report: dict[str, list[float]] = {k: [] for k in self._metrics} 166 167 def _init_metrics(self, metrics: list[str]) -> dict[str, MetricFn]: 168 result: dict[str, MetricFn] = {} 169 for name in metrics: 170 try: 171 result[name] = self._metrics_registry[name] 172 except KeyError: 173 raise ValueError( 174 f"Unknown metric '{name}', choose from {list(self._metrics_registry.keys())}" 175 ) 176 return result 177 178 @property 179 def metrics(self) -> list[str]: 180 return list(self._metrics.keys()) 181 182 def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]: 183 """ 184 Evaluate all metrics on a single batch or sample and updates the internal 185 report state depending on the ``accumulate`` mode. 
186 """ 187 step_result: dict[str, float] = {} 188 189 for name, fn in self._metrics.items(): 190 value = fn(y_true, y_pred) 191 step_result[name] = value 192 193 if self._accumulate: 194 self._report[name].append(value) 195 else: 196 self._report[name] = [value] 197 198 return step_result 199 200 def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]: 201 """ 202 Return aggregated statistics for all evaluated metrics. 203 204 Args: 205 detailed (`bool`, optional): 206 If True, returns full statistics for each metric including: 207 mean, last value, count, min, and max. 208 If False, returns only the mean value per metric. 209 """ 210 summary: dict[str, dict[str, float] | float] = {} 211 for name, values in self._report.items(): 212 if not values: 213 continue 214 215 if detailed: 216 summary[name] = { 217 MEAN_STAT: sum(values) / len(values), 218 LAST_STAT: values[-1], 219 COUNT_STAT: len(values), 220 MIN_STAT: min(values), 221 MAX_STAT: max(values), 222 } 223 else: 224 summary[name] = sum(values) / len(values) 225 226 return summary
Compute and aggregate evaluation metrics over multiple evaluation steps.
Arguments:
- metrics (`dict[str, MetricFn]`, `list[str]`, optional): Metrics to use. If a list of strings is provided, metrics are resolved from the built-in registry. If None, default metrics are used.
- accumulate (`bool`, optional): Whether to accumulate metric values across multiple `eval` calls. Defaults to True.
182 def eval(self, y_true: Sequence, y_pred: Sequence) -> dict[str, float]: 183 """ 184 Evaluate all metrics on a single batch or sample and updates the internal 185 report state depending on the ``accumulate`` mode. 186 """ 187 step_result: dict[str, float] = {} 188 189 for name, fn in self._metrics.items(): 190 value = fn(y_true, y_pred) 191 step_result[name] = value 192 193 if self._accumulate: 194 self._report[name].append(value) 195 else: 196 self._report[name] = [value] 197 198 return step_result
Evaluate all metrics on a single batch or sample and updates the internal
report state depending on the ``accumulate`` mode.
200 def report(self, detailed: bool = True) -> dict[str, dict[str, float] | float]: 201 """ 202 Return aggregated statistics for all evaluated metrics. 203 204 Args: 205 detailed (`bool`, optional): 206 If True, returns full statistics for each metric including: 207 mean, last value, count, min, and max. 208 If False, returns only the mean value per metric. 209 """ 210 summary: dict[str, dict[str, float] | float] = {} 211 for name, values in self._report.items(): 212 if not values: 213 continue 214 215 if detailed: 216 summary[name] = { 217 MEAN_STAT: sum(values) / len(values), 218 LAST_STAT: values[-1], 219 COUNT_STAT: len(values), 220 MIN_STAT: min(values), 221 MAX_STAT: max(values), 222 } 223 else: 224 summary[name] = sum(values) / len(values) 225 226 return summary
Return aggregated statistics for all evaluated metrics.
Arguments:
- detailed (`bool`, optional): If True, returns full statistics for each metric including mean, last value, count, min, and max. If False, returns only the mean value per metric.