Code example #1
0
 def setUp(self):
     """Build a ViLBERT pretraining model and a classification model for tests."""
     test_utils.setup_proxy()
     setup_imports()
     name = "vilbert"
     dummy = test_utils.dummy_args(model=name)
     config = Configuration(dummy).get_config()
     vilbert_class = registry.get_model_class(name)
     self.vision_feature_size = 1024
     self.vision_target_size = 1279
     # config.model_config[name] is a mutable node; mutate it in place.
     model_conf = config.model_config[name]
     model_conf["training_head_type"] = "pretraining"
     model_conf["visual_embedding_dim"] = self.vision_feature_size
     model_conf["v_feature_size"] = self.vision_feature_size
     model_conf["v_target_size"] = self.vision_target_size
     model_conf["dynamic_attention"] = False
     self.pretrain_model = vilbert_class(model_conf)
     self.pretrain_model.build()
     # Re-purpose the same config node for a classification head.
     model_conf["training_head_type"] = "classification"
     model_conf["num_labels"] = 2
     self.finetune_model = vilbert_class(model_conf)
     self.finetune_model.build()
Code example #2
0
    def test_model_configs_for_keys(self):
        """Every registered model with a default config must expose its own
        key in ``config.model_config``.

        Fixes: grammar in the assertion message ("doesn't exists" ->
        "doesn't exist"), and the known "mmft" exception is now skipped
        before a full Configuration is needlessly constructed for it.
        """
        models_mapping = registry.mapping["model_name_mapping"]

        for model_key, model_cls in models_mapping.items():
            if model_cls.config_path() is None:
                warnings.warn(
                    ("Model {} has no default configuration defined. " +
                     "Skipping it. Make sure it is intentional"
                     ).format(model_key))
                continue

            # Known exception: skip before building a Configuration at all.
            if model_key == "mmft":
                continue

            # Configuration construction is noisy; silence its stdout.
            with contextlib.redirect_stdout(StringIO()):
                args = dummy_args(model=model_key)
                configuration = Configuration(args)
                configuration.freeze()
                config = configuration.get_config()

                self.assertTrue(
                    model_key in config.model_config,
                    "Key for model {} doesn't exist in its configuration".
                    format(model_key),
                )
Code example #3
0
File: test_vinvl.py  Project: facebookresearch/mmf
    def setUp(self):
        """Prepare VinVL classification and pretraining configs and a sample list."""
        test_utils.setup_proxy()
        setup_imports()
        name = "vinvl"
        dummy = test_utils.dummy_args(model=name, dataset="test")
        config = Configuration(dummy).get_config()
        base = config.model_config[name]
        base.model = name
        base.do_pretraining = False

        # Classification head (VQA-style MLP) layered over the base config.
        self.classification_config = OmegaConf.create({
            **base,
            "do_pretraining": False,
            "heads": {"mlp": {"num_labels": 3129}},
            "ce_loss": {"ignore_index": -1},
        })

        # Masked-language-model pretraining head over the same base config.
        self.pretraining_config = OmegaConf.create({
            **base,
            "do_pretraining": True,
            "heads": {"mlm": {"hidden_size": 768}},
        })

        self.sample_list = self._get_sample_list()
Code example #4
0
 def setUp(self):
     """Load the default mmf_transformer config used by the tests below."""
     test_utils.setup_proxy()
     setup_imports()
     self.model_name = "mmf_transformer"
     dummy = test_utils.dummy_args(model=self.model_name)
     self.config = Configuration(dummy).get_config()
     self.config.model_config[self.model_name].model = self.model_name
Code example #5
0
File: test_vilt.py  Project: facebookresearch/mmf
 def setUp(self):
     """Build a ViLT model from its default test configuration."""
     test_utils.setup_proxy()
     setup_imports()
     name = "vilt"
     dummy = test_utils.dummy_args(model=name, dataset="test")
     vilt_config = Configuration(dummy).get_config().model_config[name]
     vilt_config.model = name
     self.pretrain_model = build_model(vilt_config)
