Example #1
0
    def test_level3_4it_pretrain2it(self):
        """Run a 3-level lottery experiment preceded by 2it of pretraining.

        Pretraining runs for '2it' and each level trains for '4it' in total.
        The test checks that:
          * the pretrain run saved its files for steps 0it and 2it
            (checked with masks=False);
          * the level-0 starting weights equal the weights at the end of
            pretraining;
          * every pruning level, 0 through ``levels`` inclusive, saved its
            files for steps 2it..4it and starts from the level-0 starting
            weights with that level's mask applied.
        """
        levels = 3
        replicate = 2

        self.desc.pretrain_dataset_hparams = copy.deepcopy(
            self.desc.dataset_hparams)
        self.desc.pretrain_training_hparams = copy.deepcopy(
            self.desc.training_hparams)
        self.desc.pretrain_training_hparams.training_steps = '2it'
        self.desc.training_hparams.training_steps = '4it'
        LotteryRunner(replicate=replicate, levels=levels, desc=self.desc,
                      verbose=False).run()

        def load_numpy_state(path):
            # Load a saved state dict and convert every tensor to a numpy
            # array so states can be compared with assertStateEqual.
            return {k: v.numpy() for k, v in torch.load(path).items()}

        # Check that the pretrain weights are present.
        pretrain_root = self.desc.run_path(replicate, 'pretrain')
        self.assertLevelFilesPresent(pretrain_root,
                                     self.to_step('0it'),
                                     self.to_step('2it'),
                                     masks=False)

        # Load the pretrain end weights and level0 start weights to ensure
        # they're the same.
        pretrain_end_weights = load_numpy_state(
            paths.model(pretrain_root, self.desc.pretrain_end_step))
        level0_weights = load_numpy_state(
            paths.model(self.desc.run_path(replicate, 0),
                        self.desc.train_start_step))
        self.assertStateEqual(pretrain_end_weights, level0_weights)

        # Evaluate each of the pruning levels: 0 through `levels` inclusive.
        # (Previously only levels 0 and 1 were checked.)
        for level in range(levels + 1):
            level_root = self.desc.run_path(replicate, level)
            self.assertLevelFilesPresent(level_root, self.to_step('2it'),
                                         self.to_step('4it'))

            # Ensure that the initial weights are a masked version of the
            # level 0 weights (which are identical to the weights at the end
            # of pretraining). Weights with no entry in the mask are compared
            # unmasked (multiplied by 1).
            mask = Mask.load(level_root).numpy()
            level_weights = load_numpy_state(
                paths.model(level_root, self.desc.train_start_step))
            self.assertStateEqual(
                level_weights,
                {k: v * mask.get(k, 1)
                 for k, v in level0_weights.items()})
Example #2
0
    def test_level3_2it(self):
        """Run a 3-level lottery experiment with 2it of training per level.

        For each level 0..3 this checks that:
          * the level's files for steps 0it and 2it are present;
          * the saved mask keeps roughly 0.8**level of the weights (atol 0.01);
          * the sparsity report's unpruned/total ratio matches the same target;
          * the level's 0it weights equal the level-0 0it weights with the
            level's mask applied.
        """
        self.desc.training_hparams.training_steps = '2it'
        LotteryRunner(replicate=2, levels=3, desc=self.desc,
                      verbose=False).run()

        def as_numpy(checkpoint):
            # Load a saved state dict and turn each tensor into a numpy array.
            return {name: tensor.numpy()
                    for name, tensor in torch.load(checkpoint).items()}

        base_weights = as_numpy(
            paths.model(self.desc.run_path(2, 0), self.to_step('0it')))

        for level in range(4):
            root = self.desc.run_path(2, level)
            self.assertLevelFilesPresent(root, self.to_step('0it'),
                                         self.to_step('2it'))

            # Target fraction of surviving weights at this level. Presumably
            # the fixture prunes 20% of the remaining weights per level.
            # Confirm against the test's pruning hparams.
            expected_density = 0.8**level
            mask = Mask.load(root).numpy()

            # The mask itself: fraction of entries that are kept.
            size = sum((array.size for array in mask.values()), 0.0)
            kept = sum((np.sum(array) for array in mask.values()), 0.0)
            self.assertTrue(
                np.allclose(expected_density, kept / size, atol=0.01))

            # The sparsity report should agree with the same target.
            with open(paths.sparsity_report(root)) as report_file:
                report = json.load(report_file)
            self.assertTrue(
                np.allclose(expected_density,
                            report['unpruned'] / report['total'],
                            atol=0.01))

            # This level's starting weights must be the level-0 starting
            # weights with the mask applied (unmasked names multiply by 1).
            start_weights = as_numpy(paths.model(root, self.to_step('0it')))
            masked_base = {name: array * mask.get(name, 1)
                           for name, array in base_weights.items()}
            self.assertStateEqual(start_weights, masked_base)
Example #3
0
    def test_level0_2it(self):
        """Run only level 0 (no pruning rounds) for 2it and check it is dense.

        Verifies the level's files for steps 0it and 2it exist, every mask
        entry is 1, and the sparsity report lists all weights as unpruned.
        """
        self.desc.training_hparams.training_steps = '2it'
        LotteryRunner(replicate=2, levels=0, desc=self.desc,
                      verbose=False).run()
        root = self.desc.run_path(2, 0)

        # The files for this level (steps 0it and 2it) must be present.
        self.assertLevelFilesPresent(root, self.to_step('0it'),
                                     self.to_step('2it'))

        # Every entry of every mask array must be 1 ...
        for layer_mask in Mask.load(root).numpy().values():
            self.assertTrue(np.all(np.equal(layer_mask, 1)))

        # ... and the sparsity report must show no weights pruned.
        with open(paths.sparsity_report(root)) as report_file:
            report = json.load(report_file)
        unpruned_fraction = report['unpruned'] / report['total']
        self.assertEqual(unpruned_fraction, 1)