Example 1
  def _create_dataset(self, dataset_name):
    """Helper function for creating a dataset."""
    return dataset_factory.create_dataset(
        dataset_name,
        self._batch_size,
        self._batch_size_test,
        shuffle_buffer_size=self._shuffle_buffer_size)
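
A minimal usage sketch (not part of the original example): a hypothetical test method that calls the helper, assuming the enclosing class is a test case whose setUp defines the batch-size and shuffle-buffer attributes, as in Example 2 below.

  def test_create_dataset(self):
    # Hypothetical check; 'MNIST' is only an illustrative dataset name.
    dataset = self._create_dataset('MNIST')
    self.assertIsNotNone(dataset)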
Example 2
    def setUp(self):
        super().setUp()

        self._batch_size = 128  # Note: Tests are run on GPU/TPU.
        self._batch_size_test = 128
        self._shuffle_buffer_size = 1024
        self._rng = jax.random.PRNGKey(42)
        self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
        self._num_classes = 10
        self._num_epochs = 1

        self._learning_rate_fn = lambda _: 0.01
        self._weight_decay = 0.0001
        self._momentum = 0.9

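        # Sanity bounds for the test loss: machine epsilon as a strictly
        # positive lower bound, and twice log(num_classes) as a loose upper
        # bound (a uniform predictor scores log(num_classes) on cross-entropy).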
        self._min_loss = jnp.finfo(float).eps
        self._max_loss = 2.0 * math.log(self._num_classes)

        self._dataset_name = 'MNIST'
        self._model_name = 'MNIST_CNN'

        self._summarywriter = tensorboard.SummaryWriter('/tmp/')

        self._dataset = dataset_factory.create_dataset(
            self._dataset_name,
            self._batch_size,
            self._batch_size_test,
            shuffle_buffer_size=self._shuffle_buffer_size)

        self._model, self._state = model_factory.create_model(
            self._model_name,
            self._rng, (self._input_shape, ),
            num_classes=self._num_classes)

        self._optimizer = flax.optim.Momentum(
            learning_rate=self._learning_rate_fn(0),
            beta=self._momentum,
            weight_decay=self._weight_decay)
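
A minimal follow-up sketch of a test built from these fixtures, assuming the training.Trainer constructor and train() call shown in Example 4, that update_iter/update_epoch have defaults, and that the returned metrics dictionary contains a 'loss' entry (none of which is confirmed by this excerpt):

    def test_loss_within_bounds(self):
        # Hypothetical test: wires the fixtures from setUp into a Trainer.
        trainer = training.Trainer(
            self._optimizer,
            self._model,
            self._state,
            self._dataset,
            self._rng,
            summary_writer=self._summarywriter)
        _, best_metrics = trainer.train(self._num_epochs,
                                        lr_fn=self._learning_rate_fn)
        # The 'loss' key is an assumption about the metrics returned by train().
        self.assertGreaterEqual(best_metrics['loss'], self._min_loss)
        self.assertLessEqual(best_metrics['loss'], self._max_loss)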
Example 3
  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

  logging.info('Saving experimental results to %s', experiment_dir)

  host_count = jax.host_count()
  local_device_count = jax.local_device_count()
  logging.info('Device count: %d, host count: %d, local device count: %d',
               jax.device_count(), host_count, local_device_count)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)

  input_shape = (1,) + dataset.shape
  base_model, _ = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes)

  # The excerpt is truncated here; the remaining arguments presumably mirror
  # the base_model call above.
  initial_model, initial_state = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes)
Example 4
def run_training():
    """Trains a model."""
    print('Logging to {}'.format(FLAGS.log_dir))
    work_unit_id = uuid.uuid4()
    experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

    logging.info('Saving experimental results to %s', experiment_dir)

    host_count = jax.host_count()
    local_device_count = jax.local_device_count()
    logging.info('Device count: %d, host count: %d, local device count: %d',
                 jax.device_count(), host_count, local_device_count)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(experiment_dir)

    dataset = dataset_factory.create_dataset(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.batch_size_test,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    logging.info('Training %s on the %s dataset...', FLAGS.model,
                 FLAGS.dataset)

    rng = jax.random.PRNGKey(FLAGS.random_seed)

    input_shape = (1, ) + dataset.shape
    model, initial_state = model_factory.create_model(
        FLAGS.model,
        rng, ((input_shape, np.float32), ),
        num_classes=dataset.num_classes)

    if FLAGS.optimizer == 'Adam':
        optimizer = flax.optim.Adam(learning_rate=FLAGS.lr,
                                    weight_decay=FLAGS.weight_decay)
    elif FLAGS.optimizer == 'Momentum':
        optimizer = flax.optim.Momentum(learning_rate=FLAGS.lr,
                                        beta=FLAGS.momentum,
                                        weight_decay=FLAGS.weight_decay,
                                        nesterov=False)
    else:
        raise ValueError('Unknown optimizer type {}'.format(FLAGS.optimizer))

    steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size

    if FLAGS.lr_schedule == 'constant':
        lr_fn = lr_schedule.create_constant_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch)
    elif FLAGS.lr_schedule == 'stepped':
        lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
        lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, lr_schedule_steps)
    elif FLAGS.lr_schedule == 'cosine':
        lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, FLAGS.epochs)
    else:
        raise ValueError('Unknown LR schedule type {}'.format(
            FLAGS.lr_schedule))

    if jax.host_id() == 0:
        trainer = training.Trainer(
            optimizer,
            model,
            initial_state,
            dataset,
            rng,
            summary_writer=summary_writer,
        )
    else:
        trainer = training.Trainer(optimizer, model, initial_state, dataset,
                                   rng)

    _, best_metrics = trainer.train(FLAGS.epochs,
                                    lr_fn=lr_fn,
                                    update_iter=FLAGS.update_iterations,
                                    update_epoch=FLAGS.update_epoch)

    logging.info('Best metrics: %s', str(best_metrics))

    if jax.host_id() == 0:
        for label, value in best_metrics.items():
            summary_writer.scalar('best/{}'.format(label), value,
                                  FLAGS.epochs * steps_per_epoch)
        summary_writer.close()
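
A minimal entry-point sketch for invoking run_training, assuming the usual absl pattern (from absl import app); the flag definitions and the actual main function are not part of this excerpt:

def main(_):
    run_training()


if __name__ == '__main__':
    app.run(main)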