Beispiel #1
0
    def test_constructor_on_multiple_tasks(self):
        """Build heads from a per-task config mapping and run one task's head.

        The config declares three tasks but only two of them are requested.
        A forward pass for ``other_task`` must return a dict whose ``losses``
        entry contains ``itm_loss`` and does not contain
        ``test/test_dataset/logit_bce``.
        """
        heads_conf = {
            "test": {"type": "mlp", "loss": "test_cls"},
            "other_task": {"type": "itm"},
            "third_task": {"type": "mlm"},
        }
        tasks = ["test", "other_task"]
        heads_dict = build_heads_dict(heads_conf, tasks, self.losses)
        # Dedicated assertions report the offending value on failure instead
        # of a bare "False is not true".
        self.assertIsInstance(heads_dict, HeadsDict)

        # Forward pass restricted to a single task.
        head_output = heads_dict.forward(
            "other_task", self.model_output, self.sample_list
        )
        self.assertIsInstance(head_output, dict)
        self.assertIn("losses", head_output)
        losses = head_output["losses"]
        self.assertNotIn("test/test_dataset/logit_bce", losses)
        self.assertIn("itm_loss", losses)
Beispiel #2
0
    def test_constructor_on_list_confs(self):
        """Build heads from a list-style config with no task names.

        With an empty task list and ``task=None`` passed to ``forward``, the
        output must be a dict whose ``losses`` entry contains
        ``test/test_dataset/logit_bce``.
        """
        heads_conf = [{"type": "mlp", "loss": "test_cls"}]
        tasks = []
        heads_dict = build_heads_dict(heads_conf, tasks, self.losses)
        # Dedicated assertions report the offending value on failure instead
        # of a bare "False is not true".
        self.assertIsInstance(heads_dict, HeadsDict)

        # Forward pass without selecting a task.
        head_output = heads_dict.forward(
            None, self.model_output, self.sample_list
        )
        self.assertIsInstance(head_output, dict)
        self.assertIn("losses", head_output)
        self.assertIn("test/test_dataset/logit_bce", head_output["losses"])
Beispiel #3
0
    def build(self):
        """Create the embeddings, encoder, losses and task heads from config.

        Reads ``text_embeddings``, ``image_encoder``, ``heads`` and ``tasks``
        from ``self.config`` and stores the constructed objects on ``self``.
        """
        cfg = self.config

        # Sub-modules are created in a fixed order: text, image, encoder.
        self.text_embeddings = ViLTTextEmbedding(**cfg.text_embeddings)
        self.image_embeddings = ViLTImageEmbedding(**cfg.image_encoder.params)
        self.encoder = build_encoder(cfg.image_encoder)

        heads_cfg = cfg.get("heads", {})
        # Without an explicit "tasks" entry, fall back to the head config keys.
        task_names = cfg.get("tasks", heads_cfg.keys())
        # A string value is treated as a comma-separated list of task names.
        self.tasks = (
            task_names.split(",") if isinstance(task_names, str) else task_names
        )

        self.losses = nn.ModuleDict()
        self.heads_dict = build_heads_dict(heads_cfg, self.tasks, self.losses)

        # Both attributes intentionally refer to the same list object.
        modalities = ["text", "image"]
        self.modality_keys = modalities
        self.modality_type = modalities