Example 1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    batch_size = FLAGS.batch_size
    learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    num_eval_steps = FLAGS.num_eval_steps
    eval_freq = FLAGS.eval_frequency
    max_target_length = FLAGS.max_target_length
    max_eval_target_length = FLAGS.max_eval_target_length
    random_seed = FLAGS.random_seed

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
        n_devices=jax.local_device_count(),
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        dynamic_batching=True,
        max_target_length=max_target_length,
        max_eval_target_length=max_eval_target_length)
    vocab_size = info_ds['text'].encoder.vocab_size
    encoder = info_ds['text'].encoder

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_target_length)

    transformer_lm_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': 512,
        'num_heads': 8,
        'num_layers': 6,
        'qkv_dim': 512,
        'mlp_dim': 2048,
        'max_len': max(max_target_length, max_eval_target_length)
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We initialize the first set of dropout PRNG keys here, but update them
    # inside the main pmapped training step for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model, cache_def = create_model(init_rng, input_shape,
                                    transformer_lm_kwargs)
    optimizer = create_optimizer(model, learning_rate)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer so each local device holds a copy for pmap.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')
    p_pred_step = jax.pmap(predict_step, axis_name='batch')

    metrics_all = []
    tick = time.time()
    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint.
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and FLAGS.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f', step,
                         summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                train_summary_writer.scalar('steps per second', steps_per_sec,
                                            step)
                for key, val in summary.items():
                    train_summary_writer.scalar(key, val, step)
                train_summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_metrics = []
            eval_iter = iter(eval_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, eval_batch in zip(num_iter, eval_iter):
                # pylint: disable=protected-access
                eval_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), eval_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            # Calculate (clipped) perplexity after averaging log-perplexities:
            eval_summary['perplexity'] = jnp.clip(jnp.exp(
                eval_summary['loss']),
                                                  a_max=1.0e4)
            logging.info('eval in step: %d, loss: %.4f', step,
                         eval_summary['loss'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(key, val, step)
                eval_summary_writer.flush()

            # Fast inference of prompt extension using trained LM.
            rng, subrng = jax.random.split(rng)
            pred_rngs = random.split(subrng, jax.local_device_count())
            prompt = jnp.array(encoder.encode(FLAGS.prompt))
            prompt = jax_utils.replicate(prompt)
            prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (1, FLAGS.max_predict_token_length)))
            predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)
            predicted = tohost(predicted)
            exemplars = ''
            for n in range(predicted.shape[0]):
                exemplars += encoder.decode(predicted[n]) + '\n\n'
            if jax.host_id() == 0:
                eval_summary_writer.text('samples', exemplars, step)
                eval_summary_writer.flush()
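
Example 1 above pulls its configuration from absl flags (FLAGS.batch_size, FLAGS.eval_frequency, FLAGS.prompt, and so on) that are defined elsewhere in the module. As a rough sketch only, the flag definitions it assumes could look like the following; the flag names are taken from the code above, but the default values and help strings are illustrative, not from the source.

from absl import flags

FLAGS = flags.FLAGS

# Names mirror the FLAGS.* accesses in Example 1; defaults are placeholders.
flags.DEFINE_string('model_dir', None, 'Directory to store model data.')
flags.DEFINE_string('data_dir', None, 'Directory containing the TFDS lm1b data.')
flags.DEFINE_integer('batch_size', 2048, 'Global batch size for training.')
flags.DEFINE_float('learning_rate', 0.05, 'Base learning rate.')
flags.DEFINE_integer('num_train_steps', 500000, 'Number of training steps.')
flags.DEFINE_integer('num_eval_steps', 20, 'Number of eval steps; -1 uses the full eval set.')
flags.DEFINE_integer('eval_frequency', 1000, 'Run evaluation every these many steps.')
flags.DEFINE_integer('max_target_length', 512, 'Maximum length of training examples.')
flags.DEFINE_integer('max_eval_target_length', 2048, 'Maximum length of eval examples.')
flags.DEFINE_integer('random_seed', 0, 'Seed for the PRNG.')
flags.DEFINE_bool('save_checkpoints', True, 'Whether to save model checkpoints.')
flags.DEFINE_bool('restore_checkpoints', True, 'Whether to restore from an existing checkpoint.')
flags.DEFINE_integer('checkpoint_freq', 10000, 'Save a checkpoint every these many steps.')
flags.DEFINE_integer('max_predict_token_length', 50, 'Maximum token length for sampled text.')
flags.DEFINE_string('prompt', 'I love to ', 'Prompt for language model sampling.')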
Example 2
def train_and_evaluate(
  random_seed, batch_size, learning_rate, num_train_steps, num_eval_steps,
  eval_freq, max_target_length, max_eval_target_length, weight_decay, data_dir,
  model_dir, restore_checkpoints, save_checkpoints, checkpoint_freq,
  max_predict_token_length, sampling_temperature, sampling_top_k, prompt_str):
  """Executes model training and evaluation loop.
  
  Args:
    random_seed: Seed for initializing the PRNG.
    batch_size: Batch size for training.
    learning_rate: Learning rate for the Adam optimizer.
    num_train_steps: Number of training steps.
    num_eval_steps: Number of evaluation steps.
    eval_freq: Frequency of evaluation during training.
    max_target_length: Maximum length of training examples.
    max_eval_target_length: Maximum length of eval examples.
    weight_decay: Decay factor for AdamW-style weight decay.
    data_dir: Directory containing TFDS lm1b/subwords32k datasets.
    model_dir: Directory where to store model data.
    restore_checkpoints: Whether to restore from existing model checkpoints.
    save_checkpoints: Whether to save model checkpoints.
    checkpoint_freq: Save a checkpoint every this many steps.
    max_predict_token_length: Maximum token length for example text inference.
    sampling_temperature: Sampling temperature for language model inference.
    sampling_top_k: Top k cutoff for logit sampling.
    prompt_str: Prompt for language model sampling.
  """
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(model_dir, 'eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
      n_devices=jax.local_device_count(),
      data_dir=data_dir,
      batch_size=batch_size,
      dynamic_batching=True,
      max_target_length=max_target_length,
      max_eval_target_length=max_eval_target_length)
  vocab_size = info_ds['text'].encoder.vocab_size
  encoder = info_ds['text'].encoder

  train_iter = iter(train_ds)
  input_shape = (batch_size, max_target_length)

  transformer_lm_kwargs = {
      'vocab_size': vocab_size,
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max(max_target_length, max_eval_target_length)
  }

  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = random.split(rng)
  # We initialize the first set of dropout PRNG keys here, but update them
  # inside the main pmapped training step for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs)
  optimizer = create_optimizer(model, learning_rate, weight_decay)
  del model  # Don't keep a copy of the initial model.
  start_step = 0
  if restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  p_pred_step = jax.pmap(predict_step, axis_name='batch')

  metrics_all = []
  tick = time.time()
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint.
    if ((step % checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.host_id() == 0 and save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            model_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_metrics = []
      eval_iter = iter(eval_ds)
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        # pylint: disable=protected-access
        eval_batch = common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), eval_batch))
        # pylint: enable=protected-access
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()

      # Fast inference of prompt extension using trained LM.
      rng, subrng = jax.random.split(rng)
      pred_rngs = random.split(subrng, jax.local_device_count())
      prompt = jnp.array(encoder.encode(prompt_str))
      prompt = jax_utils.replicate(prompt)
      prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
      cache = jax_utils.replicate(
          cache_def.initialize_cache((1, max_predict_token_length)))
      predicted = p_pred_step(
        prompt, optimizer.target, cache, pred_rngs, max_predict_token_length,
        sampling_temperature, sampling_top_k)
      predicted = tohost(predicted)
      exemplars = ''
      for n in range(predicted.shape[0]):
        exemplars += encoder.decode(predicted[n]) + '\n\n'
      if jax.host_id() == 0:
        eval_summary_writer.text('samples', exemplars, step)
        eval_summary_writer.flush()
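
Example 2 factors the flag handling out of the training loop, so it needs a thin entry point that parses flags and forwards them. A minimal sketch of such a caller is shown below; the flag names (FLAGS.weight_decay, FLAGS.sampling_temperature, FLAGS.sampling_top_k, and the rest) are assumptions for illustration and do not appear in the source, which only defines train_and_evaluate itself.

import tensorflow.compat.v2 as tf
from absl import app, flags

FLAGS = flags.FLAGS


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Enable TF2 (eager) behavior, as Example 1 does before training.
  tf.enable_v2_behavior()

  train_and_evaluate(
      random_seed=FLAGS.random_seed,
      batch_size=FLAGS.batch_size,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=FLAGS.num_train_steps,
      num_eval_steps=FLAGS.num_eval_steps,
      eval_freq=FLAGS.eval_frequency,
      max_target_length=FLAGS.max_target_length,
      max_eval_target_length=FLAGS.max_eval_target_length,
      weight_decay=FLAGS.weight_decay,
      data_dir=FLAGS.data_dir,
      model_dir=FLAGS.model_dir,
      restore_checkpoints=FLAGS.restore_checkpoints,
      save_checkpoints=FLAGS.save_checkpoints,
      checkpoint_freq=FLAGS.checkpoint_freq,
      max_predict_token_length=FLAGS.max_predict_token_length,
      sampling_temperature=FLAGS.sampling_temperature,
      sampling_top_k=FLAGS.sampling_top_k,
      prompt_str=FLAGS.prompt)


if __name__ == '__main__':
  app.run(main)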