Code example #1
    def test_train_one_epoch(self):
        """Tests training loop over one epoch."""
        trainer = training.Trainer(self._optimizer, self._model, self._state,
                                   self._dataset)

        with self.subTest(name='trainer_instantiation'):
            self.assertIsInstance(trainer, training.Trainer)

        best_model, best_metrics = trainer.train(self._num_epochs)

        with self.subTest(name='best_model_type'):
            self.assertIsInstance(best_model, flax.deprecated.nn.Model)

        with self.subTest(name='train_accuracy'):
            self.assertBetween(best_metrics['train_accuracy'], 0., 1.)

        with self.subTest(name='train_avg_loss'):
            self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,
                               self._max_loss)

        with self.subTest(name='step'):
            self.assertGreater(best_metrics['step'], 0)

        with self.subTest(name='cosine_distance'):
            self.assertBetween(best_metrics['cosine_distance'], 0., 1.)

        with self.subTest(name='cumulative_gradient_norm'):
            self.assertGreater(best_metrics['cumulative_gradient_norm'], 0)

        with self.subTest(name='test_accuracy'):
            self.assertBetween(best_metrics['test_accuracy'], 0., 1.)

        with self.subTest(name='test_avg_loss'):
            self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,
                               self._max_loss)
Code example #2
    def test_train_one_epoch_pruning_local_schedule(self):
        """Tests training loop over one epoch with local pruning rate schedule."""
        trainer = training.Trainer(self._optimizer, self._model, self._state,
                                   self._dataset)

        with self.subTest(name='trainer_instantiation'):
            self.assertIsInstance(trainer, training.Trainer)

        best_model, best_metrics = trainer.train(
            self._num_epochs,
            pruning_rate_fn={'MaskedModule_0': lambda _: 0.5})

        with self.subTest(name='best_model_type'):
            self.assertIsInstance(best_model, flax.nn.Model)

        with self.subTest(name='train_accuracy'):
            self.assertBetween(best_metrics['train_accuracy'], 0., 1.)

        with self.subTest(name='train_avg_loss'):
            self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,
                               self._max_loss)

        with self.subTest(name='step'):
            self.assertGreater(best_metrics['step'], 0)

        with self.subTest(name='cosine_distance'):
            self.assertBetween(best_metrics['cosine_distance'], 0., 1.)

        with self.subTest(name='cumulative_gradient_norm'):
            self.assertGreater(best_metrics['cumulative_gradient_norm'], 0.)

        with self.subTest(name='test_accuracy'):
            self.assertBetween(best_metrics['test_accuracy'], 0., 1.)

        with self.subTest(name='test_avg_loss'):
            self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,
                               self._max_loss)
Code example #3
File: prune.py Project: yaelandau22/rigl
    if isinstance(pruning_schedule, collections.Mapping):
      pruning_rate_fn = {
          f'MaskedModule_{layer_num}': pruning_fn_p(schedule)
          for layer_num, schedule in pruning_schedule.items()
      }
    else:
      pruning_rate_fn = pruning_fn_p(pruning_schedule)
  else:
    pruning_rate_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.pruning_rate, steps_per_epoch)

  if jax.host_id() == 0:
    trainer = training.Trainer(
        optimizer,
        initial_model,
        initial_state,
        dataset,
        rng,
        summary_writer=summary_writer,
    )
  else:
    trainer = training.Trainer(
        optimizer, initial_model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      pruning_rate_fn=pruning_rate_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch,
  )
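As a usage note: taken together, code examples #2 and #3 suggest that `Trainer.train` accepts `pruning_rate_fn` either as a single schedule callable (step to pruning rate) or as a dict keyed by masked-layer name. A minimal sketch of both shapes follows; the layer names and rates are illustrative assumptions, not values from the project.

# Sketch only; layer names and rates are assumed for illustration.
# Per-layer schedules, one callable per masked module (as in example #2):
per_layer_pruning_fn = {
    'MaskedModule_0': lambda step: 0.5,
    'MaskedModule_1': lambda step: 0.8,
}

# A single global schedule, analogous to the constant schedule built from
# FLAGS.pruning_rate in example #3:
def global_pruning_fn(step):
    return 0.5

# Either shape is passed the same way (one epoch here, for brevity):
best_model, best_metrics = trainer.train(1, pruning_rate_fn=per_layer_pruning_fn)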
Code example #4
File: train.py Project: tawawhite/rigl
def run_training():
    """Trains a model."""
    print('Logging to {}'.format(FLAGS.log_dir))
    work_unit_id = uuid.uuid4()
    experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

    logging.info('Saving experimental results to %s', experiment_dir)

    host_count = jax.host_count()
    local_device_count = jax.local_device_count()
    logging.info('Device count: %d, host count: %d, local device count: %d',
                 jax.device_count(), host_count, local_device_count)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(experiment_dir)

    dataset = dataset_factory.create_dataset(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.batch_size_test,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    logging.info('Training %s on the %s dataset...', FLAGS.model,
                 FLAGS.dataset)

    rng = jax.random.PRNGKey(FLAGS.random_seed)

    input_shape = (1, ) + dataset.shape
    model, initial_state = model_factory.create_model(
        FLAGS.model,
        rng, ((input_shape, np.float32), ),
        num_classes=dataset.num_classes)

    if FLAGS.optimizer == 'Adam':
        optimizer = flax.optim.Adam(learning_rate=FLAGS.lr,
                                    weight_decay=FLAGS.weight_decay)
    elif FLAGS.optimizer == 'Momentum':
        optimizer = flax.optim.Momentum(learning_rate=FLAGS.lr,
                                        beta=FLAGS.momentum,
                                        weight_decay=FLAGS.weight_decay,
                                        nesterov=False)

    steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size

    if FLAGS.lr_schedule == 'constant':
        lr_fn = lr_schedule.create_constant_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch)
    elif FLAGS.lr_schedule == 'stepped':
        lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
        lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, lr_schedule_steps)
    elif FLAGS.lr_schedule == 'cosine':
        lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, FLAGS.epochs)
    else:
        raise ValueError('Unknown LR schedule type {}'.format(
            FLAGS.lr_schedule))

    if jax.host_id() == 0:
        trainer = training.Trainer(
            optimizer,
            model,
            initial_state,
            dataset,
            rng,
            summary_writer=summary_writer,
        )
    else:
        trainer = training.Trainer(optimizer, model, initial_state, dataset,
                                   rng)

    _, best_metrics = trainer.train(FLAGS.epochs,
                                    lr_fn=lr_fn,
                                    update_iter=FLAGS.update_iterations,
                                    update_epoch=FLAGS.update_epoch)

    logging.info('Best metrics: %s', str(best_metrics))

    if jax.host_id() == 0:
        for label, value in best_metrics.items():
            summary_writer.scalar('best/{}'.format(label), value,
                                  FLAGS.epochs * steps_per_epoch)
        summary_writer.close()
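Code example #4 reads absl `FLAGS` and relies on an entry point that sit outside the excerpt. Below is a hypothetical sketch of how such a script is typically wired up as an absl app; the flag names match those read inside run_training, but the definitions and defaults shown are assumptions, not the project's own.

# Hypothetical wiring; not part of the excerpt above.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset', 'mnist', 'Dataset to train on (illustrative default).')
flags.DEFINE_integer('epochs', 10, 'Number of training epochs.')
flags.DEFINE_float('lr', 0.01, 'Base learning rate.')


def main(_):
    run_training()


if __name__ == '__main__':
    app.run(main)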