def compare_torchscript_transformer_models(model, vocab_size):
    # Run the same dummy multimodal sample through the eager model and its
    # TorchScript-compiled version and report whether the scores match.
    test_sample = Sample()
    test_sample.input_ids = torch.randint(low=0, high=vocab_size, size=(128,)).long()
    test_sample.input_mask = torch.ones(128).long()
    test_sample.segment_ids = torch.zeros(128).long()
    test_sample.image_feature_0 = torch.rand((1, 100, 2048)).float()
    test_sample.image = torch.rand((3, 300, 300)).float()
    test_sample_list = SampleList([test_sample])

    with torch.no_grad():
        model_output = model(test_sample_list)

    script_model = torch.jit.script(model)
    with torch.no_grad():
        script_output = script_model(test_sample_list)

    return torch.equal(model_output["scores"], script_output["scores"])
def test_finetune_model(self):
    self.finetune_model.eval()

    # Build a dummy text+image sample and compare the eager model's scores
    # against the TorchScript-compiled model on a fresh copy of the sample.
    test_sample = Sample()
    test_sample.input_ids = torch.randint(low=0, high=30255, size=(128,)).long()
    test_sample.input_mask = torch.ones(128).long()
    test_sample.segment_ids = torch.zeros(128).long()
    test_sample.image = torch.rand((3, 300, 300)).float()
    test_sample_list = SampleList([test_sample.copy()])

    with torch.no_grad():
        model_output = self.finetune_model.model(test_sample_list)

    test_sample_list = SampleList([test_sample])
    script_model = torch.jit.script(self.finetune_model.model)
    with torch.no_grad():
        script_output = script_model(test_sample_list)

    self.assertTrue(torch.equal(model_output["scores"], script_output["scores"]))
def test_modal_end_token(self):
    self.finetune_model.eval()

    # Suppose 0 for <cls>, 1 for <pad>, 2 for <sep>
    CLS = 0
    PAD = 1
    SEP = 2
    size = 128

    input_ids = torch.randint(low=0, high=30255, size=(size,)).long()
    input_mask = torch.ones(size).long()

    input_ids[0] = CLS
    length = torch.randint(low=2, high=size - 1, size=(1,))
    input_ids[length] = SEP
    input_ids[length + 1:] = PAD
    input_mask[length + 1:] = 0

    test_sample = Sample()
    test_sample.input_ids = input_ids.clone()
    test_sample.input_mask = input_mask.clone()
    test_sample.segment_ids = torch.zeros(size).long()
    test_sample.image = torch.rand((3, 300, 300)).float()
    test_sample_list = SampleList([test_sample])

    mmbt_base = self.finetune_model.model.bert
    with torch.no_grad():
        actual_modal_end_token = mmbt_base.extract_modal_end_token(test_sample_list)

    # The extracted modal end token should be <sep>, and the remaining
    # input_ids/input_mask entries should be shifted left by one position.
    expected_modal_end_token = torch.zeros([1]).fill_(SEP).long()
    self.assertTrue(torch.equal(actual_modal_end_token, expected_modal_end_token))
    self.assertTrue(torch.equal(test_sample_list.input_ids[0, :-1], input_ids[1:]))
    self.assertTrue(torch.equal(test_sample_list.input_mask[0, :-1], input_mask[1:]))