def _establish_initial_weights(self):
    """Make sure the level-0 starting weights are saved, creating them if absent.

    If a model already exists for ``self.desc.train_start_step`` under the
    level-0 run path, nothing is written. Otherwise a new model is built from
    ``self.desc.model_hparams``. When ``self.desc.pretrain_training_hparams``
    is set, the weights stored at ``self.desc.pretrain_end_step`` under the
    'pretrain' run path are loaded into the new model first. The resulting
    model is then saved at the level-0 run path for ``train_start_step``.
    """
    location = self.desc.run_path(self.replicate, 0)
    if models.registry.exists(location, self.desc.train_start_step):
        if get_platform().is_primary_process:
            print('Initial weights loaded from {}'.format(
                paths.model(location, self.desc.train_start_step)))
        return

    new_model = models.registry.get(self.desc.model_hparams,
                                    outputs=self.desc.train_outputs)

    # If there was a pretrained model, retrieve its final weights and adapt
    # them for training.
    if self.desc.pretrain_training_hparams is not None:
        pretrain_loc = self.desc.run_path(self.replicate, 'pretrain')
        if get_platform().is_primary_process:
            print('Initial weights loaded from pretrained checkpoint {}'.format(
                paths.model(pretrain_loc, self.desc.pretrain_end_step)))

        old = models.registry.load(pretrain_loc, self.desc.pretrain_end_step,
                                   self.desc.model_hparams,
                                   self.desc.pretrain_outputs)
        # Shallow copy so the update below does not touch the mapping
        # returned by the pretrained model.
        state_dict = dict(old.state_dict())

        # Select a new output layer if number of classes differs: keep the
        # new model's own entries for its output layers.
        if self.desc.train_outputs != self.desc.pretrain_outputs:
            # Build the new model's state dict once rather than once per key.
            fresh_state = new_model.state_dict()
            state_dict.update(
                {k: fresh_state[k] for k in new_model.output_layer_names})

        new_model.load_state_dict(state_dict)

    new_model.save(location, self.desc.train_start_step)
def test_save(self):
    """Check that saving a PrunedModel built with ``Mask.ones_like`` writes a
    model file whose weights match the wrapped model's state before saving."""
    expected_state = self.get_state(self.model)
    step = Step.zero(20)

    ones_mask = Mask.ones_like(self.model)
    PrunedModel(self.model, ones_mask).save(self.root, step)

    # The file must exist at the location the paths helper reports.
    saved_path = paths.model(self.root, step)
    self.assertTrue(os.path.exists(saved_path))

    # Reload the stored weights into the model and compare with the snapshot.
    self.model.load_state_dict(torch.load(saved_path))
    self.assertStateEqual(expected_state, self.get_state(self.model))
