def _create_dataset(self, dataset_name):
  """Builds the train/test dataset pair for `dataset_name`.

  Batch sizes and the shuffle-buffer size come from the test fixture
  attributes set up in `setUp`.
  """
  dataset = dataset_factory.create_dataset(
      dataset_name,
      self._batch_size,
      self._batch_size_test,
      shuffle_buffer_size=self._shuffle_buffer_size)
  return dataset
def setUp(self):
  """Initializes the MNIST fixture: hyperparameters, dataset, model, optimizer."""
  super().setUp()
  self._batch_size = 128  # Note: Tests are run on GPU/TPU.
  self._batch_size_test = 128
  self._shuffle_buffer_size = 1024
  self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
  self._num_classes = 10
  self._num_epochs = 1
  self._learning_rate_fn = lambda _: 0.01
  self._weight_decay = 0.0001
  self._momentum = 0.9
  # Single PRNG seed for model init (the original assigned this twice).
  self._rng = jax.random.PRNGKey(42)
  # Sanity bounds for the loss: machine eps up to 2 * ln(num_classes).
  self._min_loss = jnp.finfo(float).eps
  self._max_loss = 2.0 * math.log(self._num_classes)
  self._dataset_name = 'MNIST'
  self._model_name = 'MNIST_CNN'
  self._summarywriter = tensorboard.SummaryWriter('/tmp/')
  # Reuse the class helper instead of duplicating the factory call inline.
  self._dataset = self._create_dataset(self._dataset_name)
  self._model, self._state = model_factory.create_model(
      self._model_name,
      self._rng, (self._input_shape,),
      num_classes=self._num_classes)
  self._optimizer = flax.optim.Momentum(
      learning_rate=self._learning_rate_fn(0),
      beta=self._momentum,
      weight_decay=self._weight_decay)
# NOTE(review): this span duplicates the start of run_training()'s body, but
# its enclosing `def` is not visible here and the final create_model call is
# cut off mid-statement — this looks like a merge/paste artifact. Confirm
# intent and either complete it or delete it; code kept byte-identical below.
experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
             jax.device_count(), host_count, local_device_count)
# Only the first host writes TensorBoard summaries.
if jax.host_id() == 0:
  summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
    FLAGS.dataset,
    FLAGS.batch_size,
    FLAGS.batch_size_test,
    shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
# Two create_model calls: presumably `base_model` is a reference copy and
# `initial_model` the one to be trained — TODO confirm; the second call is
# truncated in this fragment.
base_model, _ = model_factory.create_model(
    FLAGS.model,
    rng, ((input_shape, jnp.float32),),
    num_classes=dataset.num_classes)
initial_model, initial_state = model_factory.create_model(
    FLAGS.model,
def run_training():
  """Trains a model configured entirely from command-line FLAGS.

  Builds the dataset, model, optimizer, and learning-rate schedule from
  FLAGS, trains for FLAGS.epochs, and logs the best metrics. Host 0
  additionally writes TensorBoard summaries under a per-run directory.

  Raises:
    ValueError: if FLAGS.optimizer or FLAGS.lr_schedule names an unknown
      optimizer / schedule type.
  """
  print('Logging to {}'.format(FLAGS.log_dir))
  # Each run gets a unique working directory keyed by a fresh UUID.
  work_unit_id = uuid.uuid4()
  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
  logging.info('Saving experimental results to %s', experiment_dir)

  host_count = jax.host_count()
  local_device_count = jax.local_device_count()
  logging.info('Device count: %d, host count: %d, local device count: %d',
               jax.device_count(), host_count, local_device_count)

  # Only the first host writes TensorBoard summaries.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  # Leading 1 is a dummy batch dimension for shape inference.
  input_shape = (1, ) + dataset.shape
  model, initial_state = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, np.float32), ),
      num_classes=dataset.num_classes)

  if FLAGS.optimizer == 'Adam':
    optimizer = flax.optim.Adam(
        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
  elif FLAGS.optimizer == 'Momentum':
    optimizer = flax.optim.Momentum(
        learning_rate=FLAGS.lr,
        beta=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        nesterov=False)
  else:
    # Fail fast here; previously an unknown flag value surfaced later as a
    # NameError on `optimizer`. Mirrors the LR-schedule error handling below.
    raise ValueError('Unknown optimizer type {}'.format(FLAGS.optimizer))

  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
  if FLAGS.lr_schedule == 'constant':
    lr_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch)
  elif FLAGS.lr_schedule == 'stepped':
    # The stepped schedule is passed as a Python literal, e.g. '[[0, 0.1]]'.
    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, lr_schedule_steps)
  elif FLAGS.lr_schedule == 'cosine':
    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, FLAGS.epochs)
  else:
    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))

  # Non-zero hosts train without a summary writer.
  if jax.host_id() == 0:
    trainer = training.Trainer(
        optimizer,
        model,
        initial_state,
        dataset,
        rng,
        summary_writer=summary_writer,
    )
  else:
    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch)
  logging.info('Best metrics: %s', str(best_metrics))

  if jax.host_id() == 0:
    for label, value in best_metrics.items():
      summary_writer.scalar('best/{}'.format(label), value,
                            FLAGS.epochs * steps_per_epoch)
    summary_writer.close()