def test_ResumeTrainingAfterInterruption(self):
  """Resuming training should match a run without interruption."""
  model, state = _get_linear_model()
  dataset = MockDatasetSource()
  optimizer = flax_training.create_optimizer(model, 0.0)
  training_dir = self.create_tempdir().full_path
  FLAGS.learning_rate = 0.01
  FLAGS.use_learning_rate_schedule = False

  # First we train for 10 epochs and get the logs.
  num_epochs = 10
  reference_run_dir = os.path.join(training_dir, 'reference')
  flax_training.train(optimizer, state, dataset, reference_run_dir,
                      num_epochs)
  records = tensorboard_event_to_dataframe(reference_run_dir)

  # In another directory (new experiment), we run the model for 4 epochs and
  # then for 10 epochs, to simulate an interruption.
  interrupted_run_dir = os.path.join(training_dir, 'interrupted')
  flax_training.train(optimizer, state, dataset, interrupted_run_dir, 4)
  flax_training.train(optimizer, state, dataset, interrupted_run_dir, 10)
  records_interrupted = tensorboard_event_to_dataframe(interrupted_run_dir)

  # Logs should match (order doesn't matter as it is a dataframe in tidy
  # format).
  def _make_hashable(row):
    return str(
        [e if not isinstance(e, float) else round(e, 5) for e in row])

  self.assertEqual(
      set([_make_hashable(e) for e in records_interrupted.values]),
      set([_make_hashable(e) for e in records.values]))
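# A minimal sketch of the `tensorboard_event_to_dataframe` helper the tests
# rely on, assuming the standard `tensorboard` and `pandas` packages and that
# all event files are written directly under `logdir`. The repository's actual
# implementation may differ; the tests only require a tidy DataFrame with
# `metric`, `value` and `step` columns.
import pandas as pd
from tensorboard.backend.event_processing import event_accumulator


def tensorboard_event_to_dataframe(logdir):
  """Reads all scalar summaries under `logdir` into a tidy DataFrame."""
  accumulator = event_accumulator.EventAccumulator(
      logdir, size_guidance={event_accumulator.SCALARS: 0})
  accumulator.Reload()
  rows = []
  for tag in accumulator.Tags()['scalars']:
    for event in accumulator.Scalars(tag):
      rows.append({'metric': tag, 'value': event.value, 'step': event.step})
  return pd.DataFrame(rows, columns=['metric', 'value', 'step'])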
def main(_):
  # As we gridsearch the weight decay and the learning rate, we add them to the
  # output directory path so that each model has its own directory to save the
  # results in. We also add the `run_seed`, which is "gridsearched" on to
  # replicate an experiment several times.
  output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                   'wd_' + str(FLAGS.weight_decay),
                                   'seed_' + str(FLAGS.run_seed))
  output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)
  if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)

  num_devices = jax.local_device_count()
  assert FLAGS.batch_size % num_devices == 0
  local_batch_size = FLAGS.batch_size // num_devices
  info = 'Total batch size: {} ({} x {} replicas)'.format(
      FLAGS.batch_size, local_batch_size, num_devices)
  logging.info(info)

  if FLAGS.dataset.lower() == 'cifar10':
    dataset_source = dataset_source_lib.Cifar10(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset.lower() == 'cifar100':
    dataset_source = dataset_source_lib.Cifar100(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset.lower() == 'fashion_mnist':
    dataset_source = dataset_source_lib.FashionMnist(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset.lower() == 'svhn':
    dataset_source = dataset_source_lib.SVHN(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  else:
    raise ValueError('Available datasets: cifar10(0), fashion_mnist, svhn.')

  if 'cifar' in FLAGS.dataset.lower() or 'svhn' in FLAGS.dataset.lower():
    image_size = 32
    num_channels = 3
  else:
    image_size = 28  # For Fashion MNIST.
    num_channels = 1

  num_classes = 100 if FLAGS.dataset.lower() == 'cifar100' else 10
  model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                      image_size, num_classes, num_channels)
  # The learning rate will be overwritten by the lr schedule, so we set it to
  # zero here.
  optimizer = flax_training.create_optimizer(model, 0.0)

  flax_training.train(optimizer, state, dataset_source, output_dir,
                      FLAGS.num_epochs)
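# A minimal sketch of the flag definitions and entry point that `main` above
# assumes, using `absl.app` and `absl.flags`. The flag names match the FLAGS
# fields accessed in `main`; the defaults, types and help strings here are
# illustrative only and may not match the repository's own definitions.
from absl import app
from absl import flags

flags.DEFINE_string('output_dir', None,
                    'Directory where checkpoints and summaries are saved.')
flags.DEFINE_string('dataset', 'cifar10',
                    'cifar10, cifar100, fashion_mnist or svhn.')
flags.DEFINE_string('model_name', 'WideResnet28x10',
                    'Name of the model to train.')
flags.DEFINE_integer('batch_size', 128,
                     'Global batch size (split across local devices).')
flags.DEFINE_integer('num_epochs', 200, 'Number of training epochs.')
flags.DEFINE_integer('run_seed', 0,
                     'Seed used to replicate a run several times.')
flags.DEFINE_float('learning_rate', 0.1,
                   'Base learning rate (overwritten by the schedule).')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay coefficient.')
flags.DEFINE_string('image_level_augmentations', 'basic',
                    'Image-level data augmentations.')
flags.DEFINE_string('batch_level_augmentations', 'none',
                    'Batch-level data augmentations.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  flags.mark_flag_as_required('output_dir')
  app.run(main)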
def test_TrainSimpleModel(self):
  """Model should reach 100% accuracy easily."""
  model, state = _get_linear_model()
  dataset = MockDatasetSource()
  num_epochs = 10
  optimizer = flax_training.create_optimizer(model, 0.0)
  training_dir = self.create_tempdir().full_path
  FLAGS.learning_rate = 0.01
  flax_training.train(optimizer, state, dataset, training_dir, num_epochs)
  records = tensorboard_event_to_dataframe(training_dir)

  # Train error rate at the last step should be 0.
  records = records[records.metric == 'train_error_rate']
  records = records.sort_values('step')
  self.assertEqual(records.value.values[-1], 0.0)
def test_RecomputeTestLoss(self):
  """Recomputes the loss of the final model to check the value logged."""
  model, state = _get_linear_model()
  dataset = MockDatasetSource()
  num_epochs = 2
  optimizer = flax_training.create_optimizer(model, 0.0)
  training_dir = self.create_tempdir().full_path
  flax_training.train(optimizer, state, dataset, training_dir, num_epochs)
  records = tensorboard_event_to_dataframe(training_dir)
  records = records[records.metric == 'test_loss']
  final_test_loss = records.sort_values('step').value.values[-1]

  # Loads the final model and state.
  optimizer, state, _ = flax_training.restore_checkpoint(
      optimizer, state, os.path.join(training_dir, 'checkpoints'))
  # Averages over the first dimension as we will use only one device (no
  # pmapped operation).
  optimizer = jax.tree_map(lambda x: jax.numpy.mean(x, axis=0), optimizer)
  state = jax.tree_map(lambda x: jax.numpy.mean(x, axis=0), state)

  logits = optimizer.target(dataset.inputs)
  loss = flax_training.cross_entropy_loss(logits, dataset.labels)
  self.assertLess(abs(final_test_loss - loss), 1e-7)
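# For reference, a minimal sketch of what `cross_entropy_loss` is expected to
# compute in the test above, assuming `dataset.labels` are one-hot encoded as
# produced by the mock dataset. The repository's own implementation may differ
# in details; the test only needs to call the same function used in training.
import jax
import jax.numpy as jnp


def cross_entropy_loss(logits, one_hot_labels):
  """Mean cross-entropy between `logits` and one-hot `one_hot_labels`."""
  log_probs = jax.nn.log_softmax(logits)
  return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))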