def test_build_optimizer_custom_model(self):
    """Build an optimizer for a default MMBT model and check its param groups.

    The model's config is copied into ``self.config`` (both ``model`` and
    ``model_config``) before calling ``build_optimizer``.
    """
    model = MMBT.from_params()
    model.build()
    self.config.model = model.config.model
    self.config.model_config = model.config

    optimizer = build_optimizer(model, self.config)

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)) which only says "False is not true".
    self.assertIsInstance(optimizer, torch.optim.Optimizer)
    # NOTE(review): the expected count of 2 presumably reflects how MMBT
    # partitions its parameters for the optimizer — confirm against MMBT's
    # optimizer-parameter logic if this ever changes.
    self.assertEqual(len(optimizer.param_groups), 2)
def test_mmbt_from_params(self):
    """MMBT.from_params must store a config equal to a structured MMBT.Config.

    The same encoder settings are passed to ``MMBT.from_params`` and to
    ``MMBT.Config``; the model's ``config`` must equal the latter after
    ``OmegaConf.structured``.
    """

    def modal_encoder_config():
        # A fresh instance per call, so the model and the expected config
        # are built from distinct objects.
        return ImageEncoder.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        )

    def text_encoder_config():
        return TextEncoder.Config(type=TextEncoderTypes.identity)

    # default init
    mmbt = MMBT.from_params(
        modal_encoder=modal_encoder_config(),
        text_encoder=text_encoder_config(),
    )
    expected_config = OmegaConf.structured(
        MMBT.Config(
            modal_encoder=modal_encoder_config(),
            text_encoder=text_encoder_config(),
        )
    )

    self.assertIsNotNone(mmbt)
    # Make sure that the config is created from MMBT.Config
    self.assertEqual(mmbt.config, expected_config)
def test_mmbt_pretrained_with_proxy(self):
    """Build a default MMBT after calling ``test_utils.setup_proxy()``.

    Renamed from ``test_mmbt_pretrained``: another method in this class is
    also named ``test_mmbt_pretrained`` and, being defined later, replaced
    this one in the class namespace, so this test was never collected or run.
    """
    # NOTE(review): presumably needs network access to fetch pretrained
    # weights — consider a skip-if-offline guard if CI runs without network.
    test_utils.setup_proxy()
    mmbt = MMBT.from_params()
    self.assertIsNotNone(mmbt)
def test_mmbt_pretrained(self):
    """Constructing MMBT with no arguments must return a model object."""
    self.assertIsNotNone(MMBT.from_params())