Example #1
    def _test_accuracy_calculator_and_faiss_avg_of_avgs(self, use_numpy):
        AC_global_average = accuracy_calculator.AccuracyCalculator(
            avg_of_avgs=False)
        AC_per_class_average = accuracy_calculator.AccuracyCalculator(
            avg_of_avgs=True)

        query_labels = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
        reference_labels = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

        if use_numpy:
            query = np.arange(10)[:, None]
            reference = np.arange(10)[:, None]
            query_labels = np.array(query_labels)
            reference_labels = np.array(reference_labels)
        else:
            query = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
            reference = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
            query_labels = torch.tensor(query_labels, device=TEST_DEVICE)
            reference_labels = torch.tensor(reference_labels,
                                            device=TEST_DEVICE)

        query[-1] = 100
        reference[0] = -100

        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)
        self.assertTrue(isclose(acc["precision_at_1"], 0.9))
        self.assertTrue(isclose(acc["r_precision"], 0.9))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 0.9))

        acc = AC_per_class_average.get_accuracy(query, reference, query_labels,
                                                reference_labels, False)
        self.assertTrue(isclose(acc["precision_at_1"], 0.5))
        self.assertTrue(isclose(acc["r_precision"], 0.5))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 0.5))
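The expected values can be recovered by hand: nine of the ten queries have
label 0 and find a correct nearest neighbor, while the single label-1 query
(moved to 100) is far from the only label-1 reference (moved to -100) and
retrieves nothing correct. A minimal sketch of the two aggregation modes
(illustrative only, not the library's implementation):

    # per-query precision_at_1 flags implied by the setup above
    per_query = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]  # the label-1 query fails
    labels = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

    # avg_of_avgs=False: one global mean over all queries
    global_avg = sum(per_query) / len(per_query)  # 0.9

    # avg_of_avgs=True: mean within each class, then mean of those means
    class_means = [
        sum(p for p, l in zip(per_query, labels) if l == c)
        / sum(1 for l in labels if l == c)
        for c in sorted(set(labels))
    ]
    per_class_avg = sum(class_means) / len(class_means)  # (1.0 + 0.0) / 2 = 0.5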
Example #2
    def test_accuracy_calculator_and_faiss_avg_of_avgs(self):
        AC_global_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"), avg_of_avgs=False
        )
        AC_per_class_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"), avg_of_avgs=True
        )
        query = np.arange(10)[:, None].astype(np.float32)
        reference = np.arange(10)[:, None].astype(np.float32)
        query[-1] = 100
        reference[0] = -100
        query_labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
        reference_labels = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        acc = AC_global_average.get_accuracy(
            query, reference, query_labels, reference_labels, False
        )
        self.assertTrue(acc["precision_at_1"] == 0.9)
        self.assertTrue(acc["r_precision"] == 0.9)
        self.assertTrue(acc["mean_average_precision_at_r"] == 0.9)

        acc = AC_per_class_average.get_accuracy(
            query, reference, query_labels, reference_labels, False
        )
        self.assertTrue(acc["precision_at_1"] == 0.5)
        self.assertTrue(acc["r_precision"] == 0.5)
        self.assertTrue(acc["mean_average_precision_at_r"] == 0.5)

    def test_accuracy_calculator_float_custom_comparison_function(self):
        def label_comparison_fn(x, y):
            return torch.abs(x - y) < 1

        self.assertRaises(
            NotImplementedError,
            lambda: accuracy_calculator.AccuracyCalculator(
                include=("NMI", "AMI"),
                avg_of_avgs=False,
                label_comparison_fn=label_comparison_fn,
            ),
        )

        AC_global_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"),
            avg_of_avgs=False,
            label_comparison_fn=label_comparison_fn,
        )

        query = torch.tensor([0, 3], device=TEST_DEVICE).unsqueeze(1)
        reference = torch.arange(4, device=TEST_DEVICE).unsqueeze(1)
        query_labels = torch.tensor([0.01, 0.02], device=TEST_DEVICE)
        reference_labels = torch.tensor([10.0, 0.03, 0.04, 0.05],
                                        device=TEST_DEVICE)

        correct = {
            "precision_at_1": (0 + 1) / 2,
            "r_precision": (2 / 3 + 3 / 3) / 2,
            "mean_average_precision_at_r": ((0 + 1 / 2 + 2 / 3) / 3 + 1) / 2,
        }
        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)
        for k in correct:
            self.assertTrue(isclose(acc[k], correct[k]))

        correct = {
            "precision_at_1": 1.0,
            "r_precision": 1.0,
            "mean_average_precision_at_r": 1.0,
        }
        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, True)
        for k in correct:
            self.assertTrue(isclose(acc[k], correct[k]))
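Working through the first `correct` dict: the query at 0 (label 0.01) has
three matching references under |x - y| < 1 (labels 0.03, 0.04, 0.05, so
R = 3), but its nearest reference carries the non-matching label 10.0, so it
scores precision_at_1 = 0, r_precision = 2/3, and AP@R = (0 + 1/2 + 2/3) / 3.
The query at 3 (label 0.02) finds matches at all three of its nearest ranks
and contributes 1 to every metric. A quick sketch of that arithmetic
(illustrative; the hit lists are hand-derived from the distances above):

    hits_q0 = [0, 1, 1]  # ranks 1..3 for the query at 0 (R = 3)
    hits_q1 = [1, 1, 1]  # ranks 1..3 for the query at 3 (R = 3)
    ap_at_r = lambda h: sum(
        sum(h[:i + 1]) / (i + 1) for i, m in enumerate(h) if m
    ) / len(h)
    assert ap_at_r(hits_q0) == (0 + 1 / 2 + 2 / 3) / 3
    assert (ap_at_r(hits_q0) + ap_at_r(hits_q1)) / 2 == ((0 + 1 / 2 + 2 / 3) / 3 + 1) / 2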
Example #4
    def test_accuracy_calculator_large_k(self):
        for ecfss in [False, True]:
            for max_k in [None, "max_bin_count"]:
                for num_embeddings in [1000, 2100]:
                    # make random features
                    encs = np.random.rand(num_embeddings, 5).astype(np.float32)
                    # and labels covering 100 classes (when num_embeddings is
                    # 2100, rows past the first 10 stay class 0)
                    labels = np.zeros((num_embeddings // 100, 100),
                                      dtype=np.int32)
                    for i in range(10):
                        labels[i] = np.arange(100)
                    labels = labels.ravel()

                    correct_p1, correct_map, correct_mapr = self.evaluate(
                        encs, labels, max_k, ecfss)

                    # use Musgrave's library
                    if max_k is None:
                        k = len(encs) - 1 if ecfss else len(encs)
                        accs = [
                            accuracy_calculator.AccuracyCalculator(),
                            accuracy_calculator.AccuracyCalculator(k=k),
                        ]
                    elif max_k == "max_bin_count":
                        accs = [
                            accuracy_calculator.AccuracyCalculator(
                                k="max_bin_count")
                        ]

                    for acc in accs:
                        d = acc.get_accuracy(
                            encs,
                            encs,
                            labels,
                            labels,
                            ecfss,
                            include=(
                                "mean_average_precision",
                                "mean_average_precision_at_r",
                                "precision_at_1",
                            ),
                        )

                        self.assertTrue(
                            np.isclose(correct_p1, d["precision_at_1"]))
                        self.assertTrue(
                            np.isclose(correct_map,
                                       d["mean_average_precision"]))
                        self.assertTrue(
                            np.isclose(correct_mapr,
                                       d["mean_average_precision_at_r"]))

    def test_global_embedding_space_tester(self):
        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(
            include=("precision_at_1", ))

        correct = {
            "compared_to_self": {"train": 1, "val": 6.0 / 8},
            "compared_to_sets_combined": {"train": 1.0 / 8, "val": 1.0 / 8},
            "compared_to_training_set": {"train": 1, "val": 1.0 / 8},
        }

        for reference_set, correct_vals in correct.items():
            tester = GlobalEmbeddingSpaceTester(reference_set=reference_set,
                                                accuracy_calculator=AC)
            tester.test(self.dataset_dict, 0, model)
            self.assertTrue(tester.all_accuracies["train"]
                            ["precision_at_1_level0"] == correct_vals["train"])
            self.assertTrue(tester.all_accuracies["val"]
                            ["precision_at_1_level0"] == correct_vals["val"])

    def test_accuracy_calculator(self):
        query_labels = np.array([0, 1, 2, 3, 4])
        knn_labels = np.array([
            [0, 1, 1, 2, 2],
            [1, 0, 1, 1, 3],
            [4, 4, 4, 4, 2],
            [3, 1, 3, 1, 3],
            [0, 0, 4, 2, 2],
        ])
        label_counts = {0: 2, 1: 3, 2: 5, 3: 4, 4: 5}
        AC = accuracy_calculator.AccuracyCalculator(exclude=("NMI", "AMI"))
        kwargs = {
            "query_labels": query_labels,
            "label_counts": label_counts,
            "knn_labels": knn_labels
        }

        function_dict = AC.get_function_dict()

        for ecfss in [False, True]:
            if ecfss:
                kwargs["knn_labels"] = kwargs["knn_labels"][:, 1:]
            kwargs["embeddings_come_from_same_source"] = ecfss
            acc = AC._get_accuracy(function_dict, **kwargs)
            self.assertTrue(
                acc["precision_at_1"] == self.correct_precision_at_1(ecfss))
            self.assertTrue(
                acc["r_precision"] == self.correct_r_precision(ecfss))
            self.assertTrue(acc["mean_average_precision_at_r"] ==
                            self.correct_mean_average_precision_at_r(ecfss))

    def test_accuracy_calculator_and_faiss_with_numpy_input(self):
        AC = accuracy_calculator.AccuracyCalculator()

        query = np.arange(10)[:, None]
        reference = np.arange(10)[:, None]
        query_labels = np.arange(10)
        reference_labels = np.arange(10)
        acc = AC.get_accuracy(query, reference, query_labels, reference_labels,
                              False)
        self.assertTrue(isclose(acc["precision_at_1"], 1))
        self.assertTrue(isclose(acc["r_precision"], 1))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 1))

        reference = (np.arange(20) / 2.0)[:, None]
        reference_labels = np.zeros(20)
        reference_labels[::2] = query_labels
        reference_labels[1::2] = np.ones(10)
        acc = AC.get_accuracy(query, reference, query_labels, reference_labels,
                              True)
        self.assertTrue(isclose(acc["precision_at_1"], 1))
        self.assertTrue(isclose(acc["r_precision"], 0.5))
        self.assertTrue(
            isclose(
                acc["mean_average_precision_at_r"],
                (1 + 2.0 / 2 + 3.0 / 5 + 4.0 / 7 + 5.0 / 9) / 10,
            ))

    def test_accuracy_calculator_and_faiss(self):
        AC = accuracy_calculator.AccuracyCalculator()

        query = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
        reference = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
        query_labels = torch.arange(10, device=TEST_DEVICE)
        reference_labels = torch.arange(10, device=TEST_DEVICE)
        acc = AC.get_accuracy(query, reference, query_labels, reference_labels,
                              False)
        self.assertTrue(isclose(acc["precision_at_1"], 1))
        self.assertTrue(isclose(acc["r_precision"], 1))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 1))

        reference = (torch.arange(20, device=TEST_DEVICE) / 2.0).unsqueeze(1)
        reference_labels = torch.zeros(20, device=TEST_DEVICE)
        reference_labels[::2] = query_labels
        reference_labels[1::2] = torch.ones(10)
        acc = AC.get_accuracy(query, reference, query_labels, reference_labels,
                              True)
        self.assertTrue(isclose(acc["precision_at_1"], 1))
        self.assertTrue(isclose(acc["r_precision"], 0.5))
        self.assertTrue(
            isclose(
                acc["mean_average_precision_at_r"],
                (1 + 2.0 / 2 + 3.0 / 5 + 4.0 / 7 + 5.0 / 9) / 10,
            ))
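The asserted value (1 + 2/2 + 3/5 + 4/7 + 5/9) / 10 follows the MAP@R
definition from "A Metric Learning Reality Check": for each query, precision
is accumulated only at the ranks (up to R) where the retrieved neighbor
matches, and the total is divided by R. A minimal reference implementation
for cross-checking (a sketch, not the library's code):

    import numpy as np

    def ap_at_r(match_row, r):
        # match_row: boolean hits for one query's neighbors, nearest first
        hits = np.asarray(match_row[:r], dtype=float)
        precisions = np.cumsum(hits) / np.arange(1, r + 1)
        return float((precisions * hits).sum() / r)

    # e.g. a query whose R = 2 nearest references both match scores 1.0
    assert ap_at_r([1, 1], 2) == (1 + 2.0 / 2) / 2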
Example #9
    def test_global_embedding_space_tester(self):
        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(
            include=("precision_at_1", ))

        correct = [
            (None, {"train": 1, "val": 6.0 / 8}),
            (
                [("train", ["train", "val"]), ("val", ["train", "val"])],
                {"train": 1.0 / 8, "val": 1.0 / 8},
            ),
            (
                [("train", ["train"]), ("val", ["train"])],
                {"train": 1, "val": 1.0 / 8},
            ),
        ]

        for splits_to_eval, correct_vals in correct:
            tester = GlobalEmbeddingSpaceTester(accuracy_calculator=AC)
            tester.test(self.dataset_dict,
                        0,
                        model,
                        splits_to_eval=splits_to_eval)
            self.assertTrue(tester.all_accuracies["train"]
                            ["precision_at_1_level0"] == correct_vals["train"])
            self.assertTrue(tester.all_accuracies["val"]
                            ["precision_at_1_level0"] == correct_vals["val"])

    def test_accuracy_calculator(self):
        query_labels = np.array([1, 1, 2, 3, 4])

        knn_labels1 = np.array([
            [0, 1, 1, 2, 2],
            [1, 0, 1, 1, 3],
            [4, 4, 4, 4, 2],
            [3, 1, 3, 1, 3],
            [0, 0, 4, 2, 2],
        ])
        label_counts1 = {1: 3, 2: 5, 3: 4, 4: 5}

        knn_labels2 = knn_labels1 + 5
        label_counts2 = {k + 5: v for k, v in label_counts1.items()}

        for avg_of_avgs in [False, True]:
            for i, (knn_labels, label_counts) in enumerate([
                (knn_labels1, label_counts1),
                (knn_labels2, label_counts2),
            ]):

                AC = accuracy_calculator.AccuracyCalculator(
                    exclude=("NMI", "AMI"), avg_of_avgs=avg_of_avgs)
                kwargs = {
                    "query_labels": query_labels,
                    "label_counts": label_counts,
                    "knn_labels": knn_labels,
                    "not_lone_query_mask": np.ones(5, dtype=bool)
                    if i == 0 else np.zeros(5, dtype=bool),
                }

                function_dict = AC.get_function_dict()

                for ecfss in [False, True]:
                    if ecfss:
                        kwargs["knn_labels"] = kwargs["knn_labels"][:, 1:]
                    kwargs["embeddings_come_from_same_source"] = ecfss
                    acc = AC._get_accuracy(function_dict, **kwargs)
                    if i == 1:
                        self.assertTrue(acc["precision_at_1"] == 0)
                        self.assertTrue(acc["r_precision"] == 0)
                        self.assertTrue(
                            acc["mean_average_precision_at_r"] == 0)
                    else:
                        self.assertTrue(
                            acc["precision_at_1"] ==
                            self.correct_precision_at_1(ecfss, avg_of_avgs))
                        self.assertTrue(
                            acc["r_precision"] == self.correct_r_precision(
                                ecfss, avg_of_avgs))
                        self.assertTrue(
                            acc["mean_average_precision_at_r"] ==
                            self.correct_mean_average_precision_at_r(
                                ecfss, avg_of_avgs))

    def test_accuracy_calculator_and_faiss_avg_of_avgs(self):
        AC_global_average = accuracy_calculator.AccuracyCalculator(
            avg_of_avgs=False)
        AC_per_class_average = accuracy_calculator.AccuracyCalculator(
            avg_of_avgs=True)
        query = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
        reference = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
        query[-1] = 100
        reference[0] = -100
        query_labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                                    device=TEST_DEVICE)
        reference_labels = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                                        device=TEST_DEVICE)
        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)
        self.assertTrue(isclose(acc["precision_at_1"], 0.9))
        self.assertTrue(isclose(acc["r_precision"], 0.9))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 0.9))

        acc = AC_per_class_average.get_accuracy(query, reference, query_labels,
                                                reference_labels, False)
        self.assertTrue(isclose(acc["precision_at_1"], 0.5))
        self.assertTrue(isclose(acc["r_precision"], 0.5))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 0.5))

    def test_pca(self):
        # just make sure pca runs without crashing
        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(include=("precision_at_1",))
        embeddings = torch.randn(1024, 512)
        labels = torch.randint(0, 10, size=(1024,))
        dataset_dict = {"train": c_f.EmbeddingDataset(embeddings, labels)}
        pca_size = 16

        def end_of_testing_hook(tester):
            self.assertTrue(
                tester.embeddings_and_labels["train"][0].shape[1] == pca_size
            )

        tester = GlobalEmbeddingSpaceTester(
            pca=pca_size,
            accuracy_calculator=AC,
            end_of_testing_hook=end_of_testing_hook,
        )
        all_accuracies = tester.test(dataset_dict, 0, model)
        self.assertTrue(not hasattr(tester, "embeddings_and_labels"))

    def test_with_same_parent_label_tester(self):
        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(
            include=("precision_at_1", ))

        correct = [
            (None, {"train": 1, "val": (1 + 0.5) / 2}),
            (
                [("train", ["train", "val"]), ("val", ["train", "val"])],
                {"train": (0.75 + 0.5) / 2, "val": (0.4 + 0.75) / 2},
            ),
            (
                [("train", ["train"]), ("val", ["train"])],
                {"train": 1, "val": (1 + 0.75) / 2},
            ),
        ]

        for splits_to_eval, correct_vals in correct:
            tester = WithSameParentLabelTester(accuracy_calculator=AC)
            all_accuracies = tester.test(self.dataset_dict,
                                         0,
                                         model,
                                         splits_to_eval=splits_to_eval)
            self.assertTrue(
                np.isclose(
                    all_accuracies["train"]["precision_at_1_level0"],
                    correct_vals["train"],
                ))
            self.assertTrue(
                np.isclose(all_accuracies["val"]["precision_at_1_level0"],
                           correct_vals["val"]))
Example #14
    def test_global_two_stream_embedding_space_tester(self):
        embedding_angles = [0, 10, 20, 30, 50, 60, 70, 80]
        embeddings1 = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles])
        embedding_angles = [81, 71, 61, 31, 51, 21, 11, 1]
        embeddings2 = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles])
        labels = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8])
        dataset_dict = {
            "train": FakeDataset(embeddings1, embeddings2, labels),
        }

        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(
            include=("precision_at_1", ))

        tester = GlobalTwoStreamEmbeddingSpaceTester(accuracy_calculator=AC,
                                                     dataloader_num_workers=0)
        all_accuracies = tester.test(dataset_dict, 0, model)

        self.assertTrue(
            np.isclose(all_accuracies["train"]["precision_at_1_level0"], 0.25))
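The expected 0.25 is consistent with a two-stream evaluation in which one
stream supplies the queries and the other the references (an assumption
about GlobalTwoStreamEmbeddingSpaceTester, not stated in the test): only the
angle pairs (30, 31) and (50, 51) retrieve their same-label counterpart as
nearest neighbor. A rough sanity sketch under that assumption, using angle
differences as a stand-in for distances between the 2D coordinates:

    import numpy as np

    a1 = np.array([0, 10, 20, 30, 50, 60, 70, 80])
    a2 = np.array([81, 71, 61, 31, 51, 21, 11, 1])
    nearest = [int(np.argmin(np.abs(a2 - a))) for a in a1]
    hits = sum(int(n == i) for i, n in enumerate(nearest))
    assert hits / len(a1) == 0.25  # only indices 3 and 4 match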
Example #15
    def test_accuracy_calculator_and_faiss(self):
        AC = accuracy_calculator.AccuracyCalculator(exclude=("NMI", "AMI"))

        query = np.arange(10)[:, None].astype(np.float32)
        reference = np.arange(10)[:, None].astype(np.float32)
        query_labels = np.arange(10).astype(int)
        reference_labels = np.arange(10).astype(int)
        acc = AC.get_accuracy(query, reference, query_labels, reference_labels, False)
        self.assertTrue(acc["precision_at_1"] == 1)
        self.assertTrue(acc["r_precision"] == 1)
        self.assertTrue(acc["mean_average_precision_at_r"] == 1)

        reference = (np.arange(20) / 2.0)[:, None].astype(np.float32)
        reference_labels = np.zeros(20).astype(int)
        reference_labels[::2] = query_labels
        reference_labels[1::2] = np.ones(10).astype(int)
        acc = AC.get_accuracy(query, reference, query_labels, reference_labels, True)
        self.assertTrue(acc["precision_at_1"] == 1)
        self.assertTrue(acc["r_precision"] == 0.5)
        self.assertTrue(
            acc["mean_average_precision_at_r"]
            == (1 + 2.0 / 2 + 3.0 / 5 + 4.0 / 7 + 5.0 / 9) / 10
        )
Example #16
    def test_valid_k(self):
        for k in [-1, 0, 1.5, "max"]:
            self.assertRaises(
                ValueError,
                lambda: accuracy_calculator.AccuracyCalculator(k=k))
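The test pins down the accepted domain of k: negative integers, zero, floats,
and arbitrary strings must all raise ValueError. A sketch of the kind of
guard this implies (hypothetical, not the library's code):

    def validate_k(k):
        if k is None or k == "max_bin_count":
            return k
        if isinstance(k, int) and not isinstance(k, bool) and k > 0:
            return k
        raise ValueError(
            "k must be a positive integer, 'max_bin_count', or None; "
            "got {!r}".format(k))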
Example #17
    def _test_accuracy_calculator_custom_comparison_function(self, use_numpy):
        def label_comparison_fn(x, y):
            return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

        self.assertRaises(
            NotImplementedError,
            lambda: accuracy_calculator.AccuracyCalculator(
                include=("NMI", "AMI"),
                avg_of_avgs=False,
                label_comparison_fn=label_comparison_fn,
            ),
        )

        AC_global_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"),
            avg_of_avgs=False,
            label_comparison_fn=label_comparison_fn,
        )

        AC_per_class_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"),
            avg_of_avgs=True,
            label_comparison_fn=label_comparison_fn,
        )

        query_labels = [(0, 2)] * 9 + [(1, 2)]
        reference_labels = [(1, 3)] + [(0, 3)] * 9

        if use_numpy:
            query = np.arange(10)[:, None]
            reference = np.arange(10)[:, None]
            query_labels = np.array(query_labels)
            reference_labels = np.array(reference_labels)
        else:
            query = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
            reference = torch.arange(10, device=TEST_DEVICE).unsqueeze(1)
            query_labels = torch.tensor(query_labels, device=TEST_DEVICE)
            reference_labels = torch.tensor(reference_labels,
                                            device=TEST_DEVICE)

        query[-1] = 100
        reference[0] = -100

        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)
        self.assertTrue(isclose(acc["precision_at_1"], 0.9))
        self.assertTrue(isclose(acc["r_precision"], 0.9))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 0.9))

        acc = AC_per_class_average.get_accuracy(query, reference, query_labels,
                                                reference_labels, False)
        self.assertTrue(isclose(acc["precision_at_1"], 0.5))
        self.assertTrue(isclose(acc["r_precision"], 0.5))
        self.assertTrue(isclose(acc["mean_average_precision_at_r"], 0.5))

        query_labels = [(1, 3), (7, 3)]
        reference_labels = [(1, 3), (7, 4), (1, 4), (1, 5), (1, 6)]

        # SIMPLE CASE
        if use_numpy:
            query = np.arange(2)[:, None]
            reference = np.arange(5)[:, None]
            query_labels = np.array(query_labels)
            reference_labels = np.array(reference_labels)
        else:
            query = torch.arange(2, device=TEST_DEVICE).unsqueeze(1)
            reference = torch.arange(5, device=TEST_DEVICE).unsqueeze(1)
            query_labels = torch.tensor(query_labels, device=TEST_DEVICE)
            reference_labels = torch.tensor(reference_labels,
                                            device=TEST_DEVICE)

        correct = {
            "precision_at_1": 0.5,
            "r_precision": (1.0 / 3 + 1) / 2,
            "mean_average_precision_at_r": ((1.0 / 3) / 3 + 1) / 2,
        }

        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)

        for k in correct:
            self.assertTrue(isclose(acc[k], correct[k]))

        acc = AC_per_class_average.get_accuracy(query, reference, query_labels,
                                                reference_labels, False)

        for k in correct:
            self.assertTrue(isclose(acc[k], correct[k]))
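The simple-case numbers can be verified by hand. The first query (value 0,
label (1, 3)) matches the references labeled (1, 4), (1, 5), (1, 6) at
values 2, 3, 4, so R = 3; its two nearest references, (1, 3) and (7, 4),
both fail the comparison, leaving a single hit at rank 3: precision_at_1 = 0,
r_precision = 1/3, AP@R = (1/3) / 3. The second query (value 1, label (7, 3))
matches only (7, 4), which sits at distance 0, so all of its metrics are 1.
The two calculators agree here because each query label is unique, making the
per-class means identical to the per-query values. A compact check of that
arithmetic:

    hits_q0, r0 = [0, 0, 1], 3
    hits_q1, r1 = [1], 1
    assert (hits_q0[0] + hits_q1[0]) / 2 == 0.5                    # precision_at_1
    assert (sum(hits_q0) / r0 + sum(hits_q1) / r1) / 2 \
        == (1.0 / 3 + 1) / 2                                       # r_precision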
Example #18
    def test_accuracy_calculator(self):
        query_labels = torch.tensor([1, 1, 2, 3, 4], device=TEST_DEVICE)

        knn_labels1 = torch.tensor(
            [
                [0, 1, 1, 2, 2],
                [1, 0, 1, 1, 3],
                [4, 4, 4, 4, 2],
                [3, 1, 3, 1, 3],
                [0, 0, 4, 2, 2],
            ],
            device=TEST_DEVICE,
        )
        label_counts1 = ([1, 2, 3, 4], [3, 5, 4, 5])

        knn_labels2 = knn_labels1 + 5
        label_counts2 = ([6, 7, 8, 9], [3, 5, 4, 5])

        for avg_of_avgs in [False, True]:
            for i, (knn_labels, label_counts) in enumerate([
                (knn_labels1, label_counts1),
                (knn_labels2, label_counts2),
            ]):

                AC = accuracy_calculator.AccuracyCalculator(
                    exclude=("NMI", "AMI"), avg_of_avgs=avg_of_avgs)
                kwargs = {
                    "query_labels": query_labels,
                    "label_counts": label_counts,
                    "knn_labels": knn_labels,
                    "not_lone_query_mask": torch.ones(5, dtype=torch.bool)
                    if i == 0 else torch.zeros(5, dtype=torch.bool),
                }

                function_dict = AC.get_function_dict()

                for ecfss in [False, True]:
                    if ecfss:
                        kwargs["knn_labels"] = kwargs["knn_labels"][:, 1:]
                    kwargs["embeddings_come_from_same_source"] = ecfss
                    acc = AC._get_accuracy(function_dict, **kwargs)
                    if i == 1:
                        self.assertTrue(acc["precision_at_1"] == 0)
                        self.assertTrue(acc["r_precision"] == 0)
                        self.assertTrue(
                            acc["mean_average_precision_at_r"] == 0)
                        self.assertTrue(acc["mean_average_precision"] == 0)
                    else:
                        self.assertTrue(
                            isclose(
                                acc["precision_at_1"],
                                self.correct_precision_at_1(
                                    ecfss, avg_of_avgs),
                            ))
                        self.assertTrue(
                            isclose(
                                acc["r_precision"],
                                self.correct_r_precision(ecfss, avg_of_avgs),
                            ))
                        self.assertTrue(
                            isclose(
                                acc["mean_average_precision_at_r"],
                                self.correct_mean_average_precision_at_r(
                                    ecfss, avg_of_avgs),
                            ))
                        self.assertTrue(
                            isclose(
                                acc["mean_average_precision"],
                                self.correct_mean_average_precision(
                                    ecfss, avg_of_avgs),
                            ))
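Unlike the dict-based label_counts in the numpy variant of this test, the
torch variant passes label_counts as a (labels, counts) pair. A tiny
conversion sketch between the two shapes (hypothetical helper, for
illustration only):

    def dict_to_label_counts(d):
        labels = sorted(d)
        return labels, [d[label] for label in labels]

    assert dict_to_label_counts({1: 3, 2: 5, 3: 4, 4: 5}) == ([1, 2, 3, 4], [3, 5, 4, 5])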
Example #19
    def test_metric_loss_only(self):

        cifar_resnet_folder = "temp_cifar_resnet_for_pytorch_metric_learning_test"
        dataset_folder = "temp_dataset_for_pytorch_metric_learning_test"
        model_folder = "temp_saved_models_for_pytorch_metric_learning_test"
        logs_folder = "temp_logs_for_pytorch_metric_learning_test"
        tensorboard_folder = "temp_tensorboard_for_pytorch_metric_learning_test"

        os.system(
            "git clone https://github.com/akamaster/pytorch_resnet_cifar10.git {}"
            .format(cifar_resnet_folder))

        loss_fn = NTXentLoss()

        normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225])

        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize_transform,
        ])

        eval_transform = transforms.Compose(
            [transforms.ToTensor(), normalize_transform])

        assert not os.path.isdir(dataset_folder)
        assert not os.path.isdir(model_folder)
        assert not os.path.isdir(logs_folder)
        assert not os.path.isdir(tensorboard_folder)

        subset_idx = np.arange(10000)

        train_dataset = datasets.CIFAR100(dataset_folder,
                                          train=True,
                                          download=True,
                                          transform=train_transform)

        train_dataset_for_eval = datasets.CIFAR100(dataset_folder,
                                                   train=True,
                                                   download=True,
                                                   transform=eval_transform)

        val_dataset = datasets.CIFAR100(dataset_folder,
                                        train=False,
                                        download=True,
                                        transform=eval_transform)

        train_dataset = torch.utils.data.Subset(train_dataset, subset_idx)
        train_dataset_for_eval = torch.utils.data.Subset(
            train_dataset_for_eval, subset_idx)
        val_dataset = torch.utils.data.Subset(val_dataset, subset_idx)

        for dtype in TEST_DTYPES:
            for splits_to_eval in [
                None,
                [("train", ["train", "val"]), ("val", ["train", "val"])],
            ]:
                from temp_cifar_resnet_for_pytorch_metric_learning_test import resnet

                model = torch.nn.DataParallel(resnet.resnet20())
                checkpoint = torch.load(
                    "{}/pretrained_models/resnet20-12fca82f.th".format(
                        cifar_resnet_folder),
                    map_location=TEST_DEVICE,
                )
                model.load_state_dict(checkpoint["state_dict"])
                model.module.linear = c_f.Identity()
                if TEST_DEVICE == torch.device("cpu"):
                    model = model.module
                model = model.to(TEST_DEVICE).type(dtype)

                optimizer = torch.optim.Adam(
                    model.parameters(),
                    lr=0.0002,
                    weight_decay=0.0001,
                    eps=1e-04,
                )

                batch_size = 32
                iterations_per_epoch = None if splits_to_eval is None else 1
                model_dict = {"trunk": model}
                optimizer_dict = {"trunk_optimizer": optimizer}
                loss_fn_dict = {"metric_loss": loss_fn}
                sampler = MPerClassSampler(
                    np.array(train_dataset.dataset.targets)[subset_idx],
                    m=4,
                    batch_size=32,
                    length_before_new_iter=len(train_dataset),
                )

                record_keeper, _, _ = logging_presets.get_record_keeper(
                    logs_folder, tensorboard_folder)
                hooks = logging_presets.get_hook_container(
                    record_keeper, primary_metric="precision_at_1")
                dataset_dict = {
                    "train": train_dataset_for_eval,
                    "val": val_dataset
                }

                tester = GlobalEmbeddingSpaceTester(
                    end_of_testing_hook=hooks.end_of_testing_hook,
                    accuracy_calculator=accuracy_calculator.AccuracyCalculator(
                        include=("precision_at_1", "AMI"), k=1),
                    data_device=TEST_DEVICE,
                    dtype=dtype,
                    dataloader_num_workers=32,
                )

                end_of_epoch_hook = hooks.end_of_epoch_hook(
                    tester,
                    dataset_dict,
                    model_folder,
                    test_interval=1,
                    patience=1,
                    splits_to_eval=splits_to_eval,
                )

                trainer = MetricLossOnly(
                    models=model_dict,
                    optimizers=optimizer_dict,
                    batch_size=batch_size,
                    loss_funcs=loss_fn_dict,
                    mining_funcs={},
                    dataset=train_dataset,
                    sampler=sampler,
                    data_device=TEST_DEVICE,
                    dtype=dtype,
                    dataloader_num_workers=32,
                    iterations_per_epoch=iterations_per_epoch,
                    freeze_trunk_batchnorm=True,
                    end_of_iteration_hook=hooks.end_of_iteration_hook,
                    end_of_epoch_hook=end_of_epoch_hook,
                )

                num_epochs = 3
                trainer.train(num_epochs=num_epochs)
                best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy(
                    tester, "val")
                if splits_to_eval is None:
                    self.assertTrue(best_epoch == 3)
                    self.assertTrue(best_accuracy > 0.2)

                accuracies, primary_metric_key = hooks.get_accuracies_of_best_epoch(
                    tester, "val")
                accuracies = c_f.sqliteObjToDict(accuracies)
                self.assertTrue(
                    accuracies[primary_metric_key][0] == best_accuracy)
                self.assertTrue(primary_metric_key == "precision_at_1_level0")

                best_epoch_accuracies = hooks.get_accuracies_of_epoch(
                    tester, "val", best_epoch)
                best_epoch_accuracies = c_f.sqliteObjToDict(
                    best_epoch_accuracies)
                self.assertTrue(best_epoch_accuracies[primary_metric_key][0] ==
                                best_accuracy)

                accuracy_history = hooks.get_accuracy_history(tester, "val")
                self.assertTrue(accuracy_history[primary_metric_key][
                    accuracy_history["epoch"].index(best_epoch)] ==
                                best_accuracy)

                loss_history = hooks.get_loss_history()
                if splits_to_eval is None:
                    self.assertTrue(
                        len(loss_history["metric_loss"]) == (len(sampler) /
                                                             batch_size) *
                        num_epochs)

                curr_primary_metric = hooks.get_curr_primary_metric(
                    tester, "val")
                self.assertTrue(curr_primary_metric ==
                                accuracy_history[primary_metric_key][-1])

                base_record_group_name = hooks.base_record_group_name(tester)

                self.assertTrue(
                    base_record_group_name ==
                    "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0")

                record_group_name = hooks.record_group_name(tester, "val")

                if splits_to_eval is None:
                    self.assertTrue(
                        record_group_name ==
                        "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0_VAL_vs_self"
                    )
                else:
                    self.assertTrue(
                        record_group_name ==
                        "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0_VAL_vs_TRAIN_and_VAL"
                    )

                shutil.rmtree(model_folder)
                shutil.rmtree(logs_folder)
                shutil.rmtree(tensorboard_folder)

        shutil.rmtree(cifar_resnet_folder)
        shutil.rmtree(dataset_folder)
Example #20
    def test_accuracy_calculator_custom_comparison_function(self):
        def label_comparison_fn(x, y):
            return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

        self.assertRaises(
            NotImplementedError,
            lambda: accuracy_calculator.AccuracyCalculator(
                include=("NMI", "AMI"),
                avg_of_avgs=False,
                label_comparison_fn=label_comparison_fn,
            ),
        )

        AC_global_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"),
            avg_of_avgs=False,
            label_comparison_fn=label_comparison_fn,
        )

        AC_per_class_average = accuracy_calculator.AccuracyCalculator(
            exclude=("NMI", "AMI"),
            avg_of_avgs=True,
            label_comparison_fn=label_comparison_fn,
        )

        query = np.arange(10)[:, None].astype(np.float32)
        reference = np.arange(10)[:, None].astype(np.float32)
        query[-1] = 100
        reference[0] = -100
        query_labels = np.array([(0, 2)] * 9 + [(1, 2)])
        reference_labels = np.array([(1, 3)] + [(0, 3)] * 9)
        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)
        self.assertTrue(acc["precision_at_1"] == 0.9)
        self.assertTrue(acc["r_precision"] == 0.9)
        self.assertTrue(acc["mean_average_precision_at_r"] == 0.9)

        acc = AC_per_class_average.get_accuracy(query, reference, query_labels,
                                                reference_labels, False)
        self.assertTrue(acc["precision_at_1"] == 0.5)
        self.assertTrue(acc["r_precision"] == 0.5)
        self.assertTrue(acc["mean_average_precision_at_r"] == 0.5)

        # SIMPLE CASE
        query = np.arange(2)[:, None].astype(np.float32)
        reference = np.arange(5)[:, None].astype(np.float32)
        query_labels = np.array([(1, 3), (7, 3)])
        reference_labels = np.array([(1, 3), (7, 4), (1, 4), (1, 5), (1, 6)])

        correct_precision_at_1 = 0.5
        correct_r_precision = (1.0 / 3 + 1) / 2
        correct_mapr = ((1.0 / 3) / 3 + 1) / 2

        acc = AC_global_average.get_accuracy(query, reference, query_labels,
                                             reference_labels, False)
        self.assertTrue(acc["precision_at_1"] == correct_precision_at_1)
        self.assertTrue(acc["r_precision"] == correct_r_precision)
        self.assertTrue(acc["mean_average_precision_at_r"] == correct_mapr)

        acc = AC_per_class_average.get_accuracy(query, reference, query_labels,
                                                reference_labels, False)
        self.assertTrue(acc["precision_at_1"] == correct_precision_at_1)
        self.assertTrue(acc["r_precision"] == correct_r_precision)
        self.assertTrue(acc["mean_average_precision_at_r"] == correct_mapr)