class Config(BaseModel.Config):
    model: str = "mmbt"
    # classification or pretraining
    training_head_type: str = "pretraining"
    bert_model_name: str = "bert-base-uncased"
    direct_features_input: bool = False
    freeze_text: bool = False
    freeze_modal: bool = False
    freeze_complete_base: bool = False
    finetune_lr_multiplier: float = 1
    # Dimension of the embedding finally returned by the modal encoder
    modal_hidden_size: int = 2048
    text_hidden_size: int = 768
    num_labels: int = 2
    # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
    modal_encoder: Encoder.Config = ImageEncoder.Config(
        type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
    )
    text_encoder: Encoder.Config = TextEncoder.Config(
        type=TextEncoderTypes.transformer,
        params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
    )
    use_modal_start_token: bool = True
    use_modal_end_token: bool = True
    fused_feature_only: bool = False
    output_dim: int = 768
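# A minimal usage sketch (assumption, not taken from the source): the test below
# treats MMBT.Config as an OmegaConf structured config, so a variant of the defaults
# above can be materialized the same way. The "classification" override value is
# illustrative only.
from omegaconf import OmegaConf

classification_config = OmegaConf.structured(
    MMBT.Config(training_head_type="classification", num_labels=2)
)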
def test_mmbt_from_params(self):
    # default init
    mmbt = MMBT.from_params(
        modal_encoder=ImageEncoder.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        ),
        text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity),
    )

    config = OmegaConf.structured(
        MMBT.Config(
            modal_encoder=ImageEncoder.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
            text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity),
        )
    )

    self.assertIsNotNone(mmbt)
    # Make sure that the config is created from MMBT.Config
    self.assertEqual(mmbt.config, config)
def build_text_encoder(config, *args, **kwargs):
    from mmf.modules.encoders import TextEncoder

    text_encoder = TextEncoder(config, *args, **kwargs)
    return text_encoder.module
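# A minimal usage sketch (assumption, not verbatim from the repo): the factory above
# is given the text_encoder node of the model config; the dict below mirrors the
# transformer defaults declared in Config and is purely illustrative.
from omegaconf import OmegaConf

text_encoder_config = OmegaConf.create(
    {"type": "transformer", "params": {"bert_model_name": "bert-base-uncased"}}
)
text_encoder_module = build_text_encoder(text_encoder_config)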