예제 #1
0
 class Config(BaseModel.Config):
     """Model config for MMBT (``model = "mmbt"``).

     Declares defaults for the text encoder and the modal (image) encoder,
     their hidden sizes, the head type, the freeze_* switches and the
     fine-tuning LR multiplier.
     """

     model: str = "mmbt"
     # classification or pretraining
     training_head_type: str = "pretraining"
     bert_model_name: str = "bert-base-uncased"
     direct_features_input: bool = False
     freeze_text: bool = False
     freeze_modal: bool = False
     freeze_complete_base: bool = False
     # NOTE(review): int literal on a float-annotated field; presumably
     # coerced to float by the structured-config layer — confirm.
     finetune_lr_multiplier: float = 1
     # Dimension of the embedding finally returned by the modal encoder
     modal_hidden_size: int = 2048
     text_hidden_size: int = 768
     num_labels: int = 2
     # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
     # NOTE(review): this default (and text_encoder's below) is one instance
     # shared by every Config that keeps the default. If this class is a
     # @dataclass (decorator not visible here) and the factory Config types
     # are non-frozen dataclasses (hence unhashable), Python 3.11+ rejects it
     # at class creation ("mutable default ... use default_factory");
     # consider field(default_factory=lambda: ...) — confirm.
     modal_encoder: EncoderFactory.Config = ImageEncoderFactory.Config(
         type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
     )
     # II("bert_model_name") presumably builds a config interpolation so the
     # text encoder follows the top-level bert_model_name field — confirm.
     text_encoder: EncoderFactory.Config = TextEncoderFactory.Config(
         type=TextEncoderTypes.transformer,
         params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
     )
     use_modal_start_token: bool = True
     use_modal_end_token: bool = True
     fused_feature_only: bool = False
     output_dim: int = 768
예제 #2
0
 def test_mmft_from_build_model(self):
     """build_model should construct an MMFTransformer from an explicit config.

     Uses two modalities: an image modality backed by a non-pretrained
     ResNet-152 encoder factory config, and a text modality with the
     identity text encoder type.
     """
     image_modality = MMFTransformerModalityConfig(
         type="image",
         key="image",
         embedding_dim=256,
         position_dim=1,
         segment_id=0,
         encoder=ImageEncoderFactory.Config(
             type=ImageEncoderTypes.resnet152,
             params=ResNet152ImageEncoder.Config(pretrained=False),
         ),
     )
     # NOTE(review): 756 differs from the 768 text embedding_dim used in the
     # default MMFTransformer config; this test only checks construction, so
     # confirm whether the value is intentional.
     text_modality = MMFTransformerModalityConfig(
         type="text",
         key="text",
         embedding_dim=756,
         position_dim=512,
         segment_id=1,
         encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
     )

     config = MMFTransformer.Config(
         modalities=[image_modality, text_modality], num_labels=2
     )
     self.assertIsNotNone(build_model(config))
예제 #3
0
    def test_mmbt_from_params(self):
        """MMBT.from_params should build a model whose config matches MMBT.Config.

        Both the model and the reference config use a non-pretrained
        ResNet-152 image encoder and the identity text encoder; each side gets
        its own freshly constructed encoder configs.
        """

        def make_image_encoder():
            # New ResNet-152 image encoder config on every call.
            return ImageEncoder.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            )

        def make_text_encoder():
            # New identity text encoder config on every call.
            return TextEncoder.Config(type=TextEncoderTypes.identity)

        # default init
        mmbt = MMBT.from_params(
            modal_encoder=make_image_encoder(),
            text_encoder=make_text_encoder(),
        )

        reference = OmegaConf.structured(
            MMBT.Config(
                modal_encoder=make_image_encoder(),
                text_encoder=make_text_encoder(),
            )
        )
        self.assertIsNotNone(mmbt)
        # The model's config must be the one derived from MMBT.Config.
        self.assertEqual(mmbt.config, reference)