Code example #6
0
 def setUp(self):
     """Build the multimodelity_transformer model used in fine-tuning tests."""
     test_utils.setup_proxy()
     setup_imports()
     self.model_name = "multimodelity_transformer"
     dummy = test_utils.dummy_args(model=self.model_name)
     self.config = Configuration(dummy).get_config()
     model_conf = self.config.model_config[self.model_name]
     model_conf.model = self.model_name
     self.finetune_model = build_model(model_conf)
Code example #7
0
 def setUp(self):
     """Instantiate an mmf_transformer model directly via the registry."""
     setup_imports()
     self.model_name = "mmf_transformer"
     dummy = test_utils.dummy_args(model=self.model_name)
     self.config = Configuration(dummy).get_config()
     self.model_class = registry.get_model_class(self.model_name)
     model_conf = self.config.model_config[self.model_name]
     self.finetune_model = self.model_class(model_conf)
     self.finetune_model.build()
Code example #8
0
File: test_mmbt_script.py  Project: tjulyz/mmf
 def setUp(self):
     """Build an MMBT classification model with two output labels."""
     setup_imports()
     name = "mmbt"
     config = Configuration(test_utils.dummy_args(model=name)).get_config()
     mmbt_class = registry.get_model_class(name)
     mmbt_conf = config.model_config[name]
     mmbt_conf["training_head_type"] = "classification"
     mmbt_conf["num_labels"] = 2
     self.finetune_model = mmbt_class(mmbt_conf)
     self.finetune_model.build()
Code example #9
0
 def setUp(self):
     """Build a TorchScript-compatible VisualBERT pretraining model."""
     test_utils.setup_proxy()
     setup_imports()
     # Swap in the jit-friendly implementations before constructing the model.
     replace_with_jit()
     name = "visual_bert"
     dummy = test_utils.dummy_args(model=name)
     vb_config = Configuration(dummy).get_config().model_config[name]
     vb_config.model = name
     self.pretrain_model = build_model(vb_config)
Code example #10
0
 def setUp(self):
     """Build an MMBT model configured for binary classification."""
     test_utils.setup_proxy()
     setup_imports()
     name = "mmbt"
     dummy = test_utils.dummy_args(model=name)
     mmbt_conf = Configuration(dummy).get_config().model_config[name]
     mmbt_conf["training_head_type"] = "classification"
     mmbt_conf["num_labels"] = 2
     mmbt_conf.model = name
     self.finetune_model = build_model(mmbt_conf)
Code example #11
0
    def setUp(self):
        """Build UNITER models for VQA2 classification and WRA pretraining."""
        test_utils.setup_proxy()
        setup_imports()
        name = "uniter"
        dummy = test_utils.dummy_args(model=name, dataset="vqa2")
        config = Configuration(dummy).get_config()
        base = config.model_config[name]
        base.model = name
        base.losses = {"vqa2": "logit_bce"}
        base.do_pretraining = False
        base.tasks = "vqa2"

        # Fine-tuning variant: MLP answer head trained with logit BCE.
        classification_overrides = {
            "do_pretraining": False,
            "tasks": "vqa2",
            "heads": {"vqa2": {"type": "mlp", "num_labels": 3129}},
            "losses": {"vqa2": "logit_bce"},
        }
        classification_config = OmegaConf.create(
            {**base, **classification_overrides})

        # Pretraining variant: word-region alignment (WRA) head.
        pretraining_overrides = {
            "do_pretraining": True,
            "tasks": "wra",
            "heads": {"wra": {"type": "wra"}},
        }
        pretraining_config = OmegaConf.create(
            {**base, **pretraining_overrides})

        self.model_for_classification = build_model(classification_config)
        self.model_for_pretraining = build_model(pretraining_config)
