Example #1
0
def test_accuracy_update(
    outputs_list: List[torch.Tensor],
    targets_list: List[torch.Tensor],
    num_classes: int,
    topk: List[int],
    true_values_list: List[Dict[str, float]],
) -> None:
    """
    Check that AccuracyMetric updates its values correctly and returns
    correct intermediate results after every update step.
    Note that now `accuracy/std` is not std exactly so it can fail if you fix it.

    Args:
        outputs_list: list of output tensors, one per update step
        targets_list: list of target tensors, one per update step
        num_classes: number of classes for classification
        topk: topk args for computing accuracy@topk
        true_values_list: list of expected intermediate metric values,
            one dict per update step
    """
    # ``zip`` silently stops at the shortest input, which would let a
    # malformed fixture skip checks unnoticed; require the lists to line up.
    assert len(outputs_list) == len(targets_list) == len(true_values_list), (
        "outputs_list, targets_list and true_values_list must have the same length"
    )
    metric = AccuracyMetric(topk_args=topk, num_classes=num_classes)
    for outputs, targets, true_values in zip(
        outputs_list, targets_list, true_values_list
    ):
        metric.update(logits=outputs, targets=targets)
        # Metrics accumulated over all updates so far.
        intermediate_metric_values = metric.compute_key_value()
        for key, expected in true_values.items():
            assert key in intermediate_metric_values
            assert np.isclose(expected, intermediate_metric_values[key])
Example #2
0
def test_accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    topk: List[int],
    true_values: Dict[str, float],
) -> None:
    """
    Test multiclass accuracy with different topk args.
    Note that now `accuracy/std` is not std exactly so it can fail if you fix it.

    Args:
        outputs: tensor of outputs
        targets: tensor of targets
        num_classes: number of classes for classification
        topk: list of topk args for accuracy@topk
        true_values: true metrics values
    """
    # Forward the parametrized ``num_classes`` so the metric is configured the
    # same way as in ``test_accuracy_update`` instead of ignoring the argument.
    metric = AccuracyMetric(topk_args=topk, num_classes=num_classes)
    metric.update(logits=outputs, targets=targets)
    metrics = metric.compute_key_value()
    for key, expected in true_values.items():
        assert key in metrics
        assert np.isclose(expected, metrics[key])
Example #3
0
 def __init__(
     self,
     input_key: str,
     target_key: str,
     topk_args: List[int] = None,
     num_classes: int = None,
     log_on_batch: bool = True,
     prefix: str = None,
     suffix: str = None,
 ):
     """Build an ``AccuracyMetric`` and hand it to the parent initializer.

     Args:
         input_key: forwarded unchanged to the parent ``__init__``
         target_key: forwarded unchanged to the parent ``__init__``
         topk_args: top-k values passed to ``AccuracyMetric``
         num_classes: number of classes passed to ``AccuracyMetric``
         log_on_batch: forwarded unchanged to the parent ``__init__``
         prefix: passed to ``AccuracyMetric`` (presumably prepended to
             metric names — confirm against ``AccuracyMetric``)
         suffix: passed to ``AccuracyMetric`` (presumably appended to
             metric names — confirm against ``AccuracyMetric``)
     """
     accuracy_metric = AccuracyMetric(
         topk_args=topk_args,
         num_classes=num_classes,
         prefix=prefix,
         suffix=suffix,
     )
     super().__init__(
         metric=accuracy_metric,
         input_key=input_key,
         target_key=target_key,
         log_on_batch=log_on_batch,
     )