Exemplo n.º 1
0
    def test_modality_key_preprocessing(self):
        """Preprocessing should read each modality from its configured key.

        The two text modalities are fed from ``sample_list.body`` and
        ``sample_list.ocr`` rather than a field named ``text``.
        """
        # Re-key the text modality prepared in setUp so it reads ``body``.
        self._text_modality_config.key = "body"
        ocr_modality = MMFTransformerModalityConfig(
            type="text",
            key="ocr",
            embedding_dim=756,
            position_dim=128,
            segment_id=2,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        )
        model = build_model(
            MMFTransformer.Config(
                modalities=[
                    self._image_modality_config,
                    self._text_modality_config,
                    ocr_modality,
                ],
                num_labels=2,
            )
        )

        batch = SampleList()
        batch.image = torch.rand(2, 256)
        batch.body = torch.randint(0, 512, (2, 128))
        batch.ocr = torch.randint(0, 512, (2, 128))
        batch.lm_label_ids = torch.randint(-1, 30522, (2, 128))
        # NOTE(review): the factor of 2 presumably accounts for the labels being
        # counted once per text modality; the checker is defined elsewhere.
        expected_lm_labels_sum = batch.lm_label_ids.sum().item() * 2

        processed = model.preprocess_sample(batch)
        self._compare_processed_for_multimodality(processed, expected_lm_labels_sum)
Exemplo n.º 2
0
 class Config(BaseModel.Config):
     """Default configuration for the ``mmbt`` model.

     Extends ``BaseModel.Config`` with BERT/text-encoder settings, a modal
     encoder, freezing switches and hidden/output sizes.
     """
     model: str = "mmbt"
     # classification or pretraining
     training_head_type: str = "pretraining"
     # Checkpoint name; text_encoder's params below refer back to this field.
     bert_model_name: str = "bert-base-uncased"
     # NOTE(review): presumably feeds precomputed features to the modal encoder
     # instead of raw inputs; confirm where MMBT consumes it.
     direct_features_input: bool = False
     # Freezing switches for the text side, the modal side, and the whole base.
     # NOTE(review): their exact effect is implemented outside this class.
     freeze_text: bool = False
     freeze_modal: bool = False
     freeze_complete_base: bool = False
     # NOTE(review): presumably a learning-rate multiplier for fine-tuned
     # parameters; confirm where it is read.
     finetune_lr_multiplier: float = 1
     # Dimension of the embedding finally returned by the modal encoder
     modal_hidden_size: int = 2048
     # NOTE(review): presumably must match the hidden size of the checkpoint
     # named by bert_model_name; keep the two in sync.
     text_hidden_size: int = 768
     num_labels: int = 2
     # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
     # Default: ResNet-152 image encoder with its default Config().
     # NOTE(review): if this class is a @dataclass, Python 3.11+ rejects
     # unhashable instance defaults like this one (and text_encoder below);
     # field(default_factory=...) would then be required.
     modal_encoder: EncoderFactory.Config = ImageEncoderFactory.Config(
         type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
     )
     # Default: transformer text encoder. II("bert_model_name") is presumably an
     # OmegaConf interpolation, so its checkpoint follows bert_model_name above.
     text_encoder: EncoderFactory.Config = TextEncoderFactory.Config(
         type=TextEncoderTypes.transformer,
         params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
     )
     # NOTE(review): presumably toggle start/end tokens around the modal input.
     use_modal_start_token: bool = True
     use_modal_end_token: bool = True
     # NOTE(review): presumably returns only the fused feature; confirm in model.
     fused_feature_only: bool = False
     output_dim: int = 768