Code example #12
0
 def setUp(self):
     """Register the BUTD beam-search config for COCO captioning tests."""
     setup_imports()
     torch.manual_seed(1234)
     yaml_path = os.path.abspath(
         os.path.join(
             get_mmf_root(),
             "..",
             "projects",
             "butd",
             "configs",
             "coco",
             "beam_search.yaml",
         )
     )
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={yaml_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Code example #13
0
File: test_vilbert.py  Project: zhang703652632/mmf
    def setUp(self):
        """Prepare a ViLBERT configuration for classification tests.

        Fixes: the original assigned ``training_head_type = "pretraining"``
        and then immediately overwrote it with ``"classification"`` without
        any read in between — the dead assignment has been removed. The
        final config state is unchanged.
        """
        test_utils.setup_proxy()
        setup_imports()
        model_name = "vilbert"
        args = test_utils.dummy_args(model=model_name)
        configuration = Configuration(args)
        config = configuration.get_config()
        self.vision_feature_size = 1024
        self.vision_target_size = 1279
        model_config = config.model_config[model_name]
        model_config["visual_embedding_dim"] = self.vision_feature_size
        model_config["v_feature_size"] = self.vision_feature_size
        model_config["v_target_size"] = self.vision_target_size
        model_config["dynamic_attention"] = False
        model_config.model = model_name
        model_config["training_head_type"] = "classification"
        model_config["num_labels"] = 2
        self.model_config = model_config
Code example #14
0
File: test_cnn_lstm.py  Project: vishalbelsare/pythia
 def setUp(self):
     """Register the CNN-LSTM CLEVR config plus vocab/output sizes."""
     torch.manual_seed(1234)
     registry.register("clevr_text_vocab_size", 80)
     registry.register("clevr_num_final_outputs", 32)
     yaml_path = os.path.abspath(
         os.path.join(
             get_mmf_root(),
             "..",
             "projects",
             "others",
             "cnn_lstm",
             "clevr",
             "defaults.yaml",
         )
     )
     args = dummy_args(model="cnn_lstm", dataset="clevr")
     args.opts.append(f"config={yaml_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "clevr"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Code example #15
0
 def setUp(self):
     """Register the BUTD nucleus-sampling config for COCO captioning tests."""
     setup_imports()
     torch.manual_seed(1234)
     yaml_path = os.path.abspath(
         os.path.join(
             get_mmf_root(),
             "..",
             "projects",
             "butd",
             "configs",
             "coco",
             "nucleus_sampling.yaml",
         )
     )
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={yaml_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     # Lower the nucleus-sampling cutoff before the config is frozen.
     configuration.config.model_config.butd.inference.params.sum_threshold = 0.5
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Code example #16
0
    def test_dataset_configs_for_keys(self):
        """Every registered dataset builder with a default config must
        expose its own key in ``config.dataset_config``.

        Fixes: grammar in the assertion message ("doesn't exists" ->
        "doesn't exist").
        """
        builder_name = registry.mapping["builder_name_mapping"]

        for builder_key, builder_cls in builder_name.items():
            if builder_cls.config_path() is None:
                warnings.warn(
                    ("Dataset {} has no default configuration defined. " +
                     "Skipping it. Make sure it is intentional"
                     ).format(builder_key))
                continue

            # Configuration construction is noisy; silence its stdout.
            with contextlib.redirect_stdout(StringIO()):
                args = dummy_args(dataset=builder_key)
                configuration = Configuration(args)
                configuration.freeze()
                config = configuration.get_config()
                self.assertTrue(
                    builder_key in config.dataset_config,
                    "Key for dataset {} doesn't exist in its configuration".
                    format(builder_key),
                )
Code example #17
0
 def test_config_overrides(self):
     """Command-line opts must override indexed list entries in the config."""
     yaml_path = os.path.abspath(
         os.path.join(
             get_mmf_root(),
             "..",
             "projects",
             "m4c",
             "configs",
             "textvqa",
             "defaults.yaml",
         )
     )
     args = dummy_args(model="m4c", dataset="textvqa")
     args.opts += [
         f"config={yaml_path}", "training.lr_steps[1]=10000",
         "dataset_config.textvqa.zoo_requirements[0]=\"test\""
     ]
     configuration = Configuration(args)
     configuration.freeze()
     config = configuration.get_config()
     # Both list-index overrides should have taken effect.
     self.assertEqual(config.training.lr_steps[1], 10000)
     self.assertEqual(config.dataset_config.textvqa.zoo_requirements[0],
                      "test")