def test_create_restore_delete(self):
    # Create the hyperparameters and objects to save.
    hp = models.registry.get_default_hparams('cifar_resnet_20')
    model = models.registry.get(hp.model_hparams)
    optimizer = optimizers.get_optimizer(hp.training_hparams, model)
    dataloader = datasets.registry.get(hp.dataset_hparams)
    step = Step.from_epoch(13, 27, 400)

    # Run one step of SGD.
    examples, labels = next(iter(dataloader))
    optimizer.zero_grad()
    model.train()
    model.loss_criterion(model(examples), labels).backward()
    optimizer.step()

    # Create a fake logger.
    logger = MetricLogger()
    logger.add('test_accuracy', Step.from_epoch(0, 0, 400), 0.1)
    logger.add('test_accuracy', Step.from_epoch(10, 0, 400), 0.5)
    logger.add('test_accuracy', Step.from_epoch(100, 0, 400), 0.8)

    # Save a checkpoint.
    checkpointing.save_checkpoint_callback(self.root, step, model, optimizer, logger)
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Create a new model and optimizer. The optimizer is built over the original model's
    # parameters so the state-dict keys used in the assertions below are shared.
    model2 = models.registry.get(hp.model_hparams)
    optimizer2 = optimizers.get_optimizer(hp.training_hparams, model)

    # Ensure the new model has different weights.
    sd1, sd2 = model.state_dict(), model2.state_dict()
    for k in model.prunable_layer_names:
        self.assertFalse(np.array_equal(sd1[k].numpy(), sd2[k].numpy()))

    self.assertIn('momentum_buffer',
                  optimizer.state[optimizer.param_groups[0]['params'][0]])
    self.assertNotIn('momentum_buffer',
                     optimizer2.state[optimizer.param_groups[0]['params'][0]])

    # Restore the checkpoint.
    step2, logger2 = checkpointing.restore_checkpoint(self.root, model2, optimizer2, 400)
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
    self.assertEqual(step, step2)
    self.assertEqual(str(logger), str(logger2))

    # Ensure the new model now matches the original.
    sd1, sd2 = model.state_dict(), model2.state_dict()
    self.assertEqual(set(sd1.keys()), set(sd2.keys()))
    for k in sd1:
        self.assertTrue(np.array_equal(sd1[k].numpy(), sd2[k].numpy()))

    # Ensure the new optimizer now matches the original.
    mom1 = optimizer.state[optimizer.param_groups[0]['params'][0]]['momentum_buffer']
    mom2 = optimizer2.state[optimizer.param_groups[0]['params'][0]]['momentum_buffer']
    self.assertTrue(np.array_equal(mom1.numpy(), mom2.numpy()))
def restore_checkpoint(output_location, model, optimizer, iterations_per_epoch):
    checkpoint_location = paths.checkpoint(output_location)
    if not get_platform().exists(checkpoint_location):
        return None, None
    checkpoint = get_platform().load_model(checkpoint_location, map_location=torch.device('cpu'))

    # Handle DataParallel: add or strip the 'module.' prefix so the saved state dict
    # matches the keys the current model expects.
    module_in_name = get_platform().is_parallel
    if module_in_name and not all(k.startswith('module.') for k in checkpoint['model_state_dict']):
        checkpoint['model_state_dict'] = {'module.' + k: v for k, v in checkpoint['model_state_dict'].items()}
    elif all(k.startswith('module.') for k in checkpoint['model_state_dict']) and not module_in_name:
        checkpoint['model_state_dict'] = {k[len('module.'):]: v for k, v in checkpoint['model_state_dict'].items()}

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = Step.from_epoch(checkpoint['ep'], checkpoint['it'], iterations_per_epoch)
    logger = MetricLogger.create_from_string(checkpoint['logger'])
    return step, logger
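Since restore_checkpoint returns (None, None) when no checkpoint exists, a caller can fall back to a fresh start. The following is a minimal sketch of that pattern, assuming a wrapper named resume_or_start that is not part of the codebase; only restore_checkpoint, Step, and MetricLogger are taken from the code above.

# Hypothetical resumption sketch: the fallback expressions are an assumption about how
# a training loop might consume restore_checkpoint, not the actual train.train internals.
def resume_or_start(output_location, model, optimizer, iterations_per_epoch):
    step, logger = restore_checkpoint(output_location, model, optimizer, iterations_per_epoch)
    start_step = step if step is not None else Step.zero(iterations_per_epoch)
    logger = logger if logger is not None else MetricLogger()
    return start_step, logger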
def test_last_step(self):
    train.train(self.hparams.training_hparams, self.model, self.train_loader, self.root,
                callbacks=self.callbacks,
                start_step=Step.from_epoch(2, 11, len(self.train_loader)),
                end_step=Step.from_epoch(3, 0, len(self.train_loader)))
    end_state = TestStandardCallbacks.get_state(self.model)

    # Check that the final state has been saved.
    end_loc = paths.model(self.root, Step.from_epoch(3, 0, len(self.train_loader)))
    self.assertTrue(os.path.exists(end_loc))

    # Check that the saved final state matches the final state of the network.
    self.model.load_state_dict(torch.load(end_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(end_state, saved_state)

    # Check that the logger has the right number of entries.
    self.assertTrue(os.path.exists(paths.logger(self.root)))
    logger = MetricLogger.create_from_file(self.root)
    self.assertEqual(len(logger.get_data('train_loss')), 1)
    self.assertEqual(len(logger.get_data('test_loss')), 1)
    self.assertEqual(len(logger.get_data('train_accuracy')), 1)
    self.assertEqual(len(logger.get_data('test_accuracy')), 1)

    # Check that the checkpoint file exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
def test_first_step(self):
    init_state = TestStandardCallbacks.get_state(self.model)

    train.train(self.hparams.training_hparams, self.model, self.train_loader, self.root,
                callbacks=self.callbacks,
                end_step=Step.from_epoch(0, 1, len(self.train_loader)))

    # Check that the initial state has been saved.
    model_state_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
    self.assertTrue(os.path.exists(model_state_loc))

    # Check that the model state at init reflects the saved state.
    self.model.load_state_dict(torch.load(model_state_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(init_state, saved_state)

    # Check that the checkpoint file exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Check that the logger file doesn't exist.
    self.assertFalse(os.path.exists(paths.logger(self.root)))
def save_checkpoint_callback(output_location, step, model, optimizer, logger):
    if get_platform().is_primary_process:
        get_platform().save_model({
            'ep': step.ep,
            'it': step.it,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'logger': str(logger),
        }, paths.checkpoint(output_location))
    get_platform().barrier()
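The callback's signature matches the (output_location, step, model, optimizer, logger) arguments that the tests' callbacks receive from train.train, so it can be wired into a run simply by adding it to the callbacks list. The sketch below is an assumption about that wiring: hparams, model, train_loader, and output_location are placeholders, and how often train.train invokes each callback is decided by the training loop, not shown here.

# Hedged sketch of wiring the checkpoint callback into a training run.
callbacks = [checkpointing.save_checkpoint_callback]
train.train(hparams.training_hparams, model, train_loader, output_location,
            callbacks=callbacks,
            end_step=Step.from_epoch(3, 0, len(train_loader)))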
def test_end_to_end(self):
    init_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
    end_loc = paths.model(self.root, Step.from_epoch(3, 0, len(self.train_loader)))

    init_state = TestStandardCallbacks.get_state(self.model)

    train.train(self.hparams.training_hparams, self.model, self.train_loader, self.root,
                callbacks=self.callbacks,
                start_step=Step.from_epoch(0, 0, len(self.train_loader)),
                end_step=Step.from_epoch(3, 0, len(self.train_loader)))

    end_state = TestStandardCallbacks.get_state(self.model)

    # Check that the initial and final states have been saved.
    self.assertTrue(os.path.exists(init_loc))
    self.assertTrue(os.path.exists(end_loc))

    # Check that the checkpoint file still exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Check that the initial and final states match those that were saved.
    self.model.load_state_dict(torch.load(init_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(init_state, saved_state)

    self.model.load_state_dict(torch.load(end_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(end_state, saved_state)

    # Check that the logger has the right number of entries.
    self.assertTrue(os.path.exists(paths.logger(self.root)))
    logger = MetricLogger.create_from_file(self.root)
    self.assertEqual(len(logger.get_data('train_loss')), 4)
    self.assertEqual(len(logger.get_data('test_loss')), 4)
    self.assertEqual(len(logger.get_data('train_accuracy')), 4)
    self.assertEqual(len(logger.get_data('test_accuracy')), 4)