def test_accuracy_top1():
    """
    Tests for catalyst.metrics.accuracy metric.
    """
    targets_template = torch.ones((BATCH_SIZE, 1))
    for class_index in range(NUM_CLASSES):
        # One-hot logits: every sample's highest-scoring class is `class_index`,
        # so the true class is always ranked first.
        logits = torch.zeros((BATCH_SIZE, NUM_CLASSES))
        logits[:, class_index] = 1
        labels = targets_template * class_index
        scores = accuracy(logits, labels, topk=(1, 3, 5))
        # With the correct class at rank 1, accuracy@1, @3 and @5 are all 1.
        for score in scores:
            assert np.isclose(score, 1)
def test_accuracy_top3():
    """
    Tests for catalyst.metrics.accuracy metric.
    """
    # Monotonically increasing logits: class k always scores k, so the
    # top-k predictions are exactly the k largest class indices.
    logits = (
        torch.linspace(0, NUM_CLASSES - 1, steps=NUM_CLASSES)
        .repeat(1, BATCH_SIZE)
        .view(-1, NUM_CLASSES)
    )
    for class_index in range(NUM_CLASSES):
        labels = torch.ones((BATCH_SIZE, 1)) * class_index
        scores = accuracy(logits, labels, topk=(1, 3, 5))
        # accuracy@k is 1 iff the true class is among the k largest indices.
        for k, score in zip((1, 3, 5), scores):
            expected = 1 if class_index >= NUM_CLASSES - k else 0
            assert np.isclose(score, expected)
def update(self, logits: torch.Tensor, targets: torch.Tensor) -> List[float]:
    """
    Updates metric value with accuracy for new data and return
    intermediate metrics values.

    Args:
        logits: tensor of logits
        targets: tensor of targets

    Returns:
        list of accuracy@k values
    """
    # Compute accuracy@k for each configured k and unwrap the 0-dim
    # tensors into plain Python floats.
    batch_values = [
        value.item() for value in accuracy(logits, targets, topk=self.topk_args)
    ]
    # Fold this batch into each running (additive) metric, weighted by
    # the number of samples in the batch.
    batch_size = len(targets)
    for batch_value, running_metric in zip(batch_values, self.additive_metrics):
        running_metric.update(batch_value, batch_size)
    return batch_values