Example #1
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = json.loads(FLAGS.xm_parameters)
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in hparam_str_dict.keys()])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)
  indices_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_strings_per_task)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)  # pylint: disable=unused-variable

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
    spec_vocab = robust_fill_dsl.CHARACTER + '|'
    spec_id_token_table = {i+3: token
                           for i, token in enumerate(spec_vocab)}
    bos_id = 1
    eos_id = 2
    spec_id_token_table[bos_id] = robust_fill_dsl.BOS
    spec_id_token_table[eos_id] = robust_fill_dsl.EOS
    spec_token_id_table = {token: id
                           for id, token in spec_id_token_table.items()}
    spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
    program_id_token_table, program_token_id_table = (
        dsl_tokens.build_token_tables())
    program_vocab_size = len(program_id_token_table) + 1  # pylint: disable=unused-variable
  elif FLAGS.dataset_type == 'scan':
    # TODO(jxihong): Scan is not handled yet.
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    if FLAGS.dataset_type == 'robust_fill':
      def decode_str(s):
        """Decode string tokens."""
        return ''.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps, outs = [], []
      for inp, out in zip(inputs, outputs):
        inps.append(decode_str(inp))
        outs.append(decode_str(out))
      return inps, outs

    elif FLAGS.dataset_type == 'scan':
      def decode_str(s):
        """Decode string tokens."""
        return ' '.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps = [decode_str(inp) for inp in inputs]
      dummy_outs = [''] * len(inps)
      return inps, dummy_outs

    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_target(target):
    """Decode program tokens."""
    target = target[:np.argmax(target == eos_id)].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      target = target[target != bos_id].tolist()
      return ''.join([spec_id_token_table[t_id] for t_id in target if t_id > 0])
    elif FLAGS.dataset_type == 'scan':
      # TODO(jxihong): Scan is not handled yet.
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      # Returns either a Concat program object, or None.
      program = program[program != bos_id].tolist()
      try:
        return robust_fill_dsl.decode_program(program, program_id_token_table)
      except:  # pylint: disable=bare-except
        return None  # Program does not compile.
    elif FLAGS.dataset_type == 'scan':
      # Returns a string.
      program = program[jnp.logical_and(program != bos_id,
                                        program != eos_id)].tolist()
      return ' '.join(scan_vocab.decode(program, program_id_token_table))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program_str(program):  # pylint: disable=unused-variable
    """Decode program tokens into a string."""
    decoded = decode_program(program)
    if FLAGS.dataset_type == 'robust_fill':
      try:
        return decoded.to_string()
      except:  # pylint: disable=bare-except
        return 'did not compile'
    else:
      assert isinstance(decoded, str), '{} should be string'.format(decoded)
      return decoded

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = {
      'inputs': io_shape[1:],
      'outputs': io_shape[1:],
      'spec_parts': target_shape[1:],
      'start_index': indices_shape[1:],
      'end_index': indices_shape[1:],
      # TODO(kshi): include programs.
      # 'program': program_shape[1:],
  }
  logging.info('padded_shapes: %s', padded_shapes)

  if FLAGS.dataset_type == 'robust_fill':
    create_dataset_fn = input_pipeline.create_robust_fill_dataset_from_tf_record
  elif FLAGS.dataset_type == 'scan':
    raise NotImplementedError()  # TODO(kshi): Implement.
    # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  dataset = create_dataset_fn(
      FLAGS.dataset_filepattern, program_token_id_table, spec_token_id_table,
      max_target_length=FLAGS.max_target_length)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=padded_shapes,
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease the batch size of the predict dataset to handle beam search.
  predict_padded_shapes = padded_shapes.copy()
  predict_padded_shapes['inputs'] = predict_io_shape[1:]
  predict_padded_shapes['outputs'] = predict_io_shape[1:]
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = create_dataset_fn(
      FLAGS.test_dataset_filepattern, program_token_id_table,
      spec_token_id_table, max_target_length=FLAGS.max_target_length)
  test_dataset = test_dataset.padded_batch(
      batch_size,
      padded_shapes=predict_padded_shapes,
      drop_remainder=False)
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_config = base_models.TransformerConfig(
      vocab_size=spec_vocab_size,
      output_vocab_size=spec_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_id,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=FLAGS.max_distance,
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=FLAGS.max_distance,
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=FLAGS.max_distance,
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=FLAGS.max_distance,
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=FLAGS.max_distance,
      bidirectional_program_attention=FLAGS.bidirectional_program_attention)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      dataset_type=FLAGS.dataset_type)
  eval_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(deterministic=True),
      dataset_type=FLAGS.dataset_type)
  predict_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(
          shift=False, deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)),
      dataset_type=FLAGS.dataset_type)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info('Checking that start_step (%s) >= finetune_start_step (%s)',
                   start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr,
        factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_id,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_id,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, targets = load_data(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, targets, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, targets = load_data(batches)

        metrics = p_eval_step(optimizer.target, inputs, outputs, targets)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time()-t_evaluation_start, step, eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for dataset, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:

        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0

          ios, targets_list, predictions, top_of_beams = [], [], [], []
          for batches in dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, targets = load_data(pred_batch)

            cache = (p_init_cache(inputs, outputs, targets)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, targets = map(tohost, (inputs, outputs, targets))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])
              ground_truth = decode_target(targets[i])
              beams_target = [decode_target(beam) for beam in beams]
              predicted_target = beams_target[0]
              for beam_target in beams_target:
                if beam_target == ground_truth:
                  predicted_target = beam_target
                  total_successes += 1
                  break
              total_denominator += 1

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets_list.append(ground_truth)
              predictions.append(predicted_target)
              logging.info('')
              logging.info('ios: %s', ios[-1])
              logging.info('targets[%s]: %s', i, targets[i])
              logging.info('ground_truth: %s', ground_truth)
              logging.info('predicted beam: %s', '\n'.join(beams_target))
              logging.info('predicted_target: %s', predicted_target)
              logging.info('beams: %s', beams)

              if not ground_truth:
                logging.warning('ground_truth is empty!')

              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decode_target(beam), beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))

          # Record beam search results as text summaries.
          message = []
          for n in np.random.choice(np.arange(len(predictions)), 8):
            text = (f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                    f'predicted: {predictions[n]}\n\n'
                    f'top of beam:\n\n{top_of_beams[n]}\n\n')
            message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size,
                all_total_successes, all_total_denominator, accuracy,
                time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)

            summary_writer.text('{}-samples-beam-{}'.format(predict_or_test,
                                                            beam_size),
                                '\n------\n'.join(message), step)
            summary_writer.flush()

    # Save a Checkpoint. Do this at the end of the training loop, so that if a
    # worker is descheduled during a round of prediction (which takes a while),
    # we will redo prediction upon restarting (to avoid losing data).
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)
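
Note: every example on this page follows the same flax checkpointing round-trip: save an unreplicated pytree on host 0 with checkpoints.save_checkpoint, and rebuild it with checkpoints.restore_checkpoint (which returns its target unchanged when no checkpoint exists yet). A minimal runnable sketch of that round-trip, using a plain dict in place of the optimizer and an illustrative /tmp path:

import os
import numpy as np
from flax.training import checkpoints

ckpt_dir = '/tmp/ckpt_sketch'  # illustrative path, not from the examples

state = {'step': np.int32(0), 'params': np.zeros(3)}

# Write a checkpoint for step 100; flax names it checkpoint_100 under ckpt_dir.
checkpoints.save_checkpoint(ckpt_dir, {**state, 'step': np.int32(100)},
                            step=100, keep=3)

# restore_checkpoint returns `state` unchanged when ckpt_dir is empty, so a
# fresh run simply starts from step 0.
restored = checkpoints.restore_checkpoint(ckpt_dir, state)
assert int(restored['step']) == 100
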
Example #2
def evaluate(config, workdir, eval_folder="eval"):
    """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".
  """
    # Create eval_dir
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    rng = jax.random.PRNGKey(config.seed + 1)

    # Build input pipeline.
    rng, ds_rng = jax.random.split(rng)
    _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model_name = config.model.name
    ncsn_def = mutils.get_model(model_name).partial(config=config)
    rng, run_rng = jax.random.split(rng)
    class_conditional = "conditional" in config.training.loss.lower()
    with nn.stateful() as init_model_state:
        with nn.stochastic(run_rng):
            input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
            input_list = [(input_shape, jnp.float32),
                          (input_shape[:1], jnp.int32)]
            if class_conditional:
                input_list.append(input_list[-1])
            _, initial_params = ncsn_def.init_by_shape(model_rng,
                                                       input_list,
                                                       train=True)
            ncsn = nn.Model(ncsn_def, initial_params)

    optimizer = losses.get_optimizer(config).create(ncsn)
    state = mutils.State(step=0,
                         optimizer=optimizer,
                         lr=config.optim.lr,
                         model_state=init_model_state,
                         ema_rate=config.model.ema_rate,
                         params_ema=initial_params,
                         rng=rng)  # pytype: disable=wrong-keyword-args

    del ncsn, init_model_state  # Do not keep a copy of the initial model.

    checkpoint_dir = os.path.join(workdir, "checkpoints")
    if config.training.loss.lower() == "ddpm":
        # Use the score matching loss with DDPM-type perturbation.
        ddpm_params = mutils.get_ddpm_params()
        eval_step = functools.partial(losses.ddpm_loss,
                                      ddpm_params=ddpm_params,
                                      train=False)
    else:
        # Use the score matching loss with NCSN-type perturbation.
        sigmas = mutils.get_sigmas(config)
        continuous = "continuous" in config.training.loss.lower()
        eval_step = functools.partial(
            losses.ncsn_loss,
            sigmas=sigmas,
            continuous=continuous,
            class_conditional=class_conditional,
            train=False,
            anneal_power=config.training.anneal_power)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    rng = jax.random.fold_in(rng, jax.host_id())

    # A data class for checkpointing.
    @flax.struct.dataclass
    class EvalMeta:
        ckpt_id: int
        round_id: int
        rng: Any

    # Add one additional round to get the exact number of samples as required.
    num_rounds = config.eval.num_samples // config.eval.batch_size + 1

    eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
    eval_meta = checkpoints.restore_checkpoint(eval_dir,
                                               eval_meta,
                                               step=None,
                                               prefix=f"meta_{jax.host_id()}_")

    if eval_meta.round_id < num_rounds - 1:
        begin_ckpt = eval_meta.ckpt_id
        begin_round = eval_meta.round_id + 1
    else:
        begin_ckpt = eval_meta.ckpt_id + 1
        begin_round = 0

    rng = eval_meta.rng
    # Use InceptionV3 for images with higher resolution.
    inceptionv3 = config.data.image_size >= 256
    inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

    logging.info("begin checkpoint: %d", begin_ckpt)
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        ckpt_filename = os.path.join(checkpoint_dir,
                                     "ckpt-{}.flax".format(ckpt))

        # Wait if the target checkpoint hasn't been produced yet.
        waiting_message_printed = False
        while not tf.io.gfile.exists(ckpt_filename):
            if not waiting_message_printed and jax.host_id() == 0:
                logging.warn("Waiting for the arrival of ckpt-%d.flax", ckpt)
                waiting_message_printed = True
            time.sleep(10)

        # In case the file was just written and not ready to read from yet.
        try:
            state = utils.load_state_dict(ckpt_filename, state)
        except:  # pylint: disable=bare-except
            time.sleep(60)
            try:
                state = utils.load_state_dict(ckpt_filename, state)
            except:  # pylint: disable=bare-except
                time.sleep(120)
                state = utils.load_state_dict(ckpt_filename, state)

        pstate = flax.jax_utils.replicate(state)
        eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

        # Compute the loss function on the full evaluation dataset.
        all_losses = []
        for i, batch in enumerate(eval_iter):
            rng, *next_rng = jax.random.split(rng,
                                              num=jax.local_device_count() + 1)
            next_rng = jnp.asarray(next_rng)
            eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
            eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
            eval_loss = flax.jax_utils.unreplicate(eval_loss)
            all_losses.append(eval_loss)
            if (i + 1) % 1000 == 0 and jax.host_id() == 0:
                logging.info("Finished %dth step loss evaluation", i + 1)

        all_losses = jnp.asarray(all_losses)

        state = jax.device_put(state)
        # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
        # Designed to be pre-emption safe. Automatically resumes when interrupted.
        for r in range(begin_round, num_rounds):
            if jax.host_id() == 0:
                logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
            rng, sample_rng = jax.random.split(rng)
            init_shape = tuple(eval_ds.element_spec["image"].shape)

            this_sample_dir = os.path.join(
                eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
            tf.io.gfile.makedirs(this_sample_dir)
            samples = sampling.get_samples(sample_rng,
                                           config,
                                           state,
                                           init_shape,
                                           scaler,
                                           inverse_scaler,
                                           class_conditional=class_conditional)
            samples = samples[-1]
            samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
            samples = samples.reshape(
                (-1, config.data.image_size, config.data.image_size, 3))
            with tf.io.gfile.GFile(
                    os.path.join(this_sample_dir, f"samples_{r}.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer, samples=samples)
                fout.write(io_buffer.getvalue())

            gc.collect()
            latents = evaluation.run_inception_distributed(
                samples, inception_model, inceptionv3=inceptionv3)
            gc.collect()
            with tf.io.gfile.GFile(
                    os.path.join(this_sample_dir, f"statistics_{r}.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    pool_3=latents["pool_3"],
                                    logits=latents["logits"])
                fout.write(io_buffer.getvalue())

            eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
            # Save an intermediate checkpoint directly if not the last round.
            # Otherwise save eval_meta after computing the Inception scores and FIDs
            if r < num_rounds - 1:
                checkpoints.save_checkpoint(eval_dir,
                                            eval_meta,
                                            step=ckpt * num_rounds + r,
                                            keep=1,
                                            prefix=f"meta_{jax.host_id()}_")

        # Compute inception scores, FIDs and KIDs.
        if jax.host_id() == 0:
            # Load all statistics that have been previously computed and saved.
            all_logits = []
            all_pools = []
            for host in range(jax.host_count()):
                this_sample_dir = os.path.join(eval_dir,
                                               f"ckpt_{ckpt}_host_{host}")

                stats = tf.io.gfile.glob(
                    os.path.join(this_sample_dir, "statistics_*.npz"))
                wait_message = False
                while len(stats) < num_rounds:
                    if not wait_message:
                        logging.warn("Waiting for statistics on host %d", host)
                        wait_message = True
                    stats = tf.io.gfile.glob(
                        os.path.join(this_sample_dir, "statistics_*.npz"))
                    time.sleep(1)

                for stat_file in stats:
                    with tf.io.gfile.GFile(stat_file, "rb") as fin:
                        stat = np.load(fin)
                        if not inceptionv3:
                            all_logits.append(stat["logits"])
                        all_pools.append(stat["pool_3"])

            if not inceptionv3:
                all_logits = np.concatenate(all_logits,
                                            axis=0)[:config.eval.num_samples]
            all_pools = np.concatenate(all_pools,
                                       axis=0)[:config.eval.num_samples]

            # Load pre-computed dataset statistics.
            data_stats = evaluation.load_dataset_stats(config)
            data_pools = data_stats["pool_3"]

            if hasattr(config.eval, "num_partitions"):
                # Divide samples into several partitions and compute FID/KID/IS on them.
                assert not inceptionv3
                fids = []
                kids = []
                inception_scores = []
                partition_size = config.eval.num_samples // config.eval.num_partitions
                tf_data_pools = tf.convert_to_tensor(data_pools)
                for i in range(config.eval.num_partitions):
                    this_pools = all_pools[i * partition_size:(i + 1) *
                                           partition_size]
                    this_logits = all_logits[i * partition_size:(i + 1) *
                                             partition_size]
                    inception_scores.append(
                        tfgan.eval.classifier_score_from_logits(this_logits))
                    fids.append(
                        tfgan.eval.
                        frechet_classifier_distance_from_activations(
                            data_pools, this_pools))
                    this_pools = tf.convert_to_tensor(this_pools)
                    kids.append(
                        tfgan.eval.kernel_classifier_distance_from_activations(
                            tf_data_pools, this_pools).numpy())

                fids = np.asarray(fids)
                inception_scores = np.asarray(inception_scores)
                kids = np.asarray(kids)
                with tf.io.gfile.GFile(
                        os.path.join(eval_dir, f"report_all_{ckpt}.npz"),
                        "wb") as f:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        all_losses=all_losses,
                                        mean_loss=all_losses.mean(),
                                        ISs=inception_scores,
                                        fids=fids,
                                        kids=kids)
                    f.write(io_buffer.getvalue())

            else:
                # Compute FID/KID/IS on all samples together.
                if not inceptionv3:
                    inception_score = tfgan.eval.classifier_score_from_logits(
                        all_logits)
                else:
                    inception_score = -1

                fid = tfgan.eval.frechet_classifier_distance_from_activations(
                    data_pools, all_pools)
                # Hack to get tfgan KID work for eager execution.
                tf_data_pools = tf.convert_to_tensor(data_pools)
                tf_all_pools = tf.convert_to_tensor(all_pools)
                kid = tfgan.eval.kernel_classifier_distance_from_activations(
                    tf_data_pools, tf_all_pools).numpy()
                del tf_data_pools, tf_all_pools

                logging.info(
                    "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, KID: %.6e",
                    ckpt, all_losses.mean(), inception_score, fid, kid)

                with tf.io.gfile.GFile(
                        os.path.join(eval_dir, f"report_{ckpt}.npz"),
                        "wb") as f:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        all_losses=all_losses,
                                        mean_loss=all_losses.mean(),
                                        IS=inception_score,
                                        fid=fid,
                                        kid=kid)
                    f.write(io_buffer.getvalue())
        else:
            # For host_id() != 0.
            # Use file existence to emulate synchronization across hosts.
            if hasattr(config.eval, "num_partitions"):
                assert not inceptionv3
                while not tf.io.gfile.exists(
                        os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
                    time.sleep(1.)

            else:
                while not tf.io.gfile.exists(
                        os.path.join(eval_dir, f"report_{ckpt}.npz")):
                    time.sleep(1.)

        # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
        # for this checkpoint.
        checkpoints.save_checkpoint(eval_dir,
                                    eval_meta,
                                    step=ckpt * num_rounds + r,
                                    keep=1,
                                    prefix=f"meta_{jax.host_id()}_")

        begin_round = 0

    # Remove all meta files after finishing evaluation.
    meta_files = tf.io.gfile.glob(
        os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
    for file in meta_files:
        tf.io.gfile.remove(file)
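
The EvalMeta dataclass above is what makes this loop pre-emption safe: each host checkpoints its (ckpt_id, round_id, rng) progress under a meta_{host_id}_ prefix and restores it on restart. A condensed sketch of that resume logic, with hypothetical directory and round count, assuming the same jax/flax versions as the listing:

import flax
import jax
from flax.training import checkpoints

@flax.struct.dataclass
class EvalMeta:
    ckpt_id: int
    round_id: int

eval_dir = '/tmp/eval_sketch'  # hypothetical
num_rounds = 4                 # hypothetical
prefix = f'meta_{jax.host_id()}_'

meta = EvalMeta(ckpt_id=0, round_id=-1)
meta = checkpoints.restore_checkpoint(eval_dir, meta, prefix=prefix)

if meta.round_id < num_rounds - 1:    # interrupted mid-checkpoint: resume it
    begin_ckpt, begin_round = meta.ckpt_id, meta.round_id + 1
else:                                 # that checkpoint finished: move on
    begin_ckpt, begin_round = meta.ckpt_id + 1, 0

for r in range(begin_round, num_rounds):
    # ... one round of sampling/statistics goes here ...
    meta = meta.replace(ckpt_id=begin_ckpt, round_id=r)
    # The step must grow monotonically, hence ckpt * num_rounds + r above.
    checkpoints.save_checkpoint(eval_dir, meta,
                                step=begin_ckpt * num_rounds + r,
                                keep=1, prefix=prefix)
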
Example #3
File: train.py Project: rishirelan/flax
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    def encode_strings(strs, max_len):
        tokenized_batch = np.zeros((len(strs), max_len), np.int32)
        for i, s in enumerate(strs):
            toks = encoder.tokenize(s).numpy()
            # Remove EOS token in prompt.
            tokenized_batch[i, :toks.shape[0] - 1] = toks[:-1]
        return tokenized_batch

    tokenized_prompts = encode_strings([config.prompts],
                                       config.max_predict_length)

    logging.info("Initializing model, optimizer, and step functions.")
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    rng, inference_rng = random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.TransformerLM(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step fn.
    p_train_step = jax.pmap(functools.partial(
        train_step, config=train_config, learning_rate_fn=learning_rate_fn),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name="batch")

    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          temperature=config.sampling_temperature,
                          top_k=config.sampling_top_k),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]),
                                                     a_max=1.0e4)
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    # Calculate (clipped) perplexity after averaging log-perplexities:
                    eval_results["perplexity"] = jnp.clip(jnp.exp(
                        eval_results["loss"]),
                                                          a_max=1.0e4)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("generate_text"):
                    exemplars = generate_prediction(
                        p_pred_step=p_pred_step,
                        target=optimizer.target,
                        tokenized_prompts=tokenized_prompts,
                        eos_id=eos_id,
                        inference_rng=inference_rng,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if config.save_checkpoints and save_checkpoint and jax.host_id(
            ) == 0:
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
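
A note on common_utils.shard, used above to feed p_train_step: it only reshapes each array's leading batch axis into (local_device_count, per_device_batch, ...), which is the layout jax.pmap expects. A small sketch of that assumed behavior:

import jax
import numpy as np
from flax.training import common_utils

n = jax.local_device_count()
batch = {'inputs': np.zeros((8 * n, 16), np.int32)}
sharded = common_utils.shard(batch)
assert sharded['inputs'].shape == (n, 8, 16)
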
Example #4
def save_checkpoint(state):
    if jax.host_id() == 0:
        # Get train state from the first replica.
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        step = int(state.step)
        checkpoints.save_checkpoint(FLAGS.model_dir, state, step, keep=3)
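
Slicing x[0] from every leaf is the same first-replica trick that flax.jax_utils.unreplicate performs. A hedged sketch of the matching restore path (model_dir is a parameter here rather than a flag):

from flax import jax_utils
from flax.training import checkpoints

def restore_checkpoint(state, model_dir):
    """Load the latest checkpoint, if any, and replicate it for pmap."""
    state = checkpoints.restore_checkpoint(model_dir, state)  # no-op when empty
    return jax_utils.replicate(state)
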
Example #5
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.mkdir(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparameters.
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            if key not in hparam_str_dict:
                hparam_str_dict[key] = value

    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        io_string = ''
        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
            io_string += inps[-1] + ' < ' + outs[-1] + ' > '
        return inps, outs, io_string[:-3]  # Remove last separator.

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except:  # pylint: disable=bare-except
            return None, ''  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease the batch size of the predict dataset to handle beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        use_relative_attention=FLAGS.use_relative_attention,
        num_relative_position_buckets=FLAGS.num_relative_position_buckets,
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(shift=False,
                                          deterministic=True,
                                          decode=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    m = models.ProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_lib.train_step,
        learning_rate_fn=learning_rate_fn,
        config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(train_lib.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_lib.initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(functools.partial(train_lib.predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, 5, 6))

    # Main Train Loop
    # ---------------------------------------------------------------------------
    train_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, train_rngs = p_train_step(optimizer,
                                                      inputs,
                                                      outputs,
                                                      programs,
                                                      train_rng=train_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f', step,
                         summary['loss'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)

            metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
            eval_metrics.append(metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions = [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    # pylint: disable=cell-var-from-loop
                    pred_batch = jax.tree_map(
                        lambda x: train_lib.pad_examples(x, padded_size),
                        pred_batch)
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache = p_init_cache(inputs, outputs, programs)
                predicted = p_pred_step(optimizer.target, inputs, outputs,
                                        cache, eos_token, programs.shape[-1],
                                        beam_size)
                predicted = train_lib.tohost(predicted)
                inputs, outputs, programs = map(train_lib.tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_score = train_lib.eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')

            all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries.
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step,
                    all_pred_acc / all_pred_denominator)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      all_pred_acc / all_pred_denominator,
                                      step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()
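The padding step above rounds the final odd-sized batch up to a multiple of the device count via train_lib.pad_examples. As a rough sketch of what such a helper can look like (an assumption for illustration; the actual train_lib definition is not shown in this excerpt), it simply repeats the last example:

import numpy as np

def pad_examples(x, desired_batch_size):
    """Pad a batch up to desired_batch_size by repeating its last example."""
    batch_pad = desired_batch_size - x.shape[0]
    # x[-1:] keeps the leading batch axis, so tiling yields batch_pad copies.
    reps = (batch_pad,) + (1,) * (x.ndim - 1)
    return np.concatenate([x, np.tile(x[-1:], reps)], axis=0)

The duplicated examples exist only to keep per-device shapes uniform for pmap; they carry no extra information.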
Example #6
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    rng, key = random.split(rng)
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, variables

    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    train_pstep = jax.pmap(functools.partial(train_step, model),
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=(2, ))

    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              FLAGS.randomized),
                                  axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3, ),
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step of the last checkpoint.
    init_step = state.optimizer.state.step + 1
    state = flax.jax_utils.replicate(state)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch_buffer_size = 3 x batch_size
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    reset_timer = True
    for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)
        state, stats, keys = train_pstep(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0],
                                      step)
                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                avg_loss = np.mean(
                    np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(
                    np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                steps_per_sec = FLAGS.print_every / (time.time() -
                                                     t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec,
                                      step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
                print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                      f"/{FLAGS.max_steps:d}: " +
                      f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                state_to_save = jax.device_get(
                    jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(FLAGS.train_dir,
                                            state_to_save,
                                            int(step),
                                            keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(
                lambda x: x[0], state)).optimizer.target
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, eval_variables),
                test_case["rays"],
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                num_rays = jnp.prod(
                    jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(
                    f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec"
                )
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(FLAGS.max_steps),
                                    keep=100)
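Example #6 builds its schedule from utils.learning_rate_decay with lr_delay_steps and lr_delay_mult. A minimal sketch consistent with those arguments, assuming log-linear interpolation from lr_init to lr_final plus an optional sine-eased warmup delay (illustrative only, not necessarily the project's exact code):

import numpy as np

def learning_rate_decay(step, lr_init, lr_final, max_steps,
                        lr_delay_steps=0, lr_delay_mult=1.0):
    """Log-linearly interpolate from lr_init to lr_final over max_steps."""
    if lr_delay_steps > 0:
        # Scale the rate down early in training, easing in with a sine curve.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay_rate = 1.0
    t = np.clip(step / max_steps, 0, 1)
    log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
    return delay_rate * log_lerp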
Example #7
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, *rays), axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )
    rng, key = random.split(rng)
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(step=0,
                                   optimizer=optimizer,
                                   model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    offset = state.step + 1
    state = jax_utils.replicate(state)
    del init_model, init_state

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         init_lr=FLAGS.lr,
                                         decay_steps=FLAGS.lr_decay * 1000,
                                         decay_rate=0.1)
    ptrain_step = jax.pmap(train_step,
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=2)
    # Prefetch_buffer_size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats[0])
        if step % FLAGS.gc_every == 0:
            gc.collect()
        # --- Train logs start ---
        # Put the training-time visualization before the host_id check because
        # in multi-host evaluation, all hosts need to run inference even though
        # we only use host 0 to record results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)
        if jax.host_id() != 0:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            summary_writer.scalar("train_loss", stats[0].loss[0], step)
            summary_writer.scalar("train_psnr", stats[0].psnr[0], step)
            if len(stats) > 1:
                summary_writer.scalar("train_loss_coarse", stats[1].loss[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0],
                                      step)
            avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
            avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
            stats_trace = []
            summary_writer.scalar("train_avg_loss", avg_loss, step)
            summary_writer.scalar("train_avg_psnr", avg_psnr, step)
            summary_writer.scalar("learning_rate", lr, step)
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " +
                  f"i_loss={stats[0].loss[0]:0.5f} | " +
                  f"avg_loss={avg_loss:0.5f}, " +
                  f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(FLAGS.train_dir,
                                        state_to_save,
                                        state_to_save.step,
                                        keep=100)
        # --- Train logs end ---

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(state.step),
                                    keep=100)
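Both NeRF-style examples report PSNR computed from the mean squared error of the rendered image. For images normalized to [0, 1], utils.compute_psnr presumably reduces to the standard formula (a sketch under that assumption):

import numpy as np

def compute_psnr(mse):
    # PSNR for signals in [0, 1] is -10 * log10(MSE).
    return -10.0 * np.log10(mse)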
Example #8
File: train.py Project: skye/flax
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_token = 2  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
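        # `toks == eos_token` is a boolean mask, and np.argmax returns the
        # index of its first True, so the slice keeps everything up to and
        # including the first EOS. (If no EOS is present, argmax returns 0 and
        # only the first token survives.)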
        valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': vocab_size,
        'emb_dim': FLAGS.emb_dim,
        'num_heads': FLAGS.num_heads,
        'num_layers': FLAGS.num_layers,
        'qkv_dim': FLAGS.qkv_dim,
        'mlp_dim': FLAGS.mlp_dim,
        'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
    }

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    model, cache_def = create_model(init_rng, input_shape, target_shape,
                                    transformer_kwargs)
    optimizer = create_optimizer(model, FLAGS.learning_rate,
                                 FLAGS.weight_decay)
    # We access model only from optimizer below via optimizer.target.
    del model

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(
        eval_step,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                           axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          use_bfloat16=FLAGS.use_bfloat16,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if step % FLAGS.eval_frequency != 0 and step > 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for _, pred_batch in enumerate(predict_iter):
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (per_device_batchsize, FLAGS.max_predict_length),
                    dtype=cache_dtype))
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_token, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
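per_host_sum_pmap aggregates the per-host BLEU match statistics before the final score is computed. One way to implement it, assuming a psum over one device per host (a sketch; the helper's definition is not shown in this excerpt):

import collections
import jax
import numpy as np

def per_host_sum_pmap(in_tree):
    """Sum a pytree of host-local values across all hosts."""
    host2devices = collections.defaultdict(list)
    for d in jax.devices():
        host2devices[d.host_id].append(d)
    devices = [host2devices[k][0] for k in host2devices]
    host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i', devices=devices)
    def pre_pmap(xs):
        # Add a leading axis of size 1: one local device participates per host.
        return jax.tree_map(lambda x: np.broadcast_to(x, (1,) + x.shape), xs)
    def post_pmap(xs):
        return jax.tree_map(lambda x: x[0], xs)
    return post_pmap(host_psum(pre_pmap(in_tree)))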
Example #9
def main(argv):
  global CFG
  CFG = FLAGS.config

  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
  _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if CFG.hardware_rng:
    models.set_hardware_bernoulli()

  if 'module_import' in CFG and CFG.module_import:
    for module in CFG.module_import:
      importlib.import_module(module)

  if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
    t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

  num_partitions = CFG.num_partitions
  topology = train_lib.compute_multihost_topology(num_partitions)
  batch_size = CFG.batch_size
  eval_batch_size = CFG.eval_batch_size
  per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
  if batch_size % topology.num_replicas:
    raise ValueError('Batch size must be divisible by the number of replicas.')

  steps_per_epoch = CFG.steps_per_epoch
  logging.info('steps per epoch: %d', steps_per_epoch)

  broadcast = functools.partial(
      train_lib.broadcast,
      num_replicas=topology.per_replica_set_num_replicas,
      num_partitions=topology.per_host_num_partitions,
      devices=topology.this_host_device_assignment)

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    tf.io.gfile.copy(FLAGS['config'].config_filename,
                     os.path.join(FLAGS.model_dir, 'config.py'),
                     overwrite=True)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  if CFG.infeed:
    # Infeed is currently synchronous, so feed it from a background thread.
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
      CFG, topology.num_replica_sets, topology.replica_set_id,
      topology.per_replica_set_host_id)

  vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
  encoder = vocab.tf_tokenizer
  eos_id = vocab.tokenizer.eos_id()

  def decode_tokens(toks, eos_id=eos_id, max_id=32000):
    """Decode tokens back to unicode."""
    del eos_id
    # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  double check this
    # is the best decoding function or just switch to using tf_decode.
    # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    valid_toks = toks.astype(np.int32)
    valid_toks[valid_toks >= max_id] = 3
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  train_config, eval_config, predict_config = get_configs(CFG)

  rng = random.PRNGKey(CFG.random_seed)
  rng, init_rng = random.split(rng)
  # This is used for infeed conversion from feature dict <--> tuple
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  device_train_input_shape = tuple([
      (batch_size // topology.num_replicas,
       CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
      for k in train_keys
  ])

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      factors=CFG.schedule,
      base_learning_rate=CFG.learning_rate,
      warmup_steps=CFG.warmup_steps)

  # First, we only abstractly initialize the optimizer and model parameters,
  # since the parameters may not even fit in device memory!
  # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
  # in get_initial_params without causing pytree incompatibility
  optimizer_def = optim.Adafactor(
      CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset)
  initialize_params_fn = functools.partial(
      get_initial_params,
      config=CFG,
      transformer_config=eval_config,
      optimizer_def=optimizer_def)
  optimizer = jax.eval_shape(initialize_params_fn, init_rng)
  # tuple-like pytree leaves for global_arg_shapes
  optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                  optimizer)

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  if num_partitions > 1:
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(num_partitions, optimizer.state_dict()))
    per_host_optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(topology.per_host_num_partitions,
                                  optimizer.state_dict()))

  # Restore unreplicated optimizer + model state from last checkpoint.
  # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
  existing_checkpoint_found = False
  if CFG.restore_checkpoints:
    existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

  # Import a pretrained-T5 checkpoint only if we didn't import a local
  # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.)
  # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
  if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
    optimizer = checkpoint_importer.restore_from_t5_checkpoint(
        optimizer, CFG.restore_t5_checkpoint)

  if CFG.restore_t5_checkpoint or existing_checkpoint_found:
    if num_partitions > 1:
      # Until checkpoint/restore is sharded, the restored checkpoint is global
      # and we need to slice each sharded parameter into the chunk containing
      # only the partitions that are present on this host.
      def per_host_chunk(x, spec):
        if spec is None or spec is x:  # unsharded or not a parameter
          return x
        if spec[0] == 1:
          dim_size = x.shape[1]
        elif spec[1] == 1:
          dim_size = x.shape[0]
        else:
          raise NotImplementedError()
        chunk_size = (
            dim_size * topology.per_host_num_partitions // num_partitions)
        lower = topology.per_replica_set_host_id * chunk_size
        upper = (topology.per_replica_set_host_id + 1) * chunk_size
        if spec[0] == 1:
          return x[:, lower:upper]
        else:
          return x[lower:upper]

      optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                    optimizer_partitions)
  else:
    # If pretraining and no checkpoint imported, we jit the (sharded-) init
    # function to minimize fragmentation. We use the same pmap(sharded_jit)
    # setup as the training step/loop to initialize everything "in-place" and
    # avoid communication or OOM.
    if num_partitions > 1:
      initialize_params_fn = sharded_jit(
          initialize_params_fn,
          in_parts=None,
          local_in_parts=None,
          out_parts=optimizer_partitions,
          local_out_parts=per_host_optimizer_partitions,
          # devices=one_replica_device_assignment,
      )
      initialize_params_fn = jax.pmap(
          initialize_params_fn,
          'batch',
          in_axes=0,
          axis_size=topology.num_replicas,
          devices=topology.device_assignment)
      init_rng = broadcast(init_rng)
      optimizer = initialize_params_fn(init_rng)
      # We maintain the optimizer in unbroadcasted form (i.e. with no leading
      # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
      # out_axes=None.
      optimizer = train_lib.unbroadcast(optimizer)
    else:
      optimizer = jax.jit(initialize_params_fn)(init_rng)

  # ---------------------------------------------------------------------------
  # Compile multidevice versions of train/eval/predict step and cache init fn.
  # ---------------------------------------------------------------------------

  # We can use either a single train-step for a host training loop:

  # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
  #  --> new_optimizer, metrics, new_dropout_rng
  def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
    return train_lib.train_step(
        optimizer,
        batch,
        prev_metrics,
        dropout_rng,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss,
        use_bfloat16=CFG.use_bfloat16)

  if num_partitions > 1:
    p_train_step = sharded_jit(
        p_train_step,
        in_parts=(optimizer_partitions, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None),
        out_parts=(optimizer_partitions, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None))
  # TODO(levskaya): the in_axes spec below might be wrong, double-check.
  p_train_step = jax.pmap(
      p_train_step,
      axis_name='batch',
      in_axes=(None, 0, 0, 0),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # OR, we use an on-device loop that feeds the training step via infeed queue.
  def device_train_loop_cond(args):
    """Stopping criterion for on-device loop."""
    _, _, _, _, step, epoch = args
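    # True while `step` still falls inside `epoch`; lax.while_loop exits once
    # the step counter crosses the epoch boundary.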
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    """On-device loop body."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    input_data, token = lax.infeed(
        token,
        shape=tuple(
            [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape]))
    # Rebuild input dict from infeed data tuple.
    batch = {k: v for k, v in zip(train_keys, input_data)}
    # Run the train_step function and return the loop state.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(
        device_train_loop,
        in_parts=(optimizer_partitions, None, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None, None),
        out_parts=(optimizer_partitions, None, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None, None))
  p_train_epoch = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, 0, 0, None, None),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Reduction psum for metric data.

  def p_allreduce_metrics(x):
    return lax.psum(x, axis_name='batch')

  if num_partitions > 1:
    p_allreduce_metrics = sharded_jit(
        p_allreduce_metrics,
        in_parts=None,
        local_in_parts=None,
        out_parts=None,
        local_out_parts=None,
        num_partitions=num_partitions,
        local_num_partitions=topology.per_host_num_partitions)
  p_allreduce_metrics = jax.pmap(
      p_allreduce_metrics,
      axis_name='batch',
      global_arg_shapes=None,
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)

  # Training evaluation computation.

  # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
  def p_eval_step(params, batch):
    return train_lib.eval_step(
        params, batch, config=eval_config, label_smoothing=CFG.label_smoothing)

  if num_partitions > 1:
    p_eval_step = sharded_jit(
        p_eval_step,
        in_parts=(optimizer_partitions.target, None),
        local_in_parts=(per_host_optimizer_partitions.target, None),
        out_parts=None,
        local_out_parts=None)
  p_eval_step = jax.pmap(
      p_eval_step,
      axis_name='batch',
      in_axes=(None, 0),
      global_arg_shapes=(optimizer_shapes.target, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Fast autoregressive decoding loop.
  # For inference and model evaluation.

  # predict_step(inputs, params,
  #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
  def p_pred_step(inputs, params):
    return train_lib.predict_step(inputs, params, eos_id,
                                  CFG.max_eval_target_length, predict_config,
                                  CFG.beam_size)

  if num_partitions > 1:
    p_pred_step = sharded_jit(
        p_pred_step,
        in_parts=(None, optimizer_partitions.target),
        local_in_parts=(None, per_host_optimizer_partitions.target),
        out_parts=None,
        local_out_parts=None)
  p_pred_step = jax.pmap(
      p_pred_step,
      axis_name='batch',
      in_axes=(0, None),
      global_arg_shapes=(None, optimizer_shapes.target),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # ---------------------------------------------------------------------------
  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # There should be a unique dropout key for each replica represented on this
  # host, but the key should be the same for the same replica on other hosts.
  # Again, this is what the replica set abstraction is for.
  dropout_rngs = random.split(
      random.fold_in(rng, topology.replica_set_id),
      topology.per_replica_set_num_replicas)
  # restore step from last checkpoint
  host_step = int(optimizer.state.step)
  empty_metrics = broadcast({
      'loss': 0.0,
      'accuracy': 0.0,
      'learning_rate': 0.0,
      'denominator': 0.0
  })
  if CFG.infeed:
    # TODO(jekbradbury): support something like this for the Python-loop case
    logging.info('Precompiling training loop and moving optimizer to device.')
    optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                             empty_metrics,
                                             jnp.array(0, dtype=jnp.int32), 1)
    optimizer = train_lib.unbroadcast(optimizer)
    metrics['loss'].block_until_ready()

  logging.info('Starting training loop.')

  local_devices = jax.local_devices()
  device_step = broadcast(host_step)
  first_epoch = host_step // steps_per_epoch

  # Main Loop over "epochs".
  train_iter = train_ds.as_numpy_iterator()
  for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
    metrics = empty_metrics

    # NOTE: 'optimizer' is unbroadcast by construction at initialization or
    # when loading a checkpoint. It is maintained in 'unbroadcast' state to
    # enable the XLA cross-replica sharding optimization.  The broadcasting is
    # handled automatically by the pmap'd functions that use it.

    # Gather all task evaluation metrics.
    logging.info('Evaluating tasks.')
    if epoch == first_epoch + 1:
      train_lib.sync_devices()
    for task in eval_cache.tasks:
      logging.info('Evaluating task %s', task.name)
      all_predicted, all_bs = [], []
      for pred_batch in eval_cache.preprocessed_examples[task.name]:
        # Handle final odd-sized batch by padding instead of dropping it.
        input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
            pred_batch['inputs'], per_replica_set_eval_batch_size)
        all_bs.append(unpadded_batch_size)
        # Split batch dimensions for pmap.
        input_batch = jax.tree_map(
            lambda x: x.reshape(
                (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
            input_batch)
        # Run fast inference on batch.
        all_predicted.append(p_pred_step(input_batch, optimizer.target))

      # Pad out the number of batches so each host has the same number.
      max_host_batch_number = np.max(
          eval_cache.preprocessed_batch_sizes[task.name])
      batch_shortfall = max_host_batch_number - len(all_predicted)
      if batch_shortfall > 0:
        # TODO(levskaya): Fix for case of entirely empty all_predicted.
        # To make sure the cross-host barriers work, we run the program the
        # same number of times on all hosts. The result of this call is
        # ignored, and the predictions are populated with zeros instead.
        p_pred_step(input_batch, optimizer.target)  # Dummy call.
        all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                             batch_shortfall)
        all_bs.extend([0] * batch_shortfall)
      all_predicted = jnp.concatenate(all_predicted)
      all_bs = jnp.array(all_bs)

      # Collect all batches from across hosts and reverse sharding.
      all_predicted = train_lib.host_allgather(
          all_predicted, topology.num_replica_sets, topology.replica_set_id,
          topology.per_replica_set_host_id == 0)
      seqlength = all_predicted.shape[-1]
      total_examples = np.sum(
          train_lib.host_allgather(all_bs, topology.num_replica_sets,
                                   topology.replica_set_id,
                                   topology.per_replica_set_host_id == 0))
      del all_bs
      assert total_examples == len(eval_cache.examples[task.name]), (
          'Total number of batches incorrect for task %s.' % task.name)
      # De-shard the collected predicted tokens and remove padding.
      all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
          -1, seqlength)[:total_examples]

      # We now run the post-processing and metric-fns on a single host.
      if jax.host_id() == 0:
        assert eval_summary_writer
        raw_predictions = []
        for tokens in all_predicted:
          raw_predictions.append(decode_tokens(tokens))

        # post-process predictions for metric fns
        predictions = [
            task.postprocess_fn(p, example=ex)
            for p, ex in zip(raw_predictions, eval_cache.examples[task.name])
        ]

        for metric_fn in task.metric_fns:
          scores = metric_fn(eval_cache.targets[task.name], predictions)
          for metric_name, metric_value in scores.items():
            tag = f'eval/{task.name}/{metric_name}'
            eval_summary_writer.scalar(tag, metric_value, host_step)
            logging.info('EVAL %s at step %d: %.3f', tag, host_step,
                         metric_value)
          eval_summary_writer.flush()

        # Save text samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
          tgt_txt = tf.compat.as_text(
              eval_cache.examples[task.name][n]['targets_plaintext'])
          pred_txt = raw_predictions[n]
          exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                        f'target: {tgt_txt}\n\n'
                        f'prediction: {pred_txt}\n\n')
        eval_summary_writer.text(f'{task.name} samples', exemplars, host_step)
        eval_summary_writer.flush()

    # Take an Xprof trace after the first loop has compiled everything.
    if epoch == first_epoch + 1:
      train_lib.sync_devices()

    # For on-device loop, we launch the computation before feeding data.
    logging.info('BEGIN Train loop.')
    if CFG.infeed:
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          optimizer, dropout_rngs, metrics, train_lib.unbroadcast(device_step),
          epoch)
      optimizer = train_lib.unbroadcast(optimizer)

    # Epoch loop.
    while int(host_step // steps_per_epoch) == epoch:
      batch = next(train_iter)
      batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), batch)
      # Feed the on-device training loop.
      if CFG.infeed:
        for i, device in enumerate(local_devices):
          # When using infeed to provide data to the computation, we're on our
          # own for feeding the right values to the right devices. Each device
          # should get the minibatch corresponding to its replica, a slice of
          # the larger batch corresponding to the host's replica set.
          if device.platform == 'tpu':
            device_coords = (*device.coords, device.id % 2)
          else:
            device_coords = (device.host_id, i)
          per_replica_set_device_coords = tuple(
              dc % prsm
              for dc, prsm in zip(device_coords, topology.per_replica_set_mesh))
          per_replica_set_replica_coords = tuple(
              prsdc // prm for prsdc, prm in zip(per_replica_set_device_coords,
                                                 topology.per_replica_mesh))
          per_replica_set_replica_id = 0
          for prsm, prm, prsrc in zip(topology.per_replica_set_mesh,
                                      topology.per_replica_mesh,
                                      per_replica_set_replica_coords):
            per_replica_set_replica_id = (
                per_replica_set_replica_id * prsm // prm + prsrc)
          input_tuple = tuple(
              [batch[k][per_replica_set_replica_id] for k in train_keys])
          # Safety check: infeed does not check shape or types but requires
          # them to agree with on-device spec, otherwise the queue and program
          # stalls.
          tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
          tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
          assert tuple_shapes == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (tuple_shapes, device_train_input_shape))
          assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
              ('infeed dtype error %s not all of type %s' % (
                  tuple_dtypes, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      # Host training loop.
      else:
        optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch,
                                                        metrics, dropout_rngs)
        optimizer = train_lib.unbroadcast(optimizer)
      host_step += 1
    logging.info('END Train loop.')

    # Maybe save a checkpoint on one host.
    if (CFG.save_checkpoints and
        epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and
        jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

    # Gather training metrics.
    metrics = p_allreduce_metrics(metrics)
    metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
    denominator = metrics.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
    logging.info('train in step: %s, %s', host_step, summary)
    if jax.host_id() == 0:
      assert train_summary_writer
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, host_step)
      train_summary_writer.flush()

    # Gather training evaluation metrics.
    logging.info('Gathering training evaluation metrics.')
    eval_metrics = []
    eval_iter = eval_ds.as_numpy_iterator()
    for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    # average metrics across devices
    eval_metrics = p_allreduce_metrics(eval_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    # average metrics across steps
    eval_metrics = jax.tree_map(np.sum, eval_metrics)
    eval_denominator = eval_metrics.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics)
    logging.info('eval in step: %s, %s', host_step, eval_summary)
    if jax.host_id() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, host_step)
      eval_summary_writer.flush()

  # Wait until computations are done before exiting
  logging.info('Finished.')
  train_lib.sync_devices()
  # Shut down the infeed threadpool.
  if CFG.infeed:
    infeed_pool.shutdown()
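Throughout Example #9 the optimizer is kept "unbroadcast", i.e. without a leading replica axis, and train_lib.unbroadcast is applied after every pmap call that returns replicated state. Assuming every leaf is identical across the leading replica axis, such a helper can be as simple as (a sketch):

import jax

def unbroadcast(tree):
    # All replicas hold identical values, so keep replica 0's copy of each leaf.
    return jax.tree_map(lambda x: x[0], tree)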
Example #11
 def test_save_restore_checkpoints(self):
     tmp_dir = self.create_tempdir().full_path
     test_object0 = {
         'a': np.array([0, 0, 0], np.int32),
         'b': np.array([0, 0, 0], np.int32)
     }
     test_object1 = {
         'a': np.array([1, 2, 3], np.int32),
         'b': np.array([1, 1, 1], np.int32)
     }
     test_object2 = {
         'a': np.array([4, 5, 6], np.int32),
         'b': np.array([2, 2, 2], np.int32)
     }
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object0)
     # Create leftover temporary checkpoint, which should be ignored.
     gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 0,
                                 prefix='test_',
                                 keep=1)
     self.assertIn('test_0', os.listdir(tmp_dir))
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 1,
                                 prefix='test_',
                                 keep=1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 2,
                                 prefix='test_',
                                 keep=1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 3,
                                 prefix='test_',
                                 keep=2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 4,
                                 prefix='test_',
                                 keep=2)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 step=3,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     # Restore from a specific path.
     new_object = checkpoints.restore_checkpoint(
         os.path.join(tmp_dir, 'test_3'), test_object0)
     jtu.check_eq(new_object, test_object2)
     # If a specific path is specified, but it does not exist, the same behavior
     # as when a directory is empty should apply: the target is returned
     # unchanged.
     new_object = checkpoints.restore_checkpoint(
         os.path.join(tmp_dir, 'test_not_there'), test_object0)
     jtu.check_eq(new_object, test_object0)
     with self.assertRaises(ValueError):
         checkpoints.restore_checkpoint(tmp_dir,
                                        test_object0,
                                        step=5,
                                        prefix='test_')
Example #12
def save_checkpoint(optimizer, ema):
  # Get the train state from the first replica.
  optimizer, ema = jax.device_get(
      jax.tree_map(lambda x: x[0], (optimizer, ema)))
  step = int(optimizer.state.step)
  checkpoints.save_checkpoint(FLAGS.model_dir, (optimizer, ema), step, keep=3)
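For symmetry, a sketch of a matching restore helper; it is not part of the original example and assumes the same (optimizer, ema) layout plus flax.jax_utils for replication:

def restore_train_state(optimizer, ema):
  # Restore the latest checkpoint with the same pytree structure, then
  # re-replicate across local devices for pmapped training.
  optimizer, ema = checkpoints.restore_checkpoint(
      FLAGS.model_dir, (optimizer, ema))
  return flax.jax_utils.replicate((optimizer, ema))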
Example #13
def train(
    state: TrainState,
    *,
    train_iter: T.Iterator[PT.MiniBatch],
    train_step: TrainStep,
    n_steps: int,
    report_freq: int,
    reporter: Reporter,
    val_iter: T.Optional[T.Iterator[PT.MiniBatch]] = None,
    val_step: T.Optional[ValStep] = None,
    val_freq: T.Optional[int] = None,
    val_metrics: T.Optional[utils.Metrics] = None,
    distributed: bool = False,
    save_ckpts: bool = True,
    ckpt_metric: str = 'loss',
    ckpt_name: str = 'model',
    extra_report_fn: T.Optional[T.Callable[[TrainState, PT.MiniBatch, int],
                                           None]] = None,
    start_step: int = 0,
) -> TrainState:

    assert 'time/step' in reporter.train_names
    if val_step is not None:
        assert val_iter is not None
        assert val_metrics is not None
        assert val_freq is not None and val_freq % report_freq == 0
        assert 'time' in reporter.val_names

    iter_slice = itertools.islice(train_iter, 0, n_steps - start_step)
    train_iter = iter(
        tqdm(iter_slice, total=n_steps - start_step, desc='Training'))

    if distributed:
        state = flax.jax_utils.replicate(state)
        if hasattr(state, 'rngs'):
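            # Fold the device index into each RNG key so that, after
            # replication, every device draws a distinct random stream.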
            pfold_in = partial(jax.pmap(jax.random.fold_in),
                               data=jnp.arange(jax.device_count()))
            state = state.replace(rngs=jax.tree_map(pfold_in, state.rngs))

    with reporter as rep, redirect_print_to_tqdm():
        cur_best = -1
        start_time = time.perf_counter()
        for i, batch in enumerate(train_iter, start=start_step):
            state = train_step(state, batch)

            if i % report_freq == 0 or i == n_steps - 1:
                time_per_step = (time.perf_counter() -
                                 start_time) / report_freq
                train_dict = {'time/step': time_per_step}
                train_dict.update(state.metrics.items())
                state = state.replace(metrics=state.metrics.reset())

                val_dict = {}
                if val_freq is not None and (i % val_freq == 0
                                             or i == n_steps - 1):
                    assert val_step is not None
                    assert val_metrics is not None
                    assert val_iter is not None

                    start_time = time.perf_counter()
                    val_state = (state if not distributed else
                                 flax.jax_utils.unreplicate(state))
                    variables = val_state.variables
                    val_metrics = T.cast(utils.Metrics, val_metrics.reset())
                    for val_batch in val_iter:
                        val_metrics = val_step(val_batch, variables,
                                               val_metrics)
                    val_dict = dict(val_metrics.items())
                    val_dict['time'] = time.perf_counter() - start_time

                    ckpt_metric_val = val_dict[ckpt_metric]
                    if save_ckpts and cur_best < ckpt_metric_val:
                        # TODO: Add a comparison option (i.e. whether lower
                        # or higher is better).
                        cur_best = ckpt_metric_val
                        checkpoints.save_checkpoint(f'ckpts_{ckpt_name}',
                                                    jax.device_get(val_state),
                                                    i,
                                                    keep=5)
                        if 'ckpt' in reporter.val_names:
                            val_dict['ckpt'] = 'Saved.'

                rep.report(i, train_dict, val_dict)
                if extra_report_fn is not None:
                    if distributed:
                        rstate, rbatch = flax.jax_utils.unreplicate(
                            (state, batch))
                    else:
                        rstate, rbatch = state, batch
                    extra_report_fn(rstate, rbatch, i)
                start_time = time.perf_counter()

    return state if not distributed else flax.jax_utils.unreplicate(state)
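One way to address the TODO above, sketched with a hypothetical lower_is_better argument that is not part of the original API:

import operator

def make_is_better(lower_is_better: bool):
    # operator.lt: a new value beats cur_best when smaller (e.g. loss);
    # operator.gt: when larger (e.g. accuracy).
    return operator.lt if lower_is_better else operator.gt

# Initialize cur_best to float('inf') or float('-inf') accordingly, then
# replace `cur_best < ckpt_metric_val` with
# `is_better(ckpt_metric_val, cur_best)`.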
Example #14
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if FLAGS.dynamic:
        train_ds_mgr, eval_ds, predict_ds, encoder = input_pipeline.get_dynamic_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=FLAGS.vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_buckets=FLAGS.num_data_buckets)
        if FLAGS.static:
            weights = np.array([float(w) for w in FLAGS.static.split(',')])
            assert len(weights) == FLAGS.num_data_buckets
            train_ds = train_ds_mgr.sampled_dataset(weights)
            FLAGS.dynamic = False
        else:
            init_dist = np.zeros(FLAGS.num_data_buckets)
            if FLAGS.data_selection_size < FLAGS.num_data_buckets:
                init_dist[range(FLAGS.data_selection_size)] = 1.0
                train_ds = train_ds_mgr.sampled_dataset(init_dist)
            else:
                train_ds = build_split(train_ds_mgr, 1.0)

    else:
        train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_to_keep=FLAGS.data_selection_size,
            pseudo_path=FLAGS.pseudo_path,
            repeat_count=FLAGS.repeat_count,
            newscommentary_size=FLAGS.newscommentary_size,
            split_tokenizer=FLAGS.split_tokenizer,
            sample_size=FLAGS.sample_size,
            newscomment_sample_ratio=FLAGS.newscomment_sample_ratio)

    if FLAGS.aux_eval_dataset:
        aux_datasets = []
        aux_names = FLAGS.aux_eval_dataset.split(',')
        for name in aux_names:
            _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
                dataset_name=name,
                eval_dataset_name=None,
                shard_idx=jax.process_index(),
                shard_count=jax.process_count(),
                data_dir=FLAGS.data_dir,
                vocab_path=vocab_path,
                target_vocab_size=FLAGS.vocab_size,
                batch_size=FLAGS.batch_size,
                max_length=FLAGS.max_target_length,
                max_eval_length=FLAGS.max_eval_target_length,
                paracrawl_size=FLAGS.paracrawl_size,
                is_scores_path=FLAGS.is_scores_path,
                num_to_keep=FLAGS.data_selection_size,
                pseudo_path=FLAGS.pseudo_path,
                repeat_count=FLAGS.repeat_count,
                newscommentary_size=FLAGS.newscommentary_size)
            aux_datasets.append(aux_eval_ds)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
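        # np.argmax returns the first EOS position; keep tokens up to and
        # including it. If no EOS is present, argmax is 0 and only the
        # first token survives.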
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # Note: this may be intended to be the per-device batch size.
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # Apply an optimizer to this parameter tree.
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that; otherwise continue from
        # where we left off.
        model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
        optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    if FLAGS.adapter != train_util.NONE:
        adapter = optim.ModelParamTraversal(
            lambda path, _: FLAGS.adapter in path)
        optimizer = optimizer_def.create(optimizer.target, focus=adapter)

    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)

    flag_key = [
        k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
    ]
    if flag_key:
        flag_key = flag_key[0]
        local_flags = {
            f.name: f.value
            for f in FLAGS.flags_by_module_dict()[flag_key]
        }
        writer.write_hparams(local_flags)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    if FLAGS.adapter != train_util.NONE:
        learning_rate_fn = common.create_learning_rate_scheduler(
            factors='constant',
            base_learning_rate=FLAGS.learning_rate,
            warmup_steps=FLAGS.warmup_steps)
    else:
        learning_rate_fn = common.create_learning_rate_scheduler(
            base_learning_rate=FLAGS.learning_rate,
            warmup_steps=FLAGS.warmup_steps,
            steps_per_cycle=FLAGS.steps_per_cycle,
            init_step=start_step,
            finetune_lr=FLAGS.finetune_lr)

    # Compile multi-device versions of the train/eval/predict steps and
    # cache the init fn.
    p_train_step = jax.pmap(functools.partial(
        train_util.train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing),
                            axis_name='batch',
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(train_util.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_util.initialize_cache,
        max_decode_len=FLAGS.max_predict_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(train_util.predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    p_get_diag_grads = jax.pmap(functools.partial(train_util.get_diag_grads,
                                                  config=eval_config),
                                axis_name='batch')

    p_get_bucket_score = jax.pmap(functools.partial(get_diag_score,
                                                    strategy=FLAGS.strategy),
                                  axis_name='batch')

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update them afterwards
    # inside the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info('Starting training loop.')
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=FLAGS.model_dir,
                                     num_profile_steps=5)
        ]
    train_metrics = []
    total_steps = start_step + FLAGS.num_train_steps
    if FLAGS.eval_only:
        total_steps = start_step + 1
    best_eval_loss = 1000
    curr_eval_loss = 1000
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, total_steps):
            is_last_step = step == total_steps - 1

            if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq
                                  == 0):
                # Dynamic macro: use gradient alignment to score different ratios
                # of top k vs bottom N-k bins
                if FLAGS.macro:
                    train_iter = get_macro_distribution(
                        p_get_diag_grads, p_get_bucket_score, aux_eval_ds,
                        train_ds_mgr, optimizer, eval_ds)
                else:
                    # Use gradient alignment to score bins
                    # take the top k bins and sample uniformly from them.
                    raw_distribution = get_new_distribution(
                        p_get_diag_grads, p_get_bucket_score, aux_eval_ds,
                        train_ds_mgr, optimizer, eval_ds)
                    logging.info(raw_distribution)
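                    # argsort ascending, reverse for descending order, then
                    # keep the indices of the top data_selection_size buckets.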
                    selected = np.argsort(
                        raw_distribution)[::-1][:FLAGS.data_selection_size]
                    new_distribution = np.zeros(100)
                    new_distribution[selected] = 1.0
                    logging.info(new_distribution)
                    train_ds = train_ds_mgr.sampled_dataset(new_distribution)
                    train_iter = iter(train_ds)

            # Shard data to devices and do a training step.
            if not FLAGS.eval_only:
                logging.info('Doing Training.')
                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    try:
                        batch = common_utils.shard(
                            jax.tree_map(np.asarray, next(train_iter)))
                        optimizer, metrics = p_train_step(
                            optimizer, batch, dropout_rng=dropout_rngs)
                        train_metrics.append(metrics)
                    except StopIteration:
                        is_last_step = True

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
                if not FLAGS.eval_only:
                    with report_progress.timed('training_metrics'):
                        logging.info('Gathering training metrics.')
                        train_metrics = common_utils.get_metrics(train_metrics)
                        lr = train_metrics.pop('learning_rate').mean()
                        metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                        denominator = metrics_sums.pop('denominator')
                        summary = jax.tree_map(lambda x: x / denominator,
                                               metrics_sums)  # pylint: disable=cell-var-from-loop
                        summary['learning_rate'] = lr
                        summary = {'train_' + k: v for k, v in summary.items()}
                        writer.write_scalars(step, summary)
                        train_metrics = []

                if FLAGS.eval_only:
                    p_eval_per_pos_step = jax.pmap(functools.partial(
                        train_util.eval_per_pos_step, config=eval_config),
                                                   axis_name='batch')
                    # Get per-example losses.
                    loss_filename = FLAGS.model_dir + '/test_losses.csv'
                    train_util.write_per_example_losses(
                        p_eval_step=p_eval_per_pos_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=FLAGS.num_eval_steps,
                        loss_filename=loss_filename)
                else:
                    with report_progress.timed('eval'):
                        eval_results = train_util.evaluate(
                            p_eval_step=p_eval_step,
                            target=optimizer.target,
                            eval_ds=eval_ds,
                            num_eval_steps=FLAGS.num_eval_steps)
                        curr_eval_loss = eval_results['loss']
                        writer.write_scalars(
                            step,
                            {'eval_' + k: v
                             for k, v in eval_results.items()})

                if FLAGS.aux_eval_dataset:
                    for aux_i, aux_eval_ds in enumerate(aux_datasets):
                        with report_progress.timed('aux_eval'):
                            eval_results = train_util.evaluate(
                                p_eval_step=p_eval_step,
                                target=optimizer.target,
                                eval_ds=aux_eval_ds,
                                num_eval_steps=FLAGS.num_eval_steps)
                            writer.write_scalars(
                                step, {
                                    'aux' + str(aux_i) + '_eval_' + k: v
                                    for k, v in eval_results.items()
                                })

                if FLAGS.compute_bleu:
                    with report_progress.timed('translate_and_bleu'):
                        decode_file = FLAGS.model_dir + '/decodes.csv'
                        exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                            p_pred_step=p_pred_step,
                            p_init_cache=p_init_cache,
                            target=optimizer.target,
                            predict_ds=predict_ds,
                            decode_tokens=decode_tokens,
                            max_predict_length=FLAGS.max_predict_length,
                            num_eval_steps=FLAGS.num_eval_steps,
                            decode_file=decode_file if FLAGS.eval_only else '')
                        writer.write_scalars(step, {'bleu': bleu_score})
                        writer.write_texts(step, {'samples': exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0
                               or is_last_step)
            if (FLAGS.save_checkpoints and save_checkpoint
                and jax.process_index() == 0):
                if curr_eval_loss < best_eval_loss:  # only save better checkpoints
                    best_eval_loss = curr_eval_loss
                    with report_progress.timed('checkpoint'):
                        checkpoints.save_checkpoint(
                            FLAGS.model_dir,
                            jax_utils.unreplicate(optimizer),
                            step,
                            keep=FLAGS.chkpts_to_keep,
                            overwrite=True)

            if is_last_step:
                break
Example #15
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Load a model from the repo published with the "How to train your ViT?
    # Data, Augmentation, and Regularization in Vision Transformers" paper.
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
  else:
    # ViT / Mixer papers
    filename = config.model.name

  pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)
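  # Note: restore_checkpoint returns the target unchanged when workdir holds
  # no checkpoint, so a fresh run keeps initial_step = 1.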

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the update RNG across local devices.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
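        # Top-1 match per example against the one-hot labels, averaged over
        # the batch.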
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
Example #16
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(config)
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    num_train_steps = config.num_train_steps
    num_eval_steps = config.num_eval_steps
    eval_freq = config.eval_frequency
    random_seed = config.random_seed
    model_type = config.model_type

    max_length = config.max_length

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets(
        n_devices=jax.local_device_count(),
        task_name=FLAGS.task_name,
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        fixed_vocab=None,
        max_length=max_length)

    vocab_size = encoder.vocab_size
    logging.info('Vocab Size: %d', vocab_size)

    train_ds = train_ds.repeat()

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_length)

    model_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': config.emb_dim,
        'num_heads': config.num_heads,
        'num_layers': config.num_layers,
        'qkv_dim': config.qkv_dim,
        'mlp_dim': config.mlp_dim,
        'max_len': max_length,
        'classifier': True,
        'num_classes': CLASS_MAP[FLAGS.task_name],
        'classifier_pool': config.classifier_pool
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    if model_type == 'transformer':
        model = create_model(init_rng, transformer.TransformerEncoder,
                             input_shape, model_kwargs)
    else:
        raise ValueError('Model type not supported')

    optimizer = create_optimizer(model,
                                 learning_rate,
                                 weight_decay=FLAGS.config.weight_decay)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if config.restore_checkpoints or FLAGS.test_only:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=config.factors,
        base_learning_rate=learning_rate,
        warmup_steps=config.warmup)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # p_pred_step = jax.pmap(predict_step, axis_name='batch')

    def run_eval(eval_ds, num_eval_steps=-1):
        eval_metrics = []
        eval_iter = iter(eval_ds)
        if num_eval_steps == -1:
            num_iter = itertools.count()
        else:
            num_iter = range(num_eval_steps)
        for _, eval_batch in zip(num_iter, eval_iter):
            # pylint: disable=protected-access
            eval_batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), eval_batch))
            # pylint: enable=protected-access
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        # Calculate (clipped) perplexity after averaging log-perplexities:
        eval_summary['perplexity'] = jnp.clip(jnp.exp(eval_summary['loss']),
                                              a_max=1.0e4)
        return eval_summary

    if FLAGS.test_only:
        with tf.io.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.json'),
                               'w') as f:
            test_summary = run_eval(test_ds)
            json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
        return

    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        logging.info('train in step: %d', step)

        # Save a Checkpoint
        if ((step % config.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('steps per second', steps_per_sec, step)
                for key, val in summary.items():
                    summary_writer.scalar(f'train_{key}', val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_summary = run_eval(eval_ds, num_eval_steps)
            logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    summary_writer.scalar(f'eval_{key}', val, step)
                summary_writer.flush()