Example #1
    def setUp(self):
        super().setUp()

        self._batch_size = 128  # Note: Tests are run on GPU/TPU.
        self._batch_size_test = 128
        self._shuffle_buffer_size = 1024
        self._rng = jax.random.PRNGKey(42)
        self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
        self._num_classes = 10
        self._num_epochs = 1

        self._learning_rate_fn = lambda _: 0.01
        self._weight_decay = 0.0001
        self._momentum = 0.9

        self._min_loss = jnp.finfo(float).eps
        self._max_loss = 2.0 * math.log(self._num_classes)

        self._dataset_name = 'MNIST'
        self._model_name = 'MNIST_CNN'

        self._summarywriter = tensorboard.SummaryWriter('/tmp/')

        self._dataset = dataset_factory.create_dataset(
            self._dataset_name,
            self._batch_size,
            self._batch_size_test,
            shuffle_buffer_size=self._shuffle_buffer_size)

        self._model, self._state = model_factory.create_model(
            self._model_name,
            self._rng, (self._input_shape, ),
            num_classes=self._num_classes)

        self._optimizer = flax.optim.Momentum(
            learning_rate=self._learning_rate_fn(0),
            beta=self._momentum,
            weight_decay=self._weight_decay)
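
    # A minimal sketch (not part of the original fixture) of a test built on
    # this setUp. The Trainer constructor and train() call follow the pattern
    # in Example #4 below; the 'loss' key in best_metrics and the default
    # values of train()'s other keyword arguments are assumptions.
    def test_single_epoch_loss_within_bounds(self):
        trainer = training.Trainer(
            self._optimizer,
            self._model,
            self._state,
            self._dataset,
            self._rng,
            summary_writer=self._summarywriter)
        _, best_metrics = trainer.train(self._num_epochs,
                                        lr_fn=self._learning_rate_fn)
        self.assertGreaterEqual(best_metrics['loss'], self._min_loss)
        self.assertLessEqual(best_metrics['loss'], self._max_loss)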
Example #2
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)

  input_shape = (1,) + dataset.shape
  base_model, _ = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes)

  initial_model, initial_state = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes,
      masked_layer_indices=FLAGS.masked_layer_indices)

  if FLAGS.optimizer == 'Adam':
    optimizer = flax.optim.Adam(
        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
  elif FLAGS.optimizer == 'Momentum':
    optimizer = flax.optim.Momentum(
        learning_rate=FLAGS.lr,
        beta=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        nesterov=False)
Example #3
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)

  input_shape = (1,) + dataset.shape
  base_model, _ = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes)

  logging.info('Generating random mask based on model')

  # Re-initialize the RNG to maintain same training pattern (as in prune code).
  mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)

  mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,
                                  FLAGS.mask_sparsity)

  if jax.host_id() == 0:
    mask_stats = symmetry.get_mask_stats(mask)
    logging.info('Mask stats: %s', str(mask_stats))

Example #4
def run_training():
    """Trains a model."""
    print('Logging to {}'.format(FLAGS.log_dir))
    work_unit_id = uuid.uuid4()
    experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

    logging.info('Saving experimental results to %s', experiment_dir)

    host_count = jax.host_count()
    local_device_count = jax.local_device_count()
    logging.info('Device count: %d, host count: %d, local device count: %d',
                 jax.device_count(), host_count, local_device_count)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(experiment_dir)

    dataset = dataset_factory.create_dataset(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.batch_size_test,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    logging.info('Training %s on the %s dataset...', FLAGS.model,
                 FLAGS.dataset)

    rng = jax.random.PRNGKey(FLAGS.random_seed)

    input_shape = (1, ) + dataset.shape
    model, initial_state = model_factory.create_model(
        FLAGS.model,
        rng, ((input_shape, np.float32), ),
        num_classes=dataset.num_classes)

    if FLAGS.optimizer == 'Adam':
        optimizer = flax.optim.Adam(learning_rate=FLAGS.lr,
                                    weight_decay=FLAGS.weight_decay)
    elif FLAGS.optimizer == 'Momentum':
        optimizer = flax.optim.Momentum(learning_rate=FLAGS.lr,
                                        beta=FLAGS.momentum,
                                        weight_decay=FLAGS.weight_decay,
                                        nesterov=False)

    steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size

    if FLAGS.lr_schedule == 'constant':
        lr_fn = lr_schedule.create_constant_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch)
    elif FLAGS.lr_schedule == 'stepped':
        lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
        lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, lr_schedule_steps)
    elif FLAGS.lr_schedule == 'cosine':
        lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
            FLAGS.lr, steps_per_epoch, FLAGS.epochs)
    else:
        raise ValueError('Unknown LR schedule type {}'.format(
            FLAGS.lr_schedule))

    if jax.host_id() == 0:
        trainer = training.Trainer(
            optimizer,
            model,
            initial_state,
            dataset,
            rng,
            summary_writer=summary_writer,
        )
    else:
        trainer = training.Trainer(optimizer, model, initial_state, dataset,
                                   rng)

    _, best_metrics = trainer.train(FLAGS.epochs,
                                    lr_fn=lr_fn,
                                    update_iter=FLAGS.update_iterations,
                                    update_epoch=FLAGS.update_epoch)

    logging.info('Best metrics: %s', str(best_metrics))

    if jax.host_id() == 0:
        for label, value in best_metrics.items():
            summary_writer.scalar('best/{}'.format(label), value,
                                  FLAGS.epochs * steps_per_epoch)
        summary_writer.close()
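

# A minimal sketch (an assumption, not part of the original example) of an
# absl-style entry point for run_training; it assumes `from absl import app`
# and that the FLAGS used above are defined elsewhere in the module.
def main(argv):
    del argv  # Unused command-line arguments.
    run_training()


if __name__ == '__main__':
    app.run(main)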
Example #5
    def _create_model(self, model_name):
        return model_factory.create_model(model_name,
                                          self._rng, (self._input_shape, ),
                                          num_classes=self._num_classes)
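
    # A minimal sketch (not part of the original snippet) of how this helper
    # might be exercised in a test; the 'MNIST_CNN' model name and the
    # (model, state) return pair of create_model follow Example #1.
    def test_create_model_returns_model_and_state(self):
        model, state = self._create_model('MNIST_CNN')
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)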
Example #6
  rng = jax.random.PRNGKey(FLAGS.random_seed)

  input_shape = (1,) + dataset.shape

  input_len = functools.reduce(operator.mul, dataset.shape)

  features = mnist_fc.feature_dim_for_param(
      input_len,
      FLAGS.param_count,
      FLAGS.depth)

  logging.info('Model Configuration: %s', str(features))

  base_model, _ = model_factory.create_model(
      MODEL,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes,
      features=features)

  model_param_count = utils.count_param(base_model, ('kernel',))

  logging.info(
      'Model config: params: %d, depth: %d, max width: %d, min width: %d',
      model_param_count, len(features), max(features), min(features))

  logging.info('Generating random mask based on model')

  # Re-initialize the RNG to maintain same training pattern (as in prune code).
  mask_rng = jax.random.PRNGKey(FLAGS.random_seed)
  mask = masked.shuffled_mask(
      base_model,