Example 1
 def test_adam(self):
   init_fn, update_fn = optimizers.get_optimizer(
       ConfigDict({
           'optimizer': 'adam',
           'l2_decay_factor': None,
           'batch_size': 50,
           'total_accumulated_batch_size': 100,  # Use gradient accumulation.
           'opt_hparams': {
               'beta1': 0.9,
               'beta2': 0.999,
               'epsilon': 1e-7,
               'weight_decay': 0.0,
           }
       }))
   del update_fn
   optimizer_state = init_fn({'foo': jnp.ones(10)})
   # Test that we can extract 'count'.
   chex.assert_type(extract_field(optimizer_state, 'count'), int)
   # Test that we can extract 'nu'.
   chex.assert_shape(extract_field(optimizer_state, 'nu')['foo'], (10,))
   # Test that we can extract 'mu'.
   chex.assert_shape(extract_field(optimizer_state, 'mu')['foo'], (10,))
   # Test that attempting to extract a nonexistent field "abc" returns None.
   chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
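
For context, the test above assumes `extract_field` walks the (possibly nested) optax optimizer state and returns the named field ('count', 'mu', 'nu', ...) or None if it is absent. A minimal sketch of such a helper under that assumption (illustrative only, not the library's implementation):

# Hedged sketch: recursively search namedtuple-style optax states for a field.
def extract_field_sketch(state, field_name):
  # A namedtuple-like optax state exposes its field names via `_fields`.
  if hasattr(state, '_fields') and field_name in state._fields:
    return getattr(state, field_name)
  # Chained / nested transformation states: search each child in order.
  if isinstance(state, (tuple, list)):
    for child in state:
      found = extract_field_sketch(child, field_name)
      if found is not None:
        return found
  return None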
Example 2
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
    """Main training loop.

  Trains the given network on the specified dataset for the given number of
  training steps. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole train set.
    eval_train_num_batches: (int) The number of batches for evaluating on train.
      Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption recovery.
    early_stopping_target_name: A string naming the metric to use to perform
      early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to stop
      training.
    early_stopping_mode: One of "above" or "below", indicates if we should stop
      when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced training
      metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
    # NOTE: the initialization RNG should *not* be per-host, as this will create
    # different sets of weights per host. However, all other RNGs should be
    # per-host.
    # TODO(znado,gilmer,gdahl): implement replicating the same initialization
    # across hosts.
    rng, init_rng = jax.random.split(rng)
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, data_rng = jax.random.split(rng)

    # Only used if checkpoint_steps is non-empty.
    checkpoint_dir = os.path.join(train_dir, 'checkpoints')

    # For logging / processing off the main thread
    pool = multiprocessing.pool.ThreadPool()

    if jax.process_index() == 0:
        logging.info('Let the training begin!')
        logging.info('Dataset input shape: %r', hps.input_shape)
        logging.info('Hyperparameters: %s', hps)

    if eval_batch_size is None:
        eval_batch_size = hps.batch_size
    if callback_configs is None:
        callback_configs = []

    # Maybe run the initializer.
    unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
        model.flax_module, initializer, model.loss_fn,
        hps.input_shape, hps.output_shape, hps, init_rng, init_logger,
        model.get_fake_batch(hps))

    if jax.process_index() == 0:
        utils.log_pytree_shape_and_statistics(unreplicated_params)
        logging.info('train_size: %d,', hps.train_size)

    # Note that global_step refers to the number of gradient calculations, not
    # the number of model updates. This means when using gradient accumulation,
    # one must supply configs where the number of steps are in units of gradient
    # calculations, not model updates, and in post processing one must divide
    # global_step by grad_accum_step_multiplier to get the number of updates.
    #
    # If using gradient accumulation, stretch the learning rate schedule by the
    # number of gradient calculations per weight update.
    stretch_factor = 1
    if hps.get('total_accumulated_batch_size') is not None:
        stretch_factor = hps.total_accumulated_batch_size // hps.batch_size
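        # Illustrative numbers: total_accumulated_batch_size=100 with
        # batch_size=50 gives stretch_factor=2, so the schedule is spread over
        # twice as many gradient calculations per weight update.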
    lr_fn = schedules.get_schedule_fn(hps.lr_hparams,
                                      num_train_steps,
                                      stretch_factor=stretch_factor)

    optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(
        hps, model)
    unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

    (unreplicated_metrics_state, metrics_update_fn,
     metrics_summary_fn) = None, None, None
    if training_metrics_config is not None:
        (metrics_init_fn, metrics_update_fn,
         metrics_summary_fn) = make_training_metrics(num_train_steps,
                                                     **training_metrics_config)
        unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

    (optimizer_state, params, batch_stats, metrics_state, global_step,
     sum_train_cost, preemption_count,
     is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
         unreplicated_optimizer_state,
         unreplicated_params,
         unreplicated_batch_stats,
         unreplicated_metrics_state,
         train_dir=train_dir,
         external_checkpoint_path=external_checkpoint_path)

    if is_restored:
        preemption_count += 1
        # Fold the restored step into the dataset RNG so that we get a
        # different shuffle each time we restore and do not repeat a previous
        # dataset ordering. This is not the only source of shuffling
        # differences across preemptions, because we oftentimes reshuffle the
        # input files in a non-deterministic manner.
        #
        # Note that if we are preempted more than once per epoch then we will
        # retrain more on the beginning of the training split, because each
        # time we restore we refill the shuffle buffer with the first
        # `shuffle_buffer_size` elements from the training split to continue
        # training.
        #
        # Also note that for evaluating on the training split, because we are
        # reshuffling each time, we will get a new eval_train split each time
        # we are preempted.
        data_rng = jax.random.fold_in(data_rng, global_step)

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
    dataset = dataset_builder(
        data_rng,
        hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps,
    )

    update_fn = functools.partial(update,
                                  training_cost=model.training_cost,
                                  grad_clip=hps.get('grad_clip'),
                                  optimizer_update_fn=optimizer_update_fn,
                                  metrics_update_fn=metrics_update_fn)
    # in_axes = (
    #     optimizer_state = 0,
    #     params = 0,
    #     batch_stats = 0,
    #     metrics_state = 0,
    #     batch = 0,
    #     step = None,
    #     lr = None,
    #     rng = None,
    #     local_device_index = 0,
    #     running_train_cost = 0,
    #     training_cost,
    #     grad_clip,
    #     optimizer_update_fn,
    #     metrics_state_update_fn)
    # Also, we can donate buffers for 'optimizer', 'batch_stats',
    # 'batch' and 'training_metrics_state' for update's pmapped computation.
    update_pmapped = jax.pmap(update_fn,
                              axis_name='batch',
                              in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
                              donate_argnums=(0, 1, 2, 8))
    # During eval, we can donate the 'batch' buffer. We don't donate the
    # 'params' and 'batch_stats' buffers as we don't re-assign those values in
    # eval, we do that only in train.
    evaluate_batch_pmapped = jax.pmap(model.evaluate_batch,
                                      axis_name='batch',
                                      donate_argnums=(2, ))
    start_time = time.time()
    start_step = global_step
    prev_eval_step = start_step

    def get_step_frequency(cur_step):
        return float(cur_step - start_step) / (time.time() - start_time)

    if jax.process_index() == 0:
        trainer_utils.log_message('Starting training!', pool, xm_work_unit)

    # Numpy array of range(0, local_device_count) to send to each device to be
    # folded into the RNG inside each train step to get a unique per-device RNG.
    local_device_indices = np.arange(jax.local_device_count())

    # Start at the resumed step and continue until we have finished the number
    # of training steps. If building a dataset iterator using a
    # tf.data.Dataset, in the case of a batch size that does not evenly divide
    # the training dataset size, if using `ds.batch(..., drop_remainder=False)`
    # on the training dataset then the final batch in this iterator will be a
    # partial batch. However, if `drop_remainder=True`, then this iterator will
    # always return batches of the same size, and the final batch will have
    # elements from the start of the (num_epochs + 1)-th epoch.
    train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                  num_train_steps)

    eval_callbacks = []
    rng, callback_rng = jax.random.split(rng)
    callback_rngs = jax.random.split(callback_rng, len(callback_configs))
    for callback_rng, config in zip(callback_rngs, callback_configs):
        eval_callback = callbacks.get_callback(config['callback_name'])(
            model, params, batch_stats, optimizer_state, dataset, hps, config,
            train_dir, callback_rng)
        eval_callbacks.append(eval_callback)

    train_iter = trainer_utils.prefetch_input_pipeline(
        train_iter, hps.num_device_prefetches)

    eval_start_time = start_time
    eval_start_step = start_step
    for _ in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
            # NOTE(dsuo): to properly profile each step, we must include batch
            # creation in the StepTraceContext (as opposed to putting `train_iter`
            # directly in the top-level for loop).
            batch = next(train_iter)

            if global_step in checkpoint_steps and jax.process_index() == 0:
                checkpoint.save_unreplicated_checkpoint_background(
                    checkpoint_dir,
                    optimizer_state,
                    params,
                    batch_stats,
                    metrics_state,
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    max_to_keep=None)
            lr = lr_fn(global_step)
            (optimizer_state, params, batch_stats, sum_train_cost,
             metrics_state, grad_norm) = update_pmapped(
                 optimizer_state, params, batch_stats, metrics_state, batch,
                 global_step, lr, rng, local_device_indices, sum_train_cost)
            global_step += 1
            # TODO(gdahl, gilmer): consider moving this test up.
            # NB: Since this test is after we increment global_step, having 0 in
            # eval_steps does nothing.
            if trainer_utils.should_eval(global_step, eval_frequency,
                                         eval_steps):
                train_steps_per_sec = (global_step - eval_start_step) / (
                    time.time() - eval_start_time)
                eval_start_step = global_step
                eval_start_time = time.time()
                batch_stats = trainer_utils.maybe_sync_batchnorm_stats(
                    batch_stats)
                report, eval_time = eval_metrics(params, batch_stats, dataset,
                                                 eval_num_batches,
                                                 eval_train_num_batches,
                                                 evaluate_batch_pmapped)
                mean_train_cost = sum_train_cost.mean().item() / max(
                    1, global_step - prev_eval_step)
                report.update(
                    learning_rate=float(lr),
                    global_step=global_step,
                    epoch=global_step * hps.batch_size // hps.train_size,
                    train_steps_per_sec=train_steps_per_sec,
                    overall_steps_per_sec=get_step_frequency(global_step),
                    eval_time=eval_time,
                    grad_norm=np.mean(grad_norm),
                    preemption_count=preemption_count,
                    train_cost=mean_train_cost)

                for eval_callback in eval_callbacks:
                    callback_metrics = eval_callback.run_eval(
                        params, batch_stats, optimizer_state, global_step)
                    if set(callback_metrics.keys()).intersection(
                            set(report.keys())):
                        raise ValueError(
                            'There was a collision between the callback '
                            'metrics and the standard eval metrics keys.')
                    report.update(callback_metrics)
                yield report
                if jax.process_index() == 0:
                    trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                          train_steps_per_sec, num_train_steps,
                                          start_time, eval_frequency,
                                          eval_steps, eval_time)
                    trainer_utils.log_epoch_report(report, metrics_logger)
                    trainer_utils.maybe_log_training_metrics(
                        metrics_state, metrics_summary_fn, metrics_logger)
                    checkpoint.save_unreplicated_checkpoint_background(
                        train_dir, optimizer_state, params, batch_stats,
                        metrics_state, global_step, preemption_count,
                        sum_train_cost)
                sum_train_cost = jnp.zeros(jax.local_device_count())
                prev_eval_step = global_step

                early_stopping_condition = trainer_utils.check_for_early_stopping(
                    early_stopping_target_name, early_stopping_target_value,
                    early_stopping_mode, report)
                if early_stopping_condition:
                    comparison_string = '>=' if early_stopping_mode == 'above' else '<='
                    logging.info(
                        'Early stopping because metric %s=%f, reached the target value '
                        'of %s %f.', early_stopping_target_name,
                        report[early_stopping_target_name], comparison_string,
                        early_stopping_target_value)
                    return

    # Always log and checkpoint on host 0 at the end of training.
    # If we moved where evals happen in the loop body, we would not need this
    # check.
    if prev_eval_step != num_train_steps:
        train_steps_per_sec = (global_step - eval_start_step) / (
            time.time() - eval_start_time)
        batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(params, batch_stats, dataset,
                                         eval_num_batches,
                                         eval_train_num_batches,
                                         evaluate_batch_pmapped)
        lr = lr_fn(global_step)
        # Correct the average for the final partial epoch.
        mean_train_cost = sum_train_cost.mean().item() / max(
            1, global_step - prev_eval_step)
        report.update(learning_rate=float(lr),
                      global_step=global_step,
                      epoch=global_step * hps.batch_size // hps.train_size,
                      train_steps_per_sec=train_steps_per_sec,
                      overall_steps_per_sec=get_step_frequency(global_step),
                      eval_time=eval_time,
                      grad_norm=np.mean(grad_norm),
                      preemption_count=preemption_count,
                      train_cost=mean_train_cost)
        yield report
        if jax.process_index() == 0:
            trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                  train_steps_per_sec, num_train_steps,
                                  start_time, eval_frequency, eval_steps,
                                  eval_time)
            trainer_utils.log_epoch_report(report, metrics_logger)
            trainer_utils.maybe_log_training_metrics(metrics_state,
                                                     metrics_summary_fn,
                                                     metrics_logger)
            checkpoint.save_unreplicated_checkpoint_background(
                train_dir, optimizer_state, params, batch_stats, metrics_state,
                global_step, preemption_count, sum_train_cost)
    # To make sure the last checkpoint was correctly saved.
    checkpoint.wait_for_checkpoint_save()
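
Since `train` is a generator, callers drive it with a loop and receive one metrics dict per eval. A minimal consumption sketch, where `trainer_kwargs` is a hypothetical dict holding the arguments documented above:

# Hedged sketch of driving the training generator; `trainer_kwargs` is a
# stand-in for the arguments documented in the docstring above.
last_report = None
for report in train(**trainer_kwargs):
    last_report = report  # One dict of eval metrics per eval step.
if last_report is not None:
    print('Final global_step:', last_report['global_step'])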
Example 3
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over all checkpoints in the specified directory, loads each
  checkpoint, and evaluates the Hessian on it. A list of dicts will be saved
  to cns at checkpoint_dir/hessian_eval_config['name'].

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) The number of batches used for evaluating on the
      validation and test sets. Set to None to evaluate on the whole test set.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval.
    min_global_step: Lower bound on which checkpoint steps to evaluate. Set to
      None to evaluate all checkpoints in the directory.
    max_global_step: Upper bound on which checkpoint steps to evaluate.
  """
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer, model.loss_fn,
      hps.input_shape,
      hps.output_shape, hps, init_rng,
      None)

  # Fold the unreplicated batch_stats and rng into the loss used by the
  # Hessian eval.
  def batch_loss(params, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        params, batch, batch_stats=unreplicated_batch_stats, dropout_rng=rng)[0]
  batch_stats = jax_utils.replicate(unreplicated_batch_stats)

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the Hessian computation hps to the experiment directory.
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)
  # Note that we do not use a learning rate schedule here. The optimizer state
  # is a list of all the optax transformation states, and we inject a constant
  # learning rate into every state that will accept it.
  for state in unreplicated_optimizer_state:
    if (isinstance(state, optax.InjectHyperparamsState) and
        'learning_rate' in state.hyperparams):
      state.hyperparams['learning_rate'] = jax_utils.replicate(1.0)
  optimizer_state = jax_utils.replicate(unreplicated_optimizer_state)
  params = jax_utils.replicate(unreplicated_params)
  data_rng = jax.random.fold_in(data_rng, 0)

  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmapped function used for evaluation.
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.process_index() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.process_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      params,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    unreplicated_checkpoint_state = dict(
        params=unreplicated_params,
        optimizer_state=unreplicated_optimizer_state,
        batch_stats=unreplicated_batch_stats,
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=unreplicated_checkpoint_state)
    results, _ = checkpoint.replicate_checkpoint(
        ckpt,
        pytree_keys=['params', 'optimizer_state', 'batch_stats'])
    params = results['params']
    optimizer_state = results['optimizer_state']
    batch_stats = results['batch_stats']
    # pylint: disable=protected-access
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.process_index() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    row = {}
    grads, updates = [], []
    hess_evecs, cov_evecs = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        params, step)
    row.update(stats)
    if (hessian_eval_config['compute_stats'] or
        hessian_eval_config['compute_interps']):
      grads, updates = hessian_evaluator.compute_dirs(
          params, optimizer_state, optimizer_update_fn)
    row.update(hessian_evaluator.evaluate_stats(params, grads,
                                                updates, hess_evecs,
                                                cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(params, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    if jax.process_index() == 0:
      logger.append_pytree(row)
Example 4
    def test_hessian_free_optimizer(self):
        """Tests the Hessian-free optimizer."""

        model_str = 'autoencoder'
        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)

        loss = 'sigmoid_binary_cross_entropy'
        metrics = 'binary_autoencoder_metrics'

        input_shape = (2, 2, 1)
        output_shape = (4, )

        hps = copy.copy(model_hps)
        hps.update({
            'optimizer': 'hessian_free',
            'opt_hparams': {
                'weight_decay': 0.0,
            },
            'hid_sizes': [2],
            'activation_function': ['id'],
            'input_shape': input_shape,
            'output_shape': output_shape
        })

        model = model_cls(hps, {}, loss, metrics)

        inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
        targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
        batch = {'inputs': inputs, 'targets': targets}

        def forward_fn(variables, inputs):
            logits = model.flax_module.apply(variables, inputs, train=True)
            return logits

        def opt_cost(variables):
            return model.loss_fn(forward_fn(variables, inputs), targets)

        init_fn, update_fn = optimizers.get_optimizer(hps, model)

        params = {
            'Dense_0': {
                'kernel': jnp.array(
                    [[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
                'bias': jnp.array([0., 0.])
            },
            'Dense_1': {
                'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
                'bias': jnp.array([0., 0., 0., 0.])
            }
        }
        variables = {'params': params}

        grad_fn = jax.grad(opt_cost)
        grads = grad_fn(variables)['params']

        outputs = forward_fn(variables, batch['inputs'])

        n = inputs.shape[0]
        m = outputs.shape[-1]
        d = ravel_pytree(params)[0].shape[0]
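        # n: number of examples in the batch, m: model output dimension,
        # d: total number of (flattened) parameters.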

        v = np.ones(d)

        state = init_fn(params)

        partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
        partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

        matmul_fn = partial(gvp, variables, outputs, state.inner_state.damping,
                            partial_forward_fn, partial_loss_fn)

        jacobian = jax.jacfwd(partial_forward_fn)(variables)['params']
        jacobian_tensor = np.concatenate(
            (jacobian['Dense_0']['bias'].reshape(
                n, m, -1), jacobian['Dense_0']['kernel'].reshape(
                    n, m, -1), jacobian['Dense_1']['bias'].reshape(n, m, -1),
             jacobian['Dense_1']['kernel'].reshape(n, m, -1)),
            axis=2)

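        # Build the expected Gauss-Newton matrix explicitly:
        #   GGN = (1/n) * sum_i J_i^T H_i J_i + damping * I,
        # where J_i is the per-example Jacobian of the model outputs and H_i
        # is the Hessian of the loss with respect to those outputs.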
        ggn_matrix = 0
        for i in range(n):
            jacobian_matrix = jacobian_tensor[i]
            hessian = jax.hessian(partial_loss_fn)(outputs[i, None])[0, :,
                                                                     0, :]
            ggn_matrix += np.transpose(
                jacobian_matrix) @ hessian @ jacobian_matrix
        ggn_matrix /= n
        ggn_matrix += state.inner_state.damping * np.identity(d)

        expected = ggn_matrix @ v

        # Test the gvp function
        self.assertAlmostEqual(jnp.linalg.norm(matmul_fn(v) - expected),
                               0,
                               places=4)

        update_pmapped = jax.pmap(update_fn,
                                  axis_name='batch',
                                  in_axes=(None, None, None, 0, None))

        batch_shard = data_utils.shard(batch)

        state.hyperparams['learning_rate'] = 1.0

        p, state = update_pmapped(grads, state, params, batch_shard, None)

        # Test the damping parameter update
        self.assertEqual(state.inner_state.damping, 3 / 2)

        # Test the search direction
        self.assertAlmostEqual(jnp.linalg.norm(
            ravel_pytree(p)[0] +
            jnp.linalg.inv(ggn_matrix) @ ravel_pytree(grads)[0]),
                               0,
                               places=4)
Example 5
    def test_adam(self):
        """Test Adam preconditioning."""

        lr = 1e-3
        beta1 = 0.9
        beta2 = 0.999
        epsilon = 1e-7

        opt_hparams = FrozenConfigDict({
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon
        })
        hparams = FrozenConfigDict({
            'optimizer': 'adam',
            'opt_hparams': opt_hparams,
            'l2_decay_factor': 0.0,
            'batch_size': 50,
            'total_accumulated_batch_size': 50,
        })

        init_fn, update_fn = optimizers.get_optimizer(hparams)

        params = {'foo': 1.0, 'bar': {'baz': 3.0}}
        gradients = [{
            'foo': 0.5,
            'bar': {
                'baz': 0.1
            }
        }, {
            'foo': 0.2,
            'bar': {
                'baz': 0.6
            }
        }]

        optimizer_state = init_fn(params)
        optimizer_state.base_state.hyperparams['learning_rate'] = lr

        for gradient in gradients:
            updates, optimizer_state = update_fn(gradient, optimizer_state,
                                                 params)
            params = optax.apply_updates(params, updates)

        # With bias correction.
        expected_preconditioner = _calculate_adam_preconditioner(
            gradients, beta2, epsilon, bias_correct=True)

        preconditioner = make_diag_preconditioner(
            'adam', opt_hparams, optimizer_state,
            FrozenConfigDict(dict(bias_correction=True)))

        self.assertTrue(
            pytree_allclose(expected_preconditioner, preconditioner))

        # Without bias correction.
        expected_preconditioner = _calculate_adam_preconditioner(
            gradients, beta2, epsilon, bias_correct=False)

        preconditioner = make_diag_preconditioner(
            'adam', opt_hparams, optimizer_state,
            FrozenConfigDict(dict(bias_correction=False)))

        self.assertTrue(
            pytree_allclose(expected_preconditioner, preconditioner))
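
For reference, the quantity being compared is the standard Adam diagonal preconditioner built from the second-moment estimate. A minimal sketch of that computation for a single parameter leaf (an illustration of the expected math, not the test's `_calculate_adam_preconditioner`):

import numpy as np

def adam_preconditioner_sketch(leaf_grads, beta2, epsilon, bias_correct):
    """Illustrative Adam preconditioner: sqrt of the EMA of squared grads."""
    nu = 0.0
    for g in leaf_grads:
        nu = beta2 * nu + (1.0 - beta2) * np.square(g)
    if bias_correct:
        nu = nu / (1.0 - beta2 ** len(leaf_grads))
    # Adam divides its update by sqrt(nu) + epsilon; that divisor is the
    # diagonal preconditioner checked above.
    return np.sqrt(nu) + epsilon

# E.g. for the 'foo' leaf with gradients [0.5, 0.2]:
#   adam_preconditioner_sketch([0.5, 0.2], 0.999, 1e-7, bias_correct=True)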