  def test_train_one_epoch(self):
    """Tests training loop over one epoch."""
    trainer = training.Trainer(self._optimizer, self._model, self._state,
                               self._dataset)

    with self.subTest(name='trainer_instantiation'):
      self.assertIsInstance(trainer, training.Trainer)

    best_model, best_metrics = trainer.train(self._num_epochs)

    with self.subTest(name='best_model_type'):
      self.assertIsInstance(best_model, flax.deprecated.nn.Model)

    with self.subTest(name='train_accuracy'):
      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)

    with self.subTest(name='train_avg_loss'):
      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,
                         self._max_loss)

    with self.subTest(name='step'):
      self.assertGreater(best_metrics['step'], 0)

    with self.subTest(name='cosine_distance'):
      self.assertBetween(best_metrics['cosine_distance'], 0., 1.)

    with self.subTest(name='cumulative_gradient_norm'):
      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0)

    with self.subTest(name='test_accuracy'):
      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)

    with self.subTest(name='test_avg_loss'):
      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,
                         self._max_loss)
  def test_train_one_epoch_pruning_local_schedule(self):
    """Tests training loop over one epoch with a local pruning rate schedule."""
    trainer = training.Trainer(self._optimizer, self._model, self._state,
                               self._dataset)

    with self.subTest(name='trainer_instantiation'):
      self.assertIsInstance(trainer, training.Trainer)

    best_model, best_metrics = trainer.train(
        self._num_epochs, pruning_rate_fn={'MaskedModule_0': lambda _: 0.5})

    with self.subTest(name='best_model_type'):
      self.assertIsInstance(best_model, flax.deprecated.nn.Model)

    with self.subTest(name='train_accuracy'):
      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)

    with self.subTest(name='train_avg_loss'):
      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,
                         self._max_loss)

    with self.subTest(name='step'):
      self.assertGreater(best_metrics['step'], 0)

    with self.subTest(name='cosine_distance'):
      self.assertBetween(best_metrics['cosine_distance'], 0., 1.)

    with self.subTest(name='cumulative_gradient_norm'):
      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0.)

    with self.subTest(name='test_accuracy'):
      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)

    with self.subTest(name='test_avg_loss'):
      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,
                         self._max_loss)
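
  # Illustrative sketch, not part of the original suite: the training script
  # below also builds a single callable pruning_rate_fn when no per-module
  # schedule is given, so, assuming Trainer.train accepts that form as well,
  # a companion test for a global schedule could look like this.
  def test_train_one_epoch_pruning_global_schedule(self):
    """Sketch: training loop over one epoch with a global pruning rate."""
    trainer = training.Trainer(self._optimizer, self._model, self._state,
                               self._dataset)

    # A constant 50% pruning rate applied to every masked layer.
    best_model, best_metrics = trainer.train(
        self._num_epochs, pruning_rate_fn=lambda _: 0.5)

    with self.subTest(name='best_model_type'):
      self.assertIsInstance(best_model, flax.deprecated.nn.Model)
    with self.subTest(name='train_accuracy'):
      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)
    with self.subTest(name='test_accuracy'):
      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)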
    if isinstance(pruning_schedule, collections.abc.Mapping):
      pruning_rate_fn = {
          f'MaskedModule_{layer_num}': pruning_fn_p(schedule)
          for layer_num, schedule in pruning_schedule.items()
      }
    else:
      pruning_rate_fn = pruning_fn_p(pruning_schedule)
  else:
    pruning_rate_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.pruning_rate, steps_per_epoch)

  if jax.host_id() == 0:
    trainer = training.Trainer(
        optimizer,
        initial_model,
        initial_state,
        dataset,
        rng,
        summary_writer=summary_writer,
    )
  else:
    trainer = training.Trainer(
        optimizer, initial_model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      pruning_rate_fn=pruning_rate_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch,
  )
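
  # Note (illustrative, based on the two branches above): `pruning_rate_fn` is
  # either a single callable mapping a global step to a pruning rate in [0, 1],
  # applied to every masked layer, or a dict keyed by module name (e.g.
  # 'MaskedModule_0') whose values are such callables, so each layer can follow
  # its own schedule. A hypothetical two-layer configuration could look like:
  #
  #   pruning_rate_fn = {
  #       'MaskedModule_0': lambda step: 0.5,                     # constant
  #       'MaskedModule_1': lambda step: min(0.9, step / 10000),  # ramp-up
  #   }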
def run_training():
  """Trains a model."""
  print('Logging to {}'.format(FLAGS.log_dir))

  work_unit_id = uuid.uuid4()
  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
  logging.info('Saving experimental results to %s', experiment_dir)

  host_count = jax.host_count()
  local_device_count = jax.local_device_count()
  logging.info('Device count: %d, host count: %d, local device count: %d',
               jax.device_count(), host_count, local_device_count)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  input_shape = (1,) + dataset.shape
  model, initial_state = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, np.float32),),
      num_classes=dataset.num_classes)

  if FLAGS.optimizer == 'Adam':
    optimizer = flax.optim.Adam(
        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
  elif FLAGS.optimizer == 'Momentum':
    optimizer = flax.optim.Momentum(
        learning_rate=FLAGS.lr,
        beta=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        nesterov=False)
  else:
    raise ValueError('Unknown optimizer type {}'.format(FLAGS.optimizer))

  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
  if FLAGS.lr_schedule == 'constant':
    lr_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch)
  elif FLAGS.lr_schedule == 'stepped':
    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, lr_schedule_steps)
  elif FLAGS.lr_schedule == 'cosine':
    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, FLAGS.epochs)
  else:
    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))

  if jax.host_id() == 0:
    trainer = training.Trainer(
        optimizer,
        model,
        initial_state,
        dataset,
        rng,
        summary_writer=summary_writer,
    )
  else:
    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch)

  logging.info('Best metrics: %s', str(best_metrics))

  if jax.host_id() == 0:
    for label, value in best_metrics.items():
      summary_writer.scalar('best/{}'.format(label), value,
                            FLAGS.epochs * steps_per_epoch)
    summary_writer.close()
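

# Illustrative entry point, not part of the original excerpt: assuming the
# FLAGS above are defined with absl.flags and run_training() is the intended
# top-level call, a minimal launcher could look like this ('from absl import
# app' would normally sit with the other imports at the top of the file).
def main(argv):
  del argv  # Unused command-line arguments.
  run_training()


if __name__ == '__main__':
  from absl import app  # Hypothetical placement; normally a top-level import.
  app.run(main)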