def standard_train(model: Model,
                   output_location: str,
                   dataset_hparams: hparams.DatasetHparams,
                   training_hparams: hparams.TrainingHparams,
                   start_step: Step = None,
                   verbose: bool = True,
                   evaluate_every_epoch: bool = True):
    """Train using the standard callbacks according to the provided hparams."""

    # If the model file for the end of training already exists in this location, do not train.
    iterations_per_epoch = datasets.registry.iterations_per_epoch(dataset_hparams)
    train_end_step = Step.from_str(training_hparams.training_steps, iterations_per_epoch)
    if (models.registry.exists(output_location, train_end_step) and
            get_platform().exists(paths.logger(output_location))):
        return

    train_loader = datasets.registry.get(dataset_hparams, train=True)
    test_loader = datasets.registry.get(dataset_hparams, train=False)
    callbacks = standard_callbacks.standard_callbacks(
        training_hparams, train_loader, test_loader, start_step=start_step,
        verbose=verbose, evaluate_every_epoch=evaluate_every_epoch)
    train(training_hparams, model, train_loader, output_location, callbacks,
          start_step=start_step)
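# Usage sketch: driving standard_train from a registry's default hparams.
# Assumptions (not shown in the code above): an open_lth-style models.registry
# exposing get_default_hparams() and get(), and the model name
# 'mnist_lenet_300_100'; adjust both to whatever this codebase actually registers.
import models.registry

def run_standard_training(output_location: str):
    defaults = models.registry.get_default_hparams('mnist_lenet_300_100')
    model = models.registry.get(defaults.model_hparams)
    standard_train(model, output_location,
                   defaults.dataset_hparams, defaults.training_hparams)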
def test_last_step(self):
    train.train(self.hparams.training_hparams, self.model, self.train_loader,
                self.root, callbacks=self.callbacks,
                start_step=Step.from_epoch(2, 11, len(self.train_loader)),
                end_step=Step.from_epoch(3, 0, len(self.train_loader)))
    end_state = TestStandardCallbacks.get_state(self.model)

    # Check that the final state has been saved.
    end_loc = paths.model(self.root, Step.from_epoch(3, 0, len(self.train_loader)))
    self.assertTrue(os.path.exists(end_loc))

    # Check that the saved final state matches the final state of the network.
    self.model.load_state_dict(torch.load(end_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(end_state, saved_state)

    # Check that the logger has the right number of entries.
    self.assertTrue(os.path.exists(paths.logger(self.root)))
    logger = MetricLogger.create_from_file(self.root)
    self.assertEqual(len(logger.get_data('train_loss')), 1)
    self.assertEqual(len(logger.get_data('test_loss')), 1)
    self.assertEqual(len(logger.get_data('train_accuracy')), 1)
    self.assertEqual(len(logger.get_data('test_accuracy')), 1)

    # Check that the checkpoint file exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
def test_first_step(self):
    init_state = TestStandardCallbacks.get_state(self.model)

    train.train(self.hparams.training_hparams, self.model, self.train_loader,
                self.root, callbacks=self.callbacks,
                end_step=Step.from_epoch(0, 1, len(self.train_loader)))

    # Check that the initial state has been saved.
    model_state_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
    self.assertTrue(os.path.exists(model_state_loc))

    # Check that the model state at init reflects the saved state.
    self.model.load_state_dict(torch.load(model_state_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(init_state, saved_state)

    # Check that the checkpoint file exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Check that the logger file doesn't exist.
    self.assertFalse(os.path.exists(paths.logger(self.root)))
def save(self, location):
    # Only the primary process writes the log in a distributed run.
    if not get_platform().is_primary_process:
        return
    if not get_platform().exists(location):
        get_platform().makedirs(location)
    with get_platform().open(paths.logger(location), 'w') as fp:
        fp.write('phase,iteration,loss,accuracy,elapsed_time\n')
        fp.write(str(self))
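# A complementary reader sketch (an addition, not part of the code above).
# Assuming str(self) emits one CSV row per entry matching the header that
# save() writes, the stdlib csv module can load the file back:
import csv

def load_log(location):
    with get_platform().open(paths.logger(location)) as fp:
        return list(csv.DictReader(fp))  # one dict per row, keyed by the header columns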
def assertLevelFilesPresent(self, level_root, start_step, end_step, masks=False):
    with self.subTest(level_root=level_root):
        self.assertTrue(os.path.exists(paths.model(level_root, start_step)))
        self.assertTrue(os.path.exists(paths.model(level_root, end_step)))
        self.assertTrue(os.path.exists(paths.logger(level_root)))
        if masks:
            self.assertTrue(os.path.exists(paths.mask(level_root)))
            self.assertTrue(os.path.exists(paths.sparsity_report(level_root)))
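# Illustrative call from a pruning-branch test (hypothetical: the 'level_1'
# subdirectory name and the step values are assumptions, not taken from the
# code above):
#
#   level_root = os.path.join(self.root, 'level_1')
#   self.assertLevelFilesPresent(level_root,
#                                Step.zero(len(self.train_loader)),
#                                Step.from_epoch(3, 0, len(self.train_loader)),
#                                masks=True)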
def test_end_to_end(self):
    init_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
    end_loc = paths.model(self.root, Step.from_epoch(3, 0, len(self.train_loader)))

    init_state = TestStandardCallbacks.get_state(self.model)

    train.train(self.hparams.training_hparams, self.model, self.train_loader,
                self.root, callbacks=self.callbacks,
                start_step=Step.from_epoch(0, 0, len(self.train_loader)),
                end_step=Step.from_epoch(3, 0, len(self.train_loader)))

    end_state = TestStandardCallbacks.get_state(self.model)

    # Check that the initial and final states have been saved.
    self.assertTrue(os.path.exists(init_loc))
    self.assertTrue(os.path.exists(end_loc))

    # Check that the checkpoint file still exists.
    self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

    # Check that the initial and final states match those that were saved.
    self.model.load_state_dict(torch.load(init_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(init_state, saved_state)

    self.model.load_state_dict(torch.load(end_loc))
    saved_state = TestStandardCallbacks.get_state(self.model)
    self.assertStateEqual(end_state, saved_state)

    # Check that the logger has the right number of entries.
    self.assertTrue(os.path.exists(paths.logger(self.root)))
    logger = MetricLogger.create_from_file(self.root)
    self.assertEqual(len(logger.get_data('train_loss')), 4)
    self.assertEqual(len(logger.get_data('test_loss')), 4)
    self.assertEqual(len(logger.get_data('train_accuracy')), 4)
    self.assertEqual(len(logger.get_data('test_accuracy')), 4)
def save(self, location, suffix=''):
    # The optional suffix distinguishes multiple logs written under the same root.
    if not get_platform().is_primary_process:
        return
    if not get_platform().exists(location):
        get_platform().makedirs(location)
    with get_platform().open(paths.logger(location, suffix), 'w') as fp:
        fp.write(str(self))
@staticmethod
def create_from_file(filename):
    # 'filename' is the run's output root; paths.logger() resolves the log path,
    # matching how the tests call MetricLogger.create_from_file(self.root).
    with get_platform().open(paths.logger(filename)) as fp:
        as_str = fp.read()
    return MetricLogger.create_from_string(as_str)
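# Round-trip sketch (an illustration, not repo code). Assumes the logger
# exposes an add(name, step, value) recording method alongside the get_data()
# used in the tests above; add() and the paths below are assumptions.
its_per_epoch = 50
logger = MetricLogger()
logger.add('test_accuracy', Step.from_epoch(1, 0, its_per_epoch), 0.97)  # hypothetical API
logger.save('outputs/demo_run')  # writes the file at paths.logger('outputs/demo_run')
restored = MetricLogger.create_from_file('outputs/demo_run')
assert restored.get_data('test_accuracy') == logger.get_data('test_accuracy')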