def test_optstate_sumsq(self):
  """Test that optstate sumsq and sum are computed correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps,
      optstate_sumsq_fields=['nu'],
      optstate_sum_fields=['nu'])
  initial_metrics_state = init_fn(self.mock_params0)
  self.assertTrue(
      pytree_equal(initial_metrics_state['optstate_sumsq'],
                   {'nu': jnp.zeros(self.num_train_steps)}))
  self.assertTrue(
      pytree_equal(initial_metrics_state['optstate_sum'],
                   {'nu': jnp.zeros(self.num_train_steps)}))
  updated_metrics_state = update_fn(initial_metrics_state, 0, self.mock_cost0,
                                    self.mock_grad1, self.mock_params0,
                                    self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1, self.mock_cost1,
                                    self.mock_grad2, self.mock_params1,
                                    self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][0],
                   total_tree_norm_sql2(self.mock_nu0))
  self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][1],
                   total_tree_norm_sql2(self.mock_nu1))
  self.assertEqual(updated_metrics_state['optstate_sum']['nu'][0],
                   total_tree_sum(self.mock_nu0))
  self.assertEqual(updated_metrics_state['optstate_sum']['nu'][1],
                   total_tree_sum(self.mock_nu1))
def test_update_param_norms(self):
  """Ensure that we update param norms correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps, enable_param_norms=True)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0, self.mock_cost0,
                                    self.mock_grad1, self.mock_params0,
                                    self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1, self.mock_cost1,
                                    self.mock_grad2, self.mock_params1,
                                    self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['param_norms'], {
              'foo':
                  jnp.array([
                      jnp.linalg.norm(self.mock_params0['foo']),
                      jnp.linalg.norm(self.mock_params1['foo']),
                      0.0, 0.0, 0.0
                  ]),
              'bar': {
                  'baz':
                      jnp.array([
                          jnp.linalg.norm(self.mock_params0['bar']['baz']),
                          jnp.linalg.norm(self.mock_params1['bar']['baz']),
                          0.0, 0.0, 0.0
                      ])
              }
          }))
def test_summarize(self):
  """Test the training metrics summarizer."""
  _, _, summarize_fn = make_training_metrics(
      self.num_train_steps, enable_train_cost=True, enable_ema=True)
  metrics_state = {
      'train_cost': jnp.array([1.0, 0.5, 0.25, 0.0, 0.0]),
      'param_norm': {
          'foo': 7.0,
          'bar': {
              'baz': 2.0
          }
      },
      'grad_ema': {
          'foo': 1 * jnp.ones(5),
          'bar': {
              'baz': 2 * jnp.ones(10)
          }
      },
      'grad_sq_ema': {
          'foo': 2 * jnp.ones(5),
          'bar': {
              'baz': 6 * jnp.ones(10)
          }
      },
      'update_ema': {
          'foo': 2 * jnp.ones(5),
          'bar': {
              'baz': 1 * jnp.ones(10)
          }
      },
      'update_sq_ema': {
          'foo': 6 * jnp.ones(5),
          'bar': {
              'baz': 2 * jnp.ones(10)
          }
      },
  }
  tree_summary = summarize_fn(metrics_state)
  # Expected values: per leaf, the variance estimate sums (sq_ema - ema**2)
  # over the leaf's entries (leaf_size * (sq_ema - ema**2) here), and
  # update_ratio = update_var / param_norm.
  self.assertTrue(
      pytree_equal(
          tree_summary, {
              'param_norm': {
                  '/foo': 7.0,
                  '/bar/baz': 2.0
              },
              'grad_var': {
                  '/foo': 5 * (2 - 1**2),
                  '/bar/baz': 10 * (6 - 2**2)
              },
              'update_var': {
                  '/foo': 5 * (6 - 2**2),
                  '/bar/baz': 10 * (2 - 1**2)
              },
              'update_ratio': {
                  '/foo': 5 * (6 - 2**2) / 7.0,
                  '/bar/baz': 10 * (2 - 1**2) / 2.0
              }
          }))
def test_init(self):
  """Test the training metrics initializer."""
  zeros_like_params = jax.tree_map(jnp.zeros_like, self.mock_params0)
  zeros_scalar_like_params = jax.tree_map(lambda x: 0.0, self.mock_params0)
  zeros_timeseries = jnp.zeros(self.num_train_steps)
  zeros_timeseries_like_params = jax.tree_map(
      lambda x: jnp.zeros(self.num_train_steps), self.mock_params0)

  # Test init with everything disabled.
  init_fn, _, _ = make_training_metrics(self.num_train_steps)
  initial_metrics_state = init_fn(self.mock_params0)
  self.assertTrue(
      pytree_equal({'param_norm': zeros_scalar_like_params},
                   initial_metrics_state))

  # Test init with all optional metrics enabled.
  init_fn, _, _ = make_training_metrics(
      self.num_train_steps,
      enable_ema=True,
      enable_train_cost=True,
      enable_param_norms=True,
      enable_gradient_norm=True,
      enable_update_norm=True,
      enable_update_norms=True)
  initial_metrics_state = init_fn(self.mock_params0)
  self.assertTrue(
      pytree_equal(
          initial_metrics_state, {
              'train_cost': zeros_timeseries,
              'param_norm': zeros_scalar_like_params,
              'grad_ema': zeros_like_params,
              'grad_sq_ema': zeros_like_params,
              'update_ema': zeros_like_params,
              'update_sq_ema': zeros_like_params,
              'param_norms': zeros_timeseries_like_params,
              'gradient_norm': zeros_timeseries,
              'update_norm': zeros_timeseries,
              'update_norms': zeros_timeseries_like_params
          }))
def test_update_update_norms(self):
  """Ensure that we update gradient and update norms correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps,
      enable_gradient_norm=True,
      enable_update_norm=True,
      enable_update_norms=True)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0, self.mock_cost0,
                                    self.mock_grad1, self.mock_params0,
                                    self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1, self.mock_cost1,
                                    self.mock_grad2, self.mock_params1,
                                    self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['update_norms'], {
              'foo':
                  jnp.array([
                      self.step_size * jnp.linalg.norm(self.mock_grad1['foo']),
                      self.step_size * jnp.linalg.norm(self.mock_grad2['foo']),
                      0.0, 0.0, 0.0
                  ]),
              'bar': {
                  'baz':
                      jnp.array([
                          self.step_size *
                          jnp.linalg.norm(self.mock_grad1['bar']['baz']),
                          self.step_size *
                          jnp.linalg.norm(self.mock_grad2['bar']['baz']),
                          0.0, 0.0, 0.0
                      ])
              }
          }))
  self.assertEqual(updated_metrics_state['gradient_norm'][0],
                   total_tree_norm_l2(self.mock_grad1))
  self.assertEqual(updated_metrics_state['gradient_norm'][1],
                   total_tree_norm_l2(self.mock_grad2))
  self.assertEqual(updated_metrics_state['update_norm'][0],
                   self.step_size * total_tree_norm_l2(self.mock_grad1))
  self.assertEqual(updated_metrics_state['update_norm'][1],
                   self.step_size * total_tree_norm_l2(self.mock_grad2))
def test_train_cost(self):
  """Ensure that the train cost is logged correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps, enable_train_cost=True)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0, self.mock_cost0,
                                    self.mock_grad1, self.mock_params0,
                                    self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1, self.mock_cost1,
                                    self.mock_grad2, self.mock_params1,
                                    self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['train_cost'],
          jnp.array([self.mock_cost0, self.mock_cost1, 0.0, 0.0, 0.0])))
def test_update_grad_ema(self):
  """Ensure that the training metrics updater updates grad ema correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps, enable_ema=True, ema_beta=0.5)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0, self.mock_cost0,
                                    self.mock_grad1, self.mock_params0,
                                    self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1, self.mock_cost1,
                                    self.mock_grad2, self.mock_params1,
                                    self.mock_params2,
                                    self.mock_optimizer_state1)
  # With ema_beta = 0.5, two updates of ema <- beta * ema + (1 - beta) * grad
  # give 0.25 * ema0 + 0.25 * grad1 + 0.5 * grad2.
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['grad_ema'],
          jax.tree_map(lambda x, y, z: 0.25 * x + 0.25 * y + 0.5 * z,
                       self.mock_zeros, self.mock_grad1, self.mock_grad2)))
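
# A minimal sketch (illustrative only, not part of the test suite) of how the
# (init_fn, update_fn, summarize_fn) triple returned by make_training_metrics
# is threaded through a training loop; the trainer below follows this pattern.
# `loss_and_grad`, `apply_update` and `batches` are hypothetical placeholders.
#
#   init_fn, update_fn, summarize_fn = make_training_metrics(
#       num_train_steps, enable_train_cost=True, enable_ema=True)
#   metrics_state = init_fn(params)
#   for step in range(num_train_steps):
#     cost, grad = loss_and_grad(params, next(batches))
#     new_params = apply_update(params, grad)
#     metrics_state = update_fn(metrics_state, step, cost, grad, params,
#                               new_params, optimizer_state)
#     params = new_params
#   summary = summarize_fn(metrics_state)  # e.g. grad_var, update_ratio, ...
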
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
  """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py.
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole
      validation and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on
      train. Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption
      recovery.
    early_stopping_target_name: A string naming the metric to use to perform
      early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to stop
      training.
    early_stopping_mode: One of "above" or "below", indicates if we should
      stop when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced
      training metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
  # NOTE: the initialization RNG should *not* be per-host, as this will create
  # different sets of weights per host. However, all other RNGs should be
  # per-host.
  # TODO(znado,gilmer,gdahl): implement replicating the same initialization
  # across hosts.
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  # Only used if checkpoint_steps is non-empty.
  checkpoint_dir = os.path.join(train_dir, 'checkpoints')

  # For logging / processing off the main thread.
  pool = multiprocessing.pool.ThreadPool()

  if jax.process_index() == 0:
    logging.info('Let the training begin!')
    logging.info('Dataset input shape: %r', hps.input_shape)
    logging.info('Hyperparameters: %s', hps)

  if eval_batch_size is None:
    eval_batch_size = hps.batch_size
  if callback_configs is None:
    callback_configs = []

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module, initializer, model.loss_fn, hps.input_shape,
      hps.output_shape, hps, init_rng, init_logger, model.get_fake_batch(hps))

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d', hps.train_size)

  # Note that global_step refers to the number of gradient calculations, not
  # the number of model updates. This means when using gradient accumulation,
  # one must supply configs where the number of steps are in units of gradient
  # calculations, not model updates, and in post processing one must divide
  # global_step by grad_accum_step_multiplier to get the number of updates.
  #
  # If using gradient accumulation, stretch the learning rate schedule by the
  # number of gradient calculations per weight update.
  stretch_factor = 1
  if hps.get('total_accumulated_batch_size') is not None:
    stretch_factor = hps.total_accumulated_batch_size // hps.batch_size
  lr_fn = schedules.get_schedule_fn(
      hps.lr_hparams, num_train_steps, stretch_factor=stretch_factor)

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps, model)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

  (unreplicated_metrics_state, metrics_update_fn,
   metrics_summary_fn) = None, None, None
  if training_metrics_config is not None:
    (metrics_init_fn, metrics_update_fn,
     metrics_summary_fn) = make_training_metrics(num_train_steps,
                                                 **training_metrics_config)
    unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

  (optimizer_state, params, batch_stats, metrics_state, global_step,
   sum_train_cost, preemption_count,
   is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
       unreplicated_optimizer_state,
       unreplicated_params,
       unreplicated_batch_stats,
       unreplicated_metrics_state,
       train_dir=train_dir,
       external_checkpoint_path=external_checkpoint_path)

  if is_restored:
    preemption_count += 1
    # Fold the restored step into the dataset RNG so that we will get a
    # different shuffle each time we restore, so that we do not repeat a
    # previous dataset ordering again after restoring. This is not the only
    # difference in shuffling each pre-emption, because we oftentimes
    # reshuffle the input files each time in a non-deterministic manner.
    #
    # Note that if we are pre-empted more than once per epoch then we will
    # retrain more on the beginning of the training split, because each time
    # we restore we refill the shuffle buffer with the first
    # `shuffle_buffer_size` elements from the training split to continue
    # training.
    #
    # Also note that for evaluating on the training split, because we are
    # reshuffling each time, we will get a new eval_train split each time we
    # are pre-empted.
    data_rng = jax.random.fold_in(data_rng, global_step)

  assert hps.batch_size % (jax.device_count()) == 0
  assert eval_batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=eval_batch_size,
      hps=hps,
  )

  update_fn = functools.partial(
      update,
      training_cost=model.training_cost,
      grad_clip=hps.get('grad_clip'),
      optimizer_update_fn=optimizer_update_fn,
      metrics_update_fn=metrics_update_fn)
  # in_axes = (
  #     optimizer_state = 0,
  #     params = 0,
  #     batch_stats = 0,
  #     metrics_state = 0,
  #     batch = 0,
  #     step = None,
  #     lr = None,
  #     rng = None,
  #     local_device_index = 0,
  #     running_train_cost = 0,
  #     training_cost,
  #     grad_clip,
  #     optimizer_update_fn,
  #     metrics_state_update_fn)
  # Also, we can donate buffers for 'optimizer', 'batch_stats', 'batch' and
  # 'training_metrics_state' for update's pmapped computation.
  update_pmapped = jax.pmap(
      update_fn,
      axis_name='batch',
      in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
      donate_argnums=(0, 1, 2, 8))
  # During eval, we can donate the 'batch' buffer. We don't donate the
  # 'params' and 'batch_stats' buffers as we don't re-assign those values in
  # eval, we do that only in train.
  evaluate_batch_pmapped = jax.pmap(
      model.evaluate_batch, axis_name='batch', donate_argnums=(2,))
  start_time = time.time()
  start_step = global_step
  prev_eval_step = start_step

  def get_step_frequency(cur_step):
    return float(cur_step - start_step) / (time.time() - start_time)

  if jax.process_index() == 0:
    trainer_utils.log_message('Starting training!', pool, xm_work_unit)

  # Numpy array of range(0, local_device_count) to send to each device to be
  # folded into the RNG inside each train step to get a unique per-device RNG.
  local_device_indices = np.arange(jax.local_device_count())

  # Start at the resumed step and continue until we have finished the number
  # of training steps. If building a dataset iterator using a tf.data.Dataset,
  # in the case of a batch size that does not evenly divide the training
  # dataset size, if using `ds.batch(..., drop_remainder=True)` on the
  # training dataset then the final batch in this iterator will be a partial
  # batch. However, if `drop_remainder=False`, then this iterator will always
  # return batches of the same size, and the final batch will have elements
  # from the start of the (num_epochs + 1)-th epoch.
  train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                num_train_steps)

  eval_callbacks = []
  rng, callback_rng = jax.random.split(rng)
  callback_rngs = jax.random.split(callback_rng, len(callback_configs))
  for callback_rng, config in zip(callback_rngs, callback_configs):
    eval_callback = callbacks.get_callback(config['callback_name'])(
        model, params, batch_stats, optimizer_state, dataset, hps, config,
        train_dir, callback_rng)
    eval_callbacks.append(eval_callback)

  train_iter = trainer_utils.prefetch_input_pipeline(
      train_iter, hps.num_device_prefetches)

  eval_start_time = start_time
  eval_start_step = start_step
  for _ in range(start_step, num_train_steps):
    with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
      # NOTE(dsuo): to properly profile each step, we must include batch
      # creation in the StepTraceAnnotation context (as opposed to putting
      # `train_iter` directly in the top-level for loop).
      batch = next(train_iter)

      if global_step in checkpoint_steps and jax.process_index() == 0:
        checkpoint.save_unreplicated_checkpoint_background(
            checkpoint_dir,
            optimizer_state,
            params,
            batch_stats,
            metrics_state,
            global_step,
            preemption_count,
            sum_train_cost,
            max_to_keep=None)
      lr = lr_fn(global_step)
      optimizer_state, params, batch_stats, sum_train_cost, metrics_state, grad_norm = update_pmapped(
          optimizer_state, params, batch_stats, metrics_state, batch,
          global_step, lr, rng, local_device_indices, sum_train_cost)
      global_step += 1

    # TODO(gdahl, gilmer): consider moving this test up.
    # NB: Since this test is after we increment global_step, having 0 in
    # eval_steps does nothing.
    if trainer_utils.should_eval(global_step, eval_frequency, eval_steps):
      train_steps_per_sec = (global_step - eval_start_step) / (
          time.time() - eval_start_time)
      eval_start_step = global_step
      eval_start_time = time.time()
      batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
      report, eval_time = eval_metrics(params, batch_stats, dataset,
                                       eval_num_batches,
                                       eval_train_num_batches,
                                       evaluate_batch_pmapped)
      mean_train_cost = sum_train_cost.mean().item() / max(
          1, global_step - prev_eval_step)
      report.update(
          learning_rate=float(lr),
          global_step=global_step,
          epoch=global_step * hps.batch_size // hps.train_size,
          train_steps_per_sec=train_steps_per_sec,
          overall_steps_per_sec=get_step_frequency(global_step),
          eval_time=eval_time,
          grad_norm=np.mean(grad_norm),
          preemption_count=preemption_count,
          train_cost=mean_train_cost)
      for eval_callback in eval_callbacks:
        callback_metrics = eval_callback.run_eval(params, batch_stats,
                                                  optimizer_state, global_step)
        if set(callback_metrics.keys()).intersection(set(report.keys())):
          raise ValueError('There was a collision between the callback '
                           'metrics and the standard eval metrics keys.')
        report.update(callback_metrics)
      yield report
      if jax.process_index() == 0:
        trainer_utils.log_eta(pool, xm_work_unit, global_step,
                              train_steps_per_sec, num_train_steps, start_time,
                              eval_frequency, eval_steps, eval_time)
        trainer_utils.log_epoch_report(report, metrics_logger)
        trainer_utils.maybe_log_training_metrics(metrics_state,
                                                 metrics_summary_fn,
                                                 metrics_logger)
        checkpoint.save_unreplicated_checkpoint_background(
            train_dir, optimizer_state, params, batch_stats, metrics_state,
            global_step, preemption_count, sum_train_cost)
      sum_train_cost = jnp.zeros(jax.local_device_count())
      prev_eval_step = global_step

      early_stopping_condition = trainer_utils.check_for_early_stopping(
          early_stopping_target_name, early_stopping_target_value,
          early_stopping_mode, report)
      if early_stopping_condition:
        comparison_string = '>=' if early_stopping_mode == 'above' else '<='
        logging.info(
            'Early stopping because metric %s=%f, reached the target value '
            'of %s %f.', early_stopping_target_name,
            report[early_stopping_target_name], comparison_string,
            early_stopping_target_value)
        return

  # Always log and checkpoint on host 0 at the end of training.
  # If we moved where in the loop body evals happen then we would not need
  # this test.
  if prev_eval_step != num_train_steps:
    train_steps_per_sec = (global_step - eval_start_step) / (
        time.time() - eval_start_time)
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    report, eval_time = eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_train_num_batches,
                                     evaluate_batch_pmapped)
    lr = lr_fn(global_step)
    # Correct the average for the final partial epoch.
    mean_train_cost = sum_train_cost.mean().item() / max(
        1, global_step - prev_eval_step)
    report.update(
        learning_rate=float(lr),
        global_step=global_step,
        epoch=global_step * hps.batch_size // hps.train_size,
        train_steps_per_sec=train_steps_per_sec,
        overall_steps_per_sec=get_step_frequency(global_step),
        eval_time=eval_time,
        grad_norm=np.mean(grad_norm),
        preemption_count=preemption_count,
        train_cost=mean_train_cost)
    yield report
    if jax.process_index() == 0:
      trainer_utils.log_eta(pool, xm_work_unit, global_step,
                            train_steps_per_sec, num_train_steps, start_time,
                            eval_frequency, eval_steps, eval_time)
      trainer_utils.log_epoch_report(report, metrics_logger)
      trainer_utils.maybe_log_training_metrics(metrics_state,
                                               metrics_summary_fn,
                                               metrics_logger)
      checkpoint.save_unreplicated_checkpoint_background(
          train_dir, optimizer_state, params, batch_stats, metrics_state,
          global_step, preemption_count, sum_train_cost)

  # To make sure the last checkpoint was correctly saved.
  checkpoint.wait_for_checkpoint_save()
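
# A rough usage sketch (an assumption for illustration, not an API guarantee
# made by this module): `train` is a generator, so the caller iterates it to
# drive training and receives one metrics report per eval. The argument values
# below are placeholders.
#
#   for report in train(
#       train_dir='/tmp/experiment',
#       model=model,
#       dataset_builder=dataset_builder,
#       initializer=initializer,
#       num_train_steps=10000,
#       hps=hps,
#       rng=jax.random.PRNGKey(0),
#       eval_batch_size=None,
#       eval_num_batches=None,
#       eval_train_num_batches=None,
#       eval_frequency=1000,
#       checkpoint_steps=[]):
#     logging.info('validation/error_rate: %f',
#                  report.get('validation/error_rate', float('nan')))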