Esempio n. 1
0
    def test_freeze_none(self):
        """With the default cifar_resnet_20 hparams, every parameter requires grad."""
        hparams = registry.get_default_hparams('cifar_resnet_20')
        net = registry.get(hparams.model_hparams)

        # One subTest per parameter so a failure reports the offending tensor name.
        for param_name, param in net.named_parameters():
            with self.subTest(tensor=param_name):
                self.assertTrue(param.requires_grad)
Esempio n. 2
0
    def test_freeze_all(self):
        """Enabling all three freeze flags leaves no parameter with requires_grad set."""
        hparams = registry.get_default_hparams('cifar_resnet_20')
        # Switch on every freezing option (same order as setting them one by one).
        for flag in ('others_frozen', 'output_frozen', 'batchnorm_frozen'):
            setattr(hparams.model_hparams, flag, True)
        net = registry.get(hparams.model_hparams)

        for param_name, param in net.named_parameters():
            with self.subTest(tensor=param_name):
                self.assertFalse(param.requires_grad)
Esempio n. 3
0
    def test_freeze_output(self):
        """Setting only output_frozen freezes exactly the tensors in model.output_layer_names."""
        hparams = registry.get_default_hparams('cifar_resnet_20')
        hparams.model_hparams.output_frozen = True
        net = registry.get(hparams.model_hparams)

        for param_name, param in net.named_parameters():
            with self.subTest(tensor=param_name):
                # Output-layer tensors must be frozen; all other tensors must stay trainable.
                check = self.assertFalse if param_name in net.output_layer_names else self.assertTrue
                check(param.requires_grad)
Esempio n. 4
0
    def test_freeze_batchnorm(self):
        """Setting only batchnorm_frozen freezes BatchNorm2d weight/bias and nothing else.

        The expected-frozen tensor names are derived from named_modules(): for
        every torch.nn.BatchNorm2d module, its '<name>.weight' and '<name>.bias'
        parameters. Every other parameter must still require grad.
        """
        default_hparams = registry.get_default_hparams('cifar_resnet_20')
        default_hparams.model_hparams.batchnorm_frozen = True
        model = registry.get(default_hparams.model_hparams)

        # A set gives O(1) membership checks in the per-parameter loop below.
        bn_names = set()
        for k, v in model.named_modules():
            if isinstance(v, torch.nn.BatchNorm2d):
                bn_names.update((k + '.weight', k + '.bias'))

        # Guard against a vacuous pass: with no BatchNorm2d modules, the loop
        # below would only verify that every parameter is trainable.
        self.assertTrue(bn_names, 'expected the model to contain BatchNorm2d layers')

        for k, v in model.named_parameters():
            with self.subTest(tensor=k):
                if k in bn_names:
                    self.assertFalse(v.requires_grad)
                else:
                    self.assertTrue(v.requires_grad)
Esempio n. 5
0
    def test_save_load_exists(self):
        """Round-trip a model through save, torch.load, and registry.load.

        Checks that nothing is present at the checkpoint location beforehand,
        that saving makes it visible to both registry.exists and os.path.exists,
        and that both loading routes reproduce the state captured before saving.
        """
        hparams = registry.get_default_hparams('cifar_resnet_20')
        original = registry.get(hparams.model_hparams)
        step = Step.from_iteration(27, 17)
        checkpoint_path = paths.model(self.root, step)
        snapshot = TestSaveLoadExists.get_state(original)

        # Nothing should exist before the save.
        self.assertFalse(registry.exists(self.root, step))
        self.assertFalse(os.path.exists(checkpoint_path))

        # Saving should make the checkpoint visible to the registry and the filesystem.
        original.save(self.root, step)
        self.assertTrue(registry.exists(self.root, step))
        self.assertTrue(os.path.exists(checkpoint_path))

        # Loading route 1: raw state dict into a freshly constructed model.
        reloaded = registry.get(hparams.model_hparams)
        reloaded.load_state_dict(torch.load(checkpoint_path))
        self.assertStateEqual(snapshot, TestSaveLoadExists.get_state(reloaded))

        # Loading route 2: the registry builds and loads the model in one call.
        restored = registry.load(self.root, step, hparams.model_hparams)
        self.assertStateEqual(snapshot, TestSaveLoadExists.get_state(restored))