Example #1
    def test_ResumeTrainingAfterInterruption(self):
        """Resuming training should match a run without interruption."""
        model, state = _get_linear_model()
        dataset = MockDatasetSource()
        optimizer = flax_training.create_optimizer(model, 0.0)
        training_dir = self.create_tempdir().full_path
        FLAGS.learning_rate = 0.01
        FLAGS.use_learning_rate_schedule = False
        # First we train for 10 epochs and get the logs.
        num_epochs = 10
        reference_run_dir = os.path.join(training_dir, 'reference')
        flax_training.train(optimizer, state, dataset, reference_run_dir,
                            num_epochs)
        records = tensorboard_event_to_dataframe(reference_run_dir)
        # In another directory (new experiment), we train the model for 4 epochs
        # and then ask for 10 epochs again; the second call resumes from the last
        # checkpoint, which simulates an interrupted and resumed run.
        interrupted_run_dir = os.path.join(training_dir, 'interrupted')
        flax_training.train(optimizer, state, dataset, interrupted_run_dir, 4)
        flax_training.train(optimizer, state, dataset, interrupted_run_dir, 10)
        records_interrupted = tensorboard_event_to_dataframe(
            interrupted_run_dir)

        # Logs should match (order doesn't matter as it is a dataframe in tidy
        # format).
        def _make_hashable(row):
            return str(
                [e if not isinstance(e, float) else round(e, 5) for e in row])

        self.assertEqual(
            set([_make_hashable(e) for e in records_interrupted.values]),
            set([_make_hashable(e) for e in records.values]))
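
The `tensorboard_event_to_dataframe` helper used by these tests is not shown. A minimal sketch of what it might look like, assuming TensorBoard's `EventAccumulator` and a tidy (metric, step, value) layout for the returned DataFrame:

import os

import pandas as pd
from tensorboard.backend.event_processing import event_accumulator


def tensorboard_event_to_dataframe(log_dir):
    """Reads every scalar logged under `log_dir` into a tidy DataFrame."""
    rows = []
    for root, _, files in os.walk(log_dir):
        if not any(f.startswith('events.out') for f in files):
            continue
        accumulator = event_accumulator.EventAccumulator(root)
        accumulator.Reload()
        for tag in accumulator.Tags()['scalars']:
            for event in accumulator.Scalars(tag):
                rows.append({'metric': tag, 'step': event.step, 'value': event.value})
    return pd.DataFrame(rows)
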
Example #2
def main(_):

    # Since we grid-search over the weight decay and the learning rate, we add
    # them to the output directory path so that each model has its own directory
    # to save its results in. We also add the `run_seed`, which is likewise
    # "grid-searched" over to replicate an experiment several times.
    output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                     'wd_' + str(FLAGS.weight_decay),
                                     'seed_' + str(FLAGS.run_seed))

    output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

    if not gfile.Exists(output_dir):
        gfile.MakeDirs(output_dir)

    num_devices = jax.local_device_count()
    assert FLAGS.batch_size % num_devices == 0
    local_batch_size = FLAGS.batch_size // num_devices
    info = 'Total batch size: {} ({} x {} replicas)'.format(
        FLAGS.batch_size, local_batch_size, num_devices)
    logging.info(info)

    if FLAGS.dataset.lower() == 'cifar10':
        dataset_source = dataset_source_lib.Cifar10(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset.lower() == 'cifar100':
        dataset_source = dataset_source_lib.Cifar100(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset.lower() == 'fashion_mnist':
        dataset_source = dataset_source_lib.FashionMnist(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset.lower() == 'svhn':
        dataset_source = dataset_source_lib.SVHN(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    else:
        raise ValueError(
            'Available datasets: cifar10(0), fashion_mnist, svhn.')

    if 'cifar' in FLAGS.dataset.lower() or 'svhn' in FLAGS.dataset.lower():
        image_size = 32
        num_channels = 3
    else:
        image_size = 28  # For Fashion Mnist
        num_channels = 1

    num_classes = 100 if FLAGS.dataset.lower() == 'cifar100' else 10
    model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                        image_size, num_classes, num_channels)
    # The learning rate will be overwritten by the lr schedule, so we set it
    # to zero here.
    optimizer = flax_training.create_optimizer(model, 0.0)

    flax_training.train(optimizer, state, dataset_source, output_dir,
                        FLAGS.num_epochs)
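
The `FLAGS` referenced in `main` are defined elsewhere in the module. A hedged sketch of how they might be declared with `absl.flags` and how the script would be launched; the flag names come from the usages above, but the types and defaults are assumptions:

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('output_dir', None, 'Directory where results are written.')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10, cifar100, fashion_mnist or svhn.')
flags.DEFINE_string('model_name', 'WideResnet', 'Name of the model to train (assumed default).')
flags.DEFINE_integer('batch_size', 128, 'Global batch size, split across local devices.')
flags.DEFINE_integer('num_epochs', 200, 'Number of training epochs.')
flags.DEFINE_integer('run_seed', 0, 'Seed used to replicate an experiment several times.')
flags.DEFINE_float('learning_rate', 0.1, 'Base learning rate (grid-searched).')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay coefficient (grid-searched).')
flags.DEFINE_string('image_level_augmentations', 'basic', 'Per-image augmentation setting.')
flags.DEFINE_string('batch_level_augmentations', 'none', 'Per-batch augmentation setting.')

if __name__ == '__main__':
    app.run(main)
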
Example #3
    def test_TrainSimpleModel(self):
        """Model should reach 100% accuracy easily."""
        model, state = _get_linear_model()
        dataset = MockDatasetSource()
        num_epochs = 10
        optimizer = flax_training.create_optimizer(model, 0.0)
        training_dir = self.create_tempdir().full_path
        FLAGS.learning_rate = 0.01
        flax_training.train(optimizer, state, dataset, training_dir,
                            num_epochs)
        records = tensorboard_event_to_dataframe(training_dir)
        # Train error rate at the last step should be 0.
        records = records[records.metric == 'train_error_rate']
        records = records.sort_values('step')
        self.assertEqual(records.value.values[-1], 0.0)
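
`_get_linear_model` and `MockDatasetSource` are fixtures defined in the test module and not shown here. As an illustration only, a tiny linearly separable dataset such a mock could wrap might look like the following; the attribute names `inputs` and `labels` come from Example #4, everything else is an assumption:

import jax.numpy as jnp


class MockDatasetSource:
    """Four points on the axes, trivially separable by a linear model."""

    def __init__(self):
        self.inputs = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
        # One-hot labels for a two-class problem.
        self.labels = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

`_get_linear_model` would then return a single dense layer sized to these two-dimensional inputs, together with its (empty) mutable state.
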
Example #4
    def test_RecomputeTestLoss(self):
        """Recomputes the loss of the final model to check the value logged."""
        model, state = _get_linear_model()
        dataset = MockDatasetSource()
        num_epochs = 2
        optimizer = flax_training.create_optimizer(model, 0.0)
        training_dir = self.create_tempdir().full_path
        flax_training.train(optimizer, state, dataset, training_dir,
                            num_epochs)
        records = tensorboard_event_to_dataframe(training_dir)
        records = records[records.metric == 'test_loss']
        final_test_loss = records.sort_values('step').value.values[-1]
        # Load the final model and state from the last checkpoint.
        optimizer, state, _ = flax_training.restore_checkpoint(
            optimizer, state, os.path.join(training_dir, 'checkpoints'))
        # Average over the first (device) dimension since only one device is
        # used here (no pmapped operations).
        optimizer = jax.tree_map(lambda x: jax.numpy.mean(x, axis=0),
                                 optimizer)
        state = jax.tree_map(lambda x: jax.numpy.mean(x, axis=0), state)
        logits = optimizer.target(dataset.inputs)
        loss = flax_training.cross_entropy_loss(logits, dataset.labels)
        self.assertLess(abs(final_test_loss - loss), 1e-7)
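
The test compares the logged `test_loss` against a recomputed `flax_training.cross_entropy_loss`. A sketch of the standard softmax cross-entropy such a helper presumably implements, assuming one-hot labels:

import jax
import jax.numpy as jnp


def cross_entropy_loss(logits, one_hot_labels):
    """Mean softmax cross-entropy; `one_hot_labels` must be one-hot encoded."""
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))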