def test_accuracy_computation(self):
        """Exercise BooleanAccuracy without a mask, with a mask, and after reset.

        The expectations below treat a row as correct only when every element
        of the prediction row equals the target row. The metric keeps
        accumulating across calls until ``reset()``, so each expectation is
        (correct rows so far) / (rows seen so far).
        """
        metric = BooleanAccuracy()
        preds = torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]])
        gold = torch.tensor([[0.0, 1.0], [2.0, 2.0], [4.0, 5.0], [7.0, 7.0]])

        # Unmasked: rows 0 and 2 match, rows 1 and 3 do not.
        metric(preds, gold)
        assert metric.get_metric() == 2 / 4

        # Hiding the single mismatching cell of row 1 lets that row count as
        # correct, so this call contributes 3 of 4 on top of the earlier 2 of 4.
        mask = torch.ones_like(preds)
        mask[1, 1] = 0.0
        metric(preds, gold, mask)
        assert metric.get_metric() == (2 + 3) / (4 + 4)

        # Patch the target in place so row 1 now matches without any mask.
        gold[1, 1] = 3
        metric(preds, gold)
        assert metric.get_metric() == (2 + 3 + 3) / (4 + 4 + 4)

        # After a reset, only the latest batch is counted.
        metric.reset()
        metric(preds, gold)
        assert metric.get_metric() == 3 / 4
# Example no. 2
# 0
    def test_accuracy_computation(self):
        """BooleanAccuracy should score whole rows and accumulate until reset.

        Each expectation is a running ratio: rows judged correct so far divided
        by rows seen so far. Going by the expected values, a row is judged
        correct only if all of its (unmasked) elements agree with the target.
        """
        acc = BooleanAccuracy()
        pred = torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]])
        target = torch.tensor([[0.0, 1.0], [2.0, 2.0], [4.0, 5.0], [7.0, 7.0]])

        acc(pred, target)  # rows 0 and 2 agree -> 2 of 4
        assert acc.get_metric() == 2 / 4

        # Mask the disagreeing element of row 1; this batch now yields 3 of 4.
        keep = torch.ones(4, 2)
        keep[1, 1] = 0
        acc(pred, target, keep)
        assert acc.get_metric() == 5 / 8

        # Edit the target so row 1 agrees even without a mask: another 3 of 4.
        target[1, 1] = 3
        acc(pred, target)
        assert acc.get_metric() == 8 / 12

        # Start over: only this last batch contributes.
        acc.reset()
        acc(pred, target)
        assert acc.get_metric() == 3 / 4
    def test_accuracy_computation(self, device: str):
        """Run BooleanAccuracy on ``device``: unmasked, masked, and after reset.

        Expected values are running totals (correct rows / rows seen), because
        the metric keeps accumulating until ``reset()`` is called.
        """
        metric = BooleanAccuracy()
        preds = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7]], device=device)
        gold = torch.tensor([[0, 1], [2, 2], [4, 5], [7, 7]], device=device)

        # Rows 0 and 2 agree; rows 1 and 3 do not.
        metric(preds, gold)
        assert metric.get_metric() == 2 / 4

        # A boolean mask hiding row 1's only mismatching element turns that row
        # into a hit, so this batch adds 3 of 4.
        mask = torch.ones(4, 2, dtype=torch.bool, device=device)
        mask[1, 1] = False
        metric(preds, gold, mask)
        assert metric.get_metric() == (2 + 3) / (4 + 4)

        # Make row 1 agree in the targets themselves; no mask is needed now.
        gold[1, 1] = 3
        metric(preds, gold)
        assert metric.get_metric() == (2 + 3 + 3) / (4 + 4 + 4)

        # Reset discards the history; only this final batch is scored.
        metric.reset()
        metric(preds, gold)
        assert metric.get_metric() == 3 / 4