Example #1
0
    def test_get_lone_query_labels_multi_dim(self):
        """Check get_lone_query_labels on 2-D (row-valued) numpy labels.

        Runs once with ``accuracy_calculator.EQUALITY`` and once with a custom
        comparison that matches rows sharing the first column but differing in
        the second, each for same_source=True and same_source=False.
        """

        def custom_label_comparison_fn(x, y):
            return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

        query_labels = np.array([
            (1, 3),
            (0, 3),
            (0, 3),
            (0, 3),
            (0, 2),
            (1, 2),
            (4, 5),
        ])

        for comparison_fn in [
                accuracy_calculator.EQUALITY, custom_label_comparison_fn
        ]:
            # The second return value (num_k) is not needed by this test.
            label_counts, _ = accuracy_calculator.get_label_match_counts(
                query_labels,
                query_labels,
                comparison_fn,
            )

            if comparison_fn is accuracy_calculator.EQUALITY:
                # With same_source=True, rows that occur only once are
                # expected to be lone; with same_source=False nothing is.
                correct = [
                    (
                        True,
                        np.array([[0, 2], [1, 2], [1, 3], [4, 5]]),
                        np.array(
                            [False, True, True, True, False, False, False]),
                    ),
                    (
                        False,
                        np.array([[]]),
                        np.array([True, True, True, True, True, True, True]),
                    ),
                ]
            else:
                # (4, 5) is the only row with no other row sharing its first
                # column, so it is expected to be lone in both modes.
                correct_lone = np.array([[4, 5]])
                correct_mask = np.array(
                    [True, True, True, True, True, True, False])
                correct = [
                    (True, correct_lone, correct_mask),
                    (False, correct_lone, correct_mask),
                ]

            for same_source, correct_lone, correct_mask in correct:
                (
                    lone_query_labels,
                    not_lone_query_mask,
                ) = accuracy_calculator.get_lone_query_labels(
                    query_labels, label_counts, same_source, comparison_fn)
                if correct_lone.size == 0:
                    self.assertEqual(lone_query_labels.size, 0)
                else:
                    # assert_array_equal also requires matching shapes,
                    # unlike np.all(a == b), which broadcasts and could pass
                    # on a result with the wrong number of rows.
                    np.testing.assert_array_equal(lone_query_labels,
                                                  correct_lone)

                np.testing.assert_array_equal(not_lone_query_mask,
                                              correct_mask)
Example #2
0
    def test_get_lone_query_labels(self):
        """Check get_lone_query_labels on 1-D numpy labels with EQUALITY.

        Query label 6 never occurs in the reference labels, so it is lone in
        both modes. With same_source=True, labels that occur only once in the
        reference labels (1, 3, 4, 5) are lone as well.
        """
        query_labels = np.array([0, 1, 2, 3, 4, 5, 6])
        reference_labels = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5])
        label_counts, _ = accuracy_calculator.get_label_match_counts(
            query_labels,
            reference_labels,
            accuracy_calculator.EQUALITY,
        )

        # (same_source, expected lone labels, expected not-lone mask with one
        # entry per query label)
        expectations = [
            (
                True,
                np.array([1, 3, 4, 5, 6]),
                np.array([True, False, True, False, False, False, False]),
            ),
            (
                False,
                np.array([6]),
                np.array([True, True, True, True, True, True, False]),
            ),
        ]
        for same_source, correct_lone, correct_mask in expectations:
            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels,
                label_counts,
                same_source,
                accuracy_calculator.EQUALITY,
            )
            # assert_array_equal checks shape as well as values; a plain
            # np.all(a == b) would broadcast a length-1 array silently.
            np.testing.assert_array_equal(lone_query_labels, correct_lone)
            np.testing.assert_array_equal(not_lone_query_mask, correct_mask)
