def test_metrics_dict_roc() -> None:
    """
    Test if adding ROC entries to a MetricsDict instance works, and returns the correct ROC and PR AUC values.
    """
    # Prepare a vector of predictions and labels. We can compute AUC off those to compare.
    # MetricsDict will get that supplied in 3 chunks, and should return the same AUC value.
    predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    labels = np.array([0, 1.0, 0, 0, 1, 1], dtype=float)
    split_length = [3, 2, 1]
    assert sum(split_length) == len(predictions)
    summed = np.cumsum(split_length)
    m = MetricsDict()
    for i, end in enumerate(summed):
        start = 0 if i == 0 else summed[i - 1]
        pred = predictions[start:end]
        label = labels[start:end]
        subject_ids = list(range(len(pred)))
        m.add_predictions(subject_ids, pred, label)
    assert m.has_prediction_entries
    actual_auc = m.get_roc_auc()
    expected_auc = roc_auc_score(labels, predictions)
    assert actual_auc == pytest.approx(expected_auc, 1e-6)
    actual_pr_auc = m.get_pr_auc()
    expected_pr_auc = 0.7111111
    assert actual_pr_auc == pytest.approx(expected_pr_auc, 1e-6)
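
# The hard-coded expected_pr_auc above matches the trapezoidal area under the precision-recall
# curve for these labels and predictions (not average precision). A self-contained reference
# computation; the helper name is illustrative and not part of MetricsDict:
def _reference_pr_auc() -> float:
    from sklearn.metrics import auc, precision_recall_curve
    ref_predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    ref_labels = np.array([0, 1.0, 0, 0, 1, 1], dtype=float)
    precision, recall, _ = precision_recall_curve(ref_labels, ref_predictions)
    return auc(recall, precision)  # approximately 0.7111111
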
def test_classification_metrics_avg() -> None:
    hue1 = "H1"
    hue2 = "H2"
    m = MetricsDict(hues=[hue1, hue2], is_classification_metrics=True)
    m.add_metric("foo", 1.0)
    m.add_metric("foo", 2.0)
    # Perfect predictions for hue1, should give AUC == 1.0
    m.add_predictions(["S1", "S2"], np.array([0.0, 1.0]), np.array([0.0, 1.0]), hue=hue1)
    expected_hue1_auc = 1.0
    # Worst possible predictions for hue2, should give AUC == 0.0
    m.add_predictions(["S1", "S2"], np.array([1.0, 0.0]), np.array([0.0, 1.0]), hue=hue2)
    expected_hue2_auc = 0.0
    averaged = m.average(across_hues=False)
    g1_averaged = averaged.values(hue=hue1)
    assert MetricType.AREA_UNDER_ROC_CURVE.value in g1_averaged
    assert g1_averaged[MetricType.AREA_UNDER_ROC_CURVE.value] == [expected_hue1_auc]
    assert MetricType.AREA_UNDER_PR_CURVE.value in g1_averaged
    assert MetricType.SUBJECT_COUNT.value in g1_averaged
    assert g1_averaged[MetricType.SUBJECT_COUNT.value] == [2.0]
    default_averaged = averaged.values()
    assert default_averaged == {"foo": [1.5]}
    can_enumerate = list(averaged.enumerate_single_values())
    assert len(can_enumerate) >= 8
    assert can_enumerate[0] == (hue1, MetricType.AREA_UNDER_ROC_CURVE.value, 1.0)
    assert can_enumerate[-1] == (MetricsDict.DEFAULT_HUE_KEY, "foo", 1.5)

    g2_averaged = averaged.values(hue=hue2)
    assert MetricType.AREA_UNDER_ROC_CURVE.value in g2_averaged
    assert g2_averaged[MetricType.AREA_UNDER_ROC_CURVE.value] == [expected_hue2_auc]

    averaged_across_hues = m.average(across_hues=True)
    assert averaged_across_hues.get_hue_names() == [MetricsDict.DEFAULT_HUE_KEY]
    assert MetricType.AREA_UNDER_ROC_CURVE.value in averaged_across_hues.values()
    expected_averaged_auc = 0.5 * (expected_hue1_auc + expected_hue2_auc)
    assert averaged_across_hues.values()[MetricType.AREA_UNDER_ROC_CURVE.value] == [expected_averaged_auc]
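
# For reference, plain sklearn agrees with the per-hue AUC constants used above
# (the helper name is illustrative only):
def _reference_hue_aucs() -> None:
    from sklearn.metrics import roc_auc_score
    assert roc_auc_score(np.array([0.0, 1.0]), np.array([0.0, 1.0])) == 1.0  # perfect ranking
    assert roc_auc_score(np.array([0.0, 1.0]), np.array([1.0, 0.0])) == 0.0  # fully inverted ranking
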
def test_metrics_dict1() -> None:
    """
    Test insertion of scalar values into a MetricsDict.
    """
    m = MetricsDict()
    assert m.get_hue_names() == [MetricsDict.DEFAULT_HUE_KEY]
    name = "foo"
    v1 = 2.7
    v2 = 3.14
    m.add_metric(name, v1)
    m.add_metric(name, v2)
    assert m.values()[name] == [v1, v2]
    with pytest.raises(ValueError) as ex:
        # noinspection PyTypeChecker
        m.add_metric(name, [1.0])  # type: ignore
    assert "Expected the metric to be a scalar" in str(ex)
    assert m.skip_nan_when_averaging[name] is False
    v3 = 3.0
    name2 = "bar"
    m.add_metric(name2, v3, skip_nan_when_averaging=True)
    assert m.skip_nan_when_averaging[name2] is True
    # Expected average: Metric "foo" averages over two values v1 and v2. For "bar", we only inserted one value anyhow
    average = m.average()
    mean_v1_v2 = mean([v1, v2])
    assert average.values() == {name: [mean_v1_v2], name2: [v3]}
    num_entries = m.num_entries()
    assert num_entries == {name: 2, name2: 1}
def test_metrics_dict_add_integer() -> None:
    """
    Adding a scalar metric where the value is an integer by accident should still store the metric.
    """
    m = MetricsDict()
    m.add_metric("foo", 1)
    assert "foo" in m.values()
    assert m.values()["foo"] == [1.0]
def test_delete_hue() -> None:
    h1 = "a"
    h2 = "b"
    a = MetricsDict(hues=[h1, h2])
    a.add_metric("foo", 1.0, hue=h1)
    a.add_metric("bar", 2.0, hue=h2)
    a.delete_hue(h1)
    assert a.get_hue_names(include_default=False) == [h2]
    assert list(a.enumerate_single_values()) == [(h2, "bar", 2.0)]
def test_delete_metric() -> None:
    """
    Test deleting a metric from the dictionary.
    """
    m = MetricsDict()
    m.add_metric(MetricType.LOSS, 1)
    assert m.values()[MetricType.LOSS.value] == [1.0]
    m.delete_metric(MetricType.LOSS)
    assert MetricType.LOSS.value not in m.values()
def test_add_foreground_dice() -> None:
    g1 = "Liver"
    g2 = "Lung"
    ground_truth_ids = [BACKGROUND_CLASS_NAME, g1, g2]
    dice = [0.85, 0.75, 0.55]
    m = MetricsDict(hues=ground_truth_ids)
    for j, ground_truth_id in enumerate(ground_truth_ids):
        m.add_metric(MetricType.DICE, dice[j], hue=ground_truth_id)
    metrics.add_average_foreground_dice(m)
    assert m.get_single_metric(MetricType.DICE) == 0.5 * (dice[1] + dice[2])
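
# The expected value above is simply the mean Dice over all hues except BACKGROUND_CLASS_NAME.
# A rough sketch of that behaviour, assuming this simple definition (not the library implementation):
def _average_foreground_dice_sketch(metrics_dict: MetricsDict, ground_truth_ids: List[str]) -> float:
    foreground = [g for g in ground_truth_ids if g != BACKGROUND_CLASS_NAME]
    return float(np.mean([metrics_dict.get_single_metric(MetricType.DICE, g) for g in foreground]))
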
def test_metrics_store_mixed_hues() -> None:
    """
    Test to make sure metrics dict is able to handle default and non-default hues
    """
    m = MetricsDict(hues=["A", "B"])
    m.add_metric("foo", 1)
    m.add_metric("foo", 1, hue="B")
    m.add_metric("bar", 2, hue="A")
    assert list(m.enumerate_single_values()) == \
           [('A', 'bar', 2), ('B', 'foo', 1), (MetricsDict.DEFAULT_HUE_KEY, 'foo', 1)]
def test_diagnostics() -> None:
    """
    Test if we can store diagnostic values (no restrictions on data types) in the metrics dict.
    """
    name = "foo"
    value1 = "something"
    value2 = (1, 2, 3)
    m = MetricsDict()
    m.add_diagnostics(name, value1)
    m.add_diagnostics(name, value2)
    assert m.diagnostics == {name: [value1, value2]}
def test_metrics_dict_to_string() -> None:
    """
    Test that a MetricsDict can be converted to a string correctly.
    """
    m = MetricsDict()
    m.add_metric("foo", 1.0)
    m.add_metric("bar", math.pi)
    # Build the expected dataframe directly (pandas.DataFrame.append has been removed in pandas 2.0).
    info_df = pd.DataFrame([{MetricsDict.DATAFRAME_COLUMNS[0]: MetricsDict.DEFAULT_HUE_KEY,
                             MetricsDict.DATAFRAME_COLUMNS[1]: "foo: 1.0000, bar: 3.1416"}],
                           columns=MetricsDict.DATAFRAME_COLUMNS)
    assert m.to_string() == tabulate_dataframe(info_df)
    assert m.to_string(tabulate=False) == info_df.to_string(index=False)
def test_metrics_dict_roc_degenerate() -> None:
    """
    Test if adding ROC entries to a MetricsDict instance works, if there is only 1 class present.
    """
    # All labels are 1, so only a single class is present. MetricsDict should handle this
    # degenerate case gracefully and return the default AUC values.
    predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    m = MetricsDict()
    subject_ids = list(range(len(predictions)))
    m.add_predictions(subject_ids, predictions, np.ones_like(predictions))
    assert m.has_prediction_entries
    assert m.get_roc_auc() == 1.0
    assert m.get_pr_auc() == 1.0
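
# For context: sklearn itself cannot compute an AUC when y_true contains only a single class,
# which is why the degenerate case above needs special handling (the helper name is illustrative):
def _sklearn_rejects_single_class_labels() -> None:
    from sklearn.metrics import roc_auc_score
    with pytest.raises(ValueError):
        roc_auc_score(np.ones(6), np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9]))
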
def test_metrics_dict_average_metrics_averaging() -> None:
    """
    Test that averaging skips NaN values only when skip_nan_when_averaging is set.
    """
    m = MetricsDict()
    metric1 = "foo"
    v1 = 1.0
    m.add_metric(metric1, v1)
    m.add_metric(metric1, np.nan, skip_nan_when_averaging=True)
    metric2 = "bar"
    v2 = 2.0
    m.add_metric(metric2, v2)
    m.add_metric(metric2, np.nan, skip_nan_when_averaging=False)
    average = m.average()
    assert average.values()[metric1] == [v1]
    assert np.isnan(average.values()[metric2])
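
# The skip_nan_when_averaging flag mirrors the difference between np.nanmean and np.mean.
# A minimal illustration (not the MetricsDict implementation):
def _nan_averaging_reference() -> None:
    assert np.nanmean([1.0, np.nan]) == 1.0  # NaN values are skipped
    assert np.isnan(np.mean([2.0, np.nan]))  # NaN values propagate
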
def test_aggregate_segmentation_metrics() -> None:
    """
    Test how per-epoch segmentation metrics are aggregated to compute foreground Dice and voxel count proportions.
    """
    g1 = "Liver"
    g2 = "Lung"
    ground_truth_ids = [BACKGROUND_CLASS_NAME, g1, g2]
    dice = [0.85, 0.75, 0.55]
    voxels_proportion = [0.85, 0.10, 0.05]
    loss = 3.14
    other_metric = 2.71
    m = MetricsDict(hues=ground_truth_ids)
    voxel_count = 200
    # Add 3 values per metric, but such that the averages are back at the value given in dice[i]
    for i in range(3):
        delta = (i - 1) * 0.05
        for j, ground_truth_id in enumerate(ground_truth_ids):
            m.add_metric(MetricType.DICE, dice[j] + delta, hue=ground_truth_id)
            m.add_metric(MetricType.VOXEL_COUNT, int(voxels_proportion[j] * voxel_count), hue=ground_truth_id)
        m.add_metric(MetricType.LOSS, loss + delta)
        m.add_metric("foo", other_metric)
    m.add_diagnostics("foo", "bar")
    aggregate = metrics.aggregate_segmentation_metrics(m)
    assert aggregate.diagnostics == m.diagnostics
    enumerated = list(aggregate.enumerate_single_values())
    expected = [
        # Dice and voxel count per foreground structure should be retained during averaging
        (g1, MetricType.DICE.value, dice[1]),
        (g1, MetricType.VOXEL_COUNT.value, voxels_proportion[1] * voxel_count),
        # Proportion of foreground voxels is computed during averaging
        (g1, MetricType.PROPORTION_FOREGROUND_VOXELS.value, voxels_proportion[1]),
        (g2, MetricType.DICE.value, dice[2]),
        (g2, MetricType.VOXEL_COUNT.value, voxels_proportion[2] * voxel_count),
        (g2, MetricType.PROPORTION_FOREGROUND_VOXELS.value, voxels_proportion[2]),
        # Loss is present in the default metrics group, and should be retained.
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.LOSS.value, loss),
        (MetricsDict.DEFAULT_HUE_KEY, "foo", other_metric),
        # Dice averaged across the foreground structures is added during the function call, as is proportion of voxels
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.DICE.value, 0.5 * (dice[1] + dice[2])),
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.PROPORTION_FOREGROUND_VOXELS.value,
         voxels_proportion[1] + voxels_proportion[2]),
    ]
    assert len(enumerated) == len(expected)
    # Numbers won't match up precisely because of rounding during averaging
    for (actual, e) in zip(enumerated, expected):
        assert actual[0:2] == e[0:2]
        assert actual[2] == pytest.approx(e[2])
def test_get_single_metric() -> None:
    h1 = "a"
    m = MetricsDict(hues=[h1])
    m1, v1 = ("foo", 1.0)
    m2, v2 = (MetricType.LOSS, 2.0)
    m.add_metric(m1, v1, hue=h1)
    m.add_metric(m2, v2)
    assert m.get_single_metric(m1, h1) == v1
    assert m.get_single_metric(m2) == v2
    with pytest.raises(KeyError) as ex1:
        m.get_single_metric(m1, "no such hue")
    assert "no such hue" in str(ex1)
    with pytest.raises(KeyError) as ex2:
        m.get_single_metric("no such metric", h1)
    assert "no such metric" in str(ex2)
    m.add_metric(m2, v2)
    with pytest.raises(ValueError) as ex3:
        m.get_single_metric(m2)
    assert "Expected a single entry" in str(ex3)

    def __init__(self, model_config: SegmentationModelBase,
                 train_val_params: TrainValidateParameters[DeviceAwareModule]):
        """
        Creates a new instance of the class.
        :param model_config: The configuration of a segmentation model.
        :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
        """
        super().__init__(model_config, train_val_params)
        self.example_to_save = np.random.randint(0, len(train_val_params.data_loader))
        self.pipeline = SegmentationForwardPass(
            model=self.train_val_params.model,
            model_config=self.model_config,
            batch_size=self.model_config.train_batch_size,
            optimizer=self.train_val_params.optimizer,
            in_training_mode=self.train_val_params.in_training_mode,
            criterion=self.compute_loss,
            gradient_scaler=train_val_params.gradient_scaler)
        self.metrics = MetricsDict(hues=[BACKGROUND_CLASS_NAME] + model_config.ground_truth_ids)

def test_metrics_dict_average_additional_metrics() -> None:
    """
    Test if computing the ROC entries and metrics at optimal threshold with MetricsDict.average() works
    as expected and returns the correct values.
    """
    # Prepare a vector of predictions and labels.
    predictions = np.array([0.5, 0.6, 0.1, 0.8, 0.2, 0.9])
    labels = np.array([0, 1.0, 0, 0, 1, 1], dtype=float)
    split_length = [3, 2, 1]

    # Get MetricsDict
    assert sum(split_length) == len(predictions)
    summed = np.cumsum(split_length)
    # MetricsDict will get that supplied in 3 chunks.
    m = MetricsDict()
    for i, end in enumerate(summed):
        start = 0 if i == 0 else summed[i - 1]
        pred = predictions[start:end]
        label = labels[start:end]
        subject_ids = list(range(len(pred)))
        m.add_predictions(subject_ids, pred, label)
    assert m.has_prediction_entries

    # Compute average MetricsDict
    averaged = m.average()

    # Compute additional expected metrics for the averaged MetricsDict
    expected_auc = roc_auc_score(labels, predictions)
    expected_fpr, expected_tpr, thresholds = roc_curve(labels, predictions)
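    # np.argmax(expected_tpr - expected_fpr) picks the threshold that maximizes the Youden index J = TPR - FPR.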
    expected_optimal_idx = np.argmax(expected_tpr - expected_fpr)
    expected_optimal_threshold = float(thresholds[expected_optimal_idx])
    expected_accuracy = np.mean((predictions > expected_optimal_threshold) == labels)

    # Check computed values against expected
    assert averaged.values()[MetricType.OPTIMAL_THRESHOLD.value][0] == pytest.approx(expected_optimal_threshold)
    assert averaged.values()[MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD.value][0] == pytest.approx(expected_accuracy)
    assert averaged.values()[MetricType.FALSE_POSITIVE_RATE_AT_OPTIMAL_THRESHOLD.value][0] == \
           pytest.approx(expected_fpr[expected_optimal_idx])
    assert averaged.values()[MetricType.FALSE_NEGATIVE_RATE_AT_OPTIMAL_THRESHOLD.value][0] == \
           pytest.approx(1 - expected_tpr[expected_optimal_idx])
    assert averaged.values()[MetricType.AREA_UNDER_ROC_CURVE.value][0] == pytest.approx(expected_auc, 1e-6)
@pytest.mark.parametrize("hues", [None, ["A", "B"]])
def test_metrics_dict_flatten(hues: Optional[List[str]]) -> None:
    m = MetricsDict(hues=hues)
    _hues = hues or [MetricsDict.DEFAULT_HUE_KEY] * 2
    m.add_metric("foo", 1.0, hue=_hues[0])
    m.add_metric("foo", 2.0, hue=_hues[1])
    m.add_metric("bar", 3.0, hue=_hues[0])
    m.add_metric("bar", 4.0, hue=_hues[1])

    if hues is None:
        average = m.average(across_hues=True)
        # We should be able to flatten out all the singleton values that the `average` operation returns
        all_values = list(average.enumerate_single_values())
        assert all_values == [(MetricsDict.DEFAULT_HUE_KEY, "foo", 1.5), (MetricsDict.DEFAULT_HUE_KEY, "bar", 3.5)]
        # When trying to flatten off a dictionary that has two values, this should fail:
        with pytest.raises(ValueError) as ex:
            list(m.enumerate_single_values())
        assert "only hold 1 item" in str(ex)
    else:
        average = m.average(across_hues=False)
        all_values = list(average.enumerate_single_values())
        assert all_values == [('A', 'foo', 1.0), ('A', 'bar', 3.0), ('B', 'foo', 2.0), ('B', 'bar', 4.0)]
def test_metrics_dict_with_default_hue() -> None:
    hue_name = "foo"
    metrics_dict = MetricsDict(hues=[hue_name, MetricsDict.DEFAULT_HUE_KEY])
    assert metrics_dict.get_hue_names(include_default=True) == [hue_name, MetricsDict.DEFAULT_HUE_KEY]
    assert metrics_dict.get_hue_names(include_default=False) == [hue_name]
def train_or_validate_epoch(
        training_steps: ModelTrainingStepsBase
) -> ModelOutputsAndMetricsForEpoch:
    """
    Trains or validates the model for one epoch.
    :param training_steps: Training pipeline to use.
    :returns: The results for training or validation. Result type depends on the type of model that is trained.
    """
    epoch_start_time = time()
    training_random_state = None
    train_val_params = training_steps.train_val_params
    config = training_steps.model_config
    if not train_val_params.in_training_mode:
        # take the snapshot of the existing random state
        training_random_state = RandomStateSnapshot.snapshot_random_state()
        # reset the random state for validation
        ml_util.set_random_seed(config.get_effective_random_seed(),
                                "Model validation")

    status_string = "training" if train_val_params.in_training_mode else "validation"
    item_start_time = time()
    num_load_time_warnings = 0
    num_load_time_exceeded = 0
    num_batches = 0
    total_extra_load_time = 0.0
    total_load_time = 0.0
    model_outputs_epoch = []
    for batch_index, sample in enumerate(train_val_params.data_loader):
        item_finish_time = time()
        item_load_time = item_finish_time - item_start_time
        # Having slow minibatch loading is OK in the very first batch of every epoch, where the data loader
        # worker processes are spawned. Later, the load time should be close to zero.
        if batch_index == 0:
            logging.info(
                f"Loaded the first minibatch of {status_string} data in {item_load_time:0.2f} sec."
            )
        elif item_load_time > MAX_ITEM_LOAD_TIME_SEC:
            num_load_time_exceeded += 1
            total_extra_load_time += item_load_time
            if num_load_time_warnings < MAX_LOAD_TIME_WARNINGS:
                logging.warning(
                    f"Loading {status_string} minibatch {batch_index} took {item_load_time:0.2f} sec. "
                    f"This can mean that there are not enough data loader worker processes, or that there "
                    f"is a "
                    f"performance problem in loading. This warning will be printed at most "
                    f"{MAX_LOAD_TIME_WARNINGS} times.")
                num_load_time_warnings += 1
        model_outputs_minibatch = training_steps.forward_and_backward_minibatch(
            sample, batch_index, train_val_params.epoch)
        model_outputs_epoch.append(model_outputs_minibatch)
        train_finish_time = time()
        logging.debug(
            f"Epoch {train_val_params.epoch} {status_string} batch {batch_index}: "
            f"Loaded in {item_load_time:0.2f}sec, "
            f"{status_string} in {(train_finish_time - item_finish_time):0.2f}sec. "
            f"Loss = {model_outputs_minibatch.loss}")
        total_load_time += item_finish_time - item_start_time
        num_batches += 1
        item_start_time = time()

    # restore the training random state when validation has finished
    if training_random_state is not None:
        training_random_state.restore_random_state()

    epoch_time_seconds = time() - epoch_start_time
    logging.info(
        f"Epoch {train_val_params.epoch} {status_string} took {epoch_time_seconds:0.2f} sec, "
        f"of which waiting for next minibatch took {total_load_time:0.2f} sec total. {num_batches} "
        "minibatches in total.")
    if num_load_time_exceeded > 0:
        logging.warning(
            "The dataloaders were not fast enough to always supply the next batch in less than "
            f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
        logging.warning(
            f"In this epoch, {num_load_time_exceeded} out of {num_batches} batches exceeded the load time "
            f"threshold. The total loading time for the slow batches was {total_extra_load_time:0.2f}sec."
        )

    _metrics = training_steps.get_epoch_results_and_store(epoch_time_seconds) \
        if train_val_params.save_metrics else MetricsDict()
    return ModelOutputsAndMetricsForEpoch(
        metrics=_metrics,
        model_outputs=model_outputs_epoch,
        is_train=train_val_params.in_training_mode)
def calculate_metrics_per_class(
        segmentation: np.ndarray,
        ground_truth: np.ndarray,
        ground_truth_ids: List[str],
        voxel_spacing: TupleFloat3,
        patient_id: Optional[int] = None) -> MetricsDict:
    """
    Calculate the dice for all foreground structures (the background class is completely ignored).
    Returns a MetricsDict with metrics for each of the foreground
    structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
    :param ground_truth_ids: The names of all foreground classes.
    :param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
    :param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]
    :param voxel_spacing: voxel_spacing in 3D Z x Y x X
    :param patient_id: for logging
    """
    number_of_classes = ground_truth.shape[0]
    if len(ground_truth_ids) != (number_of_classes - 1):
        raise ValueError(
            f"Received {len(ground_truth_ids)} foreground class names, but "
            f"the label tensor indicates that there are {number_of_classes - 1} classes."
        )
    binaries = binaries_from_multi_label_array(segmentation, number_of_classes)

    all_classes_are_binary = [
        is_binary_array(ground_truth[label_id])
        for label_id in range(ground_truth.shape[0])
    ]
    if not np.all(all_classes_are_binary):
        raise ValueError("Ground truth values should be 0 or 1")
    overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    metrics = MetricsDict(hues=ground_truth_ids)
    for i, prediction in enumerate(binaries):
        if i == 0:
            continue
        check_size_matches(prediction,
                           ground_truth[i],
                           arg1_name="prediction",
                           arg2_name="ground_truth")
        if not is_binary_array(prediction):
            raise ValueError("Predictions values should be 0 or 1")
        # simpleitk returns a Dice score of 0 if both ground truth and prediction are all zeros.
        # We want to be able to fish out those cases, and treat them specially later.
        prediction_zero = np.all(prediction == 0)
        gt_zero = np.all(ground_truth[i] == 0)
        dice = mean_surface_distance = hausdorff_distance = math.nan
        if not (prediction_zero and gt_zero):
            prediction_image = sitk.GetImageFromArray(
                prediction.astype(np.uint8))
            prediction_image.SetSpacing(
                sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
            ground_truth_image = sitk.GetImageFromArray(ground_truth[i].astype(
                np.uint8))
            ground_truth_image.SetSpacing(
                sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
            overlap_measures_filter.Execute(prediction_image,
                                            ground_truth_image)
            dice = overlap_measures_filter.GetDiceCoefficient()
            if prediction_zero or gt_zero:
                hausdorff_distance = mean_surface_distance = math.inf
            else:
                try:
                    hausdorff_distance_filter.Execute(prediction_image,
                                                      ground_truth_image)
                    hausdorff_distance = hausdorff_distance_filter.GetHausdorffDistance()
                except Exception as e:
                    logging.warning(
                        "Cannot calculate Hausdorff distance for "
                        f"structure {i} of patient {patient_id}: {e}")
                try:
                    mean_surface_distance = surface_distance(
                        prediction_image, ground_truth_image)
                except Exception as e:
                    logging.warning(
                        f"Cannot calculate mean distance for structure {i} of patient {patient_id}: {e}"
                    )
            logging.debug(
                f"Patient {patient_id}, class {i} has Dice score {dice}")

        def add_metric(metric_type: MetricType, value: float) -> None:
            metrics.add_metric(metric_type,
                               value,
                               skip_nan_when_averaging=True,
                               hue=ground_truth_ids[i - 1])

        add_metric(MetricType.DICE, dice)
        add_metric(MetricType.HAUSDORFF_mm, hausdorff_distance)
        add_metric(MetricType.MEAN_SURFACE_DIST_mm, mean_surface_distance)
    return metrics
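
# A minimal usage sketch for calculate_metrics_per_class, with illustrative shapes and values only:
# one foreground class, a tiny 1 x 2 x 2 volume, and a ground truth tensor whose channel 0 is the
# background. The helper name and data are assumptions, not part of the library.
def _example_calculate_metrics_per_class() -> MetricsDict:
    segmentation = np.array([[[0, 1], [1, 1]]])  # [Z x Y x X], multi-label values
    ground_truth = np.zeros((2, 1, 2, 2))        # [C x Z x Y x X], C = background + 1 foreground class
    ground_truth[1] = segmentation               # the foreground class matches the prediction exactly
    ground_truth[0] = 1 - segmentation           # background is the complement
    return calculate_metrics_per_class(segmentation,
                                       ground_truth,
                                       ground_truth_ids=["Liver"],
                                       voxel_spacing=(3.0, 1.0, 1.0),
                                       patient_id=1)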