def test_compute():
    """Check top-2 accuracy on two small batches, with a reset in between.

    Every sample has label 1. In the first batch only the second row ranks
    index 1 among its two highest scores, giving 0.5. In the second batch
    both rows do, giving 1.0.
    """
    # NOTE(review): `test_compute` is defined again later in this file, so
    # pytest only collects the later definition and this one never runs.
    # TODO: rename or remove one of them.
    acc = TopKCategoricalAccuracy(2)

    # Row 0 top-2 indices are {3, 2} (miss); row 1 top-2 are {0, 1} (hit).
    y_pred = torch.FloatTensor([[0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2]])
    y = torch.ones(2).long()  # idiomatic replacement for .type(torch.LongTensor)
    acc.update((y_pred, y))
    assert isinstance(acc.compute(), float)
    assert acc.compute() == 0.5

    acc.reset()

    # Row 0 top-2 indices are {1, 3} (hit); row 1 top-2 are {0, 1} (hit).
    y_pred = torch.FloatTensor([[0.4, 0.8, 0.2, 0.6], [0.8, 0.6, 0.4, 0.2]])
    y = torch.ones(2).long()
    acc.update((y_pred, y))
    assert isinstance(acc.compute(), float)
    assert acc.compute() == 1.0
def test_compute():
    """Run top-2 accuracy over a table of batches and check each result.

    All targets are class 1. The metric is reset between batches (not before
    the first), and each computed value must be a Python float that equals
    the expected accuracy.
    """
    metric = TopKCategoricalAccuracy(2)
    batches = [
        # top-2 per row: {3, 2} misses, {0, 1} hits -> 1/2
        ([[0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2]], 0.5),
        # top-2 per row: {1, 3} hits, {0, 1} hits -> 2/2
        ([[0.4, 0.8, 0.2, 0.6], [0.8, 0.6, 0.4, 0.2]], 1.0),
    ]

    for position, (scores, expected_accuracy) in enumerate(batches):
        if position > 0:
            metric.reset()
        predictions = torch.FloatTensor(scores)
        targets = torch.ones(2).long()
        metric.update((predictions, targets))
        assert isinstance(metric.compute(), float)
        assert metric.compute() == expected_accuracy