Example #3
0
    def test_get_lone_query_labels(self):
        """Check get_lone_query_labels on 1-D torch labels with EQUALITY.

        Query label 6 never occurs in the reference labels, so it is lone in
        both modes. With same_source=True, labels that occur only once in the
        reference labels (1, 3, 4, 5) are lone as well.
        """

        def assert_same(actual, expected):
            # torch.all(a == b) broadcasts, so a result with the wrong shape
            # could still pass; require matching shapes first.
            self.assertEqual(actual.shape, expected.shape)
            self.assertTrue(torch.all(actual == expected))

        query_labels = torch.tensor([0, 1, 2, 3, 4, 5, 6], device=TEST_DEVICE)
        reference_labels = torch.tensor([0, 0, 0, 1, 2, 2, 3, 4, 5],
                                        device=TEST_DEVICE)
        label_counts = accuracy_calculator.get_label_match_counts(
            query_labels,
            reference_labels,
            accuracy_calculator.EQUALITY,
        )

        # (same_source, expected lone labels, expected not-lone mask with one
        # entry per query label)
        expectations = [
            (
                True,
                torch.tensor([1, 3, 4, 5, 6], device=TEST_DEVICE),
                torch.tensor([True, False, True, False, False, False, False],
                             device=TEST_DEVICE),
            ),
            (
                False,
                torch.tensor([6], device=TEST_DEVICE),
                torch.tensor([True, True, True, True, True, True, False],
                             device=TEST_DEVICE),
            ),
        ]
        for same_source, correct_lone, correct_mask in expectations:
            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels,
                label_counts,
                same_source,
                accuracy_calculator.EQUALITY,
            )
            assert_same(lone_query_labels, correct_lone)
            assert_same(not_lone_query_mask, correct_mask)
Example #4
0
    def test_get_lone_query_labels_multi_dim(self):
        """Check label counting and lone-label detection for 2-D torch labels.

        Each label is a row of two values. ``equality2D`` matches identical
        rows; ``custom_label_comparison_fn`` matches rows that share the first
        column but differ in the second. Both are checked for
        same_source=True and same_source=False.
        """

        def equality2D(x, y):
            return (x[..., 0] == y[..., 0]) & (x[..., 1] == y[..., 1])

        def custom_label_comparison_fn(x, y):
            return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

        def assert_same(actual, expected):
            # torch.all(a == b) broadcasts, so a result with the wrong shape
            # could still pass; require matching shapes first.
            self.assertEqual(actual.shape, expected.shape)
            self.assertTrue(torch.all(actual == expected))

        query_labels = torch.tensor(
            [
                (1, 3),
                (0, 3),
                (0, 3),
                (0, 3),
                (1, 2),
                (4, 5),
            ],
            device=TEST_DEVICE,
        )
        # The expected unique rows do not depend on the comparison function.
        correct_unique_labels = torch.tensor(
            [[0, 3], [1, 2], [1, 3], [4, 5]], device=TEST_DEVICE)

        for comparison_fn in [equality2D, custom_label_comparison_fn]:
            label_counts = accuracy_calculator.get_label_match_counts(
                query_labels,
                query_labels,
                comparison_fn,
            )

            unique_labels, counts = label_counts
            if comparison_fn is equality2D:
                correct_counts = torch.tensor([3, 1, 1, 1], device=TEST_DEVICE)
            else:
                correct_counts = torch.tensor([0, 1, 1, 0], device=TEST_DEVICE)

            assert_same(counts, correct_counts)
            assert_same(unique_labels, correct_unique_labels)

            if comparison_fn is equality2D:
                correct = [
                    (
                        True,
                        torch.tensor([[1, 2], [1, 3], [4, 5]],
                                     device=TEST_DEVICE),
                        torch.tensor([False, True, True, True, False, False],
                                     device=TEST_DEVICE),
                    ),
                    (
                        False,
                        torch.tensor([[]], device=TEST_DEVICE),
                        torch.tensor([True, True, True, True, True, True],
                                     device=TEST_DEVICE),
                    ),
                ]
            else:
                # Rows with a count of 0 ((0, 3) and (4, 5)) are lone in both
                # modes.
                correct = [
                    (
                        True,
                        torch.tensor([[0, 3], [4, 5]], device=TEST_DEVICE),
                        torch.tensor([True, False, False, False, True, False],
                                     device=TEST_DEVICE),
                    ),
                    (
                        False,
                        torch.tensor([[0, 3], [4, 5]], device=TEST_DEVICE),
                        torch.tensor([True, False, False, False, True, False],
                                     device=TEST_DEVICE),
                    ),
                ]

            for same_source, correct_lone, correct_mask in correct:
                (
                    lone_query_labels,
                    not_lone_query_mask,
                ) = accuracy_calculator.get_lone_query_labels(
                    query_labels, label_counts, same_source, comparison_fn)
                if correct_lone.numel() == 0:
                    self.assertEqual(lone_query_labels.numel(), 0)
                else:
                    assert_same(lone_query_labels, correct_lone)

                assert_same(not_lone_query_mask, correct_mask)
