def test_initializers(self, init):
  """Test that each initializer runs, and the output is a valid pytree."""
  rng = jax.random.PRNGKey(0)
  flax_module, params, input_shape, model_hps = _load_model('fully_connected')
  _, init_rng = jax.random.split(rng)
  initializer = initializers.get_initializer(init)
  init_hps = initializers.get_initializer_hparams(init)
  init_hps.update(model_hps)
  loss_name = 'cross_entropy'
  loss_fn = losses.get_loss_fn(loss_name)
  new_params = initializer(
      loss_fn=loss_fn,
      flax_module=flax_module,
      params=params,
      hps=init_hps,
      input_shape=input_shape[1:],
      output_shape=OUTPUT_SHAPE,
      rng_key=init_rng)

  # Check that the new params are still valid params.
  outputs = flax_module.apply(
      {'params': new_params}, jnp.ones(input_shape), train=True)
  utils.log_pytree_shape_and_statistics(new_params)
  self.assertEqual(outputs.shape, (input_shape[0], OUTPUT_SHAPE[-1]))
def test_initialize_rescale(self):
  """Test rescaling a single layer of a model."""
  input_shape = (28, 28, 1)
  output_shape = (10,)
  model_str = 'fully_connected'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': output_shape})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')
  rng, init_rng = jax.random.split(rng)

  # First initialize with no rescale.
  flax_module, _ = trainer.initialize(
      model.flax_module_def,
      initializer,
      model.loss_fn,
      input_shape,
      output_shape,
      hps,
      init_rng,
      metrics_logger=None)
  utils.log_pytree_shape_and_statistics(flax_module.params)

  # Now rescale a layer by 100.
  rescale_factor = 100
  hps.layer_rescale_factors = {
      '/Dense_1/kernel': rescale_factor,
  }
  rescaled_module, _ = trainer.initialize(
      model.flax_module_def,
      initializer,
      model.loss_fn,
      input_shape,
      output_shape,
      hps,
      init_rng,
      metrics_logger=None)

  # Check that the right variable is rescaled.
  v1 = flax_module.params['Dense_1']['kernel']
  v2 = rescaled_module.params['Dense_1']['kernel']
  diff = np.linalg.norm(v1.reshape(-1) * rescale_factor - v2.reshape(-1))
  self.assertAlmostEqual(diff, 0.0)

  # Check that other variables are the same.
  v1 = flax_module.params['Dense_2']['kernel']
  v2 = rescaled_module.params['Dense_2']['kernel']
  diff = np.linalg.norm(v1.reshape(-1) - v2.reshape(-1))
  self.assertAlmostEqual(diff, 0.0)
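# Illustration only (not part of the test above): a minimal sketch of how a
# `layer_rescale_factors` mapping such as {'/Dense_1/kernel': 100} could be
# applied to a nested params dict. The helper name `rescale_params` and the
# exact path convention are assumptions for this example; the real rescaling
# is performed inside trainer.initialize. The sketch assumes `params` is a
# plain nested dict of arrays (unfreeze a FrozenDict first if needed).
def rescale_params(params, layer_rescale_factors):
  """Returns a copy of `params` with the selected leaves scaled."""
  new_params = copy.deepcopy(params)
  for path, factor in layer_rescale_factors.items():
    # '/Dense_1/kernel' -> ['Dense_1', 'kernel']
    keys = [k for k in path.split('/') if k]
    node = new_params
    for key in keys[:-1]:
      node = node[key]
    node[keys[-1]] = node[keys[-1]] * factor
  return new_params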
def _load_model(model_name):
  """Load a test model."""
  rng = jax.random.PRNGKey(0)
  model_cls = models.get_model(model_name)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_hps = models.get_model_hparams(model_name)
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  model = model_cls(hps, {}, loss_name, metrics_name)
  input_shape = (BATCH_SIZE,) + MODEL_TO_INPUT_SHAPE[model_name]
  _, flax_module = model.flax_module_def.create_by_shape(
      rng, [input_shape], train=True)
  utils.log_pytree_shape_and_statistics(flax_module.params)
  return flax_module, input_shape
def _load_model(model_name):
  """Load a test model."""
  rng = jax.random.PRNGKey(0)
  model_cls = models.get_model(model_name)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_hps = models.get_model_hparams(model_name)
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  model = model_cls(hps, {}, loss_name, metrics_name)
  input_shape = (BATCH_SIZE,) + MODEL_TO_INPUT_SHAPE[model_name]
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=True))
  init_dict = model_init_fn({'params': rng}, jnp.zeros(input_shape))
  # Trainable model parameters.
  params = init_dict['params']
  utils.log_pytree_shape_and_statistics(params)
  return model.flax_module, params, input_shape, hps
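# Illustration only: a quick, self-contained way to inspect the params pytree
# returned by _load_model above, using standard jax.tree_util calls. This is
# independent of the project's utils.log_pytree_shape_and_statistics helper;
# the helper name `describe_params` is hypothetical.
def describe_params(params):
  """Prints the shape of every leaf and the total parameter count."""
  shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
  print(shapes)
  total = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
  print('total parameters:', total)

# Example usage:
#   _, params, _, _ = _load_model('fully_connected')
#   describe_params(params)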
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
  """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py.
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole
      validation and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on
      train. Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption
      recovery.
    early_stopping_target_name: A string naming the metric to use to perform
      early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to
      stop training.
    early_stopping_mode: One of "above" or "below", indicates if we should
      stop when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced
      training metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
  # NOTE: the initialization RNG should *not* be per-host, as this will create
  # different sets of weights per host. However, all other RNGs should be
  # per-host.
  # TODO(znado,gilmer,gdahl): implement replicating the same initialization
  # across hosts.
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  # Only used if checkpoint_steps is non-empty.
  checkpoint_dir = os.path.join(train_dir, 'checkpoints')

  # For logging / processing off the main thread.
  pool = multiprocessing.pool.ThreadPool()

  if jax.process_index() == 0:
    logging.info('Let the training begin!')
    logging.info('Dataset input shape: %r', hps.input_shape)
    logging.info('Hyperparameters: %s', hps)

  if eval_batch_size is None:
    eval_batch_size = hps.batch_size
  if callback_configs is None:
    callback_configs = []

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer,
      model.loss_fn,
      hps.input_shape,
      hps.output_shape,
      hps,
      init_rng,
      init_logger,
      model.get_fake_batch(hps))

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)

  # Note that global_step refers to the number of gradients calculations, not
  # the number of model updates. This means when using gradient accumulation,
  # one must supply configs where the number of steps are in units of gradient
  # calculations, not model updates, and in post processing one must divide
  # global_step by grad_accum_step_multiplier to get the number of updates.
  #
  # If using gradient accumulation, stretch the learning rate schedule by the
  # number of gradient calculations per weight update.
  stretch_factor = 1
  if hps.get('total_accumulated_batch_size') is not None:
    stretch_factor = hps.total_accumulated_batch_size // hps.batch_size
  lr_fn = schedules.get_schedule_fn(
      hps.lr_hparams, num_train_steps, stretch_factor=stretch_factor)

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps, model)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

  (unreplicated_metrics_state, metrics_update_fn,
   metrics_summary_fn) = None, None, None
  if training_metrics_config is not None:
    (metrics_init_fn, metrics_update_fn, metrics_summary_fn) = (
        make_training_metrics(num_train_steps, **training_metrics_config))
    unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

  (optimizer_state, params, batch_stats, metrics_state, global_step,
   sum_train_cost, preemption_count,
   is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
       unreplicated_optimizer_state,
       unreplicated_params,
       unreplicated_batch_stats,
       unreplicated_metrics_state,
       train_dir=train_dir,
       external_checkpoint_path=external_checkpoint_path)

  if is_restored:
    preemption_count += 1
    # Fold the restored step into the dataset RNG so that we will get a
    # different shuffle each time we restore, so that we do not repeat a
    # previous dataset ordering again after restoring. This is not the only
    # difference in shuffling each pre-emption, because we often times
    # reshuffle the input files each time in a non-deterministic manner.
    #
    # Note that if we are pre-empted more than once per epoch then we will
    # retrain more on the beginning of the training split, because each time
    # we restore we refill the shuffle buffer with the first
    # `shuffle_buffer_size` elements from the training split to continue
    # training.
    #
    # Also note that for evaluating on the training split, because we are
    # reshuffling each time, we will get a new eval_train split each time we
    # are pre-empted.
    data_rng = jax.random.fold_in(data_rng, global_step)

  assert hps.batch_size % (jax.device_count()) == 0
  assert eval_batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=eval_batch_size,
      hps=hps,
  )

  update_fn = functools.partial(
      update,
      training_cost=model.training_cost,
      grad_clip=hps.get('grad_clip'),
      optimizer_update_fn=optimizer_update_fn,
      metrics_update_fn=metrics_update_fn)
  # in_axes = (
  #     optimizer_state = 0,
  #     params = 0,
  #     batch_stats = 0,
  #     metrics_state = 0,
  #     batch = 0,
  #     step = None,
  #     lr = None,
  #     rng = None,
  #     local_device_index = 0,
  #     running_train_cost = 0,
  #     training_cost,
  #     grad_clip,
  #     optimizer_update_fn,
  #     metrics_state_update_fn)
  # Also, we can donate buffers for 'optimizer', 'batch_stats',
  # 'batch' and 'training_metrics_state' for update's pmapped computation.
  update_pmapped = jax.pmap(
      update_fn,
      axis_name='batch',
      in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
      donate_argnums=(0, 1, 2, 8))
  # During eval, we can donate the 'batch' buffer. We don't donate the
  # 'params' and 'batch_stats' buffers as we don't re-assign those values in
  # eval, we do that only in train.
  evaluate_batch_pmapped = jax.pmap(
      model.evaluate_batch, axis_name='batch', donate_argnums=(2,))
  start_time = time.time()
  start_step = global_step
  prev_eval_step = start_step

  def get_step_frequency(cur_step):
    return float(cur_step - start_step) / (time.time() - start_time)

  if jax.process_index() == 0:
    trainer_utils.log_message('Starting training!', pool, xm_work_unit)

  # Numpy array of range(0, local_device_count) to send to each device to be
  # folded into the RNG inside each train step to get a unique per-device RNG.
  local_device_indices = np.arange(jax.local_device_count())

  # Start at the resumed step and continue until we have finished the number
  # of training steps. If building a dataset iterator using a tf.data.Dataset,
  # in the case of a batch size that does not evenly divide the training
  # dataset size, if using `ds.batch(..., drop_remainder=True)` on the
  # training dataset then the final batch in this iterator will be a partial
  # batch. However, if `drop_remainder=False`, then this iterator will always
  # return batches of the same size, and the final batch will have elements
  # from the start of the (num_epochs + 1)-th epoch.
  train_iter = itertools.islice(
      dataset.train_iterator_fn(), global_step, num_train_steps)

  eval_callbacks = []
  rng, callback_rng = jax.random.split(rng)
  callback_rngs = jax.random.split(callback_rng, len(callback_configs))
  for callback_rng, config in zip(callback_rngs, callback_configs):
    eval_callback = callbacks.get_callback(config['callback_name'])(
        model, params, batch_stats, optimizer_state, dataset, hps, config,
        train_dir, callback_rng)
    eval_callbacks.append(eval_callback)

  train_iter = trainer_utils.prefetch_input_pipeline(
      train_iter, hps.num_device_prefetches)

  eval_start_time = start_time
  eval_start_step = start_step
  for _ in range(start_step, num_train_steps):
    with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
      # NOTE(dsuo): to properly profile each step, we must include batch
      # creation in the StepTraceContext (as opposed to putting `train_iter`
      # directly in the top-level for loop).
      batch = next(train_iter)

      if global_step in checkpoint_steps and jax.process_index() == 0:
        checkpoint.save_unreplicated_checkpoint_background(
            checkpoint_dir,
            optimizer_state,
            params,
            batch_stats,
            metrics_state,
            global_step,
            preemption_count,
            sum_train_cost,
            max_to_keep=None)
      lr = lr_fn(global_step)
      (optimizer_state, params, batch_stats, sum_train_cost, metrics_state,
       grad_norm) = update_pmapped(
           optimizer_state,
           params,
           batch_stats,
           metrics_state,
           batch,
           global_step,
           lr,
           rng,
           local_device_indices,
           sum_train_cost)
      global_step += 1
      # TODO(gdahl, gilmer): consider moving this test up.
      # NB: Since this test is after we increment global_step, having 0 in
      # eval_steps does nothing.
      if trainer_utils.should_eval(global_step, eval_frequency, eval_steps):
        train_steps_per_sec = (global_step - eval_start_step) / (
            time.time() - eval_start_time)
        eval_start_step = global_step
        eval_start_time = time.time()
        batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(
            params,
            batch_stats,
            dataset,
            eval_num_batches,
            eval_train_num_batches,
            evaluate_batch_pmapped)
        mean_train_cost = sum_train_cost.mean().item() / max(
            1, global_step - prev_eval_step)
        report.update(
            learning_rate=float(lr),
            global_step=global_step,
            epoch=global_step * hps.batch_size // hps.train_size,
            train_steps_per_sec=train_steps_per_sec,
            overall_steps_per_sec=get_step_frequency(global_step),
            eval_time=eval_time,
            grad_norm=np.mean(grad_norm),
            preemption_count=preemption_count,
            train_cost=mean_train_cost)
        for eval_callback in eval_callbacks:
          callback_metrics = eval_callback.run_eval(
              params, batch_stats, optimizer_state, global_step)
          if set(callback_metrics.keys()).intersection(set(report.keys())):
            raise ValueError(
                'There was a collision between the callback metrics and the '
                'standard eval metrics keys.')
          report.update(callback_metrics)
        yield report
        if jax.process_index() == 0:
          trainer_utils.log_eta(
              pool,
              xm_work_unit,
              global_step,
              train_steps_per_sec,
              num_train_steps,
              start_time,
              eval_frequency,
              eval_steps,
              eval_time)
          trainer_utils.log_epoch_report(report, metrics_logger)
          trainer_utils.maybe_log_training_metrics(
              metrics_state, metrics_summary_fn, metrics_logger)
          checkpoint.save_unreplicated_checkpoint_background(
              train_dir,
              optimizer_state,
              params,
              batch_stats,
              metrics_state,
              global_step,
              preemption_count,
              sum_train_cost)
        sum_train_cost = jnp.zeros(jax.local_device_count())
        prev_eval_step = global_step

        early_stopping_condition = trainer_utils.check_for_early_stopping(
            early_stopping_target_name,
            early_stopping_target_value,
            early_stopping_mode,
            report)
        if early_stopping_condition:
          comparison_string = '>=' if early_stopping_mode == 'above' else '<='
          logging.info(
              'Early stopping because metric %s=%f, reached the target value '
              'of %s %f.',
              early_stopping_target_name,
              report[early_stopping_target_name],
              comparison_string,
              early_stopping_target_value)
          return

  # Always log and checkpoint on host 0 at the end of training.
  # If we moved where in the loop body evals happen then we would not need
  # this test.
  if prev_eval_step != num_train_steps:
    train_steps_per_sec = (global_step - eval_start_step) / (
        time.time() - eval_start_time)
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    report, eval_time = eval_metrics(
        params,
        batch_stats,
        dataset,
        eval_num_batches,
        eval_train_num_batches,
        evaluate_batch_pmapped)
    lr = lr_fn(global_step)
    # Correct the average for the final partial epoch.
    mean_train_cost = sum_train_cost.mean().item() / max(
        1, global_step - prev_eval_step)
    report.update(
        learning_rate=float(lr),
        global_step=global_step,
        epoch=global_step * hps.batch_size // hps.train_size,
        train_steps_per_sec=train_steps_per_sec,
        overall_steps_per_sec=get_step_frequency(global_step),
        eval_time=eval_time,
        grad_norm=np.mean(grad_norm),
        preemption_count=preemption_count,
        train_cost=mean_train_cost)
    yield report
    if jax.process_index() == 0:
      trainer_utils.log_eta(
          pool,
          xm_work_unit,
          global_step,
          train_steps_per_sec,
          num_train_steps,
          start_time,
          eval_frequency,
          eval_steps,
          eval_time)
      trainer_utils.log_epoch_report(report, metrics_logger)
      trainer_utils.maybe_log_training_metrics(
          metrics_state, metrics_summary_fn, metrics_logger)
      checkpoint.save_unreplicated_checkpoint_background(
          train_dir,
          optimizer_state,
          params,
          batch_stats,
          metrics_state,
          global_step,
          preemption_count,
          sum_train_cost)
  # To make sure the last checkpoint was correctly saved.
  checkpoint.wait_for_checkpoint_save()
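# Illustration only: a minimal sketch of how a caller might drive the train()
# generator defined above. The `model`, `dataset_builder`, `initializer`, and
# `hps` objects are assumed to have been built elsewhere (e.g. via
# models.get_model, datasets.get_dataset, and initializers.get_initializer);
# the helper name `run_training_sketch` and the particular argument values are
# placeholders, not part of this module.
def run_training_sketch(train_dir, model, dataset_builder, initializer, hps):
  rng = jax.random.PRNGKey(0)
  for report in train(
      train_dir=train_dir,
      model=model,
      dataset_builder=dataset_builder,
      initializer=initializer,
      num_train_steps=1000,
      hps=hps,
      rng=rng,
      eval_batch_size=None,     # defaults to hps.batch_size
      eval_num_batches=None,    # evaluate on the whole validation/test sets
      eval_train_num_batches=10,
      eval_frequency=100,
      checkpoint_steps=[]):
    # Each yielded `report` is a dict of eval metrics for that eval step.
    logging.info('step %d: %s', report['global_step'], report)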
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          use_deprecated_checkpointing=True):
  """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py.
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole
      validation and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on
      train. Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption
      recovery.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced
      training metrics.
    use_deprecated_checkpointing: Whether to use deprecated checkpointing.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
  # NOTE: the initialization RNG should *not* be per-host, as this will create
  # different sets of weights per host. However, all other RNGs should be
  # per-host.
  # TODO(znado,gilmer,gdahl): implement replicating the same initialization
  # across hosts.
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, data_rng = jax.random.split(rng)

  # Only used if checkpoint_steps is non-empty.
  checkpoint_dir = os.path.join(train_dir, 'checkpoints')

  if jax.host_id() == 0:
    logging.info('Let the training begin!')
    logging.info('Dataset input shape: %r', hps.input_shape)
    logging.info('Hyperparameters: %s', hps)

  if eval_batch_size is None:
    eval_batch_size = hps.batch_size

  # Maybe run the initializer.
  flax_module, batch_stats = initialize(
      model.flax_module_def,
      initializer,
      model.loss_fn,
      hps.input_shape,
      hps.output_shape,
      hps,
      init_rng,
      init_logger)

  if jax.host_id() == 0:
    utils.log_pytree_shape_and_statistics(flax_module.params)
    logging.info('train_size: %d,', hps.train_size)

  lr_fn = schedules.get_schedule_fn(hps.lr_hparams, num_train_steps)

  optimizer = get_optimizer(hps).create(flax_module)

  training_metrics_grabber = None
  if training_metrics_config:
    training_metrics_grabber = utils.TrainingMetricsGrabber.create(
        optimizer.target.params, training_metrics_config)

  (optimizer, batch_stats, training_metrics_grabber, global_step,
   sum_train_cost, preemption_count,
   is_restored) = _maybe_restore_latest_checkpoint(
       unreplicated_optimizer=optimizer,
       unreplicated_batch_stats=batch_stats,
       unreplicated_training_metrics_grabber=training_metrics_grabber,
       train_dir=train_dir,
       use_deprecated_checkpointing=use_deprecated_checkpointing)

  if is_restored:
    preemption_count += 1
    # Fold the restored step into the dataset RNG so that we will get a
    # different shuffle each time we restore, so that we do not repeat a
    # previous dataset ordering again after restoring. This is not the only
    # difference in shuffling each pre-emption, because we often times
    # reshuffle the input files each time in a non-deterministic manner.
    #
    # Note that if we are pre-empted more than once per epoch then we will
    # retrain more on the beginning of the training split, because each time
    # we restore we refill the shuffle buffer with the first
    # `shuffle_buffer_size` elements from the training split to continue
    # training.
    #
    # Also note that for evaluating on the training split, because we are
    # reshuffling each time, we will get a new eval_train split each time we
    # are pre-empted.
    data_rng = jax.random.fold_in(data_rng, global_step)

  assert hps.batch_size % (jax.device_count()) == 0
  assert eval_batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=eval_batch_size,
      hps=hps,
  )

  # pmap functions for the training loop.
  # in_axes = (optimizer = 0, batch_stats = 0, batch = 0, step = None,
  #            lr = None, rng = None, local_device_index = 0,
  #            training_metrics_grabber = 0,
  #            training_metrics_grabber, training_cost)
  update_pmapped = jax.pmap(
      functools.partial(update, training_cost=model.training_cost),
      axis_name='batch',
      in_axes=(0, 0, 0, None, None, None, 0, 0))
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')
  start_time = time.time()
  start_step = global_step
  prev_eval_step = start_step

  def get_step_frequency(cur_step):
    return float(cur_step - start_step) / (time.time() - start_time)

  if jax.host_id() == 0:
    logging.info('Starting training!')

  # Numpy array of range(0, local_device_count) to send to each device to be
  # folded into the RNG inside each train step to get a unique per-device RNG.
  local_device_indices = np.arange(jax.local_device_count())

  # Start at the resumed step and continue until we have finished the number
  # of training steps. If building a dataset iterator using a tf.data.Dataset,
  # in the case of a batch size that does not evenly divide the training
  # dataset size, if using `ds.batch(..., drop_remainder=True)` on the
  # training dataset then the final batch in this iterator will be a partial
  # batch. However, if `drop_remainder=False`, then this iterator will always
  # return batches of the same size, and the final batch will have elements
  # from the start of the (num_epochs + 1)-th epoch.
  train_iter = itertools.islice(
      dataset.train_iterator_fn(), global_step, num_train_steps)
  for batch in train_iter:
    if global_step in checkpoint_steps and jax.host_id() == 0:
      save_checkpoint(
          checkpoint_dir,
          {
              'optimizer': optimizer,
              'batch_stats': batch_stats,
              'training_metrics_grabber': training_metrics_grabber
          },
          global_step,
          preemption_count,
          sum_train_cost,
          max_to_keep=None,
          use_deprecated_checkpointing=use_deprecated_checkpointing)
    batch = data_utils.shard(batch)
    lr = lr_fn(global_step)
    optimizer, batch_stats, cost_val, training_metrics_grabber = update_pmapped(
        optimizer, batch_stats, batch, global_step, lr, rng,
        local_device_indices, training_metrics_grabber)
    # Calling float is needed since cost_val is a shape (1,) DeviceArray.
    sum_train_cost += float(jnp.mean(cost_val))
    global_step += 1
    # TODO(gdahl, gilmer): consider moving this test up.
    # NB: Since this test is after we increment global_step, having 0 in
    # eval_steps does nothing.
    if should_eval(global_step, eval_frequency, eval_steps):
      batch_stats = _maybe_sync_batchnorm_stats(batch_stats)
      report, eval_time = eval_metrics(optimizer.target, batch_stats, dataset,
                                       eval_num_batches,
                                       eval_train_num_batches,
                                       evaluate_batch_pmapped)
      mean_train_cost = sum_train_cost / max(1, global_step - prev_eval_step)
      report.update(
          learning_rate=float(lr),
          global_step=global_step,
          epoch=global_step * hps.batch_size // hps.train_size,
          steps_per_sec=get_step_frequency(global_step),
          eval_time=eval_time,
          preemption_count=preemption_count,
          train_cost=mean_train_cost)
      yield report
      if jax.host_id() == 0:
        _log_epoch_report(report, metrics_logger)
        _maybe_log_training_metrics(training_metrics_grabber, metrics_logger)
        save_checkpoint(
            train_dir,
            {
                'optimizer': optimizer,
                'batch_stats': batch_stats,
                'training_metrics_grabber': training_metrics_grabber
            },
            global_step,
            preemption_count,
            sum_train_cost,
            use_deprecated_checkpointing=use_deprecated_checkpointing)
      sum_train_cost = 0.0
      prev_eval_step = global_step

  # Always log and checkpoint on host 0 at the end of training.
  # If we moved where in the loop body evals happen then we would not need
  # this test.
  if prev_eval_step != num_train_steps:
    batch_stats = _maybe_sync_batchnorm_stats(batch_stats)
    report, eval_time = eval_metrics(optimizer.target, batch_stats, dataset,
                                     eval_num_batches, eval_train_num_batches,
                                     evaluate_batch_pmapped)
    lr = lr_fn(global_step)
    # Correct the average for the final partial epoch.
    mean_train_cost = sum_train_cost / max(1, global_step - prev_eval_step)
    report.update(
        learning_rate=float(lr),
        global_step=global_step,
        epoch=global_step * hps.batch_size // hps.train_size,
        steps_per_sec=get_step_frequency(global_step),
        eval_time=eval_time,
        preemption_count=preemption_count,
        train_cost=mean_train_cost)
    yield report
    if jax.host_id() == 0:
      _log_epoch_report(report, metrics_logger)
      _maybe_log_training_metrics(training_metrics_grabber, metrics_logger)
      save_checkpoint(
          train_dir,
          {
              'optimizer': optimizer,
              'batch_stats': batch_stats,
              'training_metrics_grabber': training_metrics_grabber
          },
          global_step,
          preemption_count,
          sum_train_cost,
          use_deprecated_checkpointing=use_deprecated_checkpointing)
  # To make sure the last checkpoint was correctly saved.
  checkpoint.wait_for_checkpoint_save()