コード例 #1
0
 def test_finetune_xlmr_base(self):
     """Check the model built with an ``xlm-roberta-base`` transformer base.

     Sets ``transformer_base`` on ``self.config.model_config[self.model_name]``,
     builds a model from that config entry, puts it in eval mode, and asserts
     that ``test_utils.compare_torchscript_transformer_models`` returns a
     truthy result for it when given the XLM-R vocabulary size.
     """
     xlmr_config = self.config.model_config[self.model_name]
     xlmr_config["transformer_base"] = "xlm-roberta-base"

     xlmr_model = build_model(xlmr_config)
     xlmr_model.eval()

     torchscript_matches = test_utils.compare_torchscript_transformer_models(
         xlmr_model, vocab_size=XLM_ROBERTA_VOCAB_SIZE
     )
     self.assertTrue(torchscript_matches)
コード例 #2
0
 def setUp(self):
     """Create ``self.config`` and a ``multimodelity_transformer`` fixture model.

     Runs the proxy and import setup helpers, then loads the configuration
     produced from dummy CLI args that name this model. It writes the model
     name into that model's config entry and stores the built model as
     ``self.finetune_model``.
     """
     test_utils.setup_proxy()
     setup_imports()

     self.model_name = "multimodelity_transformer"
     dummy_cli_args = test_utils.dummy_args(model=self.model_name)
     config_builder = Configuration(dummy_cli_args)
     self.config = config_builder.get_config()

     transformer_config = self.config.model_config[self.model_name]
     transformer_config.model = self.model_name
     self.finetune_model = build_model(transformer_config)
コード例 #3
0
 def setUp(self):
     """Build a ``visual_bert`` model from its default config as ``self.pretrain_model``.

     ``replace_with_jit()`` is called before the model is built. Presumably it
     patches modules for TorchScript use; TODO confirm against its definition.
     """
     test_utils.setup_proxy()
     setup_imports()
     replace_with_jit()

     name = "visual_bert"
     config_builder = Configuration(test_utils.dummy_args(model=name))
     root_config = config_builder.get_config()

     visual_bert_config = root_config.model_config[name]
     visual_bert_config.model = name
     self.pretrain_model = build_model(visual_bert_config)
コード例 #4
0
    def load_model(self):
        """Build the configured model and move it to ``self.device``.

        The model attributes come from
        ``self.config.model_config[self.config.model]``. If that entry is a
        string, it is treated as an alias, and the attributes are taken from
        ``self.config.model_config[<that string>]`` instead. Only one level of
        aliasing is followed. Before building, the ``model`` field of the
        chosen attributes is set to ``self.config.model``. The result is
        stored in ``self.model``.
        """
        logger.info("Loading model")
        model_configs = self.config.model_config
        model_name = self.config.model

        attributes = model_configs[model_name]
        if isinstance(attributes, str):
            # Alias entry: reuse another model's config block.
            attributes = model_configs[attributes]

        # open_dict relaxes OmegaConf struct mode, so ``model`` can be
        # assigned even if the node does not already define it.
        with omegaconf.open_dict(attributes):
            attributes.model = model_name

        self.model = build_model(attributes)
        self.model = self.model.to(self.device)
コード例 #5
0
    def setUp(self):
        """Build ViLBERT ``pretraining`` and ``classification`` fixture models.

        Both models are built from the same ``model_config`` node. The node is
        first configured for pretraining, using the vision sizes stored on
        ``self``. It is then switched to a 2-label classification head for the
        second build.
        """
        test_utils.setup_proxy()
        setup_imports()

        name = "vilbert"
        config_builder = Configuration(test_utils.dummy_args(model=name))
        root_config = config_builder.get_config()

        self.vision_feature_size = 1024
        self.vision_target_size = 1279

        vilbert_config = root_config.model_config[name]
        pretraining_overrides = {
            "training_head_type": "pretraining",
            "visual_embedding_dim": self.vision_feature_size,
            "v_feature_size": self.vision_feature_size,
            "v_target_size": self.vision_target_size,
            "dynamic_attention": False,
        }
        for key, value in pretraining_overrides.items():
            vilbert_config[key] = value
        vilbert_config.model = name
        self.pretrain_model = build_model(vilbert_config)

        # NOTE(review): the same config node is mutated below, after
        # ``self.pretrain_model`` was built from it. If the built model keeps a
        # reference to its config, it will also see the classification
        # settings. Confirm this sharing is intended.
        vilbert_config["training_head_type"] = "classification"
        vilbert_config["num_labels"] = 2
        self.finetune_model = build_model(vilbert_config)
コード例 #6
0
ファイル: test_mmbt.py プロジェクト: hahaxun/mmf
 def setUp(self):
     """Build an MMBT model with a 2-label classification head as ``self.finetune_model``."""
     test_utils.setup_proxy()
     setup_imports()

     name = "mmbt"
     config_builder = Configuration(test_utils.dummy_args(model=name))
     root_config = config_builder.get_config()

     mmbt_config = root_config.model_config[name]
     for key, value in (("training_head_type", "classification"), ("num_labels", 2)):
         mmbt_config[key] = value
     mmbt_config.model = name

     self.finetune_model = build_model(mmbt_config)