def test_metrics_dict_per_subject() -> None:
    """
    Add predictions for the same two subjects under two different hues, and check that querying the
    per-subject predictions and labels of the first hue yields exactly two entries.
    """
    first_hue, second_hue = "H1", "H2"
    metrics = ScalarMetricsDict(hues=[first_hue, second_hue], is_classification_metrics=True)
    # Model outputs per hue; both hues share the same subjects and the same labels.
    outputs_per_hue = (
        (first_hue, [0.0, 1.0]),
        (second_hue, [1.0, 0.0]),
    )
    for hue_name, outputs in outputs_per_hue:
        metrics.add_predictions(["S1", "S2"], np.array(outputs), np.array([0.0, 1.0]), hue=hue_name)
    per_subject = metrics.get_predictions_and_labels_per_subject(hue=first_hue)
    assert len(per_subject) == 2
def test_metrics_dic_subject_ids() -> None:
    """
    Subject IDs that were added under a named hue must be returned when querying that hue, while a query
    without a hue argument must return an empty list.
    """
    hue_name = "H1"
    metrics = ScalarMetricsDict(hues=[hue_name], is_classification_metrics=True)
    metrics.add_predictions(subject_ids=[0], predictions=np.zeros(1), labels=np.zeros(1), hue=hue_name)
    # Nothing was stored without an explicit hue, hence no subjects are expected here.
    ids_without_hue = metrics.subject_ids()
    assert ids_without_hue == []
    ids_for_hue = metrics.subject_ids(hue=hue_name)
    assert ids_for_hue == [0]
Beispiel #3
0
def compute_scalar_metrics(
        metrics_dict: ScalarMetricsDict,
        subject_ids: Sequence[str],
        model_output: torch.Tensor,
        labels: torch.Tensor,
        loss_type: ScalarLoss = ScalarLoss.BinaryCrossEntropyWithLogits
) -> None:
    """
    Computes scalar metrics for each model output channel from the model output and a label tensor,
    and stores them, together with the per-subject predictions, in the given `metrics_dict`.
    Channels are taken from dimension 1 of `model_output` and `labels`. Channel `i` is paired with the i-th
    hue returned by `metrics_dict.get_hue_names`; the default hue is only included when the metrics dictionary
    has no non-default hues. If there are more hues than channels, the surplus hues are left untouched.
    Channels for which `get_masked_model_outputs_and_labels` returns None are skipped.
    The metrics depend on the loss type:
    * `ScalarLoss.MeanSquaredError`: mean squared error, mean absolute error and R2 score.
    * Any other loss type: binary cross entropy and `MetricType.ACCURACY_AT_THRESHOLD_05`. The model output is
    passed to `F.binary_cross_entropy` unchanged (no conversion from logits happens in this function), hence it
    is assumed to already be in the range between 0 and 1.
    :param metrics_dict: An object that holds all metrics. It will be updated in-place.
    :param subject_ids: Subject ids for the model output and labels.
    :param model_output: A tensor containing model outputs, with the channels in dimension 1.
    :param labels: A tensor containing the labels, with the channels in dimension 1.
    :param loss_type: The type of loss that the model uses. It selects which set of metrics is computed.
    :raises ValueError: If the metrics dictionary provides fewer hues than the model output has channels.
    """
    _model_output_channels = model_output.shape[1]
    model_output_hues = metrics_dict.get_hue_names(
        include_default=len(metrics_dict.hues_without_default) == 0)

    if len(model_output_hues) < _model_output_channels:
        raise ValueError(
            "Hues must be provided for each model output channel, found "
            f"{_model_output_channels} channels but only {len(model_output_hues)} hues"
        )

    # Pair every channel with its hue. zip stops after the last channel: the check above permits more hues than
    # channels, and iterating over all hues would then index past the end of the channel dimension.
    for i, hue in zip(range(_model_output_channels), model_output_hues):
        # mask the model outputs and labels if required
        masked_model_outputs_and_labels = get_masked_model_outputs_and_labels(
            model_output[:, i, ...], labels[:, i, ...], subject_ids)

        # compute metrics on valid masked tensors only
        if masked_model_outputs_and_labels is None:
            continue

        _model_output = masked_model_outputs_and_labels.model_outputs.data
        _labels = masked_model_outputs_and_labels.labels.data
        _subject_ids = masked_model_outputs_and_labels.subject_ids

        if loss_type == ScalarLoss.MeanSquaredError:
            metrics = {
                MetricType.MEAN_SQUARED_ERROR:
                F.mse_loss(_model_output,
                           _labels.float(),
                           reduction='mean').item(),
                MetricType.MEAN_ABSOLUTE_ERROR:
                mean_absolute_error(_model_output, _labels),
                MetricType.R2_SCORE:
                r2_score(_model_output, _labels)
            }
        else:
            metrics = {
                MetricType.CROSS_ENTROPY:
                F.binary_cross_entropy(_model_output,
                                       _labels.float(),
                                       reduction='mean').item(),
                MetricType.ACCURACY_AT_THRESHOLD_05:
                binary_classification_accuracy(_model_output, _labels)
            }
        for key, value in metrics.items():
            if key == MetricType.R2_SCORE:
                # For a batch size 1, R2 score can be nan. We need to ignore nans
                # when averaging in case the last batch is of size 1.
                metrics_dict.add_metric(key,
                                        value,
                                        skip_nan_when_averaging=True,
                                        hue=hue)
            else:
                metrics_dict.add_metric(key, value, hue=hue)

        assert _subject_ids is not None
        metrics_dict.add_predictions(_subject_ids,
                                     _model_output.detach().cpu().numpy(),
                                     _labels.cpu().numpy(),
                                     hue=hue)