def test_mmft_from_params(self):
    """Check that ``MMFTransformer.from_params`` builds a model whose config
    equals the structured ``MMFTransformer.Config`` built from the same
    arguments (two modalities: image and text, ``num_labels=2``).
    """
    modalities_config = [
        MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=256,
            position_dim=1,
            segment_id=0,
            encoder=IdentityEncoder.Config(),
        ),
        MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=768,
            position_dim=512,
            segment_id=1,
            encoder=IdentityEncoder.Config(),
        ),
    ]
    mmft = MMFTransformer.from_params(modalities=modalities_config, num_labels=2)
    # Assert before calling build(): if from_params returned None, build()
    # would otherwise raise AttributeError and hide this clearer failure.
    self.assertIsNotNone(mmft)
    mmft.build()

    # Expected config: the same arguments passed through OmegaConf.structured,
    # compared against the config stored on the built model.
    config = OmegaConf.structured(
        MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    )
    self.assertEqual(mmft.config, config)
def test_mmft_pretrained(self):
    """Check that ``MMFTransformer.from_params`` returns a model when only
    ``num_labels`` is supplied.

    NOTE(review): the name suggests this exercises a pretrained default
    configuration, but the test only checks that construction returns a
    non-None object; it does not call build() or inspect weights — confirm
    whether that is the intended coverage.
    """
    model = MMFTransformer.from_params(num_labels=2)
    self.assertIsNotNone(model)