def test_freeze_none(self):
    """Default cifar_resnet_20 hparams: every parameter must require gradients."""
    hparams = registry.get_default_hparams('cifar_resnet_20')
    net = registry.get(hparams.model_hparams)

    # One subtest per named parameter so a failure pinpoints the tensor.
    for name, param in net.named_parameters():
        with self.subTest(tensor=name):
            self.assertTrue(param.requires_grad)
def test_freeze_all(self):
    """With others/output/batchnorm freezing all enabled, no parameter requires gradients."""
    hparams = registry.get_default_hparams('cifar_resnet_20')
    model_hparams = hparams.model_hparams

    # Enable the three freezing flags, in the same order as assigned individually.
    for flag in ('others_frozen', 'output_frozen', 'batchnorm_frozen'):
        setattr(model_hparams, flag, True)

    net = registry.get(model_hparams)
    for name, param in net.named_parameters():
        with self.subTest(tensor=name):
            self.assertFalse(param.requires_grad)
def test_freeze_output(self):
    """With only output_frozen set, exactly the model's output-layer parameters are frozen."""
    hparams = registry.get_default_hparams('cifar_resnet_20')
    hparams.model_hparams.output_frozen = True
    net = registry.get(hparams.model_hparams)

    for name, param in net.named_parameters():
        with self.subTest(tensor=name):
            # Output-layer tensors (per the model's own listing) must be frozen;
            # everything else stays trainable.
            is_output = name in net.output_layer_names
            check = self.assertFalse if is_output else self.assertTrue
            check(param.requires_grad)
def test_freeze_batchnorm(self):
    """With only batchnorm_frozen set, BatchNorm2d weight/bias are frozen and all else trains."""
    hparams = registry.get_default_hparams('cifar_resnet_20')
    hparams.model_hparams.batchnorm_frozen = True
    net = registry.get(hparams.model_hparams)

    # Fully-qualified parameter names belonging to BatchNorm2d modules,
    # built as "<module name>.weight" then "<module name>.bias".
    frozen_names = [
        f'{module_name}.{suffix}'
        for module_name, module in net.named_modules()
        if isinstance(module, torch.nn.BatchNorm2d)
        for suffix in ('weight', 'bias')
    ]

    for name, param in net.named_parameters():
        with self.subTest(tensor=name):
            if name not in frozen_names:
                self.assertTrue(param.requires_grad)
            else:
                self.assertFalse(param.requires_grad)
def test_save_load_exists(self):
    """Round-trip a model through save, existence checks, and both loading paths.

    A freshly built cifar_resnet_20 is saved at a fixed Step under self.root. The
    test checks that nothing exists before saving and that both registry.exists
    and the on-disk file report presence afterwards. The state loaded via
    torch.load + load_state_dict, and via registry.load, must equal the saved state.
    """
    hparams = registry.get_default_hparams('cifar_resnet_20')
    original = registry.get(hparams.model_hparams)
    step = Step.from_iteration(27, 17)
    checkpoint_path = paths.model(self.root, step)
    expected_state = TestSaveLoadExists.get_state(original)

    # Before saving: neither the registry nor the filesystem should see a checkpoint.
    self.assertFalse(registry.exists(self.root, step))
    self.assertFalse(os.path.exists(checkpoint_path))

    # Saving.
    original.save(self.root, step)
    self.assertTrue(registry.exists(self.root, step))
    self.assertTrue(os.path.exists(checkpoint_path))

    # Loading path 1: fresh model + raw state dict from disk.
    from_state_dict = registry.get(hparams.model_hparams)
    from_state_dict.load_state_dict(torch.load(checkpoint_path))
    self.assertStateEqual(expected_state, TestSaveLoadExists.get_state(from_state_dict))

    # Loading path 2: the registry's own loader.
    from_registry = registry.load(self.root, step, hparams.model_hparams)
    self.assertStateEqual(expected_state, TestSaveLoadExists.get_state(from_registry))