Esempio n. 1
0
def standard_train(model: Model,
                   output_location: str,
                   dataset_hparams: hparams.DatasetHparams,
                   training_hparams: hparams.TrainingHparams,
                   start_step: Step = None,
                   verbose: bool = True,
                   evaluate_every_epoch: bool = True):
    """Train ``model`` with the standard callbacks built from the given hparams.

    Training is skipped (the function returns immediately) when both the model
    saved for the final training step and the logger file already exist in
    ``output_location``.

    Args:
        model: The model to train.
        output_location: Location checked for existing results and passed to ``train``.
        dataset_hparams: Used to compute iterations per epoch and to build the
            train and test loaders.
        training_hparams: Its ``training_steps`` determines the final step; it is
            also passed to the callbacks and to ``train``.
        start_step: Forwarded to ``standard_callbacks`` and ``train``; defaults to None.
        verbose: Forwarded to ``standard_callbacks``.
        evaluate_every_epoch: Forwarded to ``standard_callbacks``.
    """
    # Work out which step marks the end of training for this dataset.
    steps_per_epoch = datasets.registry.iterations_per_epoch(dataset_hparams)
    final_step = Step.from_str(training_hparams.training_steps, steps_per_epoch)

    # Skip training only if the final model AND the logger are already on disk.
    # The logger check is only performed when the final model exists.
    if models.registry.exists(output_location, final_step):
        if get_platform().exists(paths.logger(output_location)):
            return

    train_data = datasets.registry.get(dataset_hparams, train=True)
    test_data = datasets.registry.get(dataset_hparams, train=False)

    callback_list = standard_callbacks.standard_callbacks(
        training_hparams, train_data, test_data,
        start_step=start_step,
        verbose=verbose,
        evaluate_every_epoch=evaluate_every_epoch,
    )
    train(training_hparams, model, train_data, output_location, callback_list,
          start_step=start_step)
Esempio n. 2
0
    def setUp(self):
        """Create the model, loaders and standard callbacks used by each test.

        Loads the default hparams for 'mnist_lenet_10_10' and builds the model
        from them. It then sets ``subsample_fraction`` to 0.01 and ``batch_size``
        to 50 before building the train and test loaders. It sets
        ``training_steps`` to '3ep' (presumably three epochs) and finally builds
        the standard callbacks with ``eval_on_train=True`` and ``verbose=False``.
        """
        super().setUp()

        # Default hparams for the model under test, and the model they describe.
        hp = models.registry.get_default_hparams('mnist_lenet_10_10')
        self.hparams = hp
        self.model = models.registry.get(hp.model_hparams)

        # Shrink the dataset settings, then build the train and test loaders.
        data_hp = hp.dataset_hparams
        data_hp.subsample_fraction = 0.01
        data_hp.batch_size = 50
        self.train_loader = datasets.registry.get(data_hp)
        self.test_loader = datasets.registry.get(data_hp, train=False)

        # Training length for the callbacks below.
        train_hp = hp.training_hparams
        train_hp.training_steps = '3ep'

        self.callbacks = standard_callbacks.standard_callbacks(
            train_hp,
            self.train_loader,
            self.test_loader,
            eval_on_train=True,
            verbose=False,
        )