def load_model_for_eval(model_factory,
                        model_args,
                        model_name,
                        factory_kwargs,
                        model_folder=None,
                        device=None):
    """Build the trunk and embedder for evaluation, loading weights if needed.

    The behaviour depends on which alias group ``model_name`` belongs to:

    * in ``const.UNTRAINED_TRUNK_ALIASES``: the trunk is created by the
      factory, the embedder is a ``pml_cf.Identity()``, and no weights are
      loaded;
    * in ``const.UNTRAINED_TRUNK_AND_EMBEDDER_ALIASES``: both models are
      created by the factory and no weights are loaded;
    * otherwise: both models are created and their weights are loaded with
      ``pml_cf.load_dict_of_models``. If ``model_name`` is in
      ``const.TRAINED_ALIASES`` the name passed to the loader is the second
      item returned by ``pml_cf.latest_version(model_folder, best=True)``.

    Args:
        model_factory: object exposing ``create(named_specs=..., subset=...,
            **kwargs)``.
        model_args: passed to the factory as ``named_specs``.
        model_name: alias or checkpoint name selecting what to load.
        factory_kwargs: extra keyword arguments forwarded to the factory.
        model_folder: forwarded to the ``pml_cf`` loading helpers.
        device: forwarded to ``pml_cf.load_dict_of_models``.

    Returns:
        A ``(trunk, embedder)`` tuple, each wrapped in
        ``torch.nn.DataParallel``.
    """
    # Resolve both alias checks against the caller-supplied name up front.
    identity_embedder = model_name in const.UNTRAINED_TRUNK_ALIASES
    skip_weights = model_name in const.UNTRAINED_TRUNK_AND_EMBEDDER_ALIASES

    def _build(subset):
        return model_factory.create(named_specs=model_args,
                                    subset=subset,
                                    **factory_kwargs)

    trunk_model = _build("trunk")
    embedder_model = (pml_cf.Identity()
                      if identity_embedder else _build("embedder"))

    if not (identity_embedder or skip_weights):
        checkpoint_name = model_name
        if checkpoint_name in const.TRAINED_ALIASES:
            _, checkpoint_name = pml_cf.latest_version(model_folder,
                                                       best=True)
        models_to_load = {"trunk": trunk_model, "embedder": embedder_model}
        pml_cf.load_dict_of_models(models_to_load,
                                   checkpoint_name,
                                   model_folder,
                                   device,
                                   log_if_successful=True,
                                   assert_success=True)

    wrapped_trunk = torch.nn.DataParallel(trunk_model)
    wrapped_embedder = torch.nn.DataParallel(embedder_model)
    return wrapped_trunk, wrapped_embedder
Exemplo n.º 2
0
    def test_global_embedding_space_tester(self):
        """Check precision@1 of GlobalEmbeddingSpaceTester per split setting.

        Runs the tester on ``self.dataset_dict`` with an identity model for
        three values of ``splits_to_eval`` and compares the reported
        ``precision_at_1_level0`` of the "train" and "val" splits with the
        expected values listed below.
        """
        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(
            include=("precision_at_1", ))

        # Each entry is (splits_to_eval, expected precision_at_1 per split).
        correct = [
            (None, {
                "train": 1,
                "val": 6.0 / 8
            }),
            (
                [("train", ["train", "val"]), ("val", ["train", "val"])],
                {
                    "train": 1.0 / 8,
                    "val": 1.0 / 8
                },
            ),
            ([("train", ["train"]), ("val", ["train"])], {
                "train": 1,
                "val": 1.0 / 8
            }),
        ]

        for splits_to_eval, correct_vals in correct:
            # subTest labels a failure with the offending configuration and
            # lets the remaining configurations still run.
            with self.subTest(splits_to_eval=splits_to_eval):
                tester = GlobalEmbeddingSpaceTester(accuracy_calculator=AC)
                tester.test(self.dataset_dict,
                            0,
                            model,
                            splits_to_eval=splits_to_eval)
                for split_name in ("train", "val"):
                    # assertEqual reports both values on mismatch, unlike
                    # assertTrue(a == b).
                    self.assertEqual(
                        tester.all_accuracies[split_name]
                        ["precision_at_1_level0"], correct_vals[split_name])
Exemplo n.º 3
0
def load_trunk_embedder(trunk_path, embedder_path):
    """Build a ResNet-18 trunk and an MLP embedder and load saved weights.

    Args:
        trunk_path: path of the saved state dict for the trunk.
        embedder_path: path of the saved state dict for the embedder.

    Returns:
        A ``(trunk, embedder)`` tuple. Both are ``torch.nn.DataParallel``
        modules placed on CUDA when available, otherwise on the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set trunk model and replace the final fully connected layer with an
    # identity function so the trunk outputs its pooled features.
    trunk = torchvision.models.resnet18(pretrained=True)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = common_functions.Identity()
    trunk = torch.nn.DataParallel(trunk.to(device))

    # Set embedder model. This takes in the output of the trunk and outputs
    # 64 dimensional embeddings.
    embedder = torch.nn.DataParallel(MLP([trunk_output_size, 64]).to(device))

    # Always deserialize onto the CPU. The previous `device == 'cpu'` check
    # compared a torch.device with a str, so the CPU branch was never taken
    # and a checkpoint saved on a GPU could not be loaded on a CPU-only host.
    # load_state_dict copies the values into the existing parameters, so the
    # modules stay on `device`.
    cpu = torch.device("cpu")
    trunk.load_state_dict(torch.load(trunk_path, map_location=cpu))
    embedder.load_state_dict(torch.load(embedder_path, map_location=cpu))

    return trunk, embedder
    def test_pca(self):
        """Smoke test: run the tester with PCA enabled on random embeddings.

        Inside the end-of-testing hook the "train" embeddings must have
        ``pca_size`` columns; after ``test`` returns, the tester must no
        longer expose an ``embeddings_and_labels`` attribute.
        """
        # just make sure pca runs without crashing
        model = c_f.Identity()
        AC = accuracy_calculator.AccuracyCalculator(include=("precision_at_1",))
        embeddings = torch.randn(1024, 512)
        labels = torch.randint(0, 10, size=(1024,))
        dataset_dict = {"train": c_f.EmbeddingDataset(embeddings, labels)}
        pca_size = 16

        def end_of_testing_hook(tester):
            # assertEqual shows the actual width if the reduction is wrong.
            self.assertEqual(
                tester.embeddings_and_labels["train"][0].shape[1], pca_size
            )

        tester = GlobalEmbeddingSpaceTester(
            pca=pca_size,
            accuracy_calculator=AC,
            end_of_testing_hook=end_of_testing_hook,
        )
        # The return value is not needed; only the side effects are checked.
        tester.test(dataset_dict, 0, model)
        self.assertFalse(hasattr(tester, "embeddings_and_labels"))