Example #1
0
    def test_save_load_exists(self):
        """Round-trip a Mask through save()/load(), checking exists() before and after."""
        # Nothing saved yet: neither Mask.exists nor the filesystem should see a mask.
        self.assertFalse(Mask.exists(self.root))
        self.assertFalse(os.path.exists(paths.mask(self.root)))

        Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])}).save(self.root)

        # After saving, the mask file is on disk and Mask.exists reports it.
        self.assertTrue(os.path.exists(paths.mask(self.root)))
        self.assertTrue(Mask.exists(self.root))

        loaded = Mask.load(self.root)

        # Exactly the two saved entries come back...
        self.assertEqual(len(loaded), 2)
        self.assertEqual(len(loaded.keys()), 2)
        self.assertEqual(len(loaded.values()), 2)
        self.assertEqual(set(loaded.keys()), {'hello', 'world'})

        # ...and each one holds the same array contents that were saved.
        for key, reference in (('hello', np.ones([2, 3])), ('world', np.zeros([5, 6]))):
            self.assertTrue(np.array_equal(reference, loaded[key]))
Example #2
0
    def _prune_level(self, level: int):
        """Create and save the mask for pruning ``level``, unless it already exists.

        Level 0 saves ``Mask.ones_like`` of the model returned by
        ``models.registry.get``. Every later level does the following:

        - loads the model stored under the previous level's run path at
          ``train_end_step``;
        - passes it and the previous level's mask to the configured pruning
          strategy;
        - saves the result under this level's run path.
        """
        new_location = self.desc.run_path(self.replicate, level)
        if Mask.exists(new_location):
            return  # this level was already produced

        if level == 0:
            fresh_model = models.registry.get(self.desc.model_hparams)
            Mask.ones_like(fresh_model).save(new_location)
            return

        previous = self.desc.run_path(self.replicate, level - 1)
        trained = models.registry.load(
            previous,
            self.desc.train_end_step,
            self.desc.model_hparams,
            self.desc.train_outputs,
        )
        prune = pruning.registry.get(self.desc.pruning_hparams)
        prune(trained, Mask.load(previous)).save(new_location)
Example #3
0
    def _prune_level(self, level: int):
        """Produce and save the mask for ``level``, skipping levels already on disk.

        Level 0 writes an all-ones mask, which keeps every weight. Each later
        level does the following:

        - starts from the artifacts of the level just before it;
        - loads that level's model at ``train_end_step``;
        - applies the pruning strategy chosen by ``pruning_hparams`` to the
          model and the previous mask;
        - saves the result at this level's run path.
        """
        target = self.desc.run_path(self.replicate, level)
        if Mask.exists(target):
            return

        if level != 0:
            # The previous level's run path holds the model and mask to prune from.
            source = self.desc.run_path(self.replicate, level - 1)
            # Pruning works on the trained weights, hence train_end_step.
            trained_model = models.registry.load(
                source,
                self.desc.train_end_step,
                self.desc.model_hparams,
                self.desc.train_outputs,
            )
            # Per the original author's note, registry.get returns a partial
            # wrapping the strategy's prune function; calling it performs the
            # pruning.
            strategy = pruning.registry.get(self.desc.pruning_hparams)
            pruned = strategy(trained_model, Mask.load(source))
            pruned.save(target)
        else:
            # Level 0: the mask is all ones, i.e. every weight is kept.
            Mask.ones_like(models.registry.get(self.desc.model_hparams)).save(target)