Example #1
def test_multiclass_metrics(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    zero_division: int,
    true_values: Dict[str, float],
) -> None:
    """
    Test multiclass metric
    Args:
        outputs: tensor of predictions
        targets: tensor of targets
        zero_division: zero division policy flag
        true_values: true values of metrics
    """
    metric = MulticlassPrecisionRecallF1SupportMetric(
        num_classes=num_classes, zero_division=zero_division)
    metric.update(outputs=outputs, targets=targets)
    metrics = metric.compute_key_value()
    for key in true_values:
        assert key in metrics
        assert abs(metrics[key] - true_values[key]) < EPS
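For context, here is a minimal usage sketch of the metric exercised above. It assumes the class is importable from catalyst.metrics and that outputs may be given as per-class scores of shape (batch_size, num_classes); the tensor values and the key names mentioned in the comments are illustrative, not taken from the original suite.

import torch
from catalyst import metrics

# Sketch: scores are reduced to predicted classes, statistics are accumulated,
# and compute_key_value() returns a dict of per-class and averaged
# precision/recall/f1/support values (e.g. "precision/macro").
metric = metrics.MulticlassPrecisionRecallF1SupportMetric(num_classes=3, zero_division=0)
metric.reset()

outputs = torch.tensor([
    [0.8, 0.1, 0.1],  # predicted class 0
    [0.1, 0.7, 0.2],  # predicted class 1
    [0.2, 0.2, 0.6],  # predicted class 2
    [0.1, 0.2, 0.7],  # predicted class 2
])
targets = torch.tensor([0, 1, 1, 2])

metric.update(outputs=outputs, targets=targets)
print(metric.compute_key_value())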
Example #2
def test_update_key_value_multiclass(
    outputs_list: Iterable[torch.Tensor],
    targets_list: Iterable[torch.Tensor],
    num_classes: int,
    zero_division: int,
    update_true_values: Iterable[Dict[str, float]],
    compute_true_value: Dict[str, float],
) -> None:
    """
    This test checks that metrics update works correctly with multiple calls.
    Metric should update statistics and return metrics for tmp input, so in this test
    we call update_key_value multiple times and check that all the intermediate metrics values
    are correct. After all the updates it checks that metrics computed with accumulated
    statistics are correct too.

    Args:
        outputs_list: sequence of predictions
        targets_list: sequence of targets
        num_classes: number of classes
        zero_division: int value, should be 0 or 1; return it in metrics in case of zero division
        update_true_values: sequence of true intermediate metrics
        compute_true_value: total metrics value for all the items from output_list and targets_list
    """
    metric = MulticlassPrecisionRecallF1SupportMetric(
        num_classes=num_classes, zero_division=zero_division)
    for outputs, targets, update_true_value in zip(outputs_list, targets_list,
                                                   update_true_values):
        intermediate_metrics = metric.update_key_value(outputs=outputs,
                                                       targets=targets)
        for key in update_true_value:
            assert key in intermediate_metrics
            assert abs(intermediate_metrics[key] -
                       update_true_value[key]) < EPS
    metrics = metric.compute_key_value()
    for key in compute_true_value:
        assert key in metrics
        assert abs(metrics[key] - compute_true_value[key]) < EPS
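A hedged sketch of the streaming behavior this test verifies: update_key_value returns metrics for the current batch only, while compute_key_value reflects every batch seen so far. The import path and tensor values are assumptions for illustration.

import torch
from catalyst import metrics

metric = metrics.MulticlassPrecisionRecallF1SupportMetric(num_classes=3, zero_division=0)
metric.reset()

batches = [
    (torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]), torch.tensor([0, 1])),
    (torch.tensor([[0.1, 0.1, 0.8], [0.2, 0.6, 0.2]]), torch.tensor([2, 2])),
]

for outputs, targets in batches:
    # Metrics for this batch only (the "intermediate" values checked above).
    batch_metrics = metric.update_key_value(outputs=outputs, targets=targets)
    print("batch:", batch_metrics)

# Metrics computed from the statistics accumulated over both batches.
print("total:", metric.compute_key_value())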
Example #3
def test_update(
    outputs_list: Iterable[torch.Tensor],
    targets_list: Iterable[torch.Tensor],
    num_classes: int,
    zero_division: int,
    true_values: Dict[str, float],
) -> None:
    """
    Test that the metric works correctly across multiple updates.
    Args:
        outputs_list: list of tensors of predictions
        targets_list: list of tensors of targets
        num_classes: number of classes to score
        zero_division: zero division policy flag
        true_values: true values of metrics
    """
    metric = MulticlassPrecisionRecallF1SupportMetric(
        num_classes=num_classes, zero_division=zero_division)
    for outputs, targets in zip(outputs_list, targets_list):
        metric.update(outputs=outputs, targets=targets)
    metrics = metric.compute_key_value()
    for key in true_values:
        assert key in metrics
        assert abs(metrics[key] - true_values[key]) < EPS
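Because the metric accumulates per-class statistics rather than ready metric values, batch-wise updates should yield the same result as a single update over the concatenated data; that equivalence is what true_values encodes in the test above. Below is a small sketch of that check, with illustrative tensors and an assumed tolerance.

import torch
from catalyst import metrics

batches = [
    (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])),
    (torch.tensor([[0.3, 0.7], [0.6, 0.4]]), torch.tensor([0, 0])),
]

# Streaming: one update per batch.
streaming = metrics.MulticlassPrecisionRecallF1SupportMetric(num_classes=2, zero_division=0)
streaming.reset()
for outputs, targets in batches:
    streaming.update(outputs=outputs, targets=targets)

# Single shot: one update over the concatenated batches.
single = metrics.MulticlassPrecisionRecallF1SupportMetric(num_classes=2, zero_division=0)
single.reset()
single.update(
    outputs=torch.cat([outputs for outputs, _ in batches]),
    targets=torch.cat([targets for _, targets in batches]),
)

streamed_values = streaming.compute_key_value()
for key, value in single.compute_key_value().items():
    assert abs(streamed_values[key] - value) < 1e-5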