def test_level3_4it_pretrain2it(self):
    """Run a 3-level lottery with 2 iterations of pretraining and 4 of training.

    Verifies that (1) pretraining checkpoints exist, (2) the level-0 start
    weights equal the pretrain end weights, and (3) each level's initial
    weights are a masked copy of the level-0 weights.
    """
    # Pretrain with the same dataset/training hparams as the main phase,
    # but for only 2 iterations (vs. 4 for the main training run).
    self.desc.pretrain_dataset_hparams = copy.deepcopy(self.desc.dataset_hparams)
    self.desc.pretrain_training_hparams = copy.deepcopy(self.desc.training_hparams)
    self.desc.pretrain_training_hparams.training_steps = '2it'
    self.desc.training_hparams.training_steps = '4it'
    LotteryRunner(replicate=2, levels=3, desc=self.desc, verbose=False).run()

    # Check that the pretrain weights are present (no masks during pretrain).
    pretrain_root = self.desc.run_path(2, 'pretrain')
    self.assertLevelFilesPresent(pretrain_root, self.to_step('0it'), self.to_step('2it'),
                                 masks=False)

    # Load the pretrain and level0 start weights to ensure they're the same.
    pretrain_end_weights = paths.model(self.desc.run_path(2, 'pretrain'),
                                       self.desc.pretrain_end_step)
    pretrain_end_weights = {k: v.numpy() for k, v in torch.load(pretrain_end_weights).items()}
    level0_weights = paths.model(self.desc.run_path(2, 0), self.desc.train_start_step)
    level0_weights = {k: v.numpy() for k, v in torch.load(level0_weights).items()}
    self.assertStateEqual(pretrain_end_weights, level0_weights)

    # Evaluate each of the pruning levels. Fix: levels=3 produces level
    # directories 0..3 (as test_level3_2it checks with range(0, 4)); the
    # previous range(0, 2) silently skipped levels 2 and 3.
    for level in range(0, 4):
        level_root = self.desc.run_path(2, level)
        self.assertLevelFilesPresent(level_root, self.to_step('2it'), self.to_step('4it'))

        # Ensure that the initial weights are a masked version of the level 0
        # weights (which are identical to the weights at the end of pretraining).
        mask = Mask.load(level_root).numpy()
        level_weights = paths.model(level_root, self.desc.train_start_step)
        level_weights = {k: v.numpy() for k, v in torch.load(level_weights).items()}
        self.assertStateEqual(
            level_weights,
            {k: v * mask.get(k, 1) for k, v in level0_weights.items()})
def test_level3_2it(self):
    """Run a 3-level lottery (no pretraining) for 2 iterations per level.

    Checks that every level directory has its files, that the mask density
    decays as 0.8 per level, that the sparsity report agrees, and that each
    level starts from a masked copy of the level-0 weights.
    """
    self.desc.training_hparams.training_steps = '2it'
    LotteryRunner(replicate=2, levels=3, desc=self.desc, verbose=False).run()

    # The level-0 initial weights are the reference for every masked level.
    reference_weights = paths.model(self.desc.run_path(2, 0), self.to_step('0it'))
    reference_weights = {k: v.numpy() for k, v in torch.load(reference_weights).items()}

    for level in range(0, 4):
        level_root = self.desc.run_path(2, level)
        self.assertLevelFilesPresent(level_root, self.to_step('0it'), self.to_step('2it'))

        # Each pruning round keeps 80% of the remaining weights.
        expected_density = 0.8**level
        mask = Mask.load(level_root).numpy()

        # Check the mask itself: fraction of ones matches the expected density.
        weight_count = sum(v.size for v in mask.values())
        unmasked_count = sum(np.sum(v) for v in mask.values())
        self.assertTrue(np.allclose(expected_density, unmasked_count / weight_count, atol=0.01))

        # The sparsity report on disk must agree with the mask.
        with open(paths.sparsity_report(level_root)) as fp:
            sparsity_report = json.load(fp)
        self.assertTrue(
            np.allclose(expected_density,
                        sparsity_report['unpruned'] / sparsity_report['total'],
                        atol=0.01))

        # The level's initial weights must be the level-0 weights under the mask.
        level_weights = paths.model(level_root, self.to_step('0it'))
        level_weights = {k: v.numpy() for k, v in torch.load(level_weights).items()}
        masked_reference = {k: v * mask.get(k, 1) for k, v in reference_weights.items()}
        self.assertStateEqual(level_weights, masked_reference)
def test_level0_2it(self):
    """Run a 0-level lottery for 2 iterations; the mask must be all ones.

    With levels=0 nothing is pruned, so both the stored mask and the
    sparsity report should reflect a fully dense network.
    """
    self.desc.training_hparams.training_steps = '2it'
    LotteryRunner(replicate=2, levels=0, desc=self.desc, verbose=False).run()
    level_root = self.desc.run_path(2, 0)

    # Ensure the important files are there.
    self.assertLevelFilesPresent(level_root, self.to_step('0it'), self.to_step('2it'))

    # With no pruning, every entry of every mask tensor is 1.
    for tensor in Mask.load(level_root).numpy().values():
        self.assertTrue(np.all(np.equal(tensor, 1)))

    # The sparsity report must agree: nothing was pruned.
    with open(paths.sparsity_report(level_root)) as fp:
        report = json.load(fp)
    self.assertEqual(report['unpruned'] / report['total'], 1)