def test_accuracy_update(
    outputs_list: List[torch.Tensor],
    targets_list: List[torch.Tensor],
    num_classes: int,
    topk: List[int],
    true_values_list: List[Dict[str, float]],
) -> None:
    """Check that AccuracyMetric updates its state and returns correct intermediate results.

    After every ``update`` call the metric is queried with ``compute_key_value``.
    Each expected key/value in the matching ``true_values_list`` entry is then
    compared against the computed values.

    Note that ``accuracy/std`` is currently not exactly a standard deviation, so
    this test may start failing if that computation is fixed.

    Args:
        outputs_list: list of output tensors, one per update step
        targets_list: list of target tensors, one per update step
        num_classes: number of classes for classification
        topk: topk args for computing accuracy@topk
        true_values_list: list of expected intermediate metric values, one per update step
    """
    # zip() silently stops at the shortest input. Out-of-sync fixtures would
    # otherwise skip update steps (or every check) without failing the test.
    assert len(outputs_list) == len(targets_list) == len(true_values_list), (
        "outputs_list, targets_list and true_values_list must have the same length"
    )
    metric = AccuracyMetric(topk_args=topk, num_classes=num_classes)
    for outputs, targets, true_values in zip(outputs_list, targets_list, true_values_list):
        metric.update(logits=outputs, targets=targets)
        intermediate_metric_values = metric.compute_key_value()
        for key, expected in true_values.items():
            assert key in intermediate_metric_values
            assert np.isclose(expected, intermediate_metric_values[key])
def test_accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    topk: List[int],
    true_values: Dict[str, float],
) -> None:
    """Test multiclass accuracy with different topk args.

    A single ``update`` is applied, then every expected key/value in
    ``true_values`` is compared against ``compute_key_value()``.

    Note that ``accuracy/std`` is currently not exactly a standard deviation, so
    this test may start failing if that computation is fixed.

    Args:
        outputs: tensor of outputs
        targets: tensor of targets
        num_classes: number of classes for classification
        topk: list of topk args for accuracy@topk
        true_values: expected metric values
    """
    # Forward num_classes as test_accuracy_update and the callback do. It was
    # previously accepted here but silently ignored.
    metric = AccuracyMetric(topk_args=topk, num_classes=num_classes)
    metric.update(logits=outputs, targets=targets)
    metrics = metric.compute_key_value()
    for key, expected in true_values.items():
        assert key in metrics
        assert np.isclose(expected, metrics[key])
def __init__(
    self,
    input_key: str,
    target_key: str,
    topk_args: List[int] = None,
    num_classes: int = None,
    log_on_batch: bool = True,
    prefix: str = None,
    suffix: str = None,
):
    """Build an ``AccuracyMetric`` and hand it to the parent callback initializer.

    Args:
        input_key: forwarded to the parent ``__init__`` as ``input_key``
            (presumably the batch key holding model outputs — confirm).
        target_key: forwarded to the parent ``__init__`` as ``target_key``.
        topk_args: forwarded to ``AccuracyMetric`` as ``topk_args``.
        num_classes: forwarded to ``AccuracyMetric`` as ``num_classes``.
        log_on_batch: forwarded to the parent ``__init__`` as ``log_on_batch``.
        prefix: forwarded to ``AccuracyMetric`` as ``prefix``.
        suffix: forwarded to ``AccuracyMetric`` as ``suffix``.
    """
    # NOTE(review): the ``None`` defaults make these parameters implicitly
    # Optional; the annotations are kept as-is to preserve the signature.
    accuracy_metric = AccuracyMetric(
        num_classes=num_classes,
        topk_args=topk_args,
        suffix=suffix,
        prefix=prefix,
    )
    super().__init__(
        log_on_batch=log_on_batch,
        target_key=target_key,
        input_key=input_key,
        metric=accuracy_metric,
    )