class Config(BaseModel.Config):
    """Configuration schema and default values for the "mmbt" model.

    Every attribute below is a typed option with its default. How each
    option is consumed is decided by the model code, which is not part of
    this class.
    """

    model: str = "mmbt"

    # Head selection; per the original note: classification or pretraining.
    training_head_type: str = "pretraining"
    bert_model_name: str = "bert-base-uncased"
    direct_features_input: bool = False

    # Freezing switches and the learning-rate multiplier used for finetuning.
    freeze_text: bool = False
    freeze_modal: bool = False
    freeze_complete_base: bool = False
    finetune_lr_multiplier: float = 1

    # Dimension of the embedding finally returned by the modal encoder.
    modal_hidden_size: int = 2048
    text_hidden_size: int = 768
    num_labels: int = 2

    # The annotation is the generic factory config, but this actually is
    # Union[ImageEncoderConfig, ImageFeatureEncoderConfig].
    # NOTE(review): both encoder defaults are instances created once when the
    # class body executes. If this class is processed as a dataclass,
    # Python 3.11+ rejects unhashable default instances -- confirm, and
    # consider a default_factory if that applies.
    modal_encoder: EncoderFactory.Config = ImageEncoderFactory.Config(
        type=ImageEncoderTypes.resnet152,
        params=ResNet152ImageEncoder.Config(),
    )
    # The text encoder's model name is an interpolation of the top-level
    # ``bert_model_name`` key rather than a literal value.
    text_encoder: EncoderFactory.Config = TextEncoderFactory.Config(
        type=TextEncoderTypes.transformer,
        params=TransformerEncoder.Config(
            bert_model_name=II("bert_model_name"),
        ),
    )

    use_modal_start_token: bool = True
    use_modal_end_token: bool = True
    fused_feature_only: bool = False
    output_dim: int = 768
def test_mmft_from_build_model(self):
    # Smoke test: an MMFTransformer config with an explicit image + text
    # modality list must be accepted by ``build_model`` and yield a model.
    image_modality = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=256,
        position_dim=1,
        segment_id=0,
        encoder=ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152,
            # pretrained=False is passed explicitly for the test encoder.
            params=ResNet152ImageEncoder.Config(pretrained=False),
        ),
    )
    text_modality = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=756,
        position_dim=512,
        segment_id=1,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
    )

    model_config = MMFTransformer.Config(
        modalities=[image_modality, text_modality],
        num_labels=2,
    )

    self.assertIsNotNone(build_model(model_config))
def test_mmbt_from_params(self):
    # ``MMBT.from_params`` must produce a model whose config equals the
    # structured config built directly from ``MMBT.Config`` with the same
    # arguments.

    def fresh_encoder_kwargs():
        # Builds new config objects on every call so the model and the
        # reference config never share instances.
        return {
            "modal_encoder": ImageEncoder.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
            "text_encoder": TextEncoder.Config(type=TextEncoderTypes.identity),
        }

    # default init
    mmbt = MMBT.from_params(**fresh_encoder_kwargs())
    expected_config = OmegaConf.structured(
        MMBT.Config(**fresh_encoder_kwargs())
    )

    self.assertIsNotNone(mmbt)
    # Make sure that the config is created from MMBT.Config
    self.assertEqual(mmbt.config, expected_config)
