def compare_torchscript_transformer_models(model, vocab_size):
    """Compare eager and TorchScript outputs of ``model`` on a random sample.

    A single ``Sample`` holding random token ids, an all-ones input mask,
    all-zero segment ids, a random image feature and a random image is
    wrapped in a ``SampleList``. The same list is fed first to ``model`` and
    then to ``torch.jit.script(model)``, both under ``torch.no_grad()``.

    Args:
        model: Object that is called with a ``SampleList``, returns a mapping
            containing ``"scores"``, and can be passed to
            ``torch.jit.script``.
        vocab_size: Exclusive upper bound for the random token ids.

    Returns:
        The result of ``torch.equal`` on the two ``"scores"`` tensors.
    """
    seq_len = 128

    # Random draws happen in this order: token ids, image feature, image.
    sample = Sample()
    sample.input_ids = torch.randint(
        low=0, high=vocab_size, size=(seq_len, ), dtype=torch.long)
    sample.input_mask = torch.ones(seq_len, dtype=torch.long)
    sample.segment_ids = torch.zeros(seq_len, dtype=torch.long)
    sample.image_feature_0 = torch.rand((1, 100, 2048)).float()
    sample.image = torch.rand((3, 300, 300)).float()
    batch = SampleList([sample])

    # Eager reference run.
    with torch.no_grad():
        eager_out = model(batch)

    # Scripted run on the very same batch object.
    scripted = torch.jit.script(model)
    with torch.no_grad():
        scripted_out = scripted(batch)

    eager_scores = eager_out["scores"]
    scripted_scores = scripted_out["scores"]
    return torch.equal(eager_scores, scripted_scores)
    def test_finetune_model(self):
        """Scripted finetune model must give the same scores as eager mode.

        Runs ``self.finetune_model.model`` on a random sample, then runs its
        ``torch.jit.script`` version on a freshly built ``SampleList`` of the
        same sample, and asserts the ``"scores"`` tensors are exactly equal.
        """
        self.finetune_model.eval()
        seq_len = 128

        sample = Sample()
        sample.input_ids = torch.randint(
            low=0, high=30255, size=(seq_len, ), dtype=torch.long)
        sample.input_mask = torch.ones(seq_len, dtype=torch.long)
        sample.segment_ids = torch.zeros(seq_len, dtype=torch.long)
        sample.image = torch.rand((3, 300, 300)).float()

        # The eager run gets a list built from a copy of the sample;
        # presumably the forward pass modifies its input in place, so the
        # scripted run below gets its own list -- TODO confirm.
        eager_batch = SampleList([sample.copy()])
        with torch.no_grad():
            eager_out = self.finetune_model.model(eager_batch)

        scripted_batch = SampleList([sample])
        scripted = torch.jit.script(self.finetune_model.model)
        with torch.no_grad():
            scripted_out = scripted(scripted_batch)

        scores_match = torch.equal(
            eager_out["scores"], scripted_out["scores"])
        self.assertTrue(scores_match)
    def test_modal_end_token(self):
        """Check ``extract_modal_end_token`` on a padded token sequence.

        Builds a sequence laid out as ``<cls> ... <sep> <pad> ...`` with the
        ``<sep>`` position drawn at random, then asserts that:

        * the returned modal end token equals the ``<sep>`` id, and
        * afterwards the sample list's ``input_ids`` and ``input_mask``
          (row 0, last element dropped) equal the original tensors with
          their first element dropped.
        """
        self.finetune_model.eval()

        # Special-token ids assumed here: <cls> -> 0, <pad> -> 1, <sep> -> 2.
        cls_id, pad_id, sep_id = 0, 1, 2
        seq_len = 128

        token_ids = torch.randint(
            low=0, high=30255, size=(seq_len, ), dtype=torch.long)
        token_mask = torch.ones(seq_len, dtype=torch.long)

        # <sep> lands somewhere in [2, seq_len - 2], so at least one
        # trailing position is always padding; padding is masked out.
        token_ids[0] = cls_id
        sep_pos = torch.randint(low=2, high=seq_len - 1, size=(1, ))
        token_ids[sep_pos] = sep_id
        token_ids[sep_pos + 1:] = pad_id
        token_mask[sep_pos + 1:] = 0

        # Clones go into the sample so the originals stay available for the
        # comparisons at the end.
        sample = Sample()
        sample.input_ids = token_ids.clone()
        sample.input_mask = token_mask.clone()
        sample.segment_ids = torch.zeros(seq_len, dtype=torch.long)
        sample.image = torch.rand((3, 300, 300)).float()
        batch = SampleList([sample])

        bert_base = self.finetune_model.model.bert
        with torch.no_grad():
            modal_end = bert_base.extract_modal_end_token(batch)

        want_modal_end = torch.tensor([sep_id], dtype=torch.long)
        self.assertTrue(torch.equal(modal_end, want_modal_end))

        # The list's tensors are expected to be the originals shifted left
        # by one position.
        shifted_ids = batch.input_ids[0, :-1]
        self.assertTrue(torch.equal(shifted_ids, token_ids[1:]))

        shifted_mask = batch.input_mask[0, :-1]
        self.assertTrue(torch.equal(shifted_mask, token_mask[1:]))