def decode_pmap_model(task_p: InstantiableParams,
                      input_p: Sequence[InstantiableParams],
                      job_log_dir: Optional[str],
                      restore_checkpoint_dir: Optional[str],
                      restore_checkpoint_step: Optional[int],
                      continuous_decode: bool) -> None:
  """Runs the decoding on the entire decoder datasets for a PMAP model.

  Args:
    task_p: Params of the task encapsulating the data parallel model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    restore_checkpoint_dir: The directory from which to restore checkpoint. If
      None, uses job_log_dir.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    continuous_decode: Whether to continuously decode on the latest checkpoint.
  """
  if continuous_decode and restore_checkpoint_step is not None:
    raise ValueError(
        'Continuous decoding mode requires restore_checkpoint_step=None, '
        f'actual restore_checkpoint_step={restore_checkpoint_step}')
  restore_checkpoint_dir = restore_checkpoint_dir or os.path.join(
      job_log_dir, 'checkpoints')

  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)
  prng_key, eval_key = jax.random.split(prng_key)
  prng_seed = jax.random.split(eval_key, num=jax.local_device_count())
  logging.info('decoder prng_seed: %s', prng_seed)

  inputs = [p.Instantiate() for p in input_p]
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  dirnames = _get_dir_names(input_p)
  summary_decode_dirs = [
      os.path.join(summary_base_dir, f'decode_test_{dirnames[split]}')
      for split, _ in enumerate(input_p)
  ]
  with contextlib.ExitStack() as exit_stack:
    summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_decode_dirs
    ]
    jax_task = task_p.Instantiate()
    model_states = trainer_lib.initialize_model_state(jax_task, init_key)
    model_states = checkpoints.restore_checkpoint(
        model_states, restore_checkpoint_dir, step=restore_checkpoint_step)
    replicated_model_states = trainer_lib.replicate_model_state(model_states)
    logging.info('replicated_model_states: %s',
                 jax.tree_map(lambda x: x.shape, replicated_model_states))
    last_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)

    while True:
      _decode_once_pmap_model(jax_task, task_p, inputs, input_p, prng_seed,
                              job_log_dir, replicated_model_states,
                              summary_writers)
      if not continuous_decode:
        break
      if last_checkpoint is not None:
        last_ckpt_step = int(last_checkpoint.split('_')[-1])
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    restore_checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
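

# The sketch below is illustrative only and is not part of the original
# library: a minimal example of how decode_pmap_model might be driven from an
# experiment launcher. `my_task_p` and `my_decoder_input_p` are hypothetical
# names for params assumed to be built by an experiment configuration.
def _example_decode_pmap_usage(my_task_p: InstantiableParams,
                               my_decoder_input_p: Sequence[InstantiableParams],
                               job_log_dir: str) -> None:
  """Hypothetical sketch: decode once from the latest checkpoint."""
  decode_pmap_model(
      task_p=my_task_p,
      input_p=my_decoder_input_p,
      job_log_dir=job_log_dir,
      restore_checkpoint_dir=None,   # defaults to <job_log_dir>/checkpoints
      restore_checkpoint_step=None,  # restore the latest checkpoint, if any
      continuous_decode=False)       # decode once instead of polling for ckpts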


def evaluate_pmap_model(task_p: InstantiableParams,
                        eval_input_p: Sequence[InstantiableParams],
                        job_log_dir: Optional[str]) -> None:
  """Runs the evaluation loop on the entire test dataset for a PMAP model.

  Args:
    task_p: Params for the task encapsulating the data parallel model.
    eval_input_p: List of params for the eval data input pipelines.
    job_log_dir: Directory for the job logs.
  """
  logging.info('Using pmap for data parallelism.')
  jax_task = task_p.Instantiate()
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  model_states = trainer_lib.initialize_model_state(jax_task, init_key)
  # Pmap does not use GDA, and so global_mesh and mesh_axes are None.
  model_states = checkpoints.restore_checkpoint(model_states, checkpoint_dir)
  replicated_model_states = trainer_lib.replicate_model_state(model_states)
  logging.info('replicated_model_states: %s',
               jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)

  def eval_step(mdl_vars, prng_key, global_step, inputs):
    return trainer_lib.eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=jax_task.model.fprop_dtype)

  num_devices = jax.local_device_count()
  prng_key, eval_key = jax.random.split(prng_key)
  eval_prng_seed = jax.random.split(eval_key, num=num_devices)
  logging.info('eval prng_seed: %s', eval_prng_seed)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  logging.info('Evaluation loop starting...')
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  summary_eval_dirs = [
      os.path.join(summary_base_dir, f'eval_test_{split}')
      for split, _ in enumerate(eval_input_p)
  ]

  num_steps = [
      -1 if p.reset_for_eval else p.eval_loop_num_batches for p in eval_input_p
  ]
  last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
  with contextlib.ExitStack() as exit_stack:
    eval_summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_eval_dirs
    ]

    while True:
      step_i = int(jax.device_get(replicated_model_states.step)[0])
      eval_step = functools.partial(p_eval_step,
                                    maybe_ema(replicated_model_states),
                                    eval_prng_seed,
                                    replicated_model_states.step)
      # Run the eval loop.
      model_utils.run_eval_loop_over_test_splits(
          num_steps,
          eval_step,
          eval_summary_writers,
          step_i,
          eval_input_pipelines,
          reshard_inputs=True)
      # If the last checkpoint evaluated matches max train steps, exit.
      if last_checkpoint is not None:
        last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
            last_checkpoint)
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        # Sleep for a minute.
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      # There must be a new checkpoint here.
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
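

# Illustrative sketch only (not part of the original module): because
# evaluate_pmap_model keeps polling for new checkpoints until the latest one
# reaches task_p.train.num_train_steps, it is typically launched as a separate
# eval job alongside the trainer. `my_task_p` and `my_eval_input_p` are
# hypothetical names for params built by an experiment configuration.
def _example_evaluate_pmap_usage(my_task_p: InstantiableParams,
                                 my_eval_input_p: Sequence[InstantiableParams],
                                 job_log_dir: str) -> None:
  """Hypothetical sketch: continuously evaluate checkpoints under job_log_dir."""
  evaluate_pmap_model(
      task_p=my_task_p,
      eval_input_p=my_eval_input_p,
      # Checkpoints are read from <job_log_dir>/checkpoints.
      job_log_dir=job_log_dir)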


def train_and_evaluate_pmap(
    task_p: InstantiableParams, train_input_p: InstantiableParams,
    job_log_dir: Optional[str],
    checkpoint_manager: checkpoint_managers.CheckpointManager,
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
  """Runs the training and evaluation loop.

  Args:
    task_p: Params for the task encapsulating the data parallel model.
    train_input_p: Params for the train data input pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    restore_checkpoint_dir: If set, the directory from which to restore
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
  logging.info('Using pmap for data parallelism.')
  if jax.config.jax_parallel_functions_output_gda:
    raise NotImplementedError(
        'jax.pmap does not yet support GlobalDeviceArray.')
  jax_task = task_p.Instantiate()

  if eval_input_p is not None:
    eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]

  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = _checkpoint_dir(job_log_dir)
  restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
  model_states = trainer_lib.initialize_model_state(jax_task, init_key)
  model_states = checkpoints.restore_checkpoint(
      model_states, restore_checkpoint_dir, step=restore_checkpoint_step)
  total_num_params = jax_task.model.total_num_vars
  replicated_model_states = trainer_lib.replicate_model_state(model_states)

  train_p = task_p.train
  initial_global_step = int(jax.device_get(replicated_model_states.step)[0])
  logging.info('Model initial global_step=%d', initial_global_step)
  _update_latest_model_step(train_input_p, initial_global_step,
                            train_p.eval_interval_steps)
  train_input_pipeline = train_input_p.Instantiate()

  # Unreplicated model states are not needed anymore at this point.
  del model_states

  logging.info('replicated_model_states shapes: %s',
               jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)

  fprop_dtype = task_p.model.fprop_dtype

  def train_step(states, prng_key, inputs):
    return trainer_lib.train_step_single_learner(
        jax_task,
        states,
        prng_key,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=fprop_dtype)

  def eval_step(mdl_vars, prng_key, global_step, inputs):
    return trainer_lib.eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=fprop_dtype)

  num_devices = jax.local_device_count()
  prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
  train_prng_seed = jax.random.split(train_key, num=num_devices)
  eval_prng_seed = jax.random.split(eval_key, num=num_devices)
  logging.info('train prng_seed: %s', train_prng_seed)
  logging.info('eval prng_seed: %s', eval_prng_seed)

  p_train_step = jax.pmap(train_step, donate_argnums=(0,), axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  logging.info('Training loop starting...')
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  summary_train_dir = os.path.join(summary_base_dir, 'train')
  summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
  summary_writer = summary_utils.get_summary_writer
  if eval_input_p is not None:
    summary_test_split_dirs = [
        os.path.join(summary_base_dir, f'eval_test_{split}')
        for split, _ in enumerate(eval_input_p)
    ]
    # We either run p.eval_loop_num_batches steps or one epoch (when supported
    # by a resettable input) per eval loop during training. When
    # p.reset_for_eval is set to True, we run the eval loop until
    # tf.errors.OutOfRangeError (or StopIteration) is raised, which can be
    # triggered either because the input pipeline has reached the end of the
    # input sequence, or because a pre-determined num_batches has been reached.
    eval_num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches
        for p in eval_input_p
    ]
  else:
    summary_test_split_dirs = []

  with contextlib.ExitStack() as exit_stack:
    train_summary_writer = exit_stack.enter_context(
        summary_writer(summary_train_dir))
    eval_summary_writer = exit_stack.enter_context(
        summary_writer(summary_eval_dir))
    eval_test_summary_writers = [
        exit_stack.enter_context(summary_writer(d))
        for d in summary_test_split_dirs
    ]

    summary_utils.write_model_structure(
        train_summary_writer, replicated_model_states, is_vars_replicated=True)
    summary_utils.write_total_num_params(train_summary_writer,
                                         total_num_params)

    summary_last_time = time.time()
    summary_last_step = None

    step_i = int(jax.device_get(replicated_model_states.step)[0])
    while True:
      logging.debug('step=`%d`: Beginning', step_i)
      if step_i >= train_p.num_train_steps:
        logging.info(
            'Training loop completed (step (`%d`) greater than '
            'num_train_steps (`%d`)).', step_i, train_p.num_train_steps)
        break

      if summary_last_step is None:
        summary_last_step = step_i - 1

      if checkpoint_manager.should_save(step_i):
        if jax.process_index() == 0:
          checkpoints.save_checkpoint(replicated_model_states, checkpoint_dir)
        checkpoint_manager.save_metadata(global_step_id=step_i)

      if step_i <= _N_STEPS_WARMUP_LOGGING:
        logging.info('step=`%d`: Retrieving model inputs.', step_i)
      logging.debug(' Retrieving inputs.')
      model_inputs = tf.nest.map_structure(py_utils.reshard,
                                           train_input_pipeline.get_next())
      logging.debug(' Retrieved inputs.')

      logging.debug(' Performing train_step().')
      with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
        (replicated_model_states, loss, metrics, per_example_out,
         summary_tensors) = p_train_step(replicated_model_states,
                                         train_prng_seed, model_inputs)
      logging.debug(' Completed train_step().')

      logging.debug(' Writing summaries (attempt).')
      if summary_utils.write_summary_every_n_steps(
          replicated_model_states,
          train_summary_writer,
          step_i,
          train_p.summary_interval_steps,
          loss,
          metrics,
          per_example_out,
          summary_tensors,
          train_p.norm_summary_interval_steps,
          summary_last_time,
          summary_last_step,
          unreplicate_mdl_vars=True,
          unreplicate_metrics=True):
        summary_last_time = time.time()
        summary_last_step = step_i
        # Synchronize step_i.
        step_i = int(jax.device_get(replicated_model_states.step)[0])
      else:
        # Increment locally to avoid an explicit sync.
        step_i += 1
      logging.debug(' Wrote summaries (attempted).')

      # Run eval at regular step interval.
      if step_i % train_p.eval_interval_steps == 0:
        logging.debug(' Starting eval_step().')
        logging.debug(' Retrieving eval model_inputs.')
        eval_inputs = train_input_pipeline.get_next()
        logging.debug(' Retrieved eval model_inputs.')
        logging.debug(' Performing eval_step() runs on training split.')
        eval_step_fn = functools.partial(p_eval_step,
                                         replicated_model_states.mdl_vars,
                                         eval_prng_seed,
                                         replicated_model_states.step)
        loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
            eval_inputs, eval_step_fn, reshard_inputs=True)
        logging.debug(' Completed eval_step() runs on training split.')
        logging.info('step=`%d`', step_i)
        logging.info(' eval loss: %s', loss)
        logging.info(' mean_metrics: %s', mean_metrics)
        logging.info(' summary_tensors: %s', summary_tensors)
        if step_i % train_p.summary_interval_steps == 0:
          logging.debug(' Writing eval summaries.')
          summary_utils.write_summary_entry(
              eval_summary_writer,
              step_i,
              loss,
              mean_metrics,
              summary_tensors,
              unreplicate_metrics=True)
          logging.debug(' Wrote eval summaries.')
        # Eval on the test sets.
        if eval_input_p is not None:
          logging.debug(' Performing eval_step() runs on test splits.')
          model_utils.run_eval_loop_over_test_splits(
              eval_num_steps,
              eval_step_fn,
              eval_test_summary_writers,
              step_i,
              eval_input_pipelines,
              reshard_inputs=True)
          logging.debug(' Completed eval_step() runs on test splits.')
      logging.debug('step=`%d`: End', step_i - 1)
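

# Illustrative sketch only (not part of the original module): one way
# train_and_evaluate_pmap might be wired up from a launcher. The checkpoint
# manager is taken as an argument since its construction is library-specific;
# `my_task_p`, `my_train_input_p`, and `my_eval_input_p` are hypothetical names
# for params assumed to come from an experiment configuration.
def _example_train_and_evaluate_pmap_usage(
    my_task_p: InstantiableParams,
    my_train_input_p: InstantiableParams,
    my_eval_input_p: Optional[Sequence[InstantiableParams]],
    job_log_dir: str,
    checkpoint_manager: checkpoint_managers.CheckpointManager) -> None:
  """Hypothetical sketch: train from scratch or resume from the latest ckpt."""
  train_and_evaluate_pmap(
      task_p=my_task_p,
      train_input_p=my_train_input_p,
      job_log_dir=job_log_dir,
      checkpoint_manager=checkpoint_manager,
      restore_checkpoint_dir=None,   # fall back to <job_log_dir>/checkpoints
      restore_checkpoint_step=None,  # restore the latest checkpoint, if any
      eval_input_p=my_eval_input_p)  # None disables eval on the test splits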