def test_preprocessing_with_resnet_encoder(self):
    # Runs ``preprocess_sample`` on a random image/text batch and checks the
    # per-modality input ids, position ids, masks and segment ids.
    # ``self._text_modality_config`` is not set here; it must already exist
    # on the test instance.
    batch_size, text_len = 2, 128

    self._image_modality_config = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=2048,
        position_dim=1,
        segment_id=0,
        encoder=ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        ),
    )
    model_config = MMFTransformer.Config(
        modalities=[self._image_modality_config, self._text_modality_config],
        num_labels=2,
    )
    mmft = build_model(model_config)

    batch = SampleList()
    batch.image = torch.rand(batch_size, 3, 224, 224)
    batch.text = torch.randint(0, 512, (batch_size, text_len))

    processed = mmft.preprocess_sample(batch)

    # Input ids: the image entry is 3-D, the text entry stays 2-D.
    ids = processed["input_ids"]
    self.assertEqual(ids["image"].dim(), 3)
    self.assertEqual(list(ids["image"].size()), [batch_size, 1, 2048])
    self.assertEqual(ids["text"].dim(), 2)
    self.assertEqual(list(ids["text"].size()), [batch_size, text_len])

    # NOTE(review): the return value of ``test_utils.compare_tensors`` is
    # discarded in every call below. If that helper returns a bool instead
    # of asserting internally, these checks cannot fail -- confirm.

    # Position ids: a single position 0 for the image, 0..127 for the text.
    positions = processed["position_ids"]
    test_utils.compare_tensors(positions["image"], torch.tensor([[0], [0]]))
    expected_text_positions = (
        torch.arange(0, text_len).unsqueeze(0).expand((batch_size, text_len))
    )
    test_utils.compare_tensors(positions["text"], expected_text_positions)

    # Masks: all ones for both modalities.
    attention_masks = processed["masks"]
    test_utils.compare_tensors(
        attention_masks["image"], torch.tensor([[1], [1]])
    )
    test_utils.compare_tensors(
        attention_masks["text"], torch.ones((batch_size, text_len)).long()
    )

    # Segment ids: zeros for the image, ones for the text.
    segments = processed["segment_ids"]
    test_utils.compare_tensors(segments["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(
        segments["text"], torch.ones((batch_size, text_len)).long()
    )
class Config(BaseTransformer.Config):
    """Configuration schema and default values for the "mmft" model.

    Defaults visible here: one ``MLP`` head, a "huggingface" backend, and two
    modalities -- text (segment 0) followed by image (segment 1). The
    ``num_labels`` option defaults to ``MISSING``, i.e. no usable value is
    provided here (presumably OmegaConf's mandatory-value marker -- confirm).

    All non-scalar defaults (``heads``, ``backend``, ``modalities``) are
    produced by ``default_factory`` so that every ``Config`` instance owns
    its own objects. A config instance used directly as a class-level
    default is shared between all instances, and the dataclass machinery on
    Python 3.11+ rejects unhashable default instances when the class is
    created.
    """

    model: str = "mmft"
    transformer_base: str = "bert-base-uncased"
    heads: List[BaseTransformerHead.Config] = field(
        default_factory=lambda: [MLP.Config()]
    )
    num_labels: int = MISSING

    # Initialisation-related values.
    initializer_range: float = 0.02
    initializer_mean: float = 0.0
    token_noise_std: float = 0.01
    token_noise_mean: float = 0.0
    layer_norm_weight_fill: float = 1.0
    random_initialize: bool = False

    # Freezing, weight tying and finetuning options.
    freeze_transformer: bool = False
    freeze_image_encoder: bool = False
    tie_weight_to_encoder: Optional[str] = None
    finetune_lr_multiplier: float = 1

    # Built per instance rather than once at class-definition time.
    backend: BaseTransformerBackendConfig = field(
        default_factory=lambda: MMFTransformerBackendConfig(type="huggingface")
    )
    modalities: List[BaseTransformerModalityConfig] = field(
        default_factory=lambda: [
            MMFTransformerModalityConfig(
                type="text",
                key="text",
                position_dim=512,
                embedding_dim=768,
                segment_id=0,
            ),
            MMFTransformerModalityConfig(
                type="image",
                key="image",
                embedding_dim=2048,
                position_dim=1,
                segment_id=1,
                # NOTE: One can also specify encoder in factory mode as
                # encoder=ImageEncoderFactory.Config(
                #   type="resnet152",
                #   params=ResNet152ImageEncoder.Config()
                # )
                encoder=ResNet152ImageEncoder.Config(),
            ),
        ]
    )