def test_inference_nlvr(self):
    """Check the hub `uclanlp/visualbert-nlvr2` checkpoint against reference logits.

    Feeds fixed dummy text and visual inputs through `VisualBertForVisualReasoning`.
    Asserts that the logits have shape `(1, 2)` and match hard-coded reference values
    to within `atol=1e-4`.
    """
    model = VisualBertForVisualReasoning.from_pretrained("uclanlp/visualbert-nlvr2")

    # Six text tokens: the first three use segment 0, the last three segment 1.
    input_ids = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.long).reshape(1, -1)
    token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long).reshape(1, -1)
    attention_mask = torch.tensor([1] * 6).reshape(1, -1)

    # Ten visual regions of constant features; 1024 is the NLVR visual embedding size
    # used by the conversion script.
    visual_embeds = torch.ones(size=(1, 10, 1024), dtype=torch.float32) * 0.5
    visual_token_type_ids = torch.ones(size=(1, 10), dtype=torch.long)
    visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
        )

    # (batch_size, num_labels) -- the NLVR head is a 2-way classifier.
    expected_shape = torch.Size((1, 2))
    self.assertEqual(output.logits.shape, expected_shape)

    expected_slice = torch.tensor([[-1.1436, 0.8900]])
    self.assertTrue(torch.allclose(output.logits, expected_slice, atol=1e-4))
def _infer_visual_bert_task(checkpoint_path):
    """Infer the conversion target (model type and config overrides) from a checkpoint path.

    Only the file name is inspected. Directory names that happen to contain "pre",
    "vcr", "vqa" or "nlvr" therefore cannot change the result.

    Args:
        checkpoint_path: Path to the original VisualBERT checkpoint file.

    Returns:
        A `(model_type, config_params)` tuple:
            - `model_type` is one of `"pretraining"`, `"multichoice"`, `"vqa_advanced"`,
              `"vqa"` or `"nlvr"`.
            - `config_params` holds the keyword arguments for `VisualBertConfig`.

    Raises:
        NotImplementedError: If the file name contains none of the known task keywords.
    """
    checkpoint_name = Path(checkpoint_path).name

    # Order matters: "vqa_advanced" must be tested before its substring "vqa".
    if "vcr" in checkpoint_name:
        model_type, config_params = "multichoice", {"visual_embedding_dim": 512}
    elif "vqa_advanced" in checkpoint_name:
        model_type, config_params = "vqa_advanced", {"visual_embedding_dim": 2048}
    elif "vqa" in checkpoint_name:
        model_type, config_params = "vqa", {"visual_embedding_dim": 2048, "num_labels": 3129}
    elif "nlvr" in checkpoint_name:
        model_type, config_params = "nlvr", {"visual_embedding_dim": 1024, "num_labels": 2}
    else:
        raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")

    if "pre" in checkpoint_name:
        # Pre-training checkpoints all map to the pre-training model. Only the visual
        # feature size carries over from the task-specific settings.
        return "pretraining", {"visual_embedding_dim": config_params["visual_embedding_dim"]}
    return model_type, config_params


def convert_visual_bert_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our VisualBERT structure.

    Args:
        checkpoint_path: Path to an original VisualBERT checkpoint. Its file name must
            be listed in `ACCEPTABLE_CHECKPOINTS`.
        pytorch_dump_folder_path: Directory the converted model is saved to. It is
            created if missing.

    Raises:
        ValueError: If the checkpoint file name is not in `ACCEPTABLE_CHECKPOINTS`.
        NotImplementedError: If the checkpoint's task cannot be identified, or no model
            class exists for it (currently `vqa_advanced`).
    """
    checkpoint_name = Path(checkpoint_path).name
    # Explicit check rather than `assert`, so validation survives `python -O`.
    if checkpoint_name not in ACCEPTABLE_CHECKPOINTS:
        raise ValueError(f"The checkpoint provided must be in {ACCEPTABLE_CHECKPOINTS}.")

    # Get Config
    model_type, config_params = _infer_visual_bert_task(checkpoint_path)

    model_classes = {
        "pretraining": VisualBertForPreTraining,
        "vqa": VisualBertForQuestionAnswering,
        "nlvr": VisualBertForVisualReasoning,
        "multichoice": VisualBertForMultipleChoice,
    }
    # Fail before the (potentially large) state dict is loaded. Previously an unmapped
    # type such as "vqa_advanced" crashed later with UnboundLocalError on `model`.
    if model_type not in model_classes:
        raise NotImplementedError(f"No model class available for model type `{model_type}`.")

    config = VisualBertConfig(**config_params)

    # Load State Dict
    state_dict = load_state_dict(checkpoint_path)
    new_state_dict = get_new_dict(state_dict, config)

    model = model_classes[model_type](config)
    model.load_state_dict(new_state_dict)

    # Save Checkpoints
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
def create_and_check_for_nlvr(self, config, input_dict):
    """Build a `VisualBertForVisualReasoning` from `config` and check its logits shape.

    The model is moved to `torch_device` and switched to eval mode. It is then run on
    `input_dict`, and the logits must have shape `(self.batch_size, self.num_labels)`.
    """
    nlvr_model = VisualBertForVisualReasoning(config=config)
    # `Module.to` and `Module.eval` both return the module itself, so they chain.
    nlvr_model = nlvr_model.to(torch_device).eval()

    outputs = nlvr_model(**input_dict)
    wanted_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, wanted_shape)