예제 #4
0
    def test_preprocessing_with_resnet_encoder(self):
        """Check MMFTransformer.preprocess_sample with a ResNet-152 image encoder.

        Overwrites ``self._image_modality_config`` with a non-pretrained
        ResNet-152 image modality, pairs it with ``self._text_modality_config``
        (presumably prepared in setUp), and preprocesses a random batch of two
        3x224x224 images and two 128-token texts. It then checks the input ids,
        position ids, masks and segment ids produced per modality.
        """
        resnet_encoder = ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        )
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=2048,
            position_dim=1,
            segment_id=0,
            encoder=resnet_encoder,
        )
        mmft = build_model(
            MMFTransformer.Config(
                modalities=[
                    self._image_modality_config,
                    self._text_modality_config,
                ],
                num_labels=2,
            )
        )

        batch_size, seq_len = 2, 128
        sample_list = SampleList()
        sample_list.image = torch.rand(batch_size, 3, 224, 224)
        sample_list.text = torch.randint(0, 512, (batch_size, seq_len))
        transformer_input = mmft.preprocess_sample(sample_list)

        # Expect one 2048-dim image token per sample and unchanged text ids.
        ids = transformer_input["input_ids"]
        self.assertEqual(ids["image"].dim(), 3)
        self.assertEqual(list(ids["image"].size()), [batch_size, 1, 2048])
        self.assertEqual(ids["text"].dim(), 2)
        self.assertEqual(list(ids["text"].size()), [batch_size, seq_len])

        # NOTE(review): the return value of test_utils.compare_tensors is
        # discarded below. If it returns a bool rather than asserting, these
        # checks can never fail; consider self.assertTrue(...) — confirm.
        positions = transformer_input["position_ids"]
        test_utils.compare_tensors(positions["image"], torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            positions["text"],
            torch.arange(0, seq_len).unsqueeze(0).expand((batch_size, seq_len)),
        )

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(
            masks["text"], torch.ones((batch_size, seq_len)).long()
        )

        segments = transformer_input["segment_ids"]
        test_utils.compare_tensors(segments["image"], torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            segments["text"], torch.ones((batch_size, seq_len)).long()
        )
예제 #5
0
 class Config(BaseTransformer.Config):
     """Model config for the MMF transformer (``model = "mmft"``).

     By default the model has a single MLP head and two modalities:
     text (``key="text"``, position_dim 512, embedding_dim 768, segment 0)
     and image (``key="image"``, position_dim 1, embedding_dim 2048,
     segment 1, ResNet-152 encoder). ``num_labels`` is left as ``MISSING``
     (presumably it must be supplied by the caller).

     Every default that is a list or a config object is created through
     ``field(default_factory=...)``. Each Config instance therefore gets its
     own copy instead of sharing one mutable default.
     """

     model: str = "mmft"
     transformer_base: str = "bert-base-uncased"
     heads: List[BaseTransformerHead.Config] = field(
         default_factory=lambda: [MLP.Config()]
     )
     num_labels: int = MISSING
     initializer_range: float = 0.02
     initializer_mean: float = 0.0
     token_noise_std: float = 0.01
     token_noise_mean: float = 0.0
     layer_norm_weight_fill: float = 1.0
     random_initialize: bool = False
     freeze_transformer: bool = False
     freeze_image_encoder: bool = False
     tie_weight_to_encoder: Optional[str] = None
     finetune_lr_multiplier: float = 1
     # A bare instance as the default would be one object shared by all
     # Configs. If it is a non-frozen (unhashable) dataclass, which is assumed
     # here and should be confirmed, Python 3.11+ dataclasses also reject it
     # with "mutable default ... use default_factory".
     backend: BaseTransformerBackendConfig = field(
         default_factory=lambda: MMFTransformerBackendConfig(type="huggingface")
     )
     modalities: List[BaseTransformerModalityConfig] = field(
         default_factory=lambda: [
             MMFTransformerModalityConfig(
                 type="text",
                 key="text",
                 position_dim=512,
                 embedding_dim=768,
                 segment_id=0,
             ),
             MMFTransformerModalityConfig(
                 type="image",
                 key="image",
                 embedding_dim=2048,
                 position_dim=1,
                 segment_id=1,
                 # NOTE: One can also specify encoder in factory mode as
                 # encoder=ImageEncoderFactory.Config(
                 #   type="resnet152",
                 #   params=ResNet152ImageEncoder.Config()
                 # )
                 encoder=ResNet152ImageEncoder.Config(),
             ),
         ]
     )