Example #1
def decode_pmap_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    continuous_decode: bool,
) -> None:
  """Runs the decoding on the entire decoder datasets for a PMAP model.

  Args:
    task_p: Params of the task encapsulating the data parallel model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    restore_checkpoint_dir: The directory from which to restore the checkpoint.
      If None, uses job_log_dir.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      tries to restore from the latest checkpoint, if any.
    continuous_decode: Whether to continuously decode using the latest
      checkpoint.
  """
  if continuous_decode and restore_checkpoint_step is not None:
    raise ValueError('Continuous decoding mode requires restore_checkpoint_step'
                     '=None, actual restore_checkpoint_step='
                     f'{restore_checkpoint_step}')
  restore_checkpoint_dir = restore_checkpoint_dir or os.path.join(
      job_log_dir, 'checkpoints')

  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  # From now on, different replicas should use different random seeds.
  # Here, each process gets its own unique prng_key, which is further split
  # so that each core on a host receives a different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)
  prng_key, eval_key = jax.random.split(prng_key)
  prng_seed = jax.random.split(eval_key, num=jax.local_device_count())
  logging.info('decoder prng_seed: %s', prng_seed)

  inputs = [p.Instantiate() for p in input_p]
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  dirnames = _get_dir_names(input_p)
  summary_decode_dirs = [
      os.path.join(summary_base_dir, f'decode_test_{dirnames[split]}')
      for split, _ in enumerate(input_p)
  ]
  with contextlib.ExitStack() as exit_stack:
    summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_decode_dirs
    ]

    jax_task = task_p.Instantiate()
    model_states = trainer_lib.initialize_model_state(jax_task, init_key)
    model_states = checkpoints.restore_checkpoint(
        model_states, restore_checkpoint_dir, step=restore_checkpoint_step)
    replicated_model_states = trainer_lib.replicate_model_state(model_states)
    logging.info('replicated_model_states: %s',
                 jax.tree_map(lambda x: x.shape, replicated_model_states))
    last_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)

    while True:
      _decode_once_pmap_model(jax_task, task_p, inputs, input_p, prng_seed,
                              job_log_dir, replicated_model_states,
                              summary_writers)
      if not continuous_decode:
        break
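      # Exit if the last decoded checkpoint is the final one for this run.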
      if last_checkpoint is not None:
        last_ckpt_step = int(last_checkpoint.split('_')[-1])
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    restore_checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
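
The PRNG handling above is the standard JAX recipe for per-host, per-device
seeding: fold the process index into the root key so that each host gets its
own stream, then split one sub-key per local device for the pmapped decode
step. A minimal standalone sketch of that pattern (the variable names are
illustrative; only the jax.random calls are real APIs):

import jax

# Root key; the function above hard-codes the seed 1234.
root_key = jax.random.PRNGKey(1234)
# Give each host process its own independent key stream.
process_key = jax.random.fold_in(root_key, jax.process_index())
# One sub-key per local device, suitable for a pmapped step function.
per_device_keys = jax.random.split(process_key, num=jax.local_device_count())
assert per_device_keys.shape == (jax.local_device_count(), 2)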
Example #2
def evaluate_pmap_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
) -> None:
  """Runs the evaluation loop on the entire test dataset for PMAP model.

  Args:
    task_p: Params for the task encapsulating the data parallel model.
    eval_input_p: List of params for the eval data input pipelines.
    job_log_dir: Directory for the job logs.
  """
  logging.info('Using pmap for data parallelism.')
  jax_task = task_p.Instantiate()
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  model_states = trainer_lib.initialize_model_state(jax_task, init_key)
  # Pmap does not use GDA, and so global_mesh and mesh_axes are None.
  model_states = checkpoints.restore_checkpoint(model_states, checkpoint_dir)
  replicated_model_states = trainer_lib.replicate_model_state(model_states)
  logging.info('replicated_model_states: %s',
               jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process gets its own unique prng_key, which is further split
  # so that each core on a host receives a different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)

  def eval_step(mdl_vars, prng_key, global_step, inputs):
    return trainer_lib.eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=jax_task.model.fprop_dtype)

  num_devices = jax.local_device_count()
  prng_key, eval_key = jax.random.split(prng_key)
  eval_prng_seed = jax.random.split(eval_key, num=num_devices)
  logging.info('eval prng_seed: %s', eval_prng_seed)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  logging.info('Evaluation loop starting...')
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  summary_eval_dirs = [
      os.path.join(summary_base_dir, f'eval_test_{split}')
      for split, _ in enumerate(eval_input_p)
  ]

  num_steps = [
      -1 if p.reset_for_eval else p.eval_loop_num_batches for p in eval_input_p
  ]
  last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
  with contextlib.ExitStack() as exit_stack:
    eval_summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_eval_dirs
    ]

    while True:
      step_i = int(jax.device_get(replicated_model_states.step)[0])
      eval_step_fn = functools.partial(p_eval_step,
                                       maybe_ema(replicated_model_states),
                                       eval_prng_seed,
                                       replicated_model_states.step)
      # Run the eval loop.
      model_utils.run_eval_loop_over_test_splits(
          num_steps,
          eval_step_fn,
          eval_summary_writers,
          step_i,
          eval_input_pipelines,
          reshard_inputs=True)
      # If the last evaluated checkpoint is the final one for the run, exit.
      if last_checkpoint is not None:
        last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
            last_checkpoint)
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        # Sleep for a minute.
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      # There must be a new checkpoint here.
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
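
Both evaluation loops share the same wait-for-checkpoint idiom: poll
checkpoints.latest_checkpoint until it returns something different from the
last value seen, sleeping between polls. A generic sketch of that idiom,
where get_latest is a hypothetical zero-argument callable standing in for
checkpoints.latest_checkpoint bound to a directory:

import time

def wait_for_new_checkpoint(get_latest, last_seen, poll_secs=60):
  # Poll until get_latest() returns something new, then return it.
  latest = get_latest()
  while latest == last_seen:
    time.sleep(poll_secs)
    latest = get_latest()
  return latest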
Example #3
def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
  """Runs the evaluation loop on the entire test dataset for SPMD model.

  Args:
    task_p: Params of the task encapsulating an SPMD model.
    eval_input_p: List of Params for the eval data pipelines.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  # Note that GDA checkpointing requires all processes to participate, but it
  # does not require a separate checkpoint_dir per process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir

  multi_host_checkpointing = checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  }

  def get_shape_dtype(x):
    return jax.ShapeDtypeStruct(x.shape, x.dtype)

  # Do not use eval_input_pipelines[0] directly; instantiate a fresh copy so
  # the eval pipeline's iterator state is not consumed here.
  sample_model_inputs = eval_input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

  model_p = task_p.model
  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    partitioned_train_state, partitioned_specs, eval_inputs_partition_specs, _, eval_step, _ = (
        trainer_lib.partition_spmd_model(task_p, init_key, inputs_shape))
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=partitioned_specs)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # In contrast to the pmap version, we do not fold in jax.process_index;
    # instead we use a single global key and rely on pjit to split it across
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, eval_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Evaluation loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_eval_dirs = [
        os.path.join(summary_base_dir, f'eval_{split}')
        for split, _ in enumerate(eval_input_p)
    ]

    num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
    last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
    with contextlib.ExitStack() as exit_stack:
      eval_summary_writers = [
          exit_stack.enter_context(summary_utils.get_summary_writer(d))
          for d in summary_eval_dirs
      ]
      while True:
        step_i = int(jax.device_get(partitioned_train_state.step))
        eval_step_fn = functools.partial(eval_step,
                                         partitioned_train_state.mdl_vars,
                                         eval_key, partitioned_train_state.step)
        # Run the eval loop.
        model_utils.run_eval_loop_over_test_splits(
            num_steps,
            eval_step_fn,
            eval_summary_writers,
            step_i,
            eval_input_pipelines,
            eval_inputs_partition_specs,
            inputs_shape,
            global_mesh,
            reshard_inputs=False)
        # If the last evaluated checkpoint is the final one for the run, exit.
        if last_checkpoint is not None:
          last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
              last_checkpoint)
          exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
          if exceeded_ckpt >= task_p.train.num_train_steps:
            break
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        while new_checkpoint == last_checkpoint:
          # Sleep for a minute.
          time.sleep(60)
          new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        # There must be a new checkpoint here.
        logging.info('Found new checkpoint: %s', new_checkpoint)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        if multi_host_checkpointing:
          py_utils.sync_global_devices(
              f'checkpointer:restored:{checkpoint_dir}')
        last_checkpoint = new_checkpoint
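
The SPMD example hinges on building a device mesh and entering it as a
context manager so that the pjit-partitioned eval step runs under that mesh.
A minimal sketch of the same setup using the older JAX APIs that appear
above; the (1, jax.device_count()) shape and the axis names are illustrative
assumptions, not the model's actual device_mesh or mesh_axis_names:

import jax
from jax.experimental import maps, mesh_utils

# Hypothetical two-axis logical mesh spanning all available devices.
mesh_shape = (1, jax.device_count())
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
global_mesh = maps.Mesh(device_mesh, ('replica', 'model'))
with global_mesh:
  # pjit-compiled callables invoked here are partitioned over this mesh.
  pass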