Example #1
def decode_pmap_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    continuous_decode: bool,
) -> None:
  """Runs the decoding on the entire decoder datasets for a PMAP model.

  Args:
    task_p: Params of the task encapsulating the data parallel model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    restore_checkpoint_dir: The directory from which to restore the
      checkpoint. If None, uses the `checkpoints` subdirectory of job_log_dir.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    continuous_decode: Whether to continuously decode using the latest
      checkpoint.
  """
  if continuous_decode and restore_checkpoint_step is not None:
    raise ValueError(
        'Continuous decoding mode requires restore_checkpoint_step=None, '
        f'actual restore_checkpoint_step={restore_checkpoint_step}')
  restore_checkpoint_dir = restore_checkpoint_dir or os.path.join(
      job_log_dir, 'checkpoints')

  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)
  prng_key, eval_key = jax.random.split(prng_key)
  prng_seed = jax.random.split(eval_key, num=jax.local_device_count())
  logging.info('decoder prng_seed: %s', prng_seed)

  inputs = [p.Instantiate() for p in input_p]
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  dirnames = _get_dir_names(input_p)
  summary_decode_dirs = [
      os.path.join(summary_base_dir, f'decode_test_{dirnames[split]}')
      for split, _ in enumerate(input_p)
  ]
  with contextlib.ExitStack() as exit_stack:
    summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_decode_dirs
    ]

    jax_task = task_p.Instantiate()
    model_states = trainer_lib.initialize_model_state(jax_task, init_key)
    model_states = checkpoints.restore_checkpoint(
        model_states, restore_checkpoint_dir, step=restore_checkpoint_step)
    replicated_model_states = trainer_lib.replicate_model_state(model_states)
    logging.info('replicated_model_states: %s',
                 jax.tree_map(lambda x: x.shape, replicated_model_states))
    last_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)

    while True:
      _decode_once_pmap_model(jax_task, task_p, inputs, input_p, prng_seed,
                              job_log_dir, replicated_model_states,
                              summary_writers)
      if not continuous_decode:
        break
      if last_checkpoint is not None:
        last_ckpt_step = int(last_checkpoint.split('_')[-1])
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(restore_checkpoint_dir)
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    restore_checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
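
Note: the seeding scheme above is what keeps decoding streams independent across hosts and cores. The sketch below (not part of the example; names like `root_key` and `per_device_keys` are illustrative) shows the same fold-in-then-split pattern in isolation, using only standard `jax.random` APIs.

import jax

# Each process folds its own index into the shared root key, then splits
# one key per local device so every core draws from a distinct stream.
root_key = jax.random.PRNGKey(1234)
host_key = jax.random.fold_in(root_key, jax.process_index())
per_device_keys = jax.random.split(host_key, num=jax.local_device_count())
# per_device_keys holds one key per local device, e.g. shape (num_devices, 2).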
Example #2
def evaluate_pmap_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
) -> None:
  """Runs the evaluation loop on the entire test dataset for PMAP model.

  Args:
    task_p: Params for the task encapsulating the data parallel model.
    eval_input_p: List of params for the eval data input pipelines.
    job_log_dir: Directory for the job logs.
  """
  logging.info('Using pmap for data parallelism.')
  jax_task = task_p.Instantiate()
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(shafey): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  model_states = trainer_lib.initialize_model_state(jax_task, init_key)
  # Pmap does not use GDA, and so global_mesh and mesh_axes are None.
  model_states = checkpoints.restore_checkpoint(model_states, checkpoint_dir)
  replicated_model_states = trainer_lib.replicate_model_state(model_states)
  logging.info('replicated_model_states: %s',
               jax.tree_map(lambda x: x.shape, replicated_model_states))
  # From now on, different replicas should use different random seeds.
  # Here, each process will have its unique prng_key.
  # prng_key will be further split so that each core on a host will get
  # different prng_key.
  prng_key = jax.random.fold_in(prng_key, jax.process_index())
  logging.info('root prng_key: %s', prng_key)

  def eval_step(mdl_vars, prng_key, global_step, inputs):
    return trainer_lib.eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name='batch',
        fprop_dtype=jax_task.model.fprop_dtype)

  num_devices = jax.local_device_count()
  prng_key, eval_key = jax.random.split(prng_key)
  eval_prng_seed = jax.random.split(eval_key, num=num_devices)
  logging.info('eval prng_seed: %s', eval_prng_seed)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  logging.info('Evaluation loop starting...')
  summary_base_dir = os.path.join(job_log_dir, 'summaries')
  summary_eval_dirs = [
      os.path.join(summary_base_dir, f'eval_test_{split}')
      for split, _ in enumerate(eval_input_p)
  ]

  num_steps = [
      -1 if p.reset_for_eval else p.eval_loop_num_batches for p in eval_input_p
  ]
  last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
  with contextlib.ExitStack() as exit_stack:
    eval_summary_writers = [
        exit_stack.enter_context(summary_utils.get_summary_writer(d))
        for d in summary_eval_dirs
    ]

    while True:
      step_i = int(jax.device_get(replicated_model_states.step)[0])
      eval_step = functools.partial(p_eval_step,
                                    maybe_ema(replicated_model_states),
                                    eval_prng_seed,
                                    replicated_model_states.step)
      # Run the eval loop.
      model_utils.run_eval_loop_over_test_splits(
          num_steps,
          eval_step,
          eval_summary_writers,
          step_i,
          eval_input_pipelines,
          reshard_inputs=True)
      # If the last checkpoint evaluated matches the max train steps, exit.
      if last_checkpoint is not None:
        last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
            last_checkpoint)
        exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
        if exceeded_ckpt >= task_p.train.num_train_steps:
          break
      # Release replicated_model_states.
      del replicated_model_states
      new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      while new_checkpoint == last_checkpoint:
        # Sleep for a minute.
        time.sleep(60)
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
      # There must be a new checkpoint here.
      logging.info('Found new checkpoint: %s', new_checkpoint)
      model_states = checkpoints.restore_checkpoint(model_states,
                                                    checkpoint_dir)
      replicated_model_states = trainer_lib.replicate_model_state(model_states)
      last_checkpoint = new_checkpoint
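
Note: `p_eval_step` above is just `eval_step` replicated with `jax.pmap` over the `'batch'` axis. The toy sketch below (hypothetical `toy_eval_step`, not the library's `eval_step_single_learner`) illustrates how an SPMD step averages a metric across devices under that axis name.

import jax
import jax.numpy as jnp

def toy_eval_step(params, inputs):
  # Per-device loss, then an average over all devices along the 'batch' axis,
  # mirroring data_parallel_axis_name='batch' in the example.
  loss = jnp.mean((inputs - params) ** 2)
  return jax.lax.pmean(loss, axis_name='batch')

num_devices = jax.local_device_count()
p_toy_eval_step = jax.pmap(toy_eval_step, axis_name='batch')
params = jnp.zeros((num_devices, 4))  # one replica of the params per device
inputs = jnp.ones((num_devices, 4))   # leading axis shards the batch across devices
per_device_loss = p_toy_eval_step(params, inputs)  # same value on every device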
Example #3
def train_and_evaluate_pmap(
        task_p: InstantiableParams, train_input_p: InstantiableParams,
        job_log_dir: Optional[str],
        checkpoint_manager: checkpoint_managers.CheckpointManager,
        restore_checkpoint_dir: Optional[str],
        restore_checkpoint_step: Optional[int],
        eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
    """Runs the training and evaluation loop.

  Args:
    task_p: Params for the task encapsulating the data parallel model.
    train_input_p: Params for the train data input pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    restore_checkpoint_dir: If set, the directory from which to restore
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
    logging.info('Using pmap for data parallelism.')
    if jax.config.jax_parallel_functions_output_gda:
        raise NotImplementedError(
            'jax.pmap does not yet support GlobalDeviceArray.')
    jax_task = task_p.Instantiate()

    if eval_input_p is not None:
        eval_input_pipelines = [
            input_p.Instantiate() for input_p in eval_input_p
        ]

    # TODO(shafey): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)

    checkpoint_dir = _checkpoint_dir(job_log_dir)
    restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
    model_states = trainer_lib.initialize_model_state(jax_task, init_key)
    model_states = checkpoints.restore_checkpoint(model_states,
                                                  restore_checkpoint_dir,
                                                  step=restore_checkpoint_step)
    total_num_params = jax_task.model.total_num_vars
    replicated_model_states = trainer_lib.replicate_model_state(model_states)

    train_p = task_p.train
    initial_global_step = int(jax.device_get(replicated_model_states.step)[0])
    logging.info('Model initial global_step=%d', initial_global_step)
    _update_latest_model_step(train_input_p, initial_global_step,
                              train_p.eval_interval_steps)
    train_input_pipeline = train_input_p.Instantiate()

    # Unreplicated model states are not needed anymore at that point.
    del model_states

    logging.info('replicated_model_states shapes: %s',
                 jax.tree_map(lambda x: x.shape, replicated_model_states))
    # From now on, different replicas should use different random seeds.
    # Here, each process will have its unique prng_key.
    # prng_key will be further split so that each core on a host will get
    # different prng_key.
    prng_key = jax.random.fold_in(prng_key, jax.process_index())
    logging.info('root prng_key: %s', prng_key)

    fprop_dtype = task_p.model.fprop_dtype

    def train_step(states, prng_key, inputs):
        return trainer_lib.train_step_single_learner(
            jax_task,
            states,
            prng_key,
            inputs,
            data_parallel_axis_name='batch',
            fprop_dtype=fprop_dtype)

    def eval_step(mdl_vars, prng_key, global_step, inputs):
        return trainer_lib.eval_step_single_learner(
            jax_task,
            mdl_vars,
            prng_key,
            global_step,
            inputs,
            data_parallel_axis_name='batch',
            fprop_dtype=fprop_dtype)

    num_devices = jax.local_device_count()
    prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
    train_prng_seed = jax.random.split(train_key, num=num_devices)
    eval_prng_seed = jax.random.split(eval_key, num=num_devices)
    logging.info('train prng_seed: %s', train_prng_seed)
    logging.info('eval prng_seed: %s', eval_prng_seed)

    p_train_step = jax.pmap(train_step,
                            donate_argnums=(0, ),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    logging.info('Training loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_train_dir = os.path.join(summary_base_dir, 'train')
    summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
    summary_writer = summary_utils.get_summary_writer
    if eval_input_p is not None:
        summary_test_split_dirs = [
            os.path.join(summary_base_dir, f'eval_test_{split}')
            for split, _ in enumerate(eval_input_p)
        ]
        # We either run p.eval_loop_num_batches steps or one epoch (when
        # supported by a resettable input) per eval loop during training. When
        # p.reset_for_eval is set to True, we run the eval loop until
        # tf.errors.OutOfRangeError (or StopIteration) is raised, which can be
        # triggered either because the input pipeline has reached the end of
        # the input sequence or because a pre-determined number of batches has
        # been reached.
        eval_num_steps = [
            -1 if p.reset_for_eval else p.eval_loop_num_batches
            for p in eval_input_p
        ]
    else:
        summary_test_split_dirs = []

    with contextlib.ExitStack() as exit_stack:
        train_summary_writer = exit_stack.enter_context(
            summary_writer(summary_train_dir))
        eval_summary_writer = exit_stack.enter_context(
            summary_writer(summary_eval_dir))
        eval_test_summary_writers = [
            exit_stack.enter_context(summary_writer(d))
            for d in summary_test_split_dirs
        ]

        summary_utils.write_model_structure(train_summary_writer,
                                            replicated_model_states,
                                            is_vars_replicated=True)
        summary_utils.write_total_num_params(train_summary_writer,
                                             total_num_params)

        summary_last_time = time.time()
        summary_last_step = None

        step_i = int(jax.device_get(replicated_model_states.step)[0])
        while True:
            logging.debug('step=`%d`: Beginning', step_i)
            if step_i >= train_p.num_train_steps:
                logging.info(
                    'Training loop completed (step `%d` >= num_train_steps '
                    '`%d`).', step_i, train_p.num_train_steps)
                break
            if summary_last_step is None:
                summary_last_step = step_i - 1

            if checkpoint_manager.should_save(step_i):
                if jax.process_index() == 0:
                    checkpoints.save_checkpoint(replicated_model_states,
                                                checkpoint_dir)
                checkpoint_manager.save_metadata(global_step_id=step_i)

            if step_i <= _N_STEPS_WARMUP_LOGGING:
                logging.info('step=`%d`: Retrieving model inputs.', step_i)
            logging.debug('  Retrieving inputs.')
            model_inputs = tf.nest.map_structure(
                py_utils.reshard, train_input_pipeline.get_next())
            logging.debug('  Retrieved inputs.')
            logging.debug('  Performing train_step().')
            with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
                (replicated_model_states, loss, metrics, per_example_out,
                 summary_tensors) = p_train_step(replicated_model_states,
                                                 train_prng_seed, model_inputs)
            logging.debug('  Completed train_step().')

            logging.debug('  Writing summaries (attempt).')
            if summary_utils.write_summary_every_n_steps(
                    replicated_model_states,
                    train_summary_writer,
                    step_i,
                    train_p.summary_interval_steps,
                    loss,
                    metrics,
                    per_example_out,
                    summary_tensors,
                    train_p.norm_summary_interval_steps,
                    summary_last_time,
                    summary_last_step,
                    unreplicate_mdl_vars=True,
                    unreplicate_metrics=True):
                summary_last_time = time.time()
                summary_last_step = step_i
                # Synchronize step_i
                step_i = int(jax.device_get(replicated_model_states.step)[0])
            else:
                # Increment locally to avoid an explicit sync.
                step_i += 1
            logging.debug('  Wrote summaries (attempted).')

            # Run eval at regular step interval.
            if step_i % train_p.eval_interval_steps == 0:
                logging.debug('  Starting eval_step().')
                logging.debug('  Retrieving eval model_inputs.')
                eval_inputs = train_input_pipeline.get_next()
                logging.debug('  Retrieved eval model_inputs.')
                logging.debug(
                    '  Performing eval_step() runs on training split.')
                eval_step_fn = functools.partial(
                    p_eval_step, replicated_model_states.mdl_vars,
                    eval_prng_seed, replicated_model_states.step)
                loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
                    eval_inputs, eval_step_fn, reshard_inputs=True)
                logging.debug(
                    '  Completed eval_step() runs on training split.')
                logging.info('step=`%d`', step_i)
                logging.info('  eval loss: %s', loss)
                logging.info('  mean_metrics: %s', mean_metrics)
                logging.info('  summary_tensors: %s', summary_tensors)
                if step_i % train_p.summary_interval_steps == 0:
                    logging.debug('  Writing eval summaries.')
                    summary_utils.write_summary_entry(eval_summary_writer,
                                                      step_i,
                                                      loss,
                                                      mean_metrics,
                                                      summary_tensors,
                                                      unreplicate_metrics=True)
                    logging.debug('  Wrote eval summaries.')
                # Eval on the test sets.
                if eval_input_p is not None:
                    logging.debug(
                        '  Performing eval_step() runs on test splits.')
                    model_utils.run_eval_loop_over_test_splits(
                        eval_num_steps,
                        eval_step_fn,
                        eval_test_summary_writers,
                        step_i,
                        eval_input_pipelines,
                        reshard_inputs=True)
                    logging.debug(
                        '  Completed eval_step() runs on test splits.')
            logging.debug('step=`%d`: End', step_i - 1)
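
Note: `p_train_step` above is compiled with `donate_argnums=(0,)`, which lets XLA reuse the buffers of the incoming model state for the updated state instead of allocating a second copy. The toy sketch below (hypothetical `toy_train_step`, assuming a simple SGD-style update) shows the same pattern; after the call, the donated `state` buffers should not be reused, just as the loop above only keeps the freshly returned `replicated_model_states`.

import jax
import jax.numpy as jnp

def toy_train_step(state, grads):
  # Average the per-device gradients over the 'batch' axis, then apply a step.
  grads = jax.lax.pmean(grads, axis_name='batch')
  return state - 0.1 * grads

num_devices = jax.local_device_count()
p_toy_train_step = jax.pmap(toy_train_step, axis_name='batch',
                            donate_argnums=(0,))
state = jnp.ones((num_devices, 8))
new_state = p_toy_train_step(state, jnp.full((num_devices, 8), 0.5))
# Keep using new_state only; the old state's buffers may have been donated.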