Example #1
    def test_modality_key_preprocessing(self):
        self._text_modality_config.key = "body"
        second_text_modality_config = MMFTransformerModalityConfig(
            type="text",
            key="ocr",
            embedding_dim=756,
            position_dim=128,
            segment_id=2,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        )

        modalities_config = [
            self._image_modality_config,
            self._text_modality_config,
            second_text_modality_config,
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.body = torch.randint(0, 512, (2, 128))
        sample_list.ocr = torch.randint(0, 512, (2, 128))
        sample_list.lm_label_ids = torch.randint(-1, 30522, (2, 128))
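        # Both text modalities ("body" and "ocr") receive the same lm_label_ids,
        # so the expected combined label sum is twice the per-modality sum.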
        lm_labels_sum = sample_list.lm_label_ids.sum().item() * 2

        transformer_input = mmft.preprocess_sample(sample_list)
        self._compare_processed_for_multimodality(transformer_input,
                                                  lm_labels_sum)
Example #2
    def test_finetune_model(self):
        self.model_config["training_head_type"] = "classification"
        finetune_model = build_model(self.model_config)
        finetune_model.model.eval()
        num_bbox_per_image = 10
        input_ids = torch.randint(low=0, high=BERT_VOCAB_SIZE, size=(1, 128)).long()
        attention_mask = torch.ones((1, 128)).long()
        token_type_ids = torch.zeros(1, 128).long()
        visual_embeddings = torch.rand(
            (1, num_bbox_per_image, self.vision_feature_size)
        ).float()
        image_attention_mask = torch.zeros((1, num_bbox_per_image)).long()
        visual_locations = torch.rand((1, num_bbox_per_image, 5)).float()
        finetune_model.eval()

        with torch.no_grad():
            model_output = finetune_model.model(
                input_ids=input_ids,
                image_feature=visual_embeddings,
                image_location=visual_locations,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                image_attention_mask=image_attention_mask,
            )
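        # Scripting the module should not change its numerics, so the scripted
        # output is compared against the eager output with torch.equal below.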
        script_model = torch.jit.script(finetune_model.model)
        with torch.no_grad():
            script_output = script_model(
                input_ids=input_ids,
                image_feature=visual_embeddings,
                image_location=visual_locations,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                image_attention_mask=image_attention_mask,
            )
        self.assertTrue(torch.equal(model_output["scores"], script_output["scores"]))
Example #3
    def test_mmft_from_build_model(self):
        modalities_config = [
            MMFTransformerModalityConfig(
                type="image",
                key="image",
                embedding_dim=256,
                position_dim=1,
                segment_id=0,
                encoder=ImageEncoderFactory.Config(
                    type=ImageEncoderTypes.resnet152,
                    params=ResNet152ImageEncoder.Config(pretrained=False),
                ),
            ),
            MMFTransformerModalityConfig(
                type="text",
                key="text",
                embedding_dim=756,
                position_dim=512,
                segment_id=1,
                encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
            ),
        ]
        config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
        mmft = build_model(config)
        self.assertIsNotNone(mmft)
Example #4
    def load_model_and_optimizer(self):
        attributes = self.config.model_config[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

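        # open_dict temporarily lifts OmegaConf's struct flag so the new
        # "model" key can be added to the config node.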
        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)

        if "cuda" in str(self.device):
            device_info = "CUDA Device {} is: {}".format(
                self.config.distributed.rank,
                torch.cuda.get_device_name(self.local_rank),
            )
            registry.register("global_device", self.config.distributed.rank)
            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)
        self.optimizer = build_optimizer(self.model, self.config)

        registry.register("data_parallel", False)
        registry.register("distributed", False)

        self.load_extras()
        self.parallelize_model()
Example #5
    def test_tie_mlm_head_weight_to_encoder(self):
        self._text_modality_config = MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=768,
            position_dim=128,
            segment_id=0,
            encoder=TextEncoderFactory.Config(
                type=TextEncoderTypes.transformer),
        )
        heads = [MLM.Config()]
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(
            heads=heads,
            modalities=modalities_config,
            num_labels=2,
            tie_weight_to_encoder="text",
        )
        mmft = build_model(config)

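        # With tie_weight_to_encoder="text", the MLM head's decoder weight is
        # expected to be tied to the text encoder's word-embedding weight.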
        test_utils.compare_tensors(
            mmft.heads[0].cls.predictions.decoder.weight,
            mmft.encoders["text"].embeddings.word_embeddings.weight,
        )
Example #6
    def test_finetune_xlmr_base(self):
        self.config.model_config[self.model_name][
            "transformer_base"
        ] = "xlm-roberta-base"
        model = build_model(self.config.model_config[self.model_name])
        model.eval()
        self.assertTrue(
            test_utils.compare_torchscript_transformer_models(
                model, vocab_size=XLM_ROBERTA_VOCAB_SIZE
            )
        )
Example #7
    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        model_name = "uniter"
        args = test_utils.dummy_args(model=model_name, dataset="vqa2")
        configuration = Configuration(args)
        config = configuration.get_config()
        model_config = config.model_config[model_name]
        model_config.model = model_name
        model_config.losses = {"vqa2": "logit_bce"}
        model_config.do_pretraining = False
        model_config.tasks = "vqa2"
        classification_config_dict = {
            "do_pretraining": False,
            "tasks": "vqa2",
            "heads": {"vqa2": {"type": "mlp", "num_labels": 3129}},
            "losses": {"vqa2": "logit_bce"},
        }
        classification_config = OmegaConf.create(
            {**model_config, **classification_config_dict}
        )

        pretraining_config_dict = {
            "do_pretraining": True,
            "tasks": "wra",
            "heads": {"wra": {"type": "wra"}},
        }
        pretraining_config = OmegaConf.create(
            {**model_config, **pretraining_config_dict}
        )

        self.model_for_classification = build_model(classification_config)
        self.model_for_pretraining = build_model(pretraining_config)
Example #8
    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        self.model_name = "mmf_transformer"
        args = test_utils.dummy_args(model=self.model_name)
        configuration = Configuration(args)
        self.config = configuration.get_config()
        self.config.model_config[self.model_name].model = self.model_name
        self.finetune_model = build_model(self.config.model_config[self.model_name])
Example #9
    def test_vinvl_for_classification(self):
        model_for_classification = build_model(self.classification_config)
        model_for_classification.eval()
        model_for_classification = model_for_classification.to(get_current_device())
        with torch.no_grad():
            model_output = model_for_classification(self.sample_list)

        self.assertTrue("losses" in model_output)
        self.assertTrue("ce" in model_output["losses"])
Example #10
    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        model_name = "vilt"
        args = test_utils.dummy_args(model=model_name, dataset="test")
        configuration = Configuration(args)
        config = configuration.get_config()
        model_config = config.model_config[model_name]
        model_config.model = model_name
        self.pretrain_model = build_model(model_config)
Example #11
    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        replace_with_jit()
        model_name = "visual_bert"
        args = test_utils.dummy_args(model=model_name)
        configuration = Configuration(args)
        config = configuration.get_config()
        model_config = config.model_config[model_name]
        model_config.model = model_name
        self.pretrain_model = build_model(model_config)
Example #12
    def load_model(self) -> None:
        logger.info("Loading models")

        attributes = self.config.model_config[self.config.model]
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]
        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)
        self.model.is_pl_enabled = True
Example #13
    def test_vinvl_for_pretraining(self):
        model_for_pretraining = build_model(self.pretraining_config)
        model_for_pretraining.eval()
        model_for_pretraining = model_for_pretraining.to(get_current_device())

        with torch.no_grad():
            model_output = model_for_pretraining(self.sample_list)

        self.assertTrue("losses" in model_output)
        self.assertTrue("masked_lm_loss" in model_output["losses"])
        self.assertTrue("three_way_contrastive_loss" in model_output["losses"])
Example #14
    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        model_name = "vilbert"
        args = test_utils.dummy_args(model=model_name)
        configuration = Configuration(args)
        config = configuration.get_config()
        self.vision_feature_size = 1024
        self.vision_target_size = 1279
        model_config = config.model_config[model_name]
        model_config["training_head_type"] = "pretraining"
        model_config["visual_embedding_dim"] = self.vision_feature_size
        model_config["v_feature_size"] = self.vision_feature_size
        model_config["v_target_size"] = self.vision_target_size
        model_config["dynamic_attention"] = False
        model_config.model = model_name
        self.pretrain_model = build_model(model_config)

        model_config["training_head_type"] = "classification"
        model_config["num_labels"] = 2
        self.finetune_model = build_model(model_config)
Example #15
    def load_model(self):
        logger.info("Loading model")
        attributes = self.config.model_config[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)
        self.model = self.model.to(self.device)
Example #16
    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        model_name = "mmbt"
        args = test_utils.dummy_args(model=model_name)
        configuration = Configuration(args)
        config = configuration.get_config()
        model_config = config.model_config[model_name]
        model_config["training_head_type"] = "classification"
        model_config["num_labels"] = 2
        model_config.model = model_name
        self.finetune_model = build_model(model_config)
Example #17
    def _build_model(self):
        self.model_items = load_pretrained_model(self.checkpoint)
        self.config = OmegaConf.create(self.model_items["full_config"])
        dataset_name = list(self.config.dataset_config.keys())[0]
        processor = build_processors(
            self.config.dataset_config[dataset_name].processors)
        feature_extractor = build_encoder(
            self.model_items["config"].image_feature_encodings)
        ckpt = self.model_items["checkpoint"]
        model = build_model(self.model_items["config"])
        model.load_state_dict(ckpt)

        return processor, feature_extractor, model
Example #18
    def test_preprocessing_with_resnet_encoder(self):
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=2048,
            position_dim=1,
            segment_id=0,
            encoder=ImageEncoderFactory.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
        )
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 3, 224, 224)
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)

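        # ResNet-152 pools each image down to a single 2048-dim feature, so the
        # image modality contributes one token per sample ([2, 1, 2048] below).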
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 2048])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())
Example #19
File: mmf_trainer.py  Project: zyan97/mmf
    def load_model(self):
        logger.info("Loading model")
        # These multi-task variants all reuse the base visual_bert config
        model_list = [
            'multi_task_pair_wise_visual_bert',
            'attn_based_multi_task_pair_wise_visual_bert',
            'attn_based_pair_concat_visual_bert',
        ]
        if self.config.model in model_list:
            attributes = self.config.model_config['visual_bert']
        else:
            attributes = self.config.model_config[self.config.model]
            # Easy way to point to config for other model
            if isinstance(attributes, str):
                attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)
        self.model = self.model.to(self.device)
Example #20
    def load_model(self):
        logger.info("Loading model")
        if self.config.model in self.config.model_config:
            attributes = self.config.model_config[self.config.model]
        else:
            warnings.warn(f"Model {self.config.model}'s config not present. " +
                          "Continuing with empty config")
            attributes = OmegaConf.create()
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)
        self.model = self.model.to(self.device)
Example #21
    def test_pretrained_model(self):
        self.model_config["training_head_type"] = "pretraining"
        pretrain_model = build_model(self.model_config)
        pretrain_model.model.eval()
        num_bbox_per_image = 10
        input_ids = torch.randint(low=0, high=BERT_VOCAB_SIZE, size=(1, 128)).long()
        attention_mask = torch.ones((1, 128)).long()
        token_type_ids = torch.zeros(1, 128).long()
        visual_embeddings = torch.rand(
            (1, num_bbox_per_image, self.vision_feature_size)
        ).float()
        image_attention_mask = torch.zeros((1, num_bbox_per_image)).long()
        visual_locations = torch.rand((1, num_bbox_per_image, 5)).float()
        masked_lm_labels = torch.zeros((1, 128), dtype=torch.long).fill_(-1)
        image_target = torch.zeros(1, num_bbox_per_image, self.vision_target_size)
        image_label = torch.ones(1, num_bbox_per_image).fill_(-1)
        pretrain_model.eval()

        with torch.no_grad():
            model_output = pretrain_model.model(
                input_ids=input_ids,
                image_feature=visual_embeddings,
                image_location=visual_locations,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                image_attention_mask=image_attention_mask,
                masked_lm_labels=masked_lm_labels,
                image_label=image_label,
                image_target=image_target,
            )
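        # The scripted pretraining model should reproduce the eager masked LM loss.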
        script_model = torch.jit.script(pretrain_model.model)
        with torch.no_grad():
            script_output = script_model(
                input_ids=input_ids,
                image_feature=visual_embeddings,
                image_location=visual_locations,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                image_attention_mask=image_attention_mask,
                masked_lm_labels=masked_lm_labels,
                image_label=image_label,
                image_target=image_target,
            )
        self.assertEqual(
            model_output["masked_lm_loss"], script_output["masked_lm_loss"]
        )
Example #22
    def test_one_dim_feature_preprocessing(self):
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        masks = mmft._infer_masks(sample_list, input_ids)
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())

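        # 1 image token + 128 text tokens = 129 positions; with no tokens masked,
        # every combined label is the ignore value -1.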
        mlm_labels = transformer_input["mlm_labels"]
        test_utils.compare_tensors(
            mlm_labels["combined_labels"],
            torch.full((2, 129), dtype=torch.long, fill_value=-1),
        )
Example #23
    def test_finetune_model(self):
        finetune_model = build_model(self.model_config)
        finetune_model.eval()
        test_sample = Sample()
        test_sample.input_ids = torch.randint(low=0, high=30255,
                                              size=(128, )).long()
        test_sample.input_mask = torch.ones(128).long()
        test_sample.segment_ids = torch.zeros(128).long()
        test_sample.image = torch.rand((3, 300, 300)).float()
        test_sample_list = SampleList([test_sample.copy()])

        with torch.no_grad():
            model_output = finetune_model.model(test_sample_list)

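        # The first forward ran on a copy; build a fresh SampleList for the
        # scripted run in case the model mutates its input in place.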
        test_sample_list = SampleList([test_sample])
        script_model = torch.jit.script(finetune_model.model)
        with torch.no_grad():
            script_output = script_model(test_sample_list)

        self.assertTrue(
            torch.equal(model_output["scores"], script_output["scores"]))
Example #24
    def test_modal_end_token(self):
        finetune_model = build_model(self.model_config)
        finetune_model.eval()

        # Suppose 0 for <cls>, 1 for <pad>, 2 for <sep>
        CLS = 0
        PAD = 1
        SEP = 2
        size = 128

        input_ids = torch.randint(low=0, high=30255, size=(size, )).long()
        input_mask = torch.ones(size).long()

        input_ids[0] = CLS
        length = torch.randint(low=2, high=size - 1, size=(1, ))
        input_ids[length] = SEP
        input_ids[length + 1:] = PAD
        input_mask[length + 1:] = 0

        test_sample = Sample()
        test_sample.input_ids = input_ids.clone()
        test_sample.input_mask = input_mask.clone()
        test_sample.segment_ids = torch.zeros(size).long()
        test_sample.image = torch.rand((3, 300, 300)).float()
        test_sample_list = SampleList([test_sample])

        mmbt_base = finetune_model.model.bert
        with torch.no_grad():
            actual_modal_end_token = mmbt_base.extract_modal_end_token(
                test_sample_list)

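        # The extracted end token should be SEP, and the remaining ids/mask are
        # shifted left by one position.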
        expected_modal_end_token = torch.zeros([1]).fill_(SEP).long()
        self.assertTrue(
            torch.equal(actual_modal_end_token, expected_modal_end_token))
        self.assertTrue(
            torch.equal(test_sample_list.input_ids[0, :-1], input_ids[1:]))
        self.assertTrue(
            torch.equal(test_sample_list.input_mask[0, :-1], input_mask[1:]))
Example #25
    def test_finetune_bert_base(self):
        model = build_model(self.config.model_config[self.model_name])
        model.eval()
        self.assertTrue(
            test_utils.compare_torchscript_transformer_models(
                model, vocab_size=BERT_VOCAB_SIZE
            )
        )
Example #26
    def test_load_save_finetune_model(self):
        model = build_model(self.config.model_config[self.model_name])
        self.assertTrue(test_utils.verify_torchscript_models(model))
Example #27
    def test_custom_feature_and_mask_preprocessing(self):
        extra_modality = MMFTransformerModalityConfig(
            type="my_random_feature",
            key="my_random_feature",
            embedding_dim=128,
            position_dim=4,
            segment_id=3,
            encoder=EncoderFactory.Config(type="identity"),
        )

        modalities_config = [
            self._image_modality_config,
            self._text_modality_config,
            extra_modality,
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.text = torch.randint(0, 512, (2, 128))
        sample_list.text_mask = torch.ones(2, 128)
        sample_list.text_mask[:, 70:] = 0
        sample_list.my_random_feature = torch.rand(2, 4, 128)
        sample_list.my_random_feature_mask = torch.ones(2, 4)
        sample_list.my_random_feature_mask[:, 3:] = 0

        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        self.assertEqual(input_ids["my_random_feature"].dim(), 3)
        self.assertEqual(list(input_ids["my_random_feature"].size()),
                         [2, 4, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))
        test_utils.compare_tensors(
            position_ids["my_random_feature"],
            torch.arange(0, 4).unsqueeze(0).expand((2, 4)),
        )

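        # text_mask keeps 70 positions per row (2 x 70 = 140); the custom feature
        # mask keeps 3 positions per row (2 x 3 = 6).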
        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        self.assertEqual(masks["text"].sum().item(), 140)
        self.assertEqual(masks["my_random_feature"].sum().item(), 6)

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())
        test_utils.compare_tensors(
            segment_ids["my_random_feature"],
            torch.full((2, 4), fill_value=3, dtype=torch.long),
        )
Example #28
    def test_load_save_pretrain_model(self):
        self.model_config["training_head_type"] = "pretraining"
        pretrain_model = build_model(self.model_config)
        self.assertTrue(test_utils.verify_torchscript_models(pretrain_model.model))
Example #29
    def test_load_save_finetune_model(self):
        self.model_config["training_head_type"] = "classification"
        finetune_model = build_model(self.model_config)
        self.assertTrue(test_utils.verify_torchscript_models(finetune_model.model))
Example #30
    def test_preprocessing_with_mvit_encoder(self):
        encoder_config = OmegaConf.create({
            "name": "pytorchvideo",
            "model_name": "mvit_base_32x3",
            "random_init": True,
            "drop_last_n_layers": 0,
            "pooler_name": "cls",
            "spatial_size": 224,
            "temporal_size": 8,
            "head": None,
            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            "pool_kv_stride_adaptive": [1, 8, 8],
            "pool_kvq_kernel": [3, 3, 3],
        })
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=768,
            position_dim=1,
            segment_id=0,
            encoder=encoder_config,
        )
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand((2, 3, 8, 224, 224))
        sample_list.text = torch.randint(0, 512, (2, 128))

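        # With the "cls" pooler, the MViT encoder reduces each (3, 8, 224, 224)
        # clip to a single 768-dim token, matching embedding_dim above.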
        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())