def convert_visual_bert_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our VisualBERT structure.

    The checkpoint's *file name* selects both the config and the model head:

    * names containing ``"pre"`` -> ``VisualBertForPreTraining``;
    * otherwise ``"vcr"`` -> ``VisualBertForMultipleChoice``, ``"vqa"`` ->
      ``VisualBertForQuestionAnswering`` and ``"nlvr"`` ->
      ``VisualBertForVisualReasoning``.

    Args:
        checkpoint_path: Path to an original VisualBERT checkpoint. Its file
            name must be listed in ``ACCEPTABLE_CHECKPOINTS``.
        pytorch_dump_folder_path: Directory the converted model is written to
            via ``save_pretrained``. It is created, including parents, if missing.

    Raises:
        ValueError: If the checkpoint file name is not in ``ACCEPTABLE_CHECKPOINTS``.
        NotImplementedError: If no config or model head is known for the
            checkpoint (e.g. fine-tuned ``vqa_advanced``, which has a config
            but no head class mapped here).
    """
    # Dispatch only on the file name: matching against the full path would let
    # directory names (".../pretrained/...", ".../vqa/...") pick the wrong
    # config or head.
    checkpoint_name = Path(checkpoint_path).name
    if checkpoint_name not in ACCEPTABLE_CHECKPOINTS:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise ValueError(f"The checkpoint provided must be in {ACCEPTABLE_CHECKPOINTS}.")

    # Get Config
    if "pre" in checkpoint_name:
        model_type = "pretraining"
        if "vcr" in checkpoint_name:
            config_params = {"visual_embedding_dim": 512}
        elif "vqa_advanced" in checkpoint_name:
            config_params = {"visual_embedding_dim": 2048}
        elif "vqa" in checkpoint_name:
            config_params = {"visual_embedding_dim": 2048}
        elif "nlvr" in checkpoint_name:
            config_params = {"visual_embedding_dim": 1024}
        else:
            raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")
    else:
        # "vqa_advanced" must be tested before "vqa" because it contains it.
        if "vcr" in checkpoint_name:
            config_params = {"visual_embedding_dim": 512}
            model_type = "multichoice"
        elif "vqa_advanced" in checkpoint_name:
            config_params = {"visual_embedding_dim": 2048}
            model_type = "vqa_advanced"
        elif "vqa" in checkpoint_name:
            config_params = {"visual_embedding_dim": 2048, "num_labels": 3129}
            model_type = "vqa"
        elif "nlvr" in checkpoint_name:
            config_params = {
                "visual_embedding_dim": 1024,
                "num_labels": 2,
            }
            model_type = "nlvr"
        else:
            # Previously missing: an unmatched name crashed later on an
            # undefined `config_params`.
            raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")

    # Resolve the head class before loading any weights, so unsupported types
    # fail fast instead of hitting an UnboundLocalError on `model`.
    model_classes = {
        "pretraining": VisualBertForPreTraining,
        "vqa": VisualBertForQuestionAnswering,
        "nlvr": VisualBertForVisualReasoning,
        "multichoice": VisualBertForMultipleChoice,
    }
    if model_type not in model_classes:
        raise NotImplementedError(
            f"No model class available for model type `{model_type}` (`{checkpoint_path}`)."
        )

    config = VisualBertConfig(**config_params)

    # Load State Dict
    state_dict = load_state_dict(checkpoint_path)

    new_state_dict = get_new_dict(state_dict, config)

    model = model_classes[model_type](config)
    model.load_state_dict(new_state_dict)

    # Save Checkpoints
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
Ejemplo n.º 2
0
 def create_and_check_for_pretraining(self, config, input_dict):
     """Run ``VisualBertForPreTraining`` on ``input_dict`` and check the logits shape.

     The model is built from ``config``, moved to ``torch_device`` and switched
     to eval mode. Its ``prediction_logits`` must have shape
     ``(batch_size, seq_length + visual_seq_length, vocab_size)``: one row per
     text position plus one per visual position.
     """
     # nn.Module.to() and .eval() both return the module itself, so chaining is
     # equivalent to calling them as separate statements.
     pretraining_model = VisualBertForPreTraining(config=config).to(torch_device).eval()
     outputs = pretraining_model(**input_dict)
     expected_shape = (
         self.batch_size,
         self.seq_length + self.visual_seq_length,
         self.vocab_size,
     )
     self.parent.assertEqual(outputs.prediction_logits.shape, expected_shape)
    def test_inference_vqa_coco_pre(self):
        """Check "uclanlp/visualbert-vqa-coco-pre" outputs on a fixed toy input.

        Six text tokens plus ten constant (0.5) 2048-dim visual embeddings are
        fed through ``VisualBertForPreTraining``. The test compares the shapes
        and the leading values of ``prediction_logits`` and
        ``seq_relationship_logits`` against reference numbers (atol=1e-4).
        """
        model = VisualBertForPreTraining.from_pretrained(
            "uclanlp/visualbert-vqa-coco-pre")

        input_ids = torch.tensor([1, 2, 3, 4, 5, 6],
                                 dtype=torch.long).reshape(1, -1)
        token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1],
                                      dtype=torch.long).reshape(1, -1)
        visual_embeds = torch.ones(size=(1, 10, 2048),
                                   dtype=torch.float32) * 0.5
        visual_token_type_ids = torch.ones(size=(1, 10), dtype=torch.long)
        # Attend to every text and visual position. The masks are derived from
        # the inputs (same int64 ones as before) so they can't drift out of sync
        # with the sequence lengths.
        attention_mask = torch.ones_like(input_ids)
        visual_attention_mask = torch.ones_like(visual_token_type_ids)

        # Inference only: no autograd graph is needed for the comparisons below.
        with torch.no_grad():
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                visual_embeds=visual_embeds,
                visual_attention_mask=visual_attention_mask,
                visual_token_type_ids=visual_token_type_ids,
            )

        # Expected vocabulary size for this checkpoint (hard-coded reference).
        vocab_size = 30522

        # 16 = 6 text positions + 10 visual positions.
        expected_shape = torch.Size((1, 16, vocab_size))
        self.assertEqual(output.prediction_logits.shape, expected_shape)

        expected_slice = torch.tensor([[[-5.1858, -5.1903, -4.9142],
                                        [-6.2214, -5.9238, -5.8381],
                                        [-6.3027, -5.9939, -5.9297]]])

        self.assertTrue(
            torch.allclose(output.prediction_logits[:, :3, :3],
                           expected_slice,
                           atol=1e-4))

        expected_shape_2 = torch.Size((1, 2))
        self.assertEqual(output.seq_relationship_logits.shape,
                         expected_shape_2)

        expected_slice_2 = torch.tensor([[0.7393, 0.1754]])

        self.assertTrue(
            torch.allclose(output.seq_relationship_logits,
                           expected_slice_2,
                           atol=1e-4))