Example #1
def evaluate(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
) -> None:
  """Runs the evaluation loop on the entire eval data set.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
  """
  model_config = model_utils.get_model(model_name)()
  task_p = model_config.task()
  model_p = task_p.model
  eval_input_p = [v for v in model_config.datasets() if not v.is_training]
  for inp in eval_input_p:
    inp.num_infeed_hosts = jax.process_count()
    inp.infeed_host_index = jax.process_index()

  if model_p.device_mesh is not None:
    checkpoint_type = checkpoints.retrieve_checkpoint_type(
        multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)
    evaluate_spmd_model(task_p, eval_input_p, job_log_dir, checkpoint_type)
  else:
    evaluate_pmap_model(task_p, eval_input_p, job_log_dir)
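For context, here is a minimal sketch of how evaluate() might be wired into a script entry point. The flag names, default values, and the my_project.eval_lib import path are illustrative assumptions, not part of the example above.

from absl import app
from absl import flags

# Hypothetical import path; assumes the evaluate() function above lives in an
# importable module.
from my_project.eval_lib import evaluate

flags.DEFINE_string('model', None, 'Name of the registered model to evaluate.')
flags.DEFINE_string('job_log_dir', None, 'Directory for the job logs.')
flags.DEFINE_bool('multi_host_checkpointing', False,
                  'Whether to use multi-host checkpointing.')
flags.DEFINE_bool('maybe_use_persistence_checkpointing', False,
                  'Try persistence-based checkpointing if suitable.')

FLAGS = flags.FLAGS


def main(argv):
  del argv  # Unused.
  evaluate(
      model_name=FLAGS.model,
      job_log_dir=FLAGS.job_log_dir,
      multi_host_checkpointing=FLAGS.multi_host_checkpointing,
      maybe_use_persistence_checkpointing=FLAGS
      .maybe_use_persistence_checkpointing,
  )


if __name__ == '__main__':
  app.run(main)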
Example #2
def decode(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    continuous_decode: bool,
) -> None:
    """Runs decoding once on the decoder datasets.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
    restore_checkpoint_dir: The directory from which to restore checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    continuous_decode: whether to continuously decode on the latest ckpt.
  """
    logging.info('running decode_once on model %s restored from %s',
                 model_name, restore_checkpoint_dir)
    model_config = model_utils.get_model(model_name)()
    task_p = model_config.task()
    model_p = task_p.model
    decoder_inputs = model_config.decoder_datasets()
    if not decoder_inputs:
        return
    for inp in decoder_inputs:
        inp.num_infeed_hosts = jax.process_count()
        inp.infeed_host_index = jax.process_index()

    if model_p.device_mesh is not None:
        if continuous_decode:
            raise NotImplementedError('http://b/214589358: not supported')
        checkpoint_type = checkpoints.retrieve_checkpoint_type(
            multi_host_checkpointing, maybe_use_persistence_checkpointing,
            task_p)
        decode_once_spmd_model(task_p, decoder_inputs, job_log_dir,
                               checkpoint_type, restore_checkpoint_dir,
                               restore_checkpoint_step)
    else:
        decode_pmap_model(task_p, decoder_inputs, job_log_dir,
                          restore_checkpoint_dir, restore_checkpoint_step,
                          continuous_decode)
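A hedged usage sketch of calling decode() directly follows; the model name, directories, and import path are placeholders for illustration and are not taken from the example above.

# Hypothetical import; assumes the decode() function above is defined in an
# importable module.
from my_project.decode_lib import decode

decode(
    model_name='my_registered_model',   # placeholder registry name
    job_log_dir='/tmp/decode_logs',     # placeholder log directory
    multi_host_checkpointing=False,
    maybe_use_persistence_checkpointing=False,
    restore_checkpoint_dir='/tmp/train_logs/checkpoints',
    restore_checkpoint_step=None,       # None restores the latest checkpoint
    continuous_decode=False,            # decode once, then exit
)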
Example #3
def train_and_evaluate(
        model_name: str,
        job_log_dir: Optional[str],
        multi_host_checkpointing: Optional[bool],
        maybe_use_persistence_checkpointing: bool,
        restore_checkpoint_dir: Optional[str],
        restore_checkpoint_step: Optional[int],
        eval_on_test: Optional[bool],
        checkpoint_todelete_subdir: Optional[str] = None) -> None:
    """Runs the training and evaluation loop.

  Args:
    model_name: The name of the model from the registry to train.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
    restore_checkpoint_dir: If set, the directory from which to restore
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    eval_on_test: Whether to eval on test as a part of the training loop.
    checkpoint_todelete_subdir: If set, checkpoints to be deleted will be only
      renamed into the provided subdirectory. Otherwise, they will be directly
      deleted from the file system. This is useful, when checkpoint deletion is
      time consuming.
  """
    model_config = model_utils.get_model(model_name)()
    _write_params_file(model_config, job_log_dir)
    task_p = model_config.task()

    input_p = model_config.datasets()
    # Note that we modify the input params below with runtime information;
    # therefore model_config.datasets() should not be called again, as it
    # would not have the correct runtime information populated.
    for inp in input_p:
        if not isinstance(inp, base_input.BaseInputParams):
            raise ValueError('Expecting BaseInputParams from datasets(), got: '
                             f'{inp.ToText()}')
        inp.num_infeed_hosts = jax.process_count()
        inp.infeed_host_index = jax.process_index()
    train_input_p = [v for v in input_p if v.is_training]
    if len(train_input_p) != 1:
        raise ValueError(
            f'Expecting exactly one training split. Got `{len(train_input_p)}`.'
        )
    train_input_p = train_input_p[0]
    logging.info('train_input_p=%s', train_input_p.ToText())
    eval_input_p = None
    if eval_on_test:
        eval_input_p = [v for v in input_p if not v.is_training]

    checkpoint_type = checkpoints.retrieve_checkpoint_type(
        multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)

    checkpoint_manager = _create_checkpoint_manager(
        model_name, task_p, job_log_dir, checkpoint_type,
        checkpoint_todelete_subdir)

    if task_p.model.device_mesh is not None:
        train_and_evaluate_spmd_model(task_p, train_input_p, job_log_dir,
                                      checkpoint_manager, checkpoint_type,
                                      restore_checkpoint_dir,
                                      restore_checkpoint_step, eval_input_p)
    else:
        train_and_evaluate_pmap(task_p, train_input_p, job_log_dir,
                                checkpoint_manager, restore_checkpoint_dir,
                                restore_checkpoint_step, eval_input_p)
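The per-host infeed setup shared by all three examples (num_infeed_hosts = jax.process_count(), infeed_host_index = jax.process_index()) shards the input pipeline across hosts. Below is a toy sketch of that idea, using a plain Python list as a stand-in for the real input pipeline.

import jax

# In a single-process run these are 1 and 0; in a multi-host run every host
# sees the same count but a distinct index.
num_hosts = jax.process_count()
host_index = jax.process_index()

# Stand-in "dataset": each host keeps only the examples in its own shard, so
# the hosts collectively cover the data without overlap.
examples = list(range(16))
local_shard = examples[host_index::num_hosts]
print(f'host {host_index}/{num_hosts} feeds examples {local_shard}')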