Example 1
def standard_train(model: Model,
                   output_location: str,
                   dataset_hparams: hparams.DatasetHparams,
                   training_hparams: hparams.TrainingHparams,
                   start_step: Step = None,
                   verbose: bool = True,
                   evaluate_every_epoch: bool = True):
    """Train `model` with the standard callback set derived from the hparams.

    If a completed run (final model file plus logger file) already exists at
    `output_location`, this is a no-op.
    """

    # Work out where training would end; skip entirely if that artifact
    # (and the logger) is already on disk.
    its_per_epoch = datasets.registry.iterations_per_epoch(dataset_hparams)
    end_step = Step.from_str(training_hparams.training_steps, its_per_epoch)
    already_done = (models.registry.exists(output_location, end_step) and
                    get_platform().exists(paths.logger(output_location)))
    if already_done:
        return

    # Build the data loaders and the standard callback list, then train.
    train_loader = datasets.registry.get(dataset_hparams, train=True)
    test_loader = datasets.registry.get(dataset_hparams, train=False)
    callbacks = standard_callbacks.standard_callbacks(
        training_hparams,
        train_loader,
        test_loader,
        start_step=start_step,
        verbose=verbose,
        evaluate_every_epoch=evaluate_every_epoch)
    train(training_hparams, model, train_loader, output_location, callbacks,
          start_step=start_step)
Example 2
    def test_last_step(self):
        """Resuming near the end should save exactly the final state once."""
        final_step = Step.from_epoch(3, 0, len(self.train_loader))
        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(2, 11, len(self.train_loader)),
                    end_step=final_step)

        end_state = TestStandardCallbacks.get_state(self.model)

        # The final model snapshot must have been written to disk.
        end_loc = paths.model(self.root, final_step)
        self.assertTrue(os.path.exists(end_loc))

        # The weights on disk must match the weights the run ended with.
        self.model.load_state_dict(torch.load(end_loc))
        self.assertStateEqual(end_state,
                              TestStandardCallbacks.get_state(self.model))

        # Exactly one evaluation should have been logged for each metric.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        metrics = MetricLogger.create_from_file(self.root)
        for key in ('train_loss', 'test_loss',
                    'train_accuracy', 'test_accuracy'):
            self.assertEqual(len(metrics.get_data(key)), 1)

        # A checkpoint file should also be present.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
Example 3
    def test_first_step(self):
        """Training one iteration saves the init state but no logger output."""
        # Capture the weights before any training happens.
        before = TestStandardCallbacks.get_state(self.model)

        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    end_step=Step.from_epoch(0, 1, len(self.train_loader)))

        # The step-zero snapshot must exist on disk...
        init_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
        self.assertTrue(os.path.exists(init_loc))

        # ...and must equal the pre-training weights.
        self.model.load_state_dict(torch.load(init_loc))
        self.assertStateEqual(before,
                              TestStandardCallbacks.get_state(self.model))

        # A checkpoint is written this early.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # But the logger file should not exist yet.
        self.assertFalse(os.path.exists(paths.logger(self.root)))
Example 4
 def save(self, location):
     """Write the logged metrics as CSV (with header) under `location`.

     Only the primary process writes; the directory is created if missing.
     """
     if not get_platform().is_primary_process:
         return
     if not get_platform().exists(location):
         get_platform().makedirs(location)
     with get_platform().open(paths.logger(location), 'w') as out:
         out.write("phase,iteration,loss,accuracy,elapsed_time\n" + str(self))
Example 5
 def assertLevelFilesPresent(self,
                             level_root,
                             start_step,
                             end_step,
                             masks=False):
     """Assert that the expected output files for one pruning level exist.

     Checks the start/end model snapshots and the logger; when `masks` is
     True, also checks the mask and sparsity-report files.
     """
     with self.subTest(level_root=level_root):
         expected = [paths.model(level_root, start_step),
                     paths.model(level_root, end_step),
                     paths.logger(level_root)]
         if masks:
             expected.append(paths.mask(level_root))
             expected.append(paths.sparsity_report(level_root))
         for location in expected:
             self.assertTrue(os.path.exists(location))
Example 6
    def test_end_to_end(self):
        """A full run saves init/final snapshots, a checkpoint, and 4 evals."""
        its = len(self.train_loader)
        init_loc = paths.model(self.root, Step.zero(its))
        end_loc = paths.model(self.root, Step.from_epoch(3, 0, its))

        init_state = TestStandardCallbacks.get_state(self.model)

        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(0, 0, its),
                    end_step=Step.from_epoch(3, 0, its))

        end_state = TestStandardCallbacks.get_state(self.model)

        # Both model snapshots and the checkpoint must be on disk.
        for location in (init_loc, end_loc, paths.checkpoint(self.root)):
            self.assertTrue(os.path.exists(location))

        # The saved initial weights match the pre-training weights.
        self.model.load_state_dict(torch.load(init_loc))
        self.assertStateEqual(init_state,
                              TestStandardCallbacks.get_state(self.model))

        # The saved final weights match the post-training weights.
        self.model.load_state_dict(torch.load(end_loc))
        self.assertStateEqual(end_state,
                              TestStandardCallbacks.get_state(self.model))

        # The logger should hold four entries per metric.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        metrics = MetricLogger.create_from_file(self.root)
        for key in ('train_loss', 'test_loss',
                    'train_accuracy', 'test_accuracy'):
            self.assertEqual(len(metrics.get_data(key)), 4)
Example 7
 def save(self, location, suffix=''):
     """Persist the logged metrics to the logger path under `location`.

     Only the primary process writes; the directory is created if missing.
     `suffix` is forwarded to `paths.logger` to vary the file name.
     """
     platform = get_platform()
     if not platform.is_primary_process:
         return
     if not platform.exists(location):
         platform.makedirs(location)
     with platform.open(paths.logger(location, suffix), 'w') as out:
         out.write(str(self))
Example 8
 def create_from_file(filename):
     """Rebuild a MetricLogger from the logger file found under `filename`."""
     with get_platform().open(paths.logger(filename)) as fp:
         contents = fp.read()
     return MetricLogger.create_from_string(contents)