Exemplo n.º 3
0
 def test_mmft_from_build_model(self):
     """``build_model`` should construct an MMFTransformer from its Config."""
     image_encoder = ImageEncoderFactory.Config(
         type=ImageEncoderTypes.resnet152,
         params=ResNet152ImageEncoder.Config(pretrained=False),
     )
     image_modality = MMFTransformerModalityConfig(
         type="image",
         key="image",
         embedding_dim=256,
         position_dim=1,
         segment_id=0,
         encoder=image_encoder,
     )
     text_encoder = TextEncoderFactory.Config(type=TextEncoderTypes.identity)
     text_modality = MMFTransformerModalityConfig(
         type="text",
         key="text",
         embedding_dim=756,
         position_dim=512,
         segment_id=1,
         encoder=text_encoder,
     )
     model = build_model(
         MMFTransformer.Config(
             modalities=[image_modality, text_modality], num_labels=2
         )
     )
     self.assertIsNotNone(model)
Exemplo n.º 4
0
    def test_mmf_from_params_encoder_factory(self):
        """``from_params`` should match the structured Config of the same args."""
        image_modality = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=256,
            position_dim=1,
            segment_id=0,
            encoder=ImageEncoderFactory.Config(type=ImageEncoderTypes.identity),
        )
        text_modality = MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=756,
            position_dim=512,
            segment_id=0,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        )
        modalities = [image_modality, text_modality]

        model = MMFTransformer.from_params(modalities=modalities, num_labels=2)
        model.build()

        # Reference config built the "structured" way from identical arguments.
        expected = OmegaConf.structured(
            MMFTransformer.Config(modalities=modalities, num_labels=2)
        )
        self.assertIsNotNone(model)
        self.assertEqual(model.config, expected)
Exemplo n.º 5
0
    def test_tie_mlm_head_weight_to_encoder(self):
        """The MLM head's decoder weight must equal the text word embeddings.

        Builds an MMFTransformer with an MLM head and
        ``tie_weight_to_encoder="text"``, then asserts that the head's decoder
        weight equals the ``text`` encoder's word-embedding weight.
        """
        import torch

        # A transformer text encoder is used so that the ``text`` encoder has
        # ``embeddings.word_embeddings`` to compare against. A local config
        # avoids overwriting the fixture attribute set in setUp.
        text_modality_config = MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=768,
            position_dim=128,
            segment_id=0,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.transformer),
        )
        config = MMFTransformer.Config(
            heads=[MLM.Config()],
            modalities=[self._image_modality_config, text_modality_config],
            num_labels=2,
            tie_weight_to_encoder="text",
        )
        mmft = build_model(config)

        decoder_weight = mmft.heads[0].cls.predictions.decoder.weight
        embedding_weight = mmft.encoders["text"].embeddings.word_embeddings.weight
        # Assert explicitly: a comparison whose result is discarded can never
        # fail the test.
        self.assertTrue(torch.equal(decoder_weight, embedding_weight))
Exemplo n.º 6
0
 def test_mmbt_directly_from_config(self):
     """Constructing MMBT from a structured config should keep that config."""
     resnet_encoder = ImageEncoderFactory.Config(
         type=ImageEncoderTypes.resnet152,
         params=ResNet152ImageEncoder.Config(pretrained=False),
     )
     identity_text_encoder = TextEncoderFactory.Config(
         type=TextEncoderTypes.identity
     )
     config = OmegaConf.structured(
         MMBT.Config(
             modal_encoder=resnet_encoder,
             text_encoder=identity_text_encoder,
         )
     )
     model = MMBT(config)
     self.assertIsNotNone(model)
     # The model's config must be the one created from MMBT.Config.
     self.assertEqual(model.config, config)
Exemplo n.º 7
0
 def setUp(self):
     """Run shared test hooks and build default image/text modality configs."""
     test_utils.setup_proxy()
     setup_imports()

     identity_image_encoder = ImageEncoderFactory.Config(
         type=ImageEncoderTypes.identity
     )
     self._image_modality_config = MMFTransformerModalityConfig(
         type="image",
         key="image",
         embedding_dim=256,
         position_dim=1,
         segment_id=0,
         encoder=identity_image_encoder,
     )

     identity_text_encoder = TextEncoderFactory.Config(
         type=TextEncoderTypes.identity
     )
     self._text_modality_config = MMFTransformerModalityConfig(
         type="text",
         key="text",
         embedding_dim=756,
         position_dim=128,
         segment_id=1,
         encoder=identity_text_encoder,
     )