Esempio n. 1
0
    def test_top_1_accuracy_distributed_uneven_batch(self):
        """Simulate top-1 accuracy over 2 DDP processes with uneven batches.

        Process 1 scores 3 samples; process 2 scores 2 samples (the first two
        rows of the logits flipped along dim 1). The per-process counts are
        summed by hand and written back into the metric, whose pooled accuracy
        must be 3/5. That value differs from the unweighted mean of the two
        per-process accuracies, (2/3 + 1/2) / 2.
        """
        # Simulate test on 2 process DDP execution
        accuracy = TopKClassificationAccuracy(top_k=None)

        proc1_acc = accuracy(logits=self.top_k_logits,
                             labels=torch.tensor([0, 0, 2]))
        correct1, total1 = accuracy.correct_counts_k, accuracy.total_counts_k

        # Each simulated process owns independent metric state. Clear what
        # process 1 left behind so correct2/total2 hold only process 2's
        # counts, whether or not the metric accumulates across calls.
        # This mirrors test_top_1_accuracy_distributed.
        accuracy.reset()
        proc2_acc = accuracy(
            logits=torch.flip(
                self.top_k_logits,
                dims=[1])[:2, :],  # reverse logits, select first 2 samples
            labels=torch.tensor([2, 0]),
        )  # reduce number of labels
        correct2, total2 = accuracy.correct_counts_k, accuracy.total_counts_k

        correct = torch.stack([correct1, correct2])
        total = torch.stack([total1, total2])

        # One row per simulated process, one column per k.
        # NOTE(review): top_k=None presumably defaults to k=1 only; confirm.
        assert correct.shape == torch.Size([2, 1])
        assert total.shape == torch.Size([2, 1])

        assert abs(proc1_acc[0] - 0.667) < 1e-3  # 2/3
        assert abs(proc2_acc[0] - 0.500) < 1e-3  # 1/2

        # Combine the two processes by summing their counts by hand. Both
        # state tensors are overwritten, so no reset is needed first.
        accuracy.correct_counts_k = torch.tensor([correct.sum()])
        accuracy.total_counts_k = torch.tensor([total.sum()])
        acc_topk = accuracy.compute()
        acc_top1 = acc_topk[0]

        assert abs(acc_top1 - 0.6) < 1e-3  # 3/5
Esempio n. 2
0
    def test_top_1_accuracy_distributed(self):
        """Pool top-1 accuracy from two simulated DDP processes of equal size.

        Each simulated process evaluates 3 samples on freshly reset metric
        state; the second one uses the logits flipped along dim 1. Their
        counts are summed by hand and fed back into the metric, whose pooled
        top-1 accuracy must come out as 3/6.
        """
        # Simulate test on 2 process DDP execution
        per_proc_labels = torch.tensor([[0, 0, 2], [2, 0, 0]], dtype=torch.long)
        metric = TopKClassificationAccuracy(top_k=None)

        # Process 1: original logits.
        first_acc = metric(logits=self.top_k_logits, labels=per_proc_labels[0])
        first_correct = metric.correct_counts_k
        first_total = metric.total_counts_k

        # Process 2: reversed logits, evaluated on cleared state.
        metric.reset()
        reversed_logits = torch.flip(self.top_k_logits, dims=[1])
        second_acc = metric(logits=reversed_logits, labels=per_proc_labels[1])
        second_correct = metric.correct_counts_k
        second_total = metric.total_counts_k

        stacked_correct = torch.stack((first_correct, second_correct))
        stacked_total = torch.stack((first_total, second_total))

        # One row per simulated process, one column per k.
        for stacked in (stacked_correct, stacked_total):
            assert stacked.shape == torch.Size([2, 1])

        # Accuracy each process reports for its own batch.
        assert abs(first_acc[0] - 0.667) < 1e-3  # 2/3
        assert abs(second_acc[0] - 0.333) < 1e-3  # 1/3

        # Combine the processes by summing their counts by hand, then let the
        # metric compute the pooled accuracy from those totals.
        metric.reset()
        metric.correct_counts_k = torch.tensor([stacked_correct.sum()])
        metric.total_counts_k = torch.tensor([stacked_total.sum()])
        pooled_top1 = metric.compute()[0]

        assert abs(pooled_top1 - 0.5) < 1e-3  # 3/6