def setUp(self):
  super(TrainerTest, self).setUp()
  self.test_dir = tempfile.mkdtemp()
  rng = jax.random.PRNGKey(0)
  np.random.seed(0)

  self.feature_dim = 100
  num_outputs = 1
  self.batch_size = 32
  num_examples = 2048

  def create_model(key):
    flax_module = LinearModel(num_outputs=num_outputs)
    model_init_fn = jax.jit(
        functools.partial(flax_module.init, train=False))
    fake_input_batch = np.zeros((self.batch_size, self.feature_dim))
    init_dict = model_init_fn({'params': key}, fake_input_batch)
    params = init_dict['params']
    return flax_module, params

  flax_module, params = create_model(rng)

  # Linear model coefficients.
  self.beta = params['Dense_0']['kernel']
  self.beta = self.beta.reshape((self.feature_dim, 1))
  self.beta = self.beta.astype(np.float32)

  optimizer_init_fn, self.optimizer_update_fn = optax.sgd(1.0)
  self.optimizer_state = jax_utils.replicate(optimizer_init_fn(params))
  self.params = jax_utils.replicate(params)

  data_class, self.feature, self.y = _get_synth_data(
      num_examples, self.feature_dim, num_outputs, self.batch_size)
  self.evaluator = hessian_eval.CurvatureEvaluator(
      self.params,
      CONFIG,
      dataset=data_class(),
      loss=functools.partial(_batch_square_loss, flax_module))

  # Computing the exact full-batch quantities from the linear model.
  num_obs = CONFIG['num_batches'] * self.batch_size
  xb = self.feature[:num_obs, :]
  yb = self.y[:num_obs, :]
  self.fb_grad = _quad_grad(xb, yb, self.beta)
  self.hessian = 2 * np.dot(xb.T, xb) / num_obs
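# For reference, a minimal sketch of the full-batch gradient helper used in
# setUp. This assumes `_quad_grad` returns the gradient of the mean squared
# loss sum((x @ beta - y)**2) / n, which is consistent with the exact Hessian
# 2 * x.T @ x / n computed above; the actual helper may differ in detail.
def _quad_grad_sketch(x, y, beta):
  """Gradient of the mean squared loss at beta: 2 x^T (x beta - y) / n."""
  num_obs = x.shape[0]
  return 2 * np.dot(x.T, np.dot(x, beta) - y) / num_obs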
def set_up_hessian_eval(model, flax_module, batch_stats, dataset,
                        checkpoint_dir, hessian_eval_config):
  """Builds the CurvatureEvaluator object."""
  # First copy, then unreplicate batch_stats. Note that batch_stats doesn't
  # affect the forward pass in the hessian eval because we always run the
  # model in train mode. However, we need to provide batch_stats for the
  # model.training_cost API. The copy is needed because the trainer will
  # modify the underlying arrays.
  batch_stats = jax.tree_map(lambda x: x[:][0], batch_stats)

  def batch_loss(module, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        module, batch, batch_stats=batch_stats, dropout_rng=rng)[0]

  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'])
  logger = utils.MetricLogger(pytree_path=pytree_path)
  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      flax_module,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  return hessian_evaluator, logger
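# A minimal usage sketch for the helper above, assuming `model`, `params`,
# `batch_stats`, and `dataset` come from the trainer's setup path (these
# names are illustrative placeholders, not values defined in this module):
#
#   hessian_evaluator, logger = set_up_hessian_eval(
#       model, params, batch_stats, dataset,
#       checkpoint_dir='/tmp/checkpoints',
#       hessian_eval_config=hessian_eval.DEFAULT_EVAL_CONFIG)
#   stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
#       params, step=0)
#   logger.append_pytree(stats)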
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
):
  """Evaluates the Hessian of the given checkpoints.

  Iterates over all checkpoints in the specified directory, loads each
  checkpoint, then evaluates the Hessian on it. A list of dicts will be
  saved to cns at checkpoint_dir/hessian_eval_config['name'].

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) The number of batches used for evaluating on the
      validation and test sets. Set to None to evaluate on the whole test set.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: Dataset builder returned by datasets.get_dataset.
    dataset_meta_data: Dict of meta_data about the dataset.
    hessian_eval_config: A dict specifying the configuration of the Hessian
      eval.
    min_global_step: Lower bound on what steps to filter checkpoints. Set to
      None to evaluate all checkpoints in the directory.
    max_global_step: Upper bound on what steps to filter checkpoints.
  """
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module, initializer, model.loss_fn, hps.input_shape,
      hps.output_shape, hps, init_rng, None)

  # Fold the unreplicated batch_stats and rng into the loss used by the
  # hessian eval.
  def batch_loss(params, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        params, batch, batch_stats=unreplicated_batch_stats,
        dropout_rng=rng)[0]
  batch_stats = jax_utils.replicate(unreplicated_batch_stats)

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the hessian computation hps to the experiment directory.
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)
  # Note that we do not use the learning rate. The optimizer state is a list
  # of all the optax transformation states, and we inject the learning rate
  # into all states that will accept it.
  for state in unreplicated_optimizer_state:
    if (isinstance(state, optax.InjectHyperparamsState) and
        'learning_rate' in state.hyperparams):
      state.hyperparams['learning_rate'] = jax_utils.replicate(1.0)
  optimizer_state = jax_utils.replicate(unreplicated_optimizer_state)
  params = jax_utils.replicate(unreplicated_params)
  data_rng = jax.random.fold_in(data_rng, 0)

  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmap functions for the training loop.
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.process_index() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.process_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      params,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    unreplicated_checkpoint_state = dict(
        params=unreplicated_params,
        optimizer_state=unreplicated_optimizer_state,
        batch_stats=unreplicated_batch_stats,
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=unreplicated_checkpoint_state)
    results, _ = checkpoint.replicate_checkpoint(
        ckpt,
        pytree_keys=['params', 'optimizer_state', 'batch_stats'])
    params = results['params']
    optimizer_state = results['optimizer_state']
    batch_stats = results['batch_stats']
    # pylint: disable=protected-access
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.process_index() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    row = {}
    grads, updates = [], []
    hess_evecs, cov_evecs = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        params, step)
    row.update(stats)
    if (hessian_eval_config['compute_stats'] or
        hessian_eval_config['compute_interps']):
      grads, updates = hessian_evaluator.compute_dirs(
          params, optimizer_state, optimizer_update_fn)
    row.update(hessian_evaluator.evaluate_stats(params, grads, updates,
                                                hess_evecs, cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(params, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    if jax.process_index() == 0:
      logger.append_pytree(row)
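# A minimal invocation sketch for eval_checkpoints, assuming `hps`,
# `model_cls`, `dataset_builder`, and `dataset_meta_data` were obtained from
# the library's usual setup path (all placeholder names here, not values
# defined in this module):
#
#   eval_checkpoints(
#       checkpoint_dir='/tmp/checkpoints',
#       hps=hps,
#       rng=jax.random.PRNGKey(0),
#       eval_num_batches=10,
#       model_cls=model_cls,
#       dataset_builder=dataset_builder,
#       dataset_meta_data=dataset_meta_data,
#       hessian_eval_config=hessian_eval.DEFAULT_EVAL_CONFIG,
#       min_global_step=0,
#       max_global_step=10000)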
def test_eval_hess_grad_overlap(self):
  """Test gradient overlap calculations."""
  if jax.devices()[0].platform == 'tpu':
    atol = 1e-3
    rtol = 0.1
  else:
    atol = 1e-5
    rtol = 1e-5

  # Dimension of the parameter space.
  n_params = 4
  num_batches = 5

  eval_config = hessian_eval.DEFAULT_EVAL_CONFIG
  eval_config['eval_hessian'] = False
  eval_config['eval_gradient_covariance'] = False
  eval_config['num_eigens'] = 0
  eval_config['num_lanczos_steps'] = n_params  # max iterations
  eval_config['num_batches'] = num_batches
  eval_config['num_eval_draws'] = 1

  key = jax.random.PRNGKey(0)
  key, split = jax.random.split(key)

  # Diagonal matrix values.
  mat_diag = jax.random.normal(split, (n_params,))

  def batches_gen():
    for _ in range(num_batches):
      yield None

  # Model.
  class QuadraticLoss(nn.Module):
    """Loss function which only depends on parameters."""

    @nn.compact
    def __call__(self, x):
      del x
      w = self.param('w', jax.random.normal, (n_params,))
      return jnp.sum((w**2) * mat_diag)

  flax_module = QuadraticLoss()

  def loss(params, _):
    # The input 1.0 is required but unused.
    return flax_module.apply({'params': params}, 1.0)

  # Model initialization.
  model_init_fn = jax.jit(flax_module.init)
  init_dict = model_init_fn({'params': key},
                            np.zeros((num_batches,), jnp.float32))
  params = init_dict['params']

  # Replicate.
  replicated_params = jax_utils.replicate(params)

  curve_eval = hessian_eval.CurvatureEvaluator(
      replicated_params,
      eval_config,
      loss,
      dataset=None,
      batches_gen=batches_gen)
  row, _, _ = curve_eval.evaluate_spectrum(replicated_params, 0)
  tridiag = row['tridiag_hess_grad_overlap']
  eigs_triag, vecs_triag = np.linalg.eigh(tridiag)

  # Test eigenvalues.
  np.testing.assert_allclose(
      eigs_triag, 2 * np.sort(mat_diag), atol=atol, rtol=rtol)

  # Compute overlaps.
  weights_triag = vecs_triag[0, :]**2
  grad_true = 2 * params['w'] * mat_diag
  weight_idx = np.argsort(mat_diag)
  weights_true = grad_true**2 / jnp.dot(grad_true, grad_true)
  weights_true = weights_true[weight_idx]

  # Test overlaps.
  np.testing.assert_allclose(
      weights_triag, weights_true, atol=atol, rtol=rtol)
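# A self-contained numpy sketch of the identity the test above relies on: if
# T is the tridiagonal matrix produced by Lanczos on a symmetric Hessian H
# seeded with the normalized gradient g, then eigvals(T) approximate
# eigvals(H), and the squared first components of T's eigenvectors give the
# overlaps |<g, v_i>|^2 / |g|^2 with the Hessian eigenvectors v_i. For the
# diagonal Hessian H = 2 * diag(m) of the quadratic loss sum(m * w**2), the
# overlaps reduce to g**2 / |g|^2 sorted by eigenvalue, exactly as asserted
# in the test. All names here are local to this sketch; it assumes the
# entries of m are distinct, and reuses the module's np import.
def _check_overlap_identity(m, w):
  h = 2 * np.diag(m)                  # Hessian of sum(m * w**2).
  g = 2 * m * w                       # Gradient at w.
  eigs, vecs = np.linalg.eigh(h)      # Columns of vecs are eigenvectors.
  overlaps = (vecs.T @ g)**2 / np.dot(g, g)
  # For a diagonal H the eigenvectors are coordinate axes, so the overlaps
  # equal g**2 / |g|^2 reordered by ascending eigenvalue.
  np.testing.assert_allclose(
      overlaps, (g**2 / np.dot(g, g))[np.argsort(m)])
  return eigs, overlaps

# Example: _check_overlap_identity(np.array([0.5, -1.0, 2.0, 0.1]),
#                                  np.ones(4))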
def test_block_hessian(self):
  """Test block_hessian code on a low rank factorization problem.

  See Example 1.2 in https://arxiv.org/abs/2202.00980.
  """
  full_dim = 10
  low_rank_dim = 3

  # Make the init unbalanced.
  params = {'AA': {}, 'AB': {}}
  a_init_scale = 10.0
  b_init_scale = .1

  # We make params nested as some errors with flax.unfreeze were only
  # surfaced for nested dictionaries.
  params['AA']['inner'] = jnp.array(
      np.random.normal(scale=a_init_scale, size=(full_dim, low_rank_dim)))
  params['AB']['inner'] = jnp.array(
      np.random.normal(scale=b_init_scale, size=(full_dim, low_rank_dim)))

  # The hessian eval pmaps by default, so replicate params even for cpu
  # tests.
  rep_params = flax.jax_utils.replicate(params)

  # True matrix factorization.
  true_a = jnp.array(np.random.normal(size=(full_dim, low_rank_dim)))
  true_b = jnp.array(np.random.normal(size=(full_dim, low_rank_dim)))
  y = jnp.dot(true_a, true_b.T)

  # Set up the mse loss to match the hessian API.
  def loss(params, unused_batch):
    y_pred = jnp.dot(params['AA']['inner'], params['AB']['inner'].T)
    return jnp.sum((y_pred - y)**2) / 2

  # Fake batches_gen to match the hessian_eval API.
  def batches_gen():
    yield flax.jax_utils.replicate(jnp.array(1))  # Match expected API.

  # Set up the curvature evaluator.
  eval_config = hessian_eval.DEFAULT_EVAL_CONFIG.copy()
  eval_config['block_hessian'] = True
  eval_config['param_partition_fn'] = 'outer_key'

  evaluator = hessian_eval.CurvatureEvaluator(
      rep_params, eval_config, batches_gen=batches_gen, loss=loss)
  results, _, _ = evaluator.evaluate_spectrum(rep_params, step=0)

  a_max_eig = np.linalg.eigvalsh(
      np.dot(params['AB']['inner'], params['AB']['inner'].T)).max()
  b_max_eig = np.linalg.eigvalsh(
      np.dot(params['AA']['inner'], params['AA']['inner'].T)).max()

  self.assertAlmostEqual(
      a_max_eig, results['block_hessian']['AA']['max_eig_hess'], places=5)

  # The true value is bigger than 1000, so we need fewer places here.
  self.assertAlmostEqual(
      b_max_eig, results['block_hessian']['AB']['max_eig_hess'], places=2)
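# A small check of the closed form the test above relies on: for
# f(A, B) = ||A B^T - Y||_F^2 / 2, the Hessian block with respect to A (with
# B held fixed) is I kron (B^T B), so its largest eigenvalue equals
# lambda_max(B^T B) = lambda_max(B B^T), which is what the test computes
# from the 'AB' factor. This sketch builds the A-block Hessian explicitly
# with jax.hessian on small dimensions; all names are local to the sketch,
# and it reuses the module's jax, jnp, and np imports.
def _check_block_hessian_eig(full_dim=4, low_rank_dim=2, seed=0):
  rng = np.random.RandomState(seed)
  b = jnp.array(rng.normal(size=(full_dim, low_rank_dim)))
  y = jnp.array(rng.normal(size=(full_dim, full_dim)))
  a0 = jnp.array(rng.normal(size=(full_dim, low_rank_dim)))

  def f_of_a(a_flat):
    a_mat = a_flat.reshape((full_dim, low_rank_dim))
    return jnp.sum((jnp.dot(a_mat, b.T) - y)**2) / 2

  # Explicit (d*r, d*r) Hessian block; it should equal I kron (B^T B).
  hess_a = jax.hessian(f_of_a)(a0.ravel())
  max_eig = np.linalg.eigvalsh(np.asarray(hess_a)).max()
  expected = np.linalg.eigvalsh(np.asarray(jnp.dot(b.T, b))).max()
  np.testing.assert_allclose(max_eig, expected, rtol=1e-4)
  return max_eig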