Ejemplo n.º 1
0
 def setUp(self):
     """Build a VisualBERT model from the default test configuration.

     The registry imports, proxy settings and JIT layer replacement are
     applied (in this order) before the configuration is read and the model
     is built. The result is stored as ``self.pretrain_model`` (presumably
     the default/pretraining head — the attribute name suggests so).
     """
     name = "visual_bert"

     # Environment preparation; keep this order unchanged.
     test_utils.setup_proxy()
     setup_imports()
     replace_with_jit()

     cfg = Configuration(test_utils.dummy_args(model=name)).get_config()
     visual_bert_cfg = cfg.model_config[name]
     # build_model is given the model name via the config itself.
     visual_bert_cfg.model = name
     self.pretrain_model = build_model(visual_bert_cfg)
Ejemplo n.º 2
0
 def setUp(self):
     """Build a VisualBERT model configured with a two-label classification head.

     After the proxy, registry imports and JIT layer replacement are set up,
     the default VisualBERT config is overridden with
     ``training_head_type="classification"`` and ``num_labels=2``. The model
     is then built and stored as ``self.finetune_model``.
     """
     name = "visual_bert"

     # Environment preparation; keep this order unchanged.
     test_utils.setup_proxy()
     setup_imports()
     replace_with_jit()

     cfg = Configuration(test_utils.dummy_args(model=name)).get_config()
     visual_bert_cfg = cfg.model_config[name]

     # Switch the head to classification over two labels.
     overrides = {"training_head_type": "classification", "num_labels": 2}
     for key, value in overrides.items():
         visual_bert_cfg[key] = value

     visual_bert_cfg.model = name
     self.finetune_model = build_model(visual_bert_cfg)
Ejemplo n.º 3
0
    def setUp(self):
        """Build two VisualBERT models.

        ``self.pretrain_model`` is built from the default VisualBERT model
        config. ``self.finetune_model`` is built from a copy of that config
        overridden with a two-label classification head.

        The finetune variant gets its own deep copy of the model config, so
        the overrides do not also rewrite the config object that was handed
        to ``self.pretrain_model``.
        """
        # Local import: the file-level import block is not visible from here.
        import copy

        setup_imports()
        replace_with_jit()
        model_name = "visual_bert"
        args = test_utils.dummy_args(model=model_name)
        configuration = Configuration(args)
        config = configuration.get_config()
        model_class = registry.get_model_class(model_name)

        pretrain_config = config.model_config[model_name]
        self.pretrain_model = model_class(pretrain_config)
        self.pretrain_model.build()

        # Mutating ``pretrain_config`` in place would also change the object
        # the pretraining model was constructed with. If the model keeps a
        # reference to it (assumed; verify against the model base class), its
        # config would silently claim a "classification" head. Work on a copy.
        # NOTE(review): if this is an OmegaConf node, confirm deepcopy keeps any
        # interpolations in the model config resolvable.
        finetune_config = copy.deepcopy(pretrain_config)
        finetune_config["training_head_type"] = "classification"
        finetune_config["num_labels"] = 2
        self.finetune_model = model_class(finetune_config)
        self.finetune_model.build()
Ejemplo n.º 4
0
 def __init__(self, config, *args, **kwargs):
     """Initialise via the parent class, then apply ``replace_with_jit()``.

     Args:
         config: model configuration, forwarded unchanged to the parent
             initialiser together with any extra positional/keyword arguments.

     Note:
         ``replace_with_jit()`` runs on every construction. If it patches
         shared classes (assumed; verify), the effect is not scoped to this
         instance. ``undo_replace_with_jit()`` exists to revert it.
     """
     super().__init__(config, *args, **kwargs)
     # Replace transformer layers with scriptable JIT layers
     replace_with_jit()
Ejemplo n.º 5
0
 def test_undo_replace_with_jit(self):
     """Check that undo_replace_with_jit() restores BertSelfAttention.forward.

     The original ``forward`` is captured first. Then ``replace_with_jit()``
     and ``undo_replace_with_jit()`` are called, and the test asserts that the
     very same function object is back on the class.
     """
     original_function = BertSelfAttention.forward
     # If undo is broken, don't leak a patched forward into later tests:
     # restore the attribute this test checks regardless of the outcome.
     self.addCleanup(setattr, BertSelfAttention, "forward", original_function)
     replace_with_jit()
     undo_replace_with_jit()
     # assertIs reports both objects on failure, unlike assertTrue(a is b).
     self.assertIs(BertSelfAttention.forward, original_function)