Example #1
    def setUp(self):
        super(TrainerTest, self).setUp()
        self.test_dir = tempfile.mkdtemp()
        rng = jax.random.PRNGKey(0)
        np.random.seed(0)
        self.feature_dim = 100
        num_outputs = 1
        self.batch_size = 32
        num_examples = 2048

        def create_model(key):
            flax_module = LinearModel(num_outputs=num_outputs)
            model_init_fn = jax.jit(
                functools.partial(flax_module.init, train=False))
            fake_input_batch = np.zeros((self.batch_size, self.feature_dim))
            init_dict = model_init_fn({'params': key}, fake_input_batch)
            params = init_dict['params']
            return flax_module, params

        flax_module, params = create_model(rng)
        # Linear model coefficients
        self.beta = params['Dense_0']['kernel']
        self.beta = self.beta.reshape((self.feature_dim, 1))
        self.beta = self.beta.astype(np.float32)

        optimizer_init_fn, self.optimizer_update_fn = optax.sgd(1.0)
        self.optimizer_state = jax_utils.replicate(optimizer_init_fn(params))
        self.params = jax_utils.replicate(params)

        data_class, self.feature, self.y = _get_synth_data(
            num_examples, self.feature_dim, num_outputs, self.batch_size)
        self.evaluator = hessian_eval.CurvatureEvaluator(
            self.params,
            CONFIG,
            dataset=data_class(),
            loss=functools.partial(_batch_square_loss, flax_module))
        # Computing the exact full-batch quantities from the linear model
        num_obs = CONFIG['num_batches'] * self.batch_size
        xb = self.feature[:num_obs, :]
        yb = self.y[:num_obs, :]
        self.fb_grad = _quad_grad(xb, yb, self.beta)
        self.hessian = 2 * np.dot(xb.T, xb) / num_obs
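A minimal sketch of how a test method in this class might consume the evaluator built in setUp above. The stats key 'max_eig_hess' is borrowed from Example #4 and the tolerance is arbitrary; both are assumptions rather than part of the original test.

    def test_spectrum_matches_exact_hessian(self):
        # Sketch only: evaluate_spectrum returns a stats dict plus Hessian and
        # covariance eigenvectors (see Examples #3 and #4). The stats key
        # 'max_eig_hess' and the tolerance below are assumptions.
        stats, _, _ = self.evaluator.evaluate_spectrum(self.params, step=0)
        expected_max_eig = np.linalg.eigvalsh(self.hessian).max()
        np.testing.assert_allclose(
            stats['max_eig_hess'], expected_max_eig, rtol=1e-2)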
Example #2
def set_up_hessian_eval(model, flax_module, batch_stats, dataset,
                        checkpoint_dir, hessian_eval_config):
    """Builds the CurvatureEvaluator object."""

    # First copy, then unreplicate batch_stats. Note that batch_stats doesn't
    # affect the forward pass in the hessian eval because we always run the
    # model in training mode. However, we need to provide batch_stats for the
    # model.training_cost API. The copy is needed b/c the trainer will modify
    # the underlying arrays.
    batch_stats = jax.tree_map(lambda x: x[:][0], batch_stats)

    def batch_loss(module, batch_rng):
        batch, rng = batch_rng
        return model.training_cost(module,
                                   batch,
                                   batch_stats=batch_stats,
                                   dropout_rng=rng)[0]

    pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    logger = utils.MetricLogger(pytree_path=pytree_path)
    hessian_evaluator = hessian_eval.CurvatureEvaluator(flax_module,
                                                        hessian_eval_config,
                                                        dataset=dataset,
                                                        loss=batch_loss)
    return hessian_evaluator, logger
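A short usage sketch for the returned pair, mirroring the per-checkpoint pattern in Example #3; `step` and the arguments passed to set_up_hessian_eval are assumed to come from the surrounding trainer.

# Usage sketch only; all names below are assumed to exist in the caller.
hessian_evaluator, logger = set_up_hessian_eval(
    model, flax_module, batch_stats, dataset, checkpoint_dir,
    hessian_eval_config)
row, _, _ = hessian_evaluator.evaluate_spectrum(flax_module, step)
if jax.process_index() == 0:
    logger.append_pytree(row)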
Example #3
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over all checkpoints in the specified directory, loads the checkpoint
  then evaluates the Hessian on the given checkpoint. A list of dicts will be
  saved to disk at checkpoint_dir/hessian_eval_config['name'].

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) The number of batches used for evaluating on the
      validation and test sets. Set to None to evaluate on the whole test set.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval.
    min_global_step: Lower bound on checkpoint steps to evaluate. Set to None
      to evaluate all checkpoints in the directory.
    max_global_step: Upper bound on checkpoint steps to evaluate.
  """
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer, model.loss_fn,
      hps.input_shape,
      hps.output_shape, hps, init_rng,
      None)

  # Fold in the unreplicated batch_stats and rng into the loss used by
  # hessian eval.
  def batch_loss(params, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        params, batch, batch_stats=unreplicated_batch_stats, dropout_rng=rng)[0]
  batch_stats = jax_utils.replicate(unreplicated_batch_stats)

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the hessian computation hps to the experiment directory
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)
  # Note that we do not use the learning rate.
  # The optimizer state is a list of all the optax transformation states, and
  # we inject the learning rate into all states that will accept it.
  for state in unreplicated_optimizer_state:
    if (isinstance(state, optax.InjectHyperparamsState) and
        'learning_rate' in state.hyperparams):
      state.hyperparams['learning_rate'] = jax_utils.replicate(1.0)
  optimizer_state = jax_utils.replicate(unreplicated_optimizer_state)
  params = jax_utils.replicate(unreplicated_params)
  data_rng = jax.random.fold_in(data_rng, 0)

  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmap functions for the training loop
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.process_index() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.process_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      params,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    unreplicated_checkpoint_state = dict(
        params=unreplicated_params,
        optimizer_state=unreplicated_optimizer_state,
        batch_stats=unreplicated_batch_stats,
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=unreplicated_checkpoint_state)
    results, _ = checkpoint.replicate_checkpoint(
        ckpt,
        pytree_keys=['params', 'optimizer_state', 'batch_stats'])
    params = results['params']
    optimizer_state = results['optimizer_state']
    batch_stats = results['batch_stats']
    # pylint: disable=protected-access
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.process_index() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    row = {}
    grads, updates = [], []
    hess_evecs, cov_evecs = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        params, step)
    row.update(stats)
    if hessian_eval_config[
        'compute_stats'] or hessian_eval_config['compute_interps']:
      grads, updates = hessian_evaluator.compute_dirs(
          params, optimizer_state, optimizer_update_fn)
    row.update(hessian_evaluator.evaluate_stats(params, grads,
                                                updates, hess_evecs,
                                                cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(params, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    if jax.process_index() == 0:
      logger.append_pytree(row)
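The hessian_eval_config consumed above is a plain dict. The sketch below assembles one from the keys that appear across these examples, starting from hessian_eval.DEFAULT_EVAL_CONFIG as in Example #4; the specific values are illustrative placeholders, not the library's defaults.

# Illustrative config only; keys are the ones referenced in these examples,
# values are placeholders rather than defaults.
hessian_eval_config = dict(hessian_eval.DEFAULT_EVAL_CONFIG)
hessian_eval_config.update({
    'name': 'hessian',              # subdirectory used for saved pytrees
    'num_batches': 5,               # batches drawn per curvature estimate
    'num_lanczos_steps': 40,        # Lanczos iterations for the spectrum
    'num_eigens': 0,                # leading eigenvectors to keep
    'num_eval_draws': 1,            # independent draws (Example #4)
    'eval_hessian': True,
    'eval_gradient_covariance': False,
    'compute_stats': True,          # directional statistics (Example #3)
    'compute_interps': False,       # loss interpolations (Example #3)
})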
Example #4
    def test_eval_hess_grad_overlap(self):
        """Test gradient overlap calculations."""

        if jax.devices()[0].platform == 'tpu':
            atol = 1e-3
            rtol = 0.1
        else:
            atol = 1e-5
            rtol = 1e-5

        # dimension of space
        n_params = 4
        num_batches = 5

        eval_config = hessian_eval.DEFAULT_EVAL_CONFIG
        eval_config['eval_hessian'] = False
        eval_config['eval_gradient_covariance'] = False
        eval_config['num_eigens'] = 0
        eval_config['num_lanczos_steps'] = n_params  # max iterations
        eval_config['num_batches'] = num_batches
        eval_config['num_eval_draws'] = 1

        key = jax.random.PRNGKey(0)
        key, split = jax.random.split(key)

        # Diagonal matrix values
        mat_diag = jax.random.normal(split, (n_params, ))

        def batches_gen():
            for _ in range(num_batches):
                yield None

        # Model
        class QuadraticLoss(nn.Module):
            """Loss function which only depends on parameters."""
            @nn.compact
            def __call__(self, x):
                del x
                w = self.param('w', jax.random.normal, (n_params, ))
                return jnp.sum((w**2) * mat_diag)

        flax_module = QuadraticLoss()

        def loss(params, _):
            # 1.0 is required but unused.
            return flax_module.apply({'params': params}, 1.0)

        # Model initialization.
        model_init_fn = jax.jit(flax_module.init)
        init_dict = model_init_fn({'params': key},
                                  np.zeros((num_batches, ), jnp.float32))
        params = init_dict['params']

        # replicate
        replicated_params = jax_utils.replicate(params)

        curve_eval = hessian_eval.CurvatureEvaluator(replicated_params,
                                                     eval_config,
                                                     loss,
                                                     dataset=None,
                                                     batches_gen=batches_gen)
        row, _, _ = curve_eval.evaluate_spectrum(replicated_params, 0)

        tridiag = row['tridiag_hess_grad_overlap']
        eigs_triag, vecs_triag = np.linalg.eigh(tridiag)

        # Test eigenvalues
        np.testing.assert_allclose(eigs_triag,
                                   2 * np.sort(mat_diag),
                                   atol=atol,
                                   rtol=rtol)

        # Compute overlaps
        weights_triag = vecs_triag[0, :]**2
        grad_true = 2 * params['w'] * mat_diag
        weight_idx = np.argsort(mat_diag)
        weights_true = (grad_true)**2 / jnp.dot(grad_true, grad_true)
        weights_true = weights_true[weight_idx]

        # Test overlaps
        np.testing.assert_allclose(weights_triag,
                                   weights_true,
                                   atol=atol,
                                   rtol=rtol)

    def test_block_hessian(self):
        """Test block_hessian code on a low rank factorization problem.

        See Example 1.2 in https://arxiv.org/abs/2202.00980.
        """
        full_dim = 10
        low_rank_dim = 3

        # Make the init unbalanced
        params = {'AA': {}, 'AB': {}}
        a_init_scale = 10.0
        b_init_scale = .1

        # We make params nested as some errors with flax.unfreeze were only
        # surfaced for nested dictionaries.
        params['AA']['inner'] = jnp.array(
            np.random.normal(scale=a_init_scale,
                             size=(full_dim, low_rank_dim)))
        params['AB']['inner'] = jnp.array(
            np.random.normal(scale=b_init_scale,
                             size=(full_dim, low_rank_dim)))

        # hessian eval pmaps by default, so replicate params even for cpu tests.
        rep_params = flax.jax_utils.replicate(params)

        # True matrix factorization
        true_a = jnp.array(np.random.normal(size=(full_dim, low_rank_dim)))
        true_b = jnp.array(np.random.normal(size=(full_dim, low_rank_dim)))
        y = jnp.dot(true_a, true_b.T)

        # Set up the mse loss to match the hessian API
        def loss(params, unused_batch):
            y_pred = jnp.dot(params['AA']['inner'], params['AB']['inner'].T)
            return jnp.sum((y_pred - y)**2) / 2

        # Fake batches_gen to match the hessian_eval_api.
        def batches_gen():
            yield flax.jax_utils.replicate(jnp.array(1))  # Match expected API.

        # Set up curvature evaluator
        eval_config = hessian_eval.DEFAULT_EVAL_CONFIG.copy()
        eval_config['block_hessian'] = True
        eval_config['param_partition_fn'] = 'outer_key'
        evaluator = hessian_eval.CurvatureEvaluator(rep_params,
                                                    eval_config,
                                                    batches_gen=batches_gen,
                                                    loss=loss)

        results, _, _ = evaluator.evaluate_spectrum(rep_params, step=0)
        a_max_eig = np.linalg.eigvalsh(
            np.dot(params['AB']['inner'], params['AB']['inner'].T)).max()
        b_max_eig = np.linalg.eigvalsh(
            np.dot(params['AA']['inner'], params['AA']['inner'].T)).max()

        self.assertAlmostEqual(a_max_eig,
                               results['block_hessian']['AA']['max_eig_hess'],
                               places=5)

        # True value is bigger than 1000, so we need fewer decimal places here.
        self.assertAlmostEqual(b_max_eig,
                               results['block_hessian']['AB']['max_eig_hess'],
                               places=2)
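As a sanity check on the block-Hessian assertions above: for the loss L(A) = ||A B^T - Y||^2 / 2, the Hessian with respect to the A block is I kron (B^T B), so its largest eigenvalue equals the largest eigenvalue of B B^T, which is why a_max_eig is computed from params['AB'] (and vice versa). The standalone numpy sketch below verifies this identity with illustrative dimensions.

# Standalone numpy check of the block-Hessian identity used in the test above.
import numpy as np

rng = np.random.default_rng(0)
full_dim, low_rank_dim = 10, 3
b = rng.normal(size=(full_dim, low_rank_dim))

# Exact Hessian of ||A B^T - Y||^2 / 2 with respect to (row-major) vec(A):
# it is independent of A and Y, and equals I kron (B^T B).
hess_a = np.kron(np.eye(full_dim), b.T @ b)
max_eig_exact = np.linalg.eigvalsh(hess_a).max()
max_eig_shortcut = np.linalg.eigvalsh(b @ b.T).max()
np.testing.assert_allclose(max_eig_exact, max_eig_shortcut)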