def test_modality_key_preprocessing(self):
    """Check that preprocessing honours custom modality keys.

    Two text modalities read their inputs from ``sample_list.body`` and
    ``sample_list.ocr`` instead of a default ``text`` attribute. The
    processed transformer input is verified by
    ``self._compare_processed_for_multimodality``.
    """
    # Re-point the text modality created in setUp at a custom key.
    self._text_modality_config.key = "body"
    ocr_modality = MMFTransformerModalityConfig(
        type="text",
        key="ocr",
        embedding_dim=756,
        position_dim=128,
        segment_id=2,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
    )
    model = build_model(
        MMFTransformer.Config(
            modalities=[
                self._image_modality_config,
                self._text_modality_config,
                ocr_modality,
            ],
            num_labels=2,
        )
    )

    # Keep the RNG call order fixed: image, body, ocr, then lm labels.
    batch = SampleList()
    batch.image = torch.rand(2, 256)
    batch.body = torch.randint(0, 512, (2, 128))
    batch.ocr = torch.randint(0, 512, (2, 128))
    batch.lm_label_ids = torch.randint(-1, 30522, (2, 128))
    # The expected sum is doubled. Presumably each of the two text
    # modalities contributes the same lm_label_ids — TODO confirm against
    # _compare_processed_for_multimodality.
    expected_lm_labels_sum = batch.lm_label_ids.sum().item() * 2

    processed = model.preprocess_sample(batch)
    self._compare_processed_for_multimodality(processed, expected_lm_labels_sum)
class Config(BaseModel.Config):
    """Structured configuration for the MMBT model.

    The nested encoder configs use ``default_factory``, so every ``Config``
    instance gets its own encoder config objects. A shared class-level
    default object is rejected by dataclasses on Python >= 3.11 ("mutable
    default ... use default_factory"). On older versions it would be shared
    across all instances.
    NOTE(review): assumes this class is decorated with ``@dataclass`` like
    its base (the decorator is not visible in this chunk) — confirm.
    """

    # Imported at class scope so this block does not rely on a module-level
    # ``field`` import. The name is deleted at the end of the body so it does
    # not remain as a class attribute.
    from dataclasses import field as _field

    model: str = "mmbt"
    # classification or pretraining
    training_head_type: str = "pretraining"
    bert_model_name: str = "bert-base-uncased"
    direct_features_input: bool = False
    freeze_text: bool = False
    freeze_modal: bool = False
    freeze_complete_base: bool = False
    finetune_lr_multiplier: float = 1
    # Dimension of the embedding finally returned by the modal encoder
    modal_hidden_size: int = 2048
    text_hidden_size: int = 768
    num_labels: int = 2
    # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
    modal_encoder: EncoderFactory.Config = _field(
        default_factory=lambda: ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(),
        )
    )
    # The text encoder's bert_model_name interpolates the top-level
    # ``bert_model_name`` field via II("bert_model_name").
    text_encoder: EncoderFactory.Config = _field(
        default_factory=lambda: TextEncoderFactory.Config(
            type=TextEncoderTypes.transformer,
            params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
        )
    )
    use_modal_start_token: bool = True
    use_modal_end_token: bool = True
    fused_feature_only: bool = False
    output_dim: int = 768

    del _field
def test_mmft_from_build_model(self):
    """Check that build_model constructs an MMFTransformer from its Config.

    Uses an image modality with a ResNet-152 encoder (``pretrained=False``)
    and a text modality with an identity encoder.
    """
    image_modality = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=256,
        position_dim=1,
        segment_id=0,
        encoder=ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        ),
    )
    text_modality = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=756,
        position_dim=512,
        segment_id=1,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
    )

    model = build_model(
        MMFTransformer.Config(
            modalities=[image_modality, text_modality], num_labels=2
        )
    )
    self.assertIsNotNone(model)
def test_mmf_from_params_encoder_factory(self):
    """Check that from_params + build() yields the equivalent structured config.

    The model is created via ``MMFTransformer.from_params`` and built. Its
    ``config`` must then equal ``OmegaConf.structured`` of an
    ``MMFTransformer.Config`` built from the same arguments.
    """
    identity_image = ImageEncoderFactory.Config(type=ImageEncoderTypes.identity)
    image_modality = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=256,
        position_dim=1,
        segment_id=0,
        encoder=identity_image,
    )
    identity_text = TextEncoderFactory.Config(type=TextEncoderTypes.identity)
    text_modality = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=756,
        position_dim=512,
        segment_id=0,
        encoder=identity_text,
    )
    modalities = [image_modality, text_modality]

    model = MMFTransformer.from_params(modalities=modalities, num_labels=2)
    model.build()

    expected_config = OmegaConf.structured(
        MMFTransformer.Config(modalities=modalities, num_labels=2)
    )
    self.assertIsNotNone(model)
    self.assertEqual(model.config, expected_config)
def test_tie_mlm_head_weight_to_encoder(self):
    """With ``tie_weight_to_encoder="text"``, the MLM decoder weight must
    equal the text encoder's word-embedding weight.
    """
    # Use a local config rather than overwriting self._text_modality_config,
    # so the fixture built in setUp is not mutated by this test.
    text_modality_config = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=768,
        position_dim=128,
        segment_id=0,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.transformer),
    )
    config = MMFTransformer.Config(
        heads=[MLM.Config()],
        modalities=[self._image_modality_config, text_modality_config],
        num_labels=2,
        tie_weight_to_encoder="text",
    )
    mmft = build_model(config)

    decoder_weight = mmft.heads[0].cls.predictions.decoder.weight
    embedding_weight = mmft.encoders["text"].embeddings.word_embeddings.weight
    # The previous version discarded the result of test_utils.compare_tensors.
    # Unless that helper asserts internally, nothing was actually checked.
    # Assert explicitly on the comparison instead.
    self.assertTrue(torch.equal(decoder_weight, embedding_weight))
def test_mmbt_directly_from_config(self):
    """Check that an MMBT built from a structured MMBT.Config keeps that config."""
    image_encoder = ImageEncoderFactory.Config(
        type=ImageEncoderTypes.resnet152,
        params=ResNet152ImageEncoder.Config(pretrained=False),
    )
    text_encoder = TextEncoderFactory.Config(type=TextEncoderTypes.identity)
    config = OmegaConf.structured(
        MMBT.Config(modal_encoder=image_encoder, text_encoder=text_encoder)
    )

    model = MMBT(config)

    self.assertIsNotNone(model)
    # The model must expose exactly the config created from MMBT.Config.
    self.assertEqual(model.config, config)
def setUp(self):
    """Run the test_utils/registry setup hooks and build default modalities.

    After ``test_utils.setup_proxy()`` and ``setup_imports()``, creates:
    - ``self._image_modality_config``: key "image", segment 0, identity encoder.
    - ``self._text_modality_config``: key "text", segment 1, identity encoder.
    Individual tests may mutate or replace these.
    """
    test_utils.setup_proxy()
    setup_imports()

    image_encoder = ImageEncoderFactory.Config(type=ImageEncoderTypes.identity)
    self._image_modality_config = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=256,
        position_dim=1,
        segment_id=0,
        encoder=image_encoder,
    )

    text_encoder = TextEncoderFactory.Config(type=TextEncoderTypes.identity)
    self._text_modality_config = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=756,
        position_dim=128,
        segment_id=1,
        encoder=text_encoder,
    )