def test_level3_4it_pretrain2it(self):
    """Run a 3-level lottery experiment with 2 iterations of pretraining and
    4 iterations of training, then check the files and weights it produced.

    Verifies that the pretrain files exist, that the level-0 starting weights
    equal the weights at the end of pretraining, and that every level's
    starting weights are the level-0 weights multiplied by that level's mask.
    """
    self.desc.pretrain_dataset_hparams = copy.deepcopy(self.desc.dataset_hparams)
    self.desc.pretrain_training_hparams = copy.deepcopy(self.desc.training_hparams)
    self.desc.pretrain_training_hparams.training_steps = '2it'
    self.desc.training_hparams.training_steps = '4it'

    num_levels = 3
    LotteryRunner(replicate=2, levels=num_levels, desc=self.desc, verbose=False).run()

    # Check that the pretrain weights are present.
    pretrain_root = self.desc.run_path(2, 'pretrain')
    self.assertLevelFilesPresent(pretrain_root, self.to_step('0it'),
                                 self.to_step('2it'), masks=False)

    # Load the pretrain and level0 start weights to ensure they're the same.
    pretrain_end_path = paths.model(pretrain_root, self.desc.pretrain_end_step)
    pretrain_end_weights = {
        k: v.numpy() for k, v in torch.load(pretrain_end_path).items()
    }
    level0_path = paths.model(self.desc.run_path(2, 0), self.desc.train_start_step)
    level0_weights = {
        k: v.numpy() for k, v in torch.load(level0_path).items()
    }
    self.assertStateEqual(pretrain_end_weights, level0_weights)

    # Evaluate each of the pruning levels: level 0 plus every requested level.
    # NOTE(review): assumes levels=3 yields run paths for levels 0..3, as the
    # non-pretrain 3-level test in this suite already checks.
    for level in range(num_levels + 1):
        level_root = self.desc.run_path(2, level)
        self.assertLevelFilesPresent(level_root, self.to_step('2it'),
                                     self.to_step('4it'))

        # Ensure that the initial weights are a masked version of the level 0
        # weights (which are identical to the weights at the end of pretraining).
        mask = Mask.load(level_root).numpy()
        level_path = paths.model(level_root, self.desc.train_start_step)
        level_weights = {
            k: v.numpy() for k, v in torch.load(level_path).items()
        }
        # Entries without a mask are compared unchanged (multiplied by 1).
        self.assertStateEqual(
            level_weights,
            {k: v * mask.get(k, 1) for k, v in level0_weights.items()})
def assertLevelFilesPresent(self, level_root, start_step, end_step, masks=False):
    """Assert that the expected output files exist under ``level_root``.

    Always checks the model files for ``start_step`` and ``end_step`` and the
    logger file. When ``masks`` is true, also checks the mask file and the
    sparsity report. All checks run inside one subTest labelled with
    ``level_root``.
    """
    with self.subTest(level_root=level_root):
        required = [
            paths.model(level_root, start_step),
            paths.model(level_root, end_step),
            paths.logger(level_root),
        ]
        if masks:
            required.append(paths.mask(level_root))
            required.append(paths.sparsity_report(level_root))

        for file_path in required:
            self.assertTrue(os.path.exists(file_path))
def test_last_step(self):
    """Train from ``Step.from_epoch(2, 11, ...)`` to ``Step.from_epoch(3, 0, ...)``
    and check the saved final weights, the logger contents and the checkpoint."""
    loader_len = len(self.train_loader)
    train.train(
        self.hparams.training_hparams,
        self.model,
        self.train_loader,
        self.root,
        callbacks=self.callbacks,
        start_step=Step.from_epoch(2, 11, loader_len),
        end_step=Step.from_epoch(3, 0, loader_len),
    )
    state_after_training = TestStandardCallbacks.get_state(self.model)

    # Check that final state has been saved.
    final_path = paths.model(self.root, Step.from_epoch(3, 0, loader_len))
    self.assertTrue(os.path.exists(final_path))

    # Check that the final state that is saved matches the final state of the network.
    self.model.load_state_dict(torch.load(final_path))
    state_from_disk = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(state_after_training, state_from_disk)

    # Check that the logger has the right number of entries.
    self.assertTrue(os.path.exists(paths.logger(self.root)))
    logger = MetricLogger.create_from_file(self.root)
    for metric in ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy'):
        self.assertEqual(len(logger.get_data(metric)), 1)

    # Check that the checkpoint file exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
def test_first_step(self):
    """Train up to ``Step.from_epoch(0, 1, ...)`` and check that the initial
    weights and a checkpoint were written while no logger file was."""
    state_before_training = TestStandardCallbacks.get_state(self.model)
    loader_len = len(self.train_loader)

    train.train(
        self.hparams.training_hparams,
        self.model,
        self.train_loader,
        self.root,
        callbacks=self.callbacks,
        end_step=Step.from_epoch(0, 1, loader_len),
    )

    # Check that the initial state has been saved.
    initial_path = paths.model(self.root, Step.zero(loader_len))
    self.assertTrue(os.path.exists(initial_path))

    # Check that the model state at init reflects the saved state.
    self.model.load_state_dict(torch.load(initial_path))
    state_from_disk = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(state_before_training, state_from_disk)

    # Check that the checkpoint file exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Check that the logger file doesn't exist.
    self.assertFalse(os.path.exists(paths.logger(self.root)))