Example #5
0
    def test_get_lone_query_labels_custom(self):
        """Check label counting and lone-label detection with non-equality
        comparison functions on 1-D torch labels.

        ``fn1`` matches labels less than 2 apart; ``fn2`` matches labels more
        than 99 apart. Lone labels are checked with same_source=True only.
        """

        def fn1(x, y):
            return abs(x - y) < 2

        def fn2(x, y):
            return abs(x - y) > 99

        def assert_same(actual, expected):
            # torch.all(a == b) broadcasts, so a result with the wrong shape
            # could still pass; require matching shapes first.
            self.assertEqual(actual.shape, expected.shape)
            self.assertTrue(torch.all(actual == expected))

        query_labels = torch.tensor([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100],
                                    device=TEST_DEVICE)
        # The expected unique labels do not depend on the comparison function.
        correct_unique_labels = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100], device=TEST_DEVICE)

        for comparison_fn in [fn1, fn2]:
            if comparison_fn is fn1:
                correct_counts = torch.tensor(
                    [3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1], device=TEST_DEVICE)
                # 100 only matches itself.
                correct_lone_query_labels = torch.tensor([100],
                                                         device=TEST_DEVICE)
                # One entry per query label: only the final 100 is lone.
                correct_not_lone_query_mask = torch.tensor(
                    [True] * 11 + [False], device=TEST_DEVICE)
            else:
                correct_counts = torch.tensor(
                    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], device=TEST_DEVICE)
                # 1..9 are never more than 99 apart from any query label.
                correct_lone_query_labels = torch.tensor(
                    [1, 2, 3, 4, 5, 6, 7, 8, 9], device=TEST_DEVICE)
                # One entry per query label: 0, 0 and 100 match each other;
                # 1..9 are lone.
                correct_not_lone_query_mask = torch.tensor(
                    [True, True] + [False] * 9 + [True], device=TEST_DEVICE)

            label_counts = accuracy_calculator.get_label_match_counts(
                query_labels,
                query_labels,
                comparison_fn,
            )
            unique_labels, counts = label_counts

            assert_same(unique_labels, correct_unique_labels)
            assert_same(counts, correct_counts)

            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels, label_counts, True, comparison_fn)

            assert_same(lone_query_labels, correct_lone_query_labels)
            assert_same(not_lone_query_mask, correct_not_lone_query_mask)

    def test_get_lone_query_labels_custom(self):
        """Check label counting and lone-label detection with non-equality
        comparison functions on 1-D numpy labels.

        ``fn1`` matches labels less than 2 apart; ``fn2`` matches labels more
        than 99 apart. Lone labels are checked with same_source=True only.
        """

        def fn1(x, y):
            return abs(x - y) < 2

        def fn2(x, y):
            return abs(x - y) > 99

        query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
        # The expected unique labels do not depend on the comparison function.
        correct_unique_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

        for comparison_fn in [fn1, fn2]:
            if comparison_fn is fn1:
                correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1])
                # 100 only matches itself.
                correct_lone_query_labels = np.array([100])
                # One entry per query label: only the final 100 is lone.
                correct_not_lone_query_mask = np.array([True] * 11 + [False])
            else:
                correct_counts = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2])
                # 1..9 are never more than 99 apart from any query label.
                correct_lone_query_labels = np.array(
                    [1, 2, 3, 4, 5, 6, 7, 8, 9])
                # One entry per query label: 0, 0 and 100 match each other;
                # 1..9 are lone.
                correct_not_lone_query_mask = np.array([True, True] +
                                                       [False] * 9 + [True])

            # The second return value (num_k) is not needed by this test.
            label_counts, _ = accuracy_calculator.get_label_match_counts(
                query_labels,
                query_labels,
                comparison_fn,
            )
            unique_labels, counts = label_counts

            # assert_array_equal checks shape as well as values, unlike
            # np.all(a == b), which broadcasts.
            np.testing.assert_array_equal(unique_labels, correct_unique_labels)
            np.testing.assert_array_equal(counts, correct_counts)

            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels, label_counts, True, comparison_fn)

            np.testing.assert_array_equal(lone_query_labels,
                                          correct_lone_query_labels)
            np.testing.assert_array_equal(not_lone_query_mask,
                                          correct_not_lone_query_mask)