Example #1
def get_metrics(config, metrics):
  if not config.debug_run:
    metrics = common_utils.get_metrics(metrics)
  else:
    metrics = common_utils.stack_forest(metrics)
    metrics = jax.device_get(metrics)
  return metrics
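For reference: in the rest of these examples, `common_utils.get_metrics` (from `flax.training`) is the step that brings per-step, `pmap`-replicated metric dicts back to the host as a single stacked dict of arrays. Below is a rough, self-contained sketch of that behaviour, assuming every value carries a leading device axis as in the non-debug branch of Example #1; it is an illustration, not the Flax source.

import jax
import numpy as np

def get_metrics_sketch(step_metrics):
    """Roughly what common_utils.get_metrics does with pmapped metrics."""
    # Keep a single replica copy of each metric (index 0 on the device axis).
    step_metrics = [jax.tree_map(lambda x: x[0], m) for m in step_metrics]
    # Transfer all values from device to host memory in one call.
    step_metrics = jax.device_get(step_metrics)
    # Stack the per-step dicts into one dict of arrays with a leading step
    # axis (the role stack_forest plays in the debug branch above).
    return jax.tree_map(lambda *xs: np.stack(xs), *step_metrics)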
Example #2
    def maybe_eval_and_log(self, eval_summary, master, step, tick,
                           train_metrics, train_summary):
        """Maybe evaluate and log based on the current step value."""
        if (step % self.eval_frequency == 0) or (step == self.total_steps):
            del eval_summary
            del train_summary

            train_metrics = common_utils.get_metrics(train_metrics)
            train_summary = pipeline_utils.compute_global_mean_metrics(
                train_metrics)

            tock = time.time()
            steps_per_sec = self.eval_frequency / (tock - tick)
            tick = tock

            # log train summary
            if master:
                self.write_train_summary(step=step,
                                         metric_dict=train_metrics,
                                         summary=train_summary,
                                         steps_per_sec=steps_per_sec)
            # reset metric accumulation for next evaluation cycle
            del train_metrics
            train_metrics = []

            # sync model state across replicas
            self.train_state = pipeline_utils.sync_model_state_across_replicas(
                self.train_state)

            # evaluate and log the results
            eval_summary, _ = self.eval(step, self.train_state)
        return eval_summary, train_metrics, train_summary, tick
Example #3
def test(optimizer, state, p_eval_step, step, test_ds, summary_writer):
    """Test the flax module in optimizer on test_ds.

  Args:
    optimizer: flax optimizer (contains flax module).
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    step: int; Number of training steps passed so far.
    test_ds: tf.dataset; Test dataset.
    summary_writer: tensorflow summary writer.
  """
    # Test Metrics
    test_metrics = []
    test_iter = iter(test_ds)
    for _, test_batch in zip(itertools.repeat(1), test_iter):
        # pylint: disable=protected-access
        test_batch = common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), test_batch))
        # pylint: enable=protected-access
        metrics = p_eval_step(optimizer.target, state, test_batch)
        test_metrics.append(metrics)
    test_metrics = common_utils.get_metrics(test_metrics)
    test_metrics_sums = jax.tree_map(jnp.sum, test_metrics)
    test_denominator = test_metrics_sums.pop('denominator')
    test_summary = jax.tree_map(
        lambda x: x / test_denominator,  # pylint: disable=cell-var-from-loop
        test_metrics_sums)
    logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                 test_summary['loss'], test_summary['accuracy'])
    if jax.host_id() == 0:
        for key, val in test_summary.items():
            summary_writer.scalar(f'test_{key}', val, step)
        summary_writer.flush()
Example #4
def combine_metrics(step_metrics):
    """Given a list of metric dicts, combine to a single summary metrics dict.

  Args:
    step_metrics: A dict with (metric name, metric value) items. Contains summed
      metrics and the corresponding denominator (the number of next-token
      prediction instances). Each metric value have at least one dimension.

  Returns:
    A dict with (metric name, metric value) items containing combined metrics.
  """
    metrics_all = common_utils.get_metrics(step_metrics)
    lr = None
    if 'learning_rate' in metrics_all:
        lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    if lr is not None:
        summary['learning_rate'] = lr

    # Calculate (clipped) perplexity after averaging log-perplexities:
    if 'loss' in summary:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    return summary
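A hypothetical call of `combine_metrics`, with made-up values, to show the denominator convention used above: every per-step dict carries summed metrics plus the number of prediction instances, and the summary is the denominator-weighted mean. The leading axis of size one stands in for a single-device `pmap`, and the snippet assumes the same `jnp` / `common_utils` imports the function itself relies on.

import jax.numpy as jnp

# Two fake training steps, as if emitted by a pmap over a single device
# (hence the leading axis of size 1 on every value).
fake_step_metrics = [
    {'loss': jnp.array([230.4]), 'denominator': jnp.array([64.0])},
    {'loss': jnp.array([221.8]), 'denominator': jnp.array([64.0])},
]
summary = combine_metrics(fake_step_metrics)
# summary['loss'] is the summed loss divided by the total number of next-token
# prediction instances (452.2 / 128 ~= 3.53); summary['perplexity'] is its
# clipped exponential (~34.2).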
Example #5
def predict_once(run_configuration, optimizer=None):
    """Predict the result once for each element in the dataset."""
    adapter = run_configuration.adapter
    checkpoint_path = run_configuration.original_checkpoint_path
    optimizer = optimizer or adapter.create_optimizer(run_configuration)
    dataset = run_configuration.dataset_info.dataset

    # Restore checkpoint
    optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)
    predict_step = adapter.make_predict_step()
    predict_step_parallel = jax.pmap(predict_step, axis_name='batch')

    # Perform inference
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)
    metrics_all = []
    for example in itertools.islice(dataset_iter, 200):
        train_inputs = adapter.get_train_inputs(example)
        metrics, logits, state = predict_step_parallel(optimizer.target,
                                                       train_inputs)
        adapter.handle_predict(metrics, logits, state)
        metrics_all.append(metrics)
    metrics_all = common_utils.get_metrics(metrics_all)
    metrics = jax.tree_map(jnp.sum, metrics_all)
    return metrics
Example #6
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer, state: flax.nn.Collection,
    prng_key: jnp.ndarray, pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an eventual
      dropout mask). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and the
      parameter moving averages (if used).
  """
  start_time = time.time()
  cnt = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    cnt += 1

    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)

    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Whole training step done in {} ({} steps)'.format(
      time.time()-start_time, cnt)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
Example #7
    def eval_split(self, train_state, split_name, eval_env_ids=None):
        """Evaluation loop on the specified split.

    Args:
      train_state: TrainState; Object containing training state.
      split_name: str; Name of the data split we want to evaluate the model on.
      eval_env_ids: list(int); Eval environments ids.

    Returns:
      eval_summary, eval_metrics
    """
        data_iters = self.task.dataset.data_iters[split_name]
        if eval_env_ids is None:
            eval_env_ids = list(map(int, data_iters.keys()))

        eval_metrics = {}
        if isinstance(self.steps_per_eval, dict):
            for env_id in eval_env_ids:
                env_id_str = str(env_id)
                env_eval_metrics = []
                for _ in range(self.steps_per_eval[split_name][env_id_str]):
                    env_eval_batches = self.get_next_batch(
                        [data_iters[env_id_str]])
                    e_metrics = self.pmapped_eval_step(train_state,
                                                       env_eval_batches,
                                                       env_id)
                    env_eval_metrics.append(e_metrics)

                env_eval_metrics = common_utils.get_metrics(env_eval_metrics)
                eval_metrics.update(env_eval_metrics)

            eval_summary = pipeline_utils.compute_global_mean_metrics(
                eval_metrics)
        else:
            _, data_iters = list(zip(*dict(data_iters).items()))
            eval_metrics = []
            for _ in range(self.steps_per_eval):
                env_eval_batches = self.get_next_batch(data_iters)
                e_metrics = self.pmapped_eval_step(train_state,
                                                   env_eval_batches, -1)
                eval_metrics.append(e_metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_summary = pipeline_utils.compute_global_mean_metrics(
                eval_metrics)

        return eval_summary, eval_metrics
Example #8
def write_train_metric(summary_writer, train_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
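A hypothetical calling pattern for the helper above, to show why scalars are written at `step - len(vals) + i + 1`: entry `i` of every stacked metric is logged back at the global step on which it was produced. The writer and train step below are stand-ins so the snippet runs on its own; a real script would use its TensorBoard writer and pmapped train step, and it assumes `get_metrics` is imported from `flax.training.common_utils` as the helper requires.

import time
import jax.numpy as jnp

class _PrintWriter:
    # Minimal stand-in for a summary writer with a .scalar(tag, value, step) method.
    def scalar(self, tag, value, step):
        print(f'{tag} @ {step}: {float(value):.4f}')

def _fake_train_step(step):
    # Leading axis of size 1, mirroring metrics returned by a single-device pmap.
    return {'loss': jnp.array([1.0 / (step + 1)])}

summary_writer = _PrintWriter()
logging_steps = 5
train_metrics = []
train_start = time.time()
for step in range(10):
    train_metrics.append(_fake_train_step(step))
    if (step + 1) % logging_steps == 0:
        # len(vals) == logging_steps here, so the oldest accumulated entry is
        # logged at step - logging_steps + 1 and the newest at the current step.
        write_train_metric(summary_writer, train_metrics,
                           time.time() - train_start, step)
        train_metrics = []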
Example #9
def eval_once(run_configuration, checkpoint_path, optimizer=None):
    """Evaluates a single checkpoint on a single epoch of data."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    optimizer = optimizer or adapter.create_optimizer(run_configuration)
    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    eval_name = config.eval_name or 'eval'
    log_dir = os.path.join(run_dir, eval_name)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Restore checkpoint
    optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
    step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)
    eval_step = adapter.make_eval_step()
    eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

    # Perform evaluation
    tick = time.time()
    metrics_all = []

    example = None
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)
    for unused_eval_step, example in zip(range(config.eval_steps),
                                         dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        metrics, logits, state = eval_step_parallel(optimizer.target,
                                                    train_inputs)
        metrics_all.append(metrics)

    # Write results.
    num_eval_steps = len(metrics_all)
    metrics_all = common_utils.get_metrics(metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])
    if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = num_eval_steps / (tock - tick)
        examples_per_sec = denominator / (tock - tick)
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
            summary_writer.scalar(key, val, step)

        adapter.write_summaries(example, logits, summary_writer, info, step,
                                state)
        summary_writer.flush()
Example #10
    def write_metric(train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)
Example #11
def evaluate(p_eval_step, state, eval_ds, num_eval_steps=-1):
    """Evaluate on the given dataset."""
    logging.info('Starting evaluating.')
    eval_metrics = []
    for step, batch in enumerate(eval_ds):
        batch = jax.tree_map(np.asarray, batch)
        metrics = p_eval_step(batch=batch, state=state)
        eval_metrics.append(metrics)
        if num_eval_steps > 0 and step + 1 == num_eval_steps:
            break
    eval_metrics = common_utils.get_metrics(eval_metrics)
    summary = train_utils.metrics_summary(eval_metrics, 'eval')
    return summary
Example #12
def evaluate(p_eval_step, params, eval_ds, rng):
  """Evaluate the target and return a dictionary with the metrics."""
  logging.info('Gathering evaluation metrics.')
  eval_metrics = []

  for eval_batch in eval_ds:
    eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
    eval_batch = common_utils.shard(eval_batch)
    metrics, rng = p_eval_step(rng, params, eval_batch)
    eval_metrics.append(metrics)
  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_summary = jax.tree_map(np.mean, eval_metrics)
  return eval_summary, rng
Example #13
def eval_policy(policy, rng, state, model, test_ds):
  """Evaluate the target with policy and return a dictionary with the metrics."""
  eval_metrics = []

  for eval_batch in test_ds:
    eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
    eval_batch = common_utils.shard(eval_batch)
    metrics, rng = eval_step_policy(rng, eval_batch, state, model, policy)

    # Better to leave metrics on device, and off-load after finishing epoch.
    eval_metrics.append(metrics)

  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_summary = jax.tree_map(np.mean, eval_metrics)
  return eval_summary
Example #14
def write_train_metric(train_metrics, train_time, step):
    train_metrics = get_metrics(train_metrics)

    num_steps = len(list(train_metrics.values())[0])
    for key, vals in train_metrics.items():
        assert len(vals) == num_steps

    train_metrics_by_step = [{} for _ in range(num_steps)]
    for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
            train_metrics_by_step[i][f"train_{key}"] = val

    for i in range(num_steps):
        wandb.log(train_metrics_by_step[i], step=step - num_steps + i + 1)

    wandb.log({"train_time": train_time}, step=step)
Example #15
def evaluate(*, p_eval_step, target, eval_ds):
  """Evaluate the target an return a dictionary with the metrics."""
  eval_metrics = []
  for batches in eval_ds.as_numpy_iterator():
    inputs, outputs, programs, _ = common_utils.shard(batches)

    metrics = p_eval_step(target, inputs, outputs, programs)
    eval_metrics.append(metrics)

  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_denominator = eval_metrics_sums.pop('denominator')
  eval_summary = jax.tree_map(
      lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
      eval_metrics_sums)
  return eval_summary
Example #16
def evaluate(*, p_eval_step, target, eval_ds, num_eval_steps):
    """Evaluate the target an return a dictionary with the metrics."""
    logging.info('Gathering evaluation metrics.')
    eval_metrics = []
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    for _, eval_batch in zip(range(num_eval_steps), eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        eval_batch = common_utils.shard(eval_batch)
        metrics = p_eval_step(target, eval_batch)
        eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    return eval_summary
Example #17
def train_for_one_epoch(
    dataset_source,
    optimizer, state,
    prng_key, pmapped_train_step,
    summary_writer
):
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an eventual
      dropout mask). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and PRNG
      key.
  """
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key, prng_key = jax.random.split(prng_key)
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, prng_key
Example #18
def eval_on_dataset(
    model: flax.nn.Model, state: flax.nn.Collection, dataset: tf.data.Dataset,
    pmapped_eval_step: _EvalStep):
  """Evaluates the model on the whole dataset.

  Args:
    model: The model to evaluate.
    state: Current state associated with the model (contains the batch norm MA).
    dataset: Dataset on which the model should be evaluated. Should already
      be batched.
    pmapped_eval_step: A pmapped version of the `eval_step` function (see its
      documentation for more details).

  Returns:
    A dictionary containing the loss and error rate on the batch. These metrics
    are averaged over the samples.
  """
  eval_metrics = []
  total_num_samples = 0
  all_host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')

  for eval_batch in dataset:
    # Load and shard the TF batch.
    eval_batch = load_and_shard_tf_batch(eval_batch)
    # Compute metrics and sum over all observations in the batch.
    metrics = pmapped_eval_step(model, state, eval_batch)
    eval_metrics.append(metrics)
    if 'mask' not in eval_batch:
      # Number of samples seen is num_replicas * per_replica_batch_size.
      total_num_samples += (
          eval_batch['label'].shape[0] * eval_batch['label'].shape[1] *
          jax.host_count())
    else:
      total_num_samples += all_host_psum(eval_batch['mask'])[0].sum()

  # Metrics are all the same across all replicas (since we applied psum in the
  # eval_step). The next line will fetch the metrics on one of them.
  eval_metrics = common_utils.get_metrics(eval_metrics)
  # Finally, we divide by the number of samples to get the mean error rate and
  # cross entropy.
  eval_summary = jax.tree_map(lambda x: x.sum() / total_num_samples,
                              eval_metrics)
  return eval_summary
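The mask branch above relies on a `psum` over the pmapped axis so that every device ends up with the same total count of real, unpadded examples. A small self-contained sketch of that counting step with a made-up mask; on a single-host machine it only sums across the local devices.

import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
# Shape (num_devices, per_device_batch); 1.0 marks a real example and 0.0 a
# padded one. Here the very last slot is padding.
mask = jnp.ones((num_devices, 4)).at[-1, -1].set(0.0)
# psum adds the per-device masks together, so every row of the result is
# identical; row 0 summed gives the global number of real examples.
summed = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(mask)
num_real_examples = summed[0].sum()  # == num_devices * 4 - 1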
Example #19
    def test_reproduce_paper_evals(self, num_chunks):
        """Reproduce results from https://www.aclweb.org/anthology/P18-1009.pdf."""
        num_samples = self.labels.shape[0]
        chunk_size = num_samples // num_chunks
        metrics = []
        for chunk_start in range(0, num_samples, chunk_size):
            chunk_end = min(chunk_start + chunk_size, num_samples)
            labels = self.labels[chunk_start:chunk_end]
            predictions = self.predictions[chunk_start:chunk_end]
            current_metrics = ultra_fine_entity_typing_task.get_prediction_recall_metrics(
                labels, predictions)
            current_metrics = jax.tree_map(lambda x: jnp.expand_dims(x, 0),
                                           current_metrics)
            metrics.append(current_metrics)

        metrics = common_utils.get_metrics(metrics)
        metrics_sum = jax.tree_map(jnp.sum, metrics)
        processed_metrics = metric_utils.process_metrics(metrics_sum)
        self.assertAlmostEqual(processed_metrics['total_precision_value'],
                               0.481,
                               places=3)
        self.assertAlmostEqual(processed_metrics['total_recall_value'],
                               0.232,
                               places=3)
        self.assertAlmostEqual(
            processed_metrics['coarse_grained_precision_value'],
            0.603,
            places=3)
        self.assertAlmostEqual(
            processed_metrics['coarse_grained_recall_value'], 0.616, places=3)
        self.assertAlmostEqual(
            processed_metrics['fine_grained_precision_value'], 0.404, places=3)
        self.assertAlmostEqual(processed_metrics['fine_grained_recall_value'],
                               0.384,
                               places=3)
        self.assertAlmostEqual(
            processed_metrics['ultra_fine_grained_precision_value'],
            0.428,
            places=3)
        self.assertAlmostEqual(
            processed_metrics['ultra_fine_grained_recall_value'],
            0.088,
            places=3)
Example #20
    def eval_split(self, train_state, split_name):
        """Evaluation loop on the specified split.

    Args:
      train_state: TrainState; Object containing training state.
      split_name: str; Name of the data split we want to evaluate the model on.

    Returns:
      eval_summary, eval_metrics
    """
        data_iters = self.task.dataset.data_iters[split_name]
        eval_metrics = []
        for _ in range(self.steps_per_eval):
            env_eval_batches = self.get_next_batch(data_iters)
            e_metrics = self.pmapped_eval_step(train_state, env_eval_batches)
            eval_metrics.append(e_metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)

        return eval_summary, eval_metrics
Example #21
def evaluate(
    eval_step_fn,
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    eval_data: Sequence[Dict[str, Any]],
) -> Tuple[Dict[str, Any], Sequence[Tuple[Dict[str, Any], Optional[Dict[
    str, Any]]]]]:
  """Evaluate current parameters and return a dictionary with metrics.

  Args:
    eval_step_fn: partial eval step that takes in model params and inputs only
    train_state: contains model params, loss fn, grad update fn.
    model_vars: model variables that are not optimized.
    eval_data: sequence of evaluation data.

  Returns:
    Dictionary of metrics aggregated over all evaluation steps and the info
    for the very first batch (batch itself and corresponding auxiliary output).
  """

  logging.info('Performing evaluation.')
  eval_metrics = []
  eval_auxiliary = []
  for batch in eval_data:
    batch = jax.tree_map(jnp.asarray, batch)
    metrics, auxiliary_output = eval_step_fn(
        train_state,
        model_vars,
        batch,
    )
    eval_metrics.append(metrics)
    batch_auxiliary = (jax.device_get(batch), jax.device_get(auxiliary_output))
    eval_auxiliary.append(batch_auxiliary)
  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_summary = metric_utils.process_metrics(eval_metrics_sums, prefix='eval')
  return eval_summary, eval_auxiliary
Example #22
 def run_eval(eval_ds, num_eval_steps=-1):
     eval_metrics = []
     eval_iter = iter(eval_ds)
     if num_eval_steps == -1:
         num_iter = itertools.count()
     else:
         num_iter = range(num_eval_steps)
     for _, eval_batch in zip(num_iter, eval_iter):
         # pylint: disable=protected-access
         eval_batch = common_utils.shard(
             jax.tree_map(lambda x: x._numpy(), eval_batch))
         # pylint: enable=protected-access
         metrics = p_eval_step(optimizer.target, eval_batch)
         eval_metrics.append(metrics)
     eval_metrics = common_utils.get_metrics(eval_metrics)
     eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
     eval_denominator = eval_metrics_sums.pop('denominator')
     eval_summary = jax.tree_map(
         lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
         eval_metrics_sums)
     # Calculate (clipped) perplexity after averaging log-perplexities:
     eval_summary['perplexity'] = jnp.clip(jnp.exp(eval_summary['loss']),
                                           a_max=1.0e4)
     return eval_summary
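As in Examples #4, #9 and #23, perplexity is derived only after the summed log-loss has been divided by the denominator; the clip at 1.0e4 simply bounds the logged value while the loss is still large. A tiny numeric illustration with made-up totals:

import jax.numpy as jnp

total_log_loss = jnp.float32(452.2)  # summed per-token negative log-likelihood
denominator = jnp.float32(128.0)     # number of next-token prediction instances
mean_log_loss = total_log_loss / denominator                 # ~= 3.53
perplexity = jnp.clip(jnp.exp(mean_log_loss), a_max=1.0e4)   # ~= 34.2, well below the clip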
Example #23
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  num_eval_steps = FLAGS.num_eval_steps
  eval_freq = FLAGS.eval_frequency
  max_length = FLAGS.max_length
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')

  parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params')
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length)

  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length,
      repeat=1)
  train_iter = iter(train_ds)
  bs = device_batch_size * jax.device_count()

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (bs, max_length)
  transformer_kwargs = {
      'vocab_size': len(vocabs['forms']),
      'output_vocab_size': len(vocabs['xpos']),
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max_length,
  }
  model = create_model(init_rng, tuple(input_shape), transformer_kwargs)

  optimizer = create_optimizer(model, learning_rate)
  del model  # don't keep a copy of the initial model
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # reset metric accumulation for next evaluation cycle.
      metrics_all = []

      eval_metrics = []
      eval_iter = iter(eval_ds)
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          logging.info('Uneven batch size %d.', cur_pred_batch_size)
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)

        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])

      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
Example #24
def run_train(run_configuration):
    """Runs the training workflow."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    log_dir = os.path.join(run_dir, 'train')
    checkpoint_path = run_configuration.original_checkpoint_path

    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)

    summary_freq = config.logging.summary_freq
    metrics_all = []
    tick = time.time()
    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rngs, logits, state = train_step(
            optimizer, train_inputs, dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % config.logging.save_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.logging.save_freq:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(
                    checkpoint_dir, jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if summary_freq and step % summary_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train step: %d, loss: %.4f', step, summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = summary_freq / (tock - tick)
                examples_per_sec = denominator / (tock - tick)
                tick = tock
                summary_writer.scalar('per-second/steps', steps_per_sec, step)
                summary_writer.scalar('per-second/examples', examples_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar(key, val, step)

                adapter.write_summaries(example, logits, summary_writer, info,
                                        step, state)

                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []
Example #25
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = create_input_iter(local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=FLAGS.cache)
    eval_iter = create_input_iter(local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=FLAGS.cache)

    num_epochs = FLAGS.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    base_learning_rate = base_learning_rate / FLAGS.loss_scaling

    model, model_state = create_model(rng, device_batch_size, image_size,
                                      model_dtype)
    optimizer = optim.Momentum(beta=FLAGS.momentum,
                               nesterov=True).create(model)
    state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
    del model, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch, num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #26
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = "tpu_driver"
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')

    # Load Dataset
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_token = 2  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': vocab_size,
        'emb_dim': 1024,
        'num_heads': 16,
        'num_layers': 6,
        'qkv_dim': 1024,
        'mlp_dim': 4096,
        'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
    }

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    model, cache_def = create_model(init_rng, input_shape, target_shape,
                                    transformer_kwargs)
    optimizer = create_optimizer(model, FLAGS.learning_rate,
                                 FLAGS.weight_decay)
    # We access model only from optimizer below via optimizer.target.
    del model

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(
        eval_step,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                           axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if step % FLAGS.eval_frequency != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for _, pred_batch in enumerate(predict_iter):
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (per_device_batchsize, FLAGS.max_predict_length),
                    dtype=cache_dtype))
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_token, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
Example #27
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_t5_mlm",
                           model_args,
                           data_args,
                           framework="flax")

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # We tokenize every text, then concatenate them together before splitting them into smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )
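    # Hypothetical sanity check (not part of the original script): log the derived
    # lengths so a mismatch between the text grouping below and the data collator
    # is easy to spot in the run logs.
    logger.info(
        "Derived expanded_inputs_length=%d and targets_length=%d from max_seq_length=%d",
        expanded_inputs_length, targets_length, max_seq_length)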

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could instead add padding if the model
        # supported it. You can customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [
                t[i:i + expanded_inputs_length]
                for i in range(0, total_length, expanded_inputs_length)
            ]
            for k, t in concatenated_examples.items()
        }
        return result
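
    # Illustrative example (not in the original script): with expanded_inputs_length=4,
    # group_texts({"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}) returns
    # {"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]]}: texts are concatenated and
    # re-split into fixed-size chunks, and any remainder shorter than a full
    # chunk is dropped.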

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Not writing TensorBoard metrics on this process: either tensorboard is not "
            "installed (run `pip install tensorboard` to enable it) or this is not the main process.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )
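
    # Illustrative sanity check, not in the original script: by construction the
    # collator should emit fixed-length "input_ids" (max_seq_length) and "labels"
    # (targets_length); asserting this on a single dummy batch catches length
    # mismatches before the much slower training loop starts.
    _probe_batch = data_collator([tokenized_datasets["train"][0]])
    assert _probe_batch["input_ids"].shape[-1] == max_seq_length
    assert _probe_batch["labels"].shape[-1] == targets_length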

    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs

    num_of_hosts = jax.process_count()
    current_host_idx = jax.process_index()

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])
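
    # Illustrative note (not in the original script): the joined schedule ramps the
    # learning rate linearly from 0 to `training_args.learning_rate` over
    # `warmup_steps`, then decays it linearly back to 0 over the remaining steps,
    # e.g. linear_decay_lr_schedule_fn(0) == 0.0 and
    # linear_decay_lr_schedule_fn(num_train_steps) ~= 0.0.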

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set([
            layer[-2:] for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        ])
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in layer_norm_named_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)
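
    # Illustrative example (not in the original script) of what decay_mask_fn
    # produces for a toy parameter tree:
    #   {"dense": {"kernel": w, "bias": b}, "final_layer_norm": {"weight": s}}
    #   -> {"dense": {"kernel": True, "bias": False},
    #       "final_layer_norm": {"weight": False}}
    # i.e. only kernels are weight-decayed; biases and LayerNorm scales are not.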

    # Create the optimizer (Adafactor or AdamW)
    if training_args.adafactor:
        # We use the default parameters here to initialize Adafactor.
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn)
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            local_host_model_inputs = {
                key: np.split(value, num_of_hosts, axis=0)[current_host_idx]
                for key, value in model_inputs.data.items()
            }

            # Model forward
            model_inputs = shard(local_host_model_inputs)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                # Avoid using jax.numpy here in case of TPU training
                eval_samples_idx = np.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size,
                                                       drop_last=False)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...",
                             position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples)

                    # Model forward
                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                        state.params,
                        model_inputs.data,
                        min_device_batch=per_device_eval_batch_size)
                    eval_metrics.append(metrics)

                # get eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                # Update progress bar
                epochs.write(
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
                )

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save a checkpoint every `save_steps` steps and push it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training
        eval_samples_idx = np.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size,
                                               drop_last=False)

        eval_metrics = []
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params,
                model_inputs.data,
                min_device_batch=per_device_eval_batch_size)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(),
                                    eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
Example No. 28
0
  Args:
    model: The model to evaluate.
    state: Model state containing state for stateful flax.nn functions, such as
      batch normalization.
    eval_dataset: Dataset to evaluate the model over.

  Returns:
    Dictionary containing the average loss and accuracy of the model on the
    given dataset.
  """
  p_eval_step = jax.pmap(_eval_step, axis_name='batch')

  batch_sizes = []
  metrics = []
  for batch in eval_dataset:
    batch_size = len(batch[LABELKEY])

    # These are required for pmap call.
    batch = _shard_batch(batch)
    batch_metrics = p_eval_step(model, state, batch)

    batch_sizes.append(batch_size)
    metrics.append(batch_metrics)

  # Note: use weighted mean, since we do mean of means with potentially
  # different batch sizes otherwise.
  batch_sizes = jnp.array(batch_sizes)
  weights = batch_sizes / jnp.sum(batch_sizes)
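  # Worked example (illustrative): with batch sizes [2, 4] and per-batch mean
  # accuracies [a1, a2], the weighted sum (2*a1 + 4*a2) / 6 equals the true
  # dataset mean, whereas the naive (a1 + a2) / 2 over-weights the smaller batch.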
  eval_metrics = common_utils.get_metrics(metrics)
  return jax.tree_map(lambda x: (weights * x).sum(), eval_metrics)
Example No. 29
0
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

                eval_metrics = []
                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
                    model_inputs = data_collator(samples, pad_to_multiple_of=16)

                    # Model forward
                    model_inputs = shard(model_inputs.data)
                    metrics = p_eval_step(state.params, model_inputs)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
                eval_normalizer = eval_metrics.pop("normalizer")
                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)

                # Update progress bar
                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save a checkpoint every `save_steps` steps and push it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
Example No. 30
0
def train_loop(config, dropout_rngs, eval_ds, eval_freq, num_eval_steps,
               num_train_steps, optimizer, state, p_eval_step, p_train_step,
               start_step, train_iter, summary_writer):
    """Training loop.

  Args:
    config: experiment config.
    dropout_rngs: float array; Jax PRNG key.
    eval_ds: tf.dataset; Evaluation dataset.
    eval_freq: int; Evaluation frequency.
    num_eval_steps: int; Number of evaluation steps.
    num_train_steps: int; Number of training steps.
    optimizer: flax optimizer.
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    p_train_step: fn; Pmapped train step function.
    start_step: int; global training step.
    train_iter: iter(tf.dataset); Training data iterator.
    summary_writer: tensorflow summary writer.

  Returns:
    optimizer, state, global training step
  """
    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    step = 0
    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, state, metrics, dropout_rngs = p_train_step(
            optimizer, state, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        # Save a Checkpoint
        if ((step % config.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            (jax_utils.unreplicate(optimizer),
                                             jax_utils.unreplicate(state)),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Log the averaged training metrics:
            logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('examples_per_second',
                                      steps_per_sec * config.batch_size, step)
                for key, val in summary.items():
                    summary_writer.scalar(f'train_{key}', val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_metrics = []
            eval_iter = iter(eval_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, eval_batch in zip(num_iter, eval_iter):
                # pylint: disable=protected-access
                eval_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), eval_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    summary_writer.scalar(f'val_{key}', val, step)
                summary_writer.flush()
    return optimizer, state, step