def decode_pmap_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    continuous_decode: bool,
) -> None:
  """Runs decoding over all decoder datasets for a PMAP model.

  Args:
    task_p: Params of the task encapsulating the data parallel model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    restore_checkpoint_dir: The directory from which to restore checkpoints.
      If None, uses job_log_dir.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      tries to restore from the latest checkpoint, if any.
    continuous_decode: Whether to continuously decode on the latest checkpoint.
  """
  if continuous_decode and restore_checkpoint_step is not None:
    raise ValueError(
        'Continuous decoding mode requires restore_checkpoint_step=None, '
        f'actual restore_checkpoint_step={restore_checkpoint_step}')
  restore_checkpoint_dir = restore_checkpoint_dir or os.path.join(
      job_log_dir, 'checkpoints')

  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get a
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)
  prng_key, eval_key = jax.random.split(prng_key)
  prng_seed = jax.random.split(eval_key, num=jax.local_device_count())
  logging.info('decoder prng_seed: %s', prng_seed)

  inputs = [p.Instantiate() for p in input_p]
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  dirnames = _get_dir_names(input_p)
  summary_decode_dirs = [
      os.path.join(summary_base_dir, f'decode_test_{dirnames[split]}')
      for split, _ in enumerate(input_p)
  ]
  with contextlib.ExitStack() as exit_stack:
    summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_decode_dirs
    ]
    jax_task = task_p.Instantiate()
    model_states = trainer_lib.initialize_model_state(jax_task, init_key)
    model_states = checkpoints.restore_checkpoint(
        model_states, restore_checkpoint_dir, step=restore_checkpoint_step)
    replicated_model_states = trainer_lib.replicate_model_state(model_states)
    logging.info('replicated_model_states: %s',
                 jax.tree_map(lambda x: x.shape, replicated_model_states))
    last_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
    while True:
      _decode_once_pmap_model(jax_task, task_p, inputs, input_p, prng_seed,
                              job_log_dir, replicated_model_states,
                              summary_writers)
      if not continuous_decode:
        break
      if last_checkpoint is not None:
        last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
            last_checkpoint)
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        # Sleep for a minute before polling for a new checkpoint.
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    restore_checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
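

# A minimal, self-contained sketch (illustrative only, never called by this
# module) of the per-host / per-device seed derivation used above: fold the
# process index into the root key so each host draws distinct randomness,
# then split one key per local device for the pmapped decode step.
def _per_device_prng_seed_sketch():
  """Returns one PRNG seed per local device, unique across hosts."""
  root_key = jax.random.PRNGKey(1234)
  # Each process (host) folds in its index to get distinct randomness.
  host_key = jax.random.fold_in(root_key, jax.process_index())
  _, decode_key = jax.random.split(host_key)
  # One key per local device, suitable as a leading-axis pmap argument.
  return jax.random.split(decode_key, num=jax.local_device_count())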


def evaluate_pmap_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
) -> None:
  """Runs the evaluation loop on the entire test dataset for a PMAP model.

  Args:
    task_p: Params for the task encapsulating the data parallel model.
    eval_input_p: List of params for the eval data input pipelines.
    job_log_dir: Directory for the job logs.
  """
  logging.info('Using pmap for data parallelism.')
  jax_task = task_p.Instantiate()
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  model_states = trainer_lib.initialize_model_state(jax_task, init_key)
  # Pmap does not use GDA, so global_mesh and mesh_axes are None.
  model_states = checkpoints.restore_checkpoint(model_states, checkpoint_dir)
  replicated_model_states = trainer_lib.replicate_model_state(model_states)
  logging.info('replicated_model_states: %s',
               jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get a
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)

  def eval_step(mdl_vars, prng_key, global_step, inputs):
    return trainer_lib.eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=jax_task.model.fprop_dtype)

  num_devices = jax.local_device_count()
  prng_key, eval_key = jax.random.split(prng_key)
  eval_prng_seed = jax.random.split(eval_key, num=num_devices)
  logging.info('eval prng_seed: %s', eval_prng_seed)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  logging.info('Evaluation loop starting...')
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  summary_eval_dirs = [
      os.path.join(summary_base_dir, f'eval_test_{split}')
      for split, _ in enumerate(eval_input_p)
  ]

  num_steps = [
      -1 if p.reset_for_eval else p.eval_loop_num_batches for p in eval_input_p
  ]
  last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
  with contextlib.ExitStack() as exit_stack:
    eval_summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_eval_dirs
    ]
    while True:
      step_i = int(jax.device_get(replicated_model_states.step)[0])
      eval_step_fn = functools.partial(p_eval_step,
                                       maybe_ema(replicated_model_states),
                                       eval_prng_seed,
                                       replicated_model_states.step)
      # Run the eval loop.
      model_utils.run_eval_loop_over_test_splits(
          num_steps,
          eval_step_fn,
          eval_summary_writers,
          step_i,
          eval_input_pipelines,
          reshard_inputs=True)
      # If the last checkpoint evaluated matches max train steps, exit.
      if last_checkpoint is not None:
        last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
            last_checkpoint)
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        # Sleep for a minute.
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      # There must be a new checkpoint here.
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
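

# Illustrative only: a toy, self-contained version of the pmapped eval-step
# pattern above. `_toy_eval_step` and its inputs are made up for this sketch;
# the real step function comes from trainer_lib.eval_step_single_learner.
def _pmap_eval_step_sketch():
  """Runs a toy eval step on all local devices; returns per-device losses."""
  import jax.numpy as jnp  # Local import keeps the sketch self-contained.

  def _toy_eval_step(prng_key, inputs):
    # Per-device loss, averaged across replicas over the 'batch' axis.
    noise = jax.random.normal(prng_key, inputs.shape)
    loss = jnp.mean((inputs - noise)**2)
    return jax.lax.pmean(loss, axis_name='batch')

  p_toy_eval_step = jax.pmap(_toy_eval_step, axis_name='batch')
  num_devices = jax.local_device_count()
  seeds = jax.random.split(jax.random.PRNGKey(0), num=num_devices)
  # Every pmap argument needs a leading axis equal to the device count.
  batch = jnp.ones((num_devices, 4))
  return p_toy_eval_step(seeds, batch)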


def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
  """Runs the evaluation loop on the entire test dataset for an SPMD model.

  Args:
    task_p: Params of the task encapsulating an SPMD model.
    eval_input_p: List of params for the eval data pipelines.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  # Note that GDA checkpointing requires all processes to participate in
  # checkpointing, but it does not require a separate checkpoint_dir per
  # process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  def get_shape_dtype(x):
    return jax.ShapeDtypeStruct(x.shape, x.dtype)

  # Do not use eval_input_pipelines[0] directly.
  sample_model_inputs = eval_input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

  model_p = task_p.model
  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    (partitioned_train_state, partitioned_specs, eval_inputs_partition_specs,
     _, eval_step, _) = trainer_lib.partition_spmd_model(
         task_p, init_key, inputs_shape)
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=partitioned_specs)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # In contrast to the pmap version, we do not fold in jax.process_index;
    # instead we use a single global key and rely on pjit to split it for
    # the different replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, eval_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Evaluation loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_eval_dirs = [
        os.path.join(summary_base_dir, f'eval_{split}')
        for split, _ in enumerate(eval_input_p)
    ]

    num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
    last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
    with contextlib.ExitStack() as exit_stack:
      eval_summary_writers = [
          exit_stack.enter_context(summary_utils.get_summary_writer(d))
          for d in summary_eval_dirs
      ]
      while True:
        step_i = int(jax.device_get(partitioned_train_state.step))
        eval_step_fn = functools.partial(eval_step,
                                         partitioned_train_state.mdl_vars,
                                         eval_key,
                                         partitioned_train_state.step)
        # Run the eval loop.
        model_utils.run_eval_loop_over_test_splits(
            num_steps,
            eval_step_fn,
            eval_summary_writers,
            step_i,
            eval_input_pipelines,
            eval_inputs_partition_specs,
            inputs_shape,
            global_mesh,
            reshard_inputs=False)
        # If the last checkpoint evaluated matches max train steps, exit.
        if last_checkpoint is not None:
          last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
              last_checkpoint)
          exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
          if exceeded_ckpt >= task_p.train.num_train_steps:
            break
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        while new_checkpoint == last_checkpoint:
          # Sleep for a minute.
          time.sleep(60)
          new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        # There must be a new checkpoint here.
        logging.info('Found new checkpoint: %s', new_checkpoint)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        if multi_host_checkpointing:
          py_utils.sync_global_devices(
              f'checkpointer:restored:{checkpoint_dir}')
        last_checkpoint = new_checkpoint
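

# Illustrative only: a toy version of the mesh + pjit pattern above, using a
# 1-D mesh over all devices and a made-up step function. The pjit import
# path and argument names have shifted across JAX versions, so treat this as
# a sketch under those assumptions rather than this module's actual API.
def _spmd_eval_step_sketch():
  """Shards a toy batch over a 1-D 'data' mesh axis and reduces a loss."""
  import jax.numpy as jnp  # Local imports keep the sketch self-contained.
  from jax.experimental.pjit import pjit, PartitionSpec

  device_mesh = mesh_utils.create_device_mesh((jax.device_count(),))
  global_mesh = maps.Mesh(device_mesh, ('data',))

  def _toy_eval_step(inputs):
    # pjit splits the global batch over the 'data' mesh axis for us.
    return jnp.mean(inputs**2)

  p_toy_eval_step = pjit(
      _toy_eval_step,
      in_axis_resources=PartitionSpec('data'),
      out_axis_resources=None)  # The scalar output is replicated.

  with global_mesh:
    # Size the batch so it divides evenly across the mesh axis.
    batch = jnp.arange(jax.device_count() * 4, dtype=jnp.float32)
    return p_toy_eval_step(batch)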