def save(self, save_location: str, save_step: Step, suffix: str = ""):
    """Write this model's state dict to the model path for ``save_step``.

    Does nothing unless the current platform reports it is the primary
    process. Creates ``save_location`` through the platform if the platform
    reports it does not exist yet.

    Args:
        save_location: Directory the model file is written under.
        save_step: Step passed to ``paths.model`` to build the file path.
        suffix: Forwarded to ``paths.model`` as its ``suffix`` argument.
    """
    if get_platform().is_primary_process:
        # Only create the directory when the platform says it is missing.
        if not get_platform().exists(save_location):
            get_platform().makedirs(save_location)

        destination = paths.model(save_location, save_step, suffix=suffix)
        get_platform().save_model(self.state_dict(), destination)
def test_level3_2it(self):
    """Run a 3-level lottery experiment with 2 training iterations and check
    each level's files, mask density, sparsity report and starting weights."""
    self.desc.training_hparams.training_steps = '2it'
    LotteryRunner(replicate=2, levels=3, desc=self.desc, verbose=False).run()

    level0_path = paths.model(self.desc.run_path(2, 0), self.to_step('0it'))
    level0_weights = {
        name: tensor.numpy() for name, tensor in torch.load(level0_path).items()
    }

    for level in range(4):
        level_root = self.desc.run_path(2, level)
        self.assertLevelFilesPresent(level_root, self.to_step('0it'),
                                     self.to_step('2it'))

        # Check the mask: the expected remaining fraction is 0.8 ** level.
        expected_fraction = 0.8**level
        mask = Mask.load(level_root).numpy()

        # Check the mask itself.
        n_entries, n_kept = 0.0, 0.0
        for layer_mask in mask.values():
            n_entries += layer_mask.size
            n_kept += np.sum(layer_mask)
        self.assertTrue(
            np.allclose(expected_fraction, n_kept / n_entries, atol=0.01))

        # Check the sparsity report.
        with open(paths.sparsity_report(level_root)) as fp:
            report = json.load(fp)
        reported_fraction = report['unpruned'] / report['total']
        self.assertTrue(
            np.allclose(expected_fraction, reported_fraction, atol=0.01))

        # Ensure that the initial weights are a masked version of the level 0 weights.
        level_path = paths.model(level_root, self.to_step('0it'))
        level_weights = {
            name: tensor.numpy()
            for name, tensor in torch.load(level_path).items()
        }
        expected_weights = {
            name: weights * mask.get(name, 1)
            for name, weights in level0_weights.items()
        }
        self.assertStateEqual(level_weights, expected_weights)
def load(save_location: str, save_step: Step, model_hparams: ModelHparams, outputs=None):
    """Build a model from ``model_hparams`` and fill it with saved weights.

    Args:
        save_location: Directory the model file was saved under.
        save_step: Step passed to ``paths.model`` to locate the file.
        model_hparams: Hyperparameters forwarded to ``get`` to build the model.
        outputs: Forwarded to ``get`` as its second argument.

    Returns:
        The newly built model after ``load_state_dict`` has been applied.
    """
    # Read the weights first, then construct the model they go into.
    weights = get_platform().load_model(paths.model(save_location, save_step))
    network = get(model_hparams, outputs)
    network.load_state_dict(weights)
    return network
def test_end_to_end(self):
    """Train from ``Step.from_epoch(0, 0, ...)`` to ``Step.from_epoch(3, 0, ...)``
    and check the saved initial/final weights, the checkpoint and the logger."""
    loader_len = len(self.train_loader)
    first_path = paths.model(self.root, Step.zero(loader_len))
    last_path = paths.model(self.root, Step.from_epoch(3, 0, loader_len))

    state_before = TestStandardCallbacks.get_state(self.model)
    train.train(
        self.hparams.training_hparams,
        self.model,
        self.train_loader,
        self.root,
        callbacks=self.callbacks,
        start_step=Step.from_epoch(0, 0, loader_len),
        end_step=Step.from_epoch(3, 0, loader_len),
    )
    state_after = TestStandardCallbacks.get_state(self.model)

    # Check that the initial and final states have been saved.
    self.assertTrue(os.path.exists(first_path))
    self.assertTrue(os.path.exists(last_path))

    # Check that the checkpoint file still exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Check that the initial and final states match those that were saved.
    for saved_path, expected_state in ((first_path, state_before),
                                       (last_path, state_after)):
        self.model.load_state_dict(torch.load(saved_path))
        state_from_disk = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(expected_state, state_from_disk)

    # Check that the logger has the right number of entries.
    self.assertTrue(os.path.exists(paths.logger(self.root)))
    logger = MetricLogger.create_from_file(self.root)
    for metric in ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy'):
        self.assertEqual(len(logger.get_data(metric)), 4)
def run(self):
    """Train a model for this replicate unless its final weights already exist.

    Returns immediately when a model file for ``self.desc.end_step`` exists
    under the replicate's run path. Otherwise optionally prints a banner
    (verbose mode, primary process only), saves the description to the run
    path and calls ``train.standard_train`` on a freshly built model.
    """
    output_dir = self.desc.run_path(self.replicate)

    # Skip the run if the final weights are already on disk.
    if get_platform().exists(paths.model(output_dir, self.desc.end_step)):
        return

    if self.verbose and get_platform().is_primary_process:
        banner = '=' * 82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-' * 82
        print(banner)
        print(self.desc.display)
        print(f'Output Location: {output_dir}' + '\n' + '=' * 82 + '\n')

    self.desc.save(output_dir)

    model = models.registry.get(self.desc.model_hparams)
    train.standard_train(
        model,
        output_dir,
        self.desc.dataset_hparams,
        self.desc.training_hparams,
        evaluate_every_epoch=self.evaluate_every_epoch,
        verbose=self.verbose,
        weight_save_steps=self.weight_save_steps,
    )
def test_save_load_exists(self):
    """Check ``registry.exists`` before and after ``model.save``, and that both
    ``torch.load`` and ``registry.load`` recover the saved model state."""
    hparams = registry.get_default_hparams('cifar_resnet_20')
    model = registry.get(hparams.model_hparams)
    step = Step.from_iteration(27, 17)
    saved_path = paths.model(self.root, step)
    original_state = TestSaveLoadExists.get_state(model)

    # Nothing has been written yet.
    self.assertFalse(registry.exists(self.root, step))
    self.assertFalse(os.path.exists(saved_path))

    # Test saving.
    model.save(self.root, step)
    self.assertTrue(registry.exists(self.root, step))
    self.assertTrue(os.path.exists(saved_path))

    # Test loading: first by hand into a fresh model...
    manual_model = registry.get(hparams.model_hparams)
    manual_model.load_state_dict(torch.load(saved_path))
    self.assertStateEqual(original_state,
                          TestSaveLoadExists.get_state(manual_model))

    # ...then through the registry's own loader.
    registry_model = registry.load(self.root, step, hparams.model_hparams)
    self.assertStateEqual(original_state,
                          TestSaveLoadExists.get_state(registry_model))
def exists(save_location, save_step, suffix=""):
    """Return the platform's answer to whether the model file exists.

    The path is built by ``paths.model`` from ``save_location``, ``save_step``
    and ``suffix`` (passed positionally, in that order).
    """
    model_path = paths.model(save_location, save_step, suffix)
    return get_platform().exists(model_path)
def exists(save_location, save_step):
    """Return the platform's answer to whether the model file for
    ``save_step`` under ``save_location`` exists."""
    model_path = paths.model(save_location, save_step)
    return get_platform().exists(model_path)