Example #1
    def test_inference_vqa(self):
        model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")

        # Dummy text inputs: a single 6-token sequence split into two segments.
        input_ids = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.long).reshape(1, -1)
        token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long).reshape(1, -1)
        # Dummy visual inputs: 10 region features of dimension 2048, all set to 0.5.
        visual_embeds = torch.ones(size=(1, 10, 2048), dtype=torch.float32) * 0.5
        visual_token_type_ids = torch.ones(size=(1, 10), dtype=torch.long)
        # Attend to every text token and every visual region.
        attention_mask = torch.tensor([1] * 6).reshape(1, -1)
        visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
        )

        # The VQA head classifies over 3129 answer labels, not the 30522-token vocabulary.

        expected_shape = torch.Size((1, 3129))
        self.assertEqual(output.logits.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-8.9898, 3.0803, -1.8016, 2.4542, -8.3420, -2.0224, -3.3124, -4.4139, -3.1491, -3.8997]]
        )

        self.assertTrue(torch.allclose(output.logits[:, :10], expected_slice, atol=1e-4))
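The integration test above drives the model with constant dummy tensors. For orientation, here is a minimal sketch of what a real inference call could look like; the random visual_embeds is a stand-in for Faster R-CNN region features, which in practice are produced by a separate detector, and the question string is purely illustrative.

import torch
from transformers import BertTokenizer, VisualBertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")

inputs = tokenizer("What is the man riding?", return_tensors="pt")
visual_embeds = torch.rand(1, 10, 2048)  # stand-in for precomputed detector features
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)

outputs = model(
    **inputs,
    visual_embeds=visual_embeds,
    visual_attention_mask=visual_attention_mask,
    visual_token_type_ids=visual_token_type_ids,
)
answer_idx = outputs.logits.argmax(-1).item()  # index into the 3129 VQA answer labels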
Example #2
def convert_visual_bert_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our VisualBERT structure.
    """

    assert (
        checkpoint_path.split("/")[-1] in ACCEPTABLE_CHECKPOINTS
    ), f"The checkpoint provided must be in {ACCEPTABLE_CHECKPOINTS}."

    # Get Config
    if "pre" in checkpoint_path:
        model_type = "pretraining"
        if "vcr" in checkpoint_path:
            config_params = {"visual_embedding_dim": 512}
        elif "vqa_advanced" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048}
        elif "vqa" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048}
        elif "nlvr" in checkpoint_path:
            config_params = {"visual_embedding_dim": 1024}
        else:
            raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")
    else:
        if "vcr" in checkpoint_path:
            config_params = {"visual_embedding_dim": 512}
            model_type = "multichoice"
        elif "vqa_advanced" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048}
            model_type = "vqa_advanced"
        elif "vqa" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048, "num_labels": 3129}
            model_type = "vqa"
        elif "nlvr" in checkpoint_path:
            config_params = {
                "visual_embedding_dim": 1024,
                "num_labels": 2,
            }
            model_type = "nlvr"

    config = VisualBertConfig(**config_params)

    # Load State Dict
    state_dict = load_state_dict(checkpoint_path)

    new_state_dict = get_new_dict(state_dict, config)

    if model_type == "pretraining":
        model = VisualBertForPreTraining(config)
    elif model_type == "vqa":
        model = VisualBertForQuestionAnswering(config)
    elif model_type == "nlvr":
        model = VisualBertForVisualReasoning(config)
    elif model_type == "multichoice":
        model = VisualBertForMultipleChoice(config)

    model.load_state_dict(new_state_dict)
    # Save Checkpoints
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
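A conversion function like this is typically driven from the command line. A minimal sketch of such an entry point, assuming it lives in the same file as the helpers above (the argument names here are illustrative, not necessarily those of the actual script):

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("orig_checkpoint_path", type=str, help="Path to the original .th checkpoint.")
    parser.add_argument("pytorch_dump_folder_path", type=str, help="Output folder for the converted model.")
    args = parser.parse_args()

    convert_visual_bert_checkpoint(args.orig_checkpoint_path, args.pytorch_dump_folder_path)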
Example #3
    def create_and_check_for_vqa(self, config, input_dict):
        model = VisualBertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(**input_dict)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
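The helper above receives config and input_dict from the surrounding model tester. A hedged sketch of what those arguments could look like, with deliberately small illustrative sizes so the check runs quickly (the real tester builds them internally):

import torch
from transformers import VisualBertConfig

batch_size, seq_length, num_regions = 2, 7, 5

config = VisualBertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    visual_embedding_dim=16,
    num_labels=3129,  # VQA answer-label count checked against logits.shape
)
input_dict = {
    "input_ids": torch.randint(0, config.vocab_size, (batch_size, seq_length)),
    "attention_mask": torch.ones(batch_size, seq_length, dtype=torch.long),
    "token_type_ids": torch.zeros(batch_size, seq_length, dtype=torch.long),
    "visual_embeds": torch.rand(batch_size, num_regions, config.visual_embedding_dim),
    "visual_attention_mask": torch.ones(batch_size, num_regions, dtype=torch.long),
    "visual_token_type_ids": torch.ones(batch_size, num_regions, dtype=torch.long),
}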