def test_adam(self):
  init_fn, update_fn = optimizers.get_optimizer(
      ConfigDict({
          'optimizer': 'adam',
          'l2_decay_factor': None,
          'batch_size': 50,
          'total_accumulated_batch_size': 100,  # Use gradient accumulation.
          'opt_hparams': {
              'beta1': 0.9,
              'beta2': 0.999,
              'epsilon': 1e-7,
              'weight_decay': 0.0,
          }
      }))
  del update_fn
  optimizer_state = init_fn({'foo': jnp.ones(10)})
  # Test that we can extract 'count'.
  chex.assert_type(extract_field(optimizer_state, 'count'), int)
  # Test that we can extract 'nu'.
  chex.assert_shape(extract_field(optimizer_state, 'nu')['foo'], (10,))
  # Test that we can extract 'mu'.
  chex.assert_shape(extract_field(optimizer_state, 'mu')['foo'], (10,))
  # Test that attempting to extract a nonexistent field "abc" returns None.
  chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
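# For reference, a minimal sketch of a helper consistent with the assertions
# above (hypothetical; the codebase's actual `extract_field` may differ):
# recursively walk the optax state, which is a (possibly nested) tuple of
# NamedTuples, and return the first field with the requested name, or None.
def _extract_field_sketch(state, field_name):
  if hasattr(state, field_name):
    return getattr(state, field_name)
  if isinstance(state, tuple):
    for sub_state in state:
      result = _extract_field_sketch(sub_state, field_name)
      if result is not None:
        return result
  return None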
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
  """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py.
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole
      validation and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on
      train. Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption
      recovery.
    early_stopping_target_name: A string naming the metric to use to perform
      early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to
      stop training.
    early_stopping_mode: One of "above" or "below", indicates if we should
      stop when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced
      training metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
  # NOTE: the initialization RNG should *not* be per-host, as this will
  # create different sets of weights per host. However, all other RNGs should
  # be per-host.
  # TODO(znado,gilmer,gdahl): implement replicating the same initialization
  # across hosts.
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  # Only used if checkpoint_steps is non-empty.
  checkpoint_dir = os.path.join(train_dir, 'checkpoints')

  # For logging / processing off the main thread.
  pool = multiprocessing.pool.ThreadPool()

  if jax.process_index() == 0:
    logging.info('Let the training begin!')
    logging.info('Dataset input shape: %r', hps.input_shape)
    logging.info('Hyperparameters: %s', hps)

  if eval_batch_size is None:
    eval_batch_size = hps.batch_size
  if callback_configs is None:
    callback_configs = []

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer,
      model.loss_fn,
      hps.input_shape,
      hps.output_shape,
      hps,
      init_rng,
      init_logger,
      model.get_fake_batch(hps))

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)

  # Note that global_step refers to the number of gradient calculations, not
  # the number of model updates. This means when using gradient accumulation,
  # one must supply configs where the number of steps is in units of gradient
  # calculations, not model updates, and in post processing one must divide
  # global_step by grad_accum_step_multiplier to get the number of updates.
  #
  # If using gradient accumulation, stretch the learning rate schedule by the
  # number of gradient calculations per weight update.
  stretch_factor = 1
  if hps.get('total_accumulated_batch_size') is not None:
    stretch_factor = hps.total_accumulated_batch_size // hps.batch_size

  lr_fn = schedules.get_schedule_fn(
      hps.lr_hparams, num_train_steps, stretch_factor=stretch_factor)

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(
      hps, model)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

  (unreplicated_metrics_state, metrics_update_fn, metrics_summary_fn) = (
      None, None, None)
  if training_metrics_config is not None:
    (metrics_init_fn, metrics_update_fn, metrics_summary_fn) = (
        make_training_metrics(num_train_steps, **training_metrics_config))
    unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

  (optimizer_state, params, batch_stats, metrics_state, global_step,
   sum_train_cost, preemption_count,
   is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
       unreplicated_optimizer_state,
       unreplicated_params,
       unreplicated_batch_stats,
       unreplicated_metrics_state,
       train_dir=train_dir,
       external_checkpoint_path=external_checkpoint_path)

  if is_restored:
    preemption_count += 1
    # Fold the restored step into the dataset RNG so that we will get a
    # different shuffle each time we restore, so that we do not repeat a
    # previous dataset ordering again after restoring. This is not the only
    # source of variation in shuffling across pre-emptions, because we often
    # reshuffle the input files each time in a non-deterministic manner.
    #
    # Note that if we are pre-empted more than once per epoch then we will
    # retrain more on the beginning of the training split, because each time
    # we restore we refill the shuffle buffer with the first
    # `shuffle_buffer_size` elements from the training split to continue
    # training.
    #
    # Also note that for evaluating on the training split, because we are
    # reshuffling each time, we will get a new eval_train split each time we
    # are pre-empted.
    data_rng = jax.random.fold_in(data_rng, global_step)

  assert hps.batch_size % (jax.device_count()) == 0
  assert eval_batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=eval_batch_size,
      hps=hps,
  )

  update_fn = functools.partial(
      update,
      training_cost=model.training_cost,
      grad_clip=hps.get('grad_clip'),
      optimizer_update_fn=optimizer_update_fn,
      metrics_update_fn=metrics_update_fn)
  # in_axes = (
  #     optimizer_state = 0,
  #     params = 0,
  #     batch_stats = 0,
  #     metrics_state = 0,
  #     batch = 0,
  #     step = None,
  #     lr = None,
  #     rng = None,
  #     local_device_index = 0,
  #     running_train_cost = 0,
  #     training_cost,
  #     grad_clip,
  #     optimizer_update_fn,
  #     metrics_state_update_fn)
  # Also, we can donate buffers for 'optimizer', 'batch_stats',
  # 'batch' and 'training_metrics_state' for update's pmapped computation.
  update_pmapped = jax.pmap(
      update_fn,
      axis_name='batch',
      in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
      donate_argnums=(0, 1, 2, 8))
  # During eval, we can donate the 'batch' buffer. We don't donate the
  # 'params' and 'batch_stats' buffers as we don't re-assign those values in
  # eval, we do that only in train.
  evaluate_batch_pmapped = jax.pmap(
      model.evaluate_batch, axis_name='batch', donate_argnums=(2,))
  start_time = time.time()
  start_step = global_step
  prev_eval_step = start_step

  def get_step_frequency(cur_step):
    return float(cur_step - start_step) / (time.time() - start_time)

  if jax.process_index() == 0:
    trainer_utils.log_message('Starting training!', pool, xm_work_unit)

  # Numpy array of range(0, local_device_count) to send to each device to be
  # folded into the RNG inside each train step to get a unique per-device
  # RNG.
  local_device_indices = np.arange(jax.local_device_count())

  # Start at the resumed step and continue until we have finished the number
  # of training steps. If building a dataset iterator using a tf.data.Dataset
  # with a batch size that does not evenly divide the training dataset size,
  # then using `ds.batch(..., drop_remainder=False)` on the training dataset
  # means the final batch in this iterator will be a partial batch. However,
  # with `drop_remainder=True` (and the dataset repeated before batching),
  # this iterator will always return batches of the same size, and the final
  # batch will have elements from the start of the (num_epochs + 1)-th epoch.
  train_iter = itertools.islice(
      dataset.train_iterator_fn(), global_step, num_train_steps)

  eval_callbacks = []
  rng, callback_rng = jax.random.split(rng)
  callback_rngs = jax.random.split(callback_rng, len(callback_configs))
  for callback_rng, config in zip(callback_rngs, callback_configs):
    eval_callback = callbacks.get_callback(config['callback_name'])(
        model, params, batch_stats, optimizer_state, dataset, hps, config,
        train_dir, callback_rng)
    eval_callbacks.append(eval_callback)

  train_iter = trainer_utils.prefetch_input_pipeline(
      train_iter, hps.num_device_prefetches)

  eval_start_time = start_time
  eval_start_step = start_step
  for _ in range(start_step, num_train_steps):
    with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
      # NOTE(dsuo): to properly profile each step, we must include batch
      # creation in the StepTraceContext (as opposed to putting `train_iter`
      # directly in the top-level for loop).
      batch = next(train_iter)

      if global_step in checkpoint_steps and jax.process_index() == 0:
        checkpoint.save_unreplicated_checkpoint_background(
            checkpoint_dir,
            optimizer_state,
            params,
            batch_stats,
            metrics_state,
            global_step,
            preemption_count,
            sum_train_cost,
            max_to_keep=None)
      lr = lr_fn(global_step)
      (optimizer_state, params, batch_stats, sum_train_cost, metrics_state,
       grad_norm) = update_pmapped(
           optimizer_state,
           params,
           batch_stats,
           metrics_state,
           batch,
           global_step,
           lr,
           rng,
           local_device_indices,
           sum_train_cost)
      global_step += 1

    # TODO(gdahl, gilmer): consider moving this test up.
    # NB: Since this test is after we increment global_step, having 0 in
    # eval_steps does nothing.
    if trainer_utils.should_eval(global_step, eval_frequency, eval_steps):
      train_steps_per_sec = (global_step - eval_start_step) / (
          time.time() - eval_start_time)
      eval_start_step = global_step
      eval_start_time = time.time()
      batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
      report, eval_time = eval_metrics(params, batch_stats, dataset,
                                       eval_num_batches,
                                       eval_train_num_batches,
                                       evaluate_batch_pmapped)
      mean_train_cost = sum_train_cost.mean().item() / max(
          1, global_step - prev_eval_step)
      report.update(
          learning_rate=float(lr),
          global_step=global_step,
          epoch=global_step * hps.batch_size // hps.train_size,
          train_steps_per_sec=train_steps_per_sec,
          overall_steps_per_sec=get_step_frequency(global_step),
          eval_time=eval_time,
          grad_norm=np.mean(grad_norm),
          preemption_count=preemption_count,
          train_cost=mean_train_cost)
      for eval_callback in eval_callbacks:
        callback_metrics = eval_callback.run_eval(params, batch_stats,
                                                  optimizer_state,
                                                  global_step)
        if set(callback_metrics.keys()).intersection(set(report.keys())):
          raise ValueError('There was a collision between the callback '
                           'metrics and the standard eval metrics keys.')
        report.update(callback_metrics)
      yield report
      if jax.process_index() == 0:
        trainer_utils.log_eta(pool, xm_work_unit, global_step,
                              train_steps_per_sec, num_train_steps,
                              start_time, eval_frequency, eval_steps,
                              eval_time)
        trainer_utils.log_epoch_report(report, metrics_logger)
        trainer_utils.maybe_log_training_metrics(metrics_state,
                                                 metrics_summary_fn,
                                                 metrics_logger)
        checkpoint.save_unreplicated_checkpoint_background(
            train_dir, optimizer_state, params, batch_stats, metrics_state,
            global_step, preemption_count, sum_train_cost)
      sum_train_cost = jnp.zeros(jax.local_device_count())
      prev_eval_step = global_step

      early_stopping_condition = trainer_utils.check_for_early_stopping(
          early_stopping_target_name, early_stopping_target_value,
          early_stopping_mode, report)
      if early_stopping_condition:
        comparison_string = '>=' if early_stopping_mode == 'above' else '<='
        logging.info(
            'Early stopping because metric %s=%f, reached the target value '
            'of %s %f.', early_stopping_target_name,
            report[early_stopping_target_name], comparison_string,
            early_stopping_target_value)
        return

  # Always log and checkpoint on host 0 at the end of training.
  # If we moved where in the loop body evals happen then we would not need
  # this test.
  if prev_eval_step != num_train_steps:
    train_steps_per_sec = (global_step - eval_start_step) / (
        time.time() - eval_start_time)
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    report, eval_time = eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches,
                                     eval_train_num_batches,
                                     evaluate_batch_pmapped)
    lr = lr_fn(global_step)
    # Correct the average for the final partial epoch.
    mean_train_cost = sum_train_cost.mean().item() / max(
        1, global_step - prev_eval_step)
    report.update(
        learning_rate=float(lr),
        global_step=global_step,
        epoch=global_step * hps.batch_size // hps.train_size,
        train_steps_per_sec=train_steps_per_sec,
        overall_steps_per_sec=get_step_frequency(global_step),
        eval_time=eval_time,
        grad_norm=np.mean(grad_norm),
        preemption_count=preemption_count,
        train_cost=mean_train_cost)
    yield report
    if jax.process_index() == 0:
      trainer_utils.log_eta(pool, xm_work_unit, global_step,
                            train_steps_per_sec, num_train_steps, start_time,
                            eval_frequency, eval_steps, eval_time)
      trainer_utils.log_epoch_report(report, metrics_logger)
      trainer_utils.maybe_log_training_metrics(metrics_state,
                                               metrics_summary_fn,
                                               metrics_logger)
      checkpoint.save_unreplicated_checkpoint_background(
          train_dir, optimizer_state, params, batch_stats, metrics_state,
          global_step, preemption_count, sum_train_cost)
  # To make sure the last checkpoint was correctly saved.
  checkpoint.wait_for_checkpoint_save()
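# A hypothetical end-to-end usage sketch of the training loop above. The
# argument values and the `my_*` names are illustrative stand-ins, not names
# from this codebase. Because `train` is a generator, eval metrics are
# consumed by iterating over the reports it yields.
def example_train_loop_usage(my_model, my_dataset_builder, my_initializer,
                             hps):
  """Illustrative (hypothetical) driver for the `train` generator."""
  for report in train(
      train_dir='/tmp/experiment',
      model=my_model,
      dataset_builder=my_dataset_builder,
      initializer=my_initializer,
      num_train_steps=10000,
      hps=hps,
      rng=jax.random.PRNGKey(0),
      eval_batch_size=None,
      eval_num_batches=None,
      eval_train_num_batches=None,
      eval_frequency=1000,
      checkpoint_steps=[]):
    logging.info('step %d: train_cost=%f',
                 report['global_step'], report['train_cost'])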
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over all checkpoints in the specified directory, loads the
  checkpoint then evaluates the Hessian on the given checkpoint. A list of
  dicts will be saved to cns at checkpoint_dir/hessian_eval_config['name'].

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole test
      set.
    model_cls: One of the model classes (not an instance) defined in
      model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval.
    min_global_step: Lower bound on what steps to filter checkpoints. Set to
      None to evaluate all checkpoints in the directory.
    max_global_step: Upper bound on what steps to filter checkpoints.
  """
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer,
      model.loss_fn,
      hps.input_shape,
      hps.output_shape,
      hps,
      init_rng,
      None)

  # Fold the unreplicated batch_stats and rng into the loss used by the
  # Hessian eval.
  def batch_loss(params, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        params,
        batch,
        batch_stats=unreplicated_batch_stats,
        dropout_rng=rng)[0]

  batch_stats = jax_utils.replicate(unreplicated_batch_stats)

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the Hessian computation hps to the experiment directory.
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)
  # Note that we do not use the learning rate.
  # The optimizer state is a list of all the optax transformation states, and
  # we inject the learning rate into all states that will accept it.
  for state in unreplicated_optimizer_state:
    if (isinstance(state, optax.InjectHyperparamsState) and
        'learning_rate' in state.hyperparams):
      state.hyperparams['learning_rate'] = jax_utils.replicate(1.0)
  optimizer_state = jax_utils.replicate(unreplicated_optimizer_state)
  params = jax_utils.replicate(unreplicated_params)
  data_rng = jax.random.fold_in(data_rng, 0)

  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # Eval iterators not used.
      hps=hps,
  )

  # pmap function for the eval loop.
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.process_index() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.process_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      params,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    unreplicated_checkpoint_state = dict(
        params=unreplicated_params,
        optimizer_state=unreplicated_optimizer_state,
        batch_stats=unreplicated_batch_stats,
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path, target=unreplicated_checkpoint_state)
    results, _ = checkpoint.replicate_checkpoint(
        ckpt, pytree_keys=['params', 'optimizer_state', 'batch_stats'])
    params = results['params']
    optimizer_state = results['optimizer_state']
    batch_stats = results['batch_stats']
    # pylint: disable=protected-access
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.process_index() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    row = {}

    grads, updates = [], []
    hess_evecs, cov_evecs = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        params, step)
    row.update(stats)

    if hessian_eval_config['compute_stats'] or hessian_eval_config[
        'compute_interps']:
      grads, updates = hessian_evaluator.compute_dirs(
          params, optimizer_state, optimizer_update_fn)
    row.update(hessian_evaluator.evaluate_stats(params, grads, updates,
                                                hess_evecs, cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(params, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    if jax.process_index() == 0:
      logger.append_pytree(row)
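# A hypothetical usage sketch for `eval_checkpoints` above. The config dict
# shows only the keys referenced in the function body ('name',
# 'compute_stats', 'compute_interps'); a real config likely carries
# additional Lanczos/Hessian settings, and the path and step bounds here are
# illustrative.
def example_eval_checkpoints_usage(model_cls, dataset_builder,
                                   dataset_meta_data, hps):
  """Runs the Hessian eval over checkpoints saved between steps 0 and 10000."""
  eval_checkpoints(
      checkpoint_dir='/tmp/experiment/checkpoints',
      hps=hps,
      rng=jax.random.PRNGKey(0),
      eval_num_batches=None,
      model_cls=model_cls,
      dataset_builder=dataset_builder,
      dataset_meta_data=dataset_meta_data,
      hessian_eval_config={
          'name': 'hessian',
          'compute_stats': False,
          'compute_interps': False,
      },
      min_global_step=0,
      max_global_step=10000)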
def test_hessian_free_optimizer(self):
  """Tests the Hessian-free optimizer."""
  model_str = 'autoencoder'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)

  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'

  input_shape = (2, 2, 1)
  output_shape = (4,)

  hps = copy.copy(model_hps)
  hps.update({
      'optimizer': 'hessian_free',
      'opt_hparams': {
          'weight_decay': 0.0,
      },
      'hid_sizes': [2],
      'activation_function': ['id'],
      'input_shape': input_shape,
      'output_shape': output_shape
  })

  model = model_cls(hps, {}, loss, metrics)

  inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
  targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
  batch = {'inputs': inputs, 'targets': targets}

  def forward_fn(variables, inputs):
    logits = model.flax_module.apply(variables, inputs, train=True)
    return logits

  def opt_cost(variables):
    return model.loss_fn(forward_fn(variables, inputs), targets)

  init_fn, update_fn = optimizers.get_optimizer(hps, model)

  params = {
      'Dense_0': {
          'kernel': jnp.array([[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
          'bias': jnp.array([0., 0.])
      },
      'Dense_1': {
          'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
          'bias': jnp.array([0., 0., 0., 0.])
      }
  }
  variables = {'params': params}

  grad_fn = jax.grad(opt_cost)
  grads = grad_fn(variables)['params']

  outputs = forward_fn(variables, batch['inputs'])

  n = inputs.shape[0]
  m = outputs.shape[-1]
  d = ravel_pytree(params)[0].shape[0]
  v = np.ones(d)

  state = init_fn(params)

  partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
  partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

  matmul_fn = partial(gvp, variables, outputs, state.inner_state.damping,
                      partial_forward_fn, partial_loss_fn)

  jacobian = jax.jacfwd(partial_forward_fn)(variables)['params']
  jacobian_tensor = np.concatenate(
      (jacobian['Dense_0']['bias'].reshape(n, m, -1),
       jacobian['Dense_0']['kernel'].reshape(n, m, -1),
       jacobian['Dense_1']['bias'].reshape(n, m, -1),
       jacobian['Dense_1']['kernel'].reshape(n, m, -1)),
      axis=2)

  ggn_matrix = 0
  for i in range(n):
    jacobian_matrix = jacobian_tensor[i]
    hessian = jax.hessian(partial_loss_fn)(outputs[i, None])[0, :, 0, :]
    ggn_matrix += np.transpose(jacobian_matrix) @ hessian @ jacobian_matrix
  ggn_matrix /= n
  ggn_matrix += state.inner_state.damping * np.identity(d)

  expected = ggn_matrix @ v

  # Test the gvp function.
  self.assertAlmostEqual(
      jnp.linalg.norm(matmul_fn(v) - expected), 0, places=4)

  update_pmapped = jax.pmap(
      update_fn, axis_name='batch', in_axes=(None, None, None, 0, None))
  batch_shard = data_utils.shard(batch)
  state.hyperparams['learning_rate'] = 1.0
  p, state = update_pmapped(grads, state, params, batch_shard, None)

  # Test the damping parameter update.
  self.assertEqual(state.inner_state.damping, 3 / 2)

  # Test the search direction.
  self.assertAlmostEqual(
      jnp.linalg.norm(
          ravel_pytree(p)[0] +
          jnp.linalg.inv(ggn_matrix) @ ravel_pytree(grads)[0]),
      0,
      places=4)
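# A matrix-free sketch of the damped Gauss-Newton-vector product that the
# `gvp` helper above is tested against (illustrative only; the library's
# `gvp` has a different signature). For a composite loss L(f(p)), the product
# is (J^T H J + damping * I) v, with J the Jacobian of the network f and H
# the Hessian of the loss L with respect to the network outputs.
def ggn_vector_product_sketch(params, forward_fn, loss_fn, damping, v):
  """Computes (J^T H J + damping * I) v for a flat vector v."""
  unravel = ravel_pytree(params)[1]
  v_tree = unravel(v)
  # J v: forward-mode product with the network Jacobian.
  outputs, jv = jax.jvp(forward_fn, (params,), (v_tree,))
  # H (J v): Hessian-vector product through the loss, as a jvp of the grad.
  hjv = jax.jvp(jax.grad(loss_fn), (outputs,), (jv,))[1]
  # J^T (H J v): reverse-mode product with the transposed Jacobian.
  jthjv = jax.vjp(forward_fn, params)[1](hjv)[0]
  return ravel_pytree(jthjv)[0] + damping * v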
def test_adam(self):
  """Test Adam preconditioning."""
  lr = 1e-3
  beta1 = 0.9
  beta2 = 0.999
  epsilon = 1e-7
  opt_hparams = FrozenConfigDict({
      'beta1': beta1,
      'beta2': beta2,
      'epsilon': epsilon
  })
  hparams = FrozenConfigDict({
      'optimizer': 'adam',
      'opt_hparams': opt_hparams,
      'l2_decay_factor': 0.0,
      'batch_size': 50,
      'total_accumulated_batch_size': 50,
  })
  init_fn, update_fn = optimizers.get_optimizer(hparams)
  params = {'foo': 1.0, 'bar': {'baz': 3.0}}
  gradients = [{
      'foo': 0.5,
      'bar': {
          'baz': 0.1
      }
  }, {
      'foo': 0.2,
      'bar': {
          'baz': 0.6
      }
  }]
  optimizer_state = init_fn(params)
  optimizer_state.base_state.hyperparams['learning_rate'] = lr
  for gradient in gradients:
    updates, optimizer_state = update_fn(gradient, optimizer_state, params)
    params = optax.apply_updates(params, updates)

  # With bias correction.
  expected_preconditioner = _calculate_adam_preconditioner(
      gradients, beta2, epsilon, bias_correct=True)
  preconditioner = make_diag_preconditioner(
      'adam', opt_hparams, optimizer_state,
      FrozenConfigDict(dict(bias_correction=True)))
  self.assertTrue(pytree_allclose(expected_preconditioner, preconditioner))

  # Without bias correction.
  expected_preconditioner = _calculate_adam_preconditioner(
      gradients, beta2, epsilon, bias_correct=False)
  preconditioner = make_diag_preconditioner(
      'adam', opt_hparams, optimizer_state,
      FrozenConfigDict(dict(bias_correction=False)))
  self.assertTrue(pytree_allclose(expected_preconditioner, preconditioner))
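# For context, a minimal sketch of the reference helper assumed by the test
# above (hypothetical; the actual `_calculate_adam_preconditioner` may
# differ, and this assumes `import jax` and `import jax.numpy as jnp`): run
# the Adam second-moment recursion nu_t = beta2 * nu_{t-1} + (1 - beta2) *
# g_t^2 over the gradient sequence and return the diagonal preconditioner
# sqrt(nu_hat) + epsilon, where nu_hat = nu_t / (1 - beta2^t) when bias
# correction is on.
def _calculate_adam_preconditioner_sketch(gradients, beta2, epsilon,
                                          bias_correct):
  nu = jax.tree_util.tree_map(lambda g: jnp.zeros_like(jnp.asarray(g)),
                              gradients[0])
  num_steps = len(gradients)
  for grad in gradients:
    nu = jax.tree_util.tree_map(
        lambda nu_leaf, g: beta2 * nu_leaf + (1.0 - beta2) * g**2, nu, grad)
  if bias_correct:
    nu = jax.tree_util.tree_map(
        lambda leaf: leaf / (1.0 - beta2**num_steps), nu)
  return jax.tree_util.tree_map(lambda leaf: jnp.sqrt(leaf) + epsilon, nu)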