Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = FLAGS.config

    model = create_model(config)
    optimizer = create_optimizer(config, model)
    del model  # don't keep a copy of the initial model

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    start_step = int(optimizer.state[0].step)

    optimizer = optimizer.replicate()

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length

    # The commented lines below correspond to a data pipeline that uses publicly
    # available data, in the form of English Wikipedia as processed and hosted by
    # the HuggingFace datasets library. The pipeline works, and downstream task
    # performance shows a benefit to pre-training, but I (Nikita) have yet to
    # confirm that final model quality is on par with the original BERT.
    #
    # dataset = datasets.load_dataset('wikipedia', '20200501.en')['train']
    # data_pipeline = data.PretrainingDataPipelineV1(
    #   dataset, tokenizer,
    #   max_predictions_per_seq=config.max_predictions_per_seq)

    # The data pipeline below relies on having text files of Wikipedia + Books in
    # the same format as the original BERT data. That original data is not
    # publicly available, so you will need to provide your own. I (Nikita) have
    # had success using data from Gong et al. "Efficient Training of BERT by
    # Progressively Stacking", but this data was also obtained through private
    # correspondence and may not be generally available.
    # The data_files argument may be a list, if data is split across multiple
    # input files.
    dataset = datasets.load_dataset(
        'bert_data.py',
        data_files=os.path.expanduser('~/data/bert/corpus.train.tok'))['train']
    data_pipeline = data.PretrainingDataPipeline(
        dataset,
        tokenizer,
        max_predictions_per_seq=config.max_predictions_per_seq)

    datasets.logging.set_verbosity_error()

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_iter = data_pipeline.get_inputs(
            batch_size=config.train_batch_size, training=True)
        train_step_fn = training.create_train_step(
            compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch, train_state)
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(output_dir,
                                            optimizer.unreplicate(), step)
                config_path = os.path.join(output_dir, 'config.json')
                if not os.path.exists(config_path):
                    with open(config_path, 'w') as f:
                        json.dump({'model_type': 'bert', **config.model}, f)

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(compute_pretraining_stats,
                                          sample_feature_name='input_ids')
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            'loss': jnp.mean(eval_stats['loss']),
            'masked_lm_loss': jnp.mean(eval_stats['masked_lm_loss']),
            'next_sentence_loss': jnp.mean(eval_stats['next_sentence_loss']),
            'masked_lm_accuracy': (
                jnp.sum(eval_stats['masked_lm_correct']) /
                jnp.sum(eval_stats['masked_lm_total'])),
            'next_sentence_accuracy': (
                jnp.sum(eval_stats['next_sentence_correct']) /
                jnp.sum(eval_stats['next_sentence_total'])),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f'{name} = {val:.06f}'
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')
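
Note added here (not part of the original example): the factors string 'constant * linear_warmup * linear_decay' above is interpreted by training.create_learning_rate_scheduler. The snippet below is a rough, self-contained stand-in showing how such multiplicative factors can compose into a per-step learning rate; make_schedule is a hypothetical helper and the exact decay shape is an assumption, not the library's implementation.

# Rough stand-in for a factors-style schedule (hypothetical helper, not the
# actual training.create_learning_rate_scheduler; the decay shape below is an
# assumption made for illustration).
import jax.numpy as jnp

def make_schedule(factors='constant * linear_warmup * linear_decay',
                  base_learning_rate=1e-4, warmup_steps=10000,
                  steps_per_cycle=90000):
    names = [f.strip() for f in factors.split('*')]

    def step_fn(step):
        lr = 1.0
        for name in names:
            if name == 'constant':
                lr *= base_learning_rate
            elif name == 'linear_warmup':
                lr *= jnp.minimum(1.0, step / warmup_steps)
            elif name == 'linear_decay':
                lr *= jnp.clip(
                    1.0 - (step - warmup_steps) / steps_per_cycle, 0.0, 1.0)
            else:
                raise ValueError('Unknown factor: %s' % name)
        return lr

    return step_fn

lr_fn = make_schedule()
print(lr_fn(5000), lr_fn(10000), lr_fn(55000))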
Example #2
def train_and_evaluate(
  random_seed, batch_size, learning_rate, num_train_steps, num_eval_steps,
  eval_freq, max_target_length, max_eval_target_length, weight_decay, data_dir,
  model_dir, restore_checkpoints, save_checkpoints, checkpoint_freq,
  max_predict_token_length, sampling_temperature, sampling_top_k, prompt_str):
  """Executes model training and evaluation loop.
  
  Args:
    random_seed: Seed for initializing PRNG random seed.
    batch_size: Batch size for training.
    learning_rate: Learning rate for the Adam optimizer.
    num_train_steps: Number of training steps.
    num_eval_steps: Number of evaluation steps.
    eval_freq: Frequency of evaluation during training.
    max_target_length: Maximum length of training examples.
    max_eval_target_length: Maximum length of eval examples.
    weight_decay: Decay factor for AdamW-style weight decay.
    data_dir: Directory containing TFDS lm1b/subwords32k datasets.
    model_dir: Directory where to store model data.
    restore_checkpoints: Whether to restore from existing model checkpoints.
    save_checkpoints: Whether to save model checkpoints.
    checkpoint_freq: Save a checkpoint every this many steps.
    max_predict_token_length: Maximum example text inference token length.
    sampling_temperature: Sampling temperature for language model inference.
    sampling_top_k: Top k cutoff for logit sampling.
    prompt_str: Prompt for language model sampling.
  """
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(model_dir, 'eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
      n_devices=jax.local_device_count(),
      data_dir=data_dir,
      batch_size=batch_size,
      dynamic_batching=True,
      max_target_length=max_target_length,
      max_eval_target_length=max_eval_target_length)
  vocab_size = info_ds['text'].encoder.vocab_size
  encoder = info_ds['text'].encoder

  train_iter = iter(train_ds)
  input_shape = (batch_size, max_target_length)

  transformer_lm_kwargs = {
      'vocab_size': vocab_size,
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max(max_target_length, max_eval_target_length)
  }

  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = random.split(rng)
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs)
  optimizer = create_optimizer(model, learning_rate, weight_decay)
  del model  # Don't keep a copy of the initial model.
  start_step = 0
  if restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  p_pred_step = jax.pmap(predict_step, axis_name='batch')

  metrics_all = []
  tick = time.time()
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.host_id() == 0 and save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            model_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_metrics = []
      eval_iter = iter(eval_ds)
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        # pylint: disable=protected-access
        eval_batch = common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), eval_batch))
        # pylint: enable=protected-access
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()

      # Fast inference of prompt extension using trained LM.
      rng, subrng = jax.random.split(rng)
      pred_rngs = random.split(subrng, jax.local_device_count())
      prompt = jnp.array(encoder.encode(prompt_str))
      prompt = jax_utils.replicate(prompt)
      prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
      cache = jax_utils.replicate(
          cache_def.initialize_cache((1, max_predict_token_length)))
      predicted = p_pred_step(
          prompt, optimizer.target, cache, pred_rngs, max_predict_token_length,
          sampling_temperature, sampling_top_k)
      predicted = tohost(predicted)
      exemplars = ''
      for n in range(predicted.shape[0]):
        exemplars += encoder.decode(predicted[n]) + '\n\n'
      if jax.host_id() == 0:
        eval_summary_writer.text('samples', exemplars, step)
        eval_summary_writer.flush()
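
Aside (added for clarity, not from the original source): the periodic metric handling above sums per-step metric dicts and divides once by the accumulated 'denominator' token count before turning the mean log-loss into a clipped perplexity. A minimal self-contained sketch of that bookkeeping with made-up numbers, loosely imitating what common_utils.get_metrics plus the tree_map calls accomplish:

import jax
import jax.numpy as jnp

# Two fake per-step metric dicts: summed token loss plus a token count.
metrics_all = [
    {'loss': jnp.array(120.0), 'denominator': jnp.array(64.0)},
    {'loss': jnp.array(110.0), 'denominator': jnp.array(60.0)},
]
# Sum each metric across steps.
metrics_sums = jax.tree_map(lambda *xs: jnp.sum(jnp.stack(xs)), *metrics_all)
denominator = metrics_sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
# Perplexity is exp(mean log-loss), clipped to avoid huge values early on.
summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
print(summary)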
Example #3
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining.
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warning('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
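
Aside (illustration added here, not part of the original example): each training step above first passes the host batch through common_utils.shard. The sketch below is a rough paraphrase of that reshaping, not the library source: the leading batch axis is split into a local-device axis plus a per-device batch axis so jax.pmap can hand one slice to each device.

import jax
import numpy as np

def shard_like(batch, n_devices):
    # Reshape the leading axis of every array to (n_devices, per_device, ...).
    return jax.tree_map(
        lambda x: x.reshape((n_devices, -1) + x.shape[1:]), batch)

n_devices = jax.local_device_count()
host_batch = {'inputs': np.zeros((4 * n_devices, 128), np.int32),
              'targets': np.zeros((4 * n_devices, 128), np.int32)}
sharded = shard_like(host_batch, n_devices)
print(jax.tree_map(np.shape, sharded))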
Example #4
def main(argv):
    global CFG
    CFG = FLAGS.config

    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
    _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

    # Use hardware RNG for bernoulli randoms in dropout mask creation.
    if CFG.hardware_rng:
        models.set_hardware_bernoulli()

    if 'module_import' in CFG and CFG.module_import:
        for module in CFG.module_import:
            importlib.import_module(module)

    if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
        t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

    num_partitions = CFG.num_partitions
    topology = train_lib.compute_multihost_topology(num_partitions)
    batch_size = CFG.batch_size
    eval_batch_size = CFG.eval_batch_size
    per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
    if batch_size % topology.num_replicas:
        raise ValueError(
            'Batch size must be divisible by the number of replicas.')

    steps_per_epoch = CFG.steps_per_epoch
    logging.info('steps per epoch: %d', steps_per_epoch)

    broadcast = functools.partial(
        train_lib.broadcast,
        num_replicas=topology.per_replica_set_num_replicas,
        num_partitions=topology.per_host_num_partitions,
        devices=topology.this_host_device_assignment)

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        tf.io.gfile.copy(FLAGS['config'].config_filename,
                         os.path.join(FLAGS.model_dir, 'config.py'),
                         overwrite=True)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    # Write summaries in background thread to avoid blocking on device sync
    if CFG.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
        CFG, topology.num_replica_sets, topology.replica_set_id,
        topology.per_replica_set_host_id)

    vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
    encoder = vocab.tf_tokenizer
    eos_id = vocab.tokenizer.eos_id()

    def decode_tokens(toks, eos_id=eos_id, max_id=32000):
        """Decode tokens back to unicode."""
        del eos_id
        # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  double check this
        # is the best decoding function or just switch to using tf_decode.
        # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        valid_toks = toks.astype(np.int32)
        valid_toks[valid_toks >= max_id] = 3
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    train_config, eval_config, predict_config = get_configs(CFG)

    rng = random.PRNGKey(CFG.random_seed)
    rng, init_rng = random.split(rng)
    # This is used for infeed conversion from feature dict <--> tuple
    train_keys = [
        'inputs', 'targets', 'inputs_position', 'targets_position',
        'inputs_segmentation', 'targets_segmentation'
    ]
    device_train_input_shape = tuple([
        (batch_size // topology.num_replicas,
         CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
        for k in train_keys
    ])

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        factors=CFG.schedule,
        base_learning_rate=CFG.learning_rate,
        warmup_steps=CFG.warmup_steps)

    # First, we only abstractly initialize the optimizer and model parameters,
    # since the parameters may not even fit in device memory!
    # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
    # in get_initial_params without causing pytree incompatibility
    optimizer_def = optim.Adafactor(CFG.learning_rate,
                                    decay_rate=0.8,
                                    step_offset=CFG.step_offset)
    initialize_params_fn = functools.partial(get_initial_params,
                                             config=CFG,
                                             transformer_config=eval_config,
                                             optimizer_def=optimizer_def)
    optimizer = jax.eval_shape(initialize_params_fn, init_rng)
    # tuple-like pytree leaves for global_arg_shapes
    optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                    optimizer)

    # Build parameter partition annotations for preserving partitions from train
    # to eval.
    if num_partitions > 1:
        optimizer_partitions = optimizer.restore_state(
            partitions.set_partitions(num_partitions, optimizer.state_dict()))
        per_host_optimizer_partitions = optimizer.restore_state(
            partitions.set_partitions(topology.per_host_num_partitions,
                                      optimizer.state_dict()))

    # Restore unreplicated optimizer + model state from last checkpoint.
    # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
    existing_checkpoint_found = False
    if CFG.restore_checkpoints:
        existing_checkpoint_found = train_lib.checkpoint_exists(
            FLAGS.model_dir)
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

    # Import a pretrained T5 checkpoint only if we didn't import a local
    # "native" checkpoint (e.g., due to resuming a pre-empted finetuning run).
    # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
    if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
        optimizer = checkpoint_importer.restore_from_t5_checkpoint(
            optimizer, CFG.restore_t5_checkpoint)

    if CFG.restore_t5_checkpoint or existing_checkpoint_found:
        if num_partitions > 1:
            # Until checkpoint/restore is sharded, the restored checkpoint is global
            # and we need to slice each sharded parameter into the chunk containing
            # only the partitions that are present on this host.
            def per_host_chunk(x, spec):
                if spec is None or spec is x:  # unsharded or not a parameter
                    return x
                if spec[0] == 1:
                    dim_size = x.shape[1]
                elif spec[1] == 1:
                    dim_size = x.shape[0]
                else:
                    raise NotImplementedError()
                chunk_size = (dim_size * topology.per_host_num_partitions //
                              num_partitions)
                lower = topology.per_replica_set_host_id * chunk_size
                upper = (topology.per_replica_set_host_id + 1) * chunk_size
                if spec[0] == 1:
                    return x[:, lower:upper]
                else:
                    return x[lower:upper]

            optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                          optimizer_partitions)
    else:
        # If pretraining and no checkpoint imported, we jit the (sharded-) init
        # function to minimize fragmentation. We use the same pmap(sharded_jit)
        # setup as the training step/loop to initialize everything "in-place" and
        # avoid communication or OOM.
        if num_partitions > 1:
            initialize_params_fn = sharded_jit(
                initialize_params_fn,
                in_parts=None,
                local_in_parts=None,
                out_parts=optimizer_partitions,
                local_out_parts=per_host_optimizer_partitions,
                # devices=one_replica_device_assignment,
            )
            initialize_params_fn = jax.pmap(initialize_params_fn,
                                            'batch',
                                            in_axes=0,
                                            axis_size=topology.num_replicas,
                                            devices=topology.device_assignment)
            init_rng = broadcast(init_rng)
            optimizer = initialize_params_fn(init_rng)
            # We maintain the optimizer in unbroadcasted form (i.e. with no leading
            # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
            # out_axes=None.
            optimizer = train_lib.unbroadcast(optimizer)
        else:
            optimizer = jax.jit(initialize_params_fn)(init_rng)

    # ---------------------------------------------------------------------------
    # Compile multidevice versions of train/eval/predict step and cache init fn.
    # ---------------------------------------------------------------------------

    # We can use either a single train-step for a host training loop:

    # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
    #  --> new_optimizer, metrics, new_dropout_rng
    def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
        return train_lib.train_step(optimizer,
                                    batch,
                                    prev_metrics,
                                    dropout_rng,
                                    config=train_config,
                                    learning_rate_fn=learning_rate_fn,
                                    num_microbatches=CFG.microbatches,
                                    label_smoothing=CFG.label_smoothing,
                                    z_loss=CFG.z_loss,
                                    use_bfloat16=CFG.use_bfloat16)

    if num_partitions > 1:
        p_train_step = sharded_jit(
            p_train_step,
            in_parts=(optimizer_partitions, None, None, None),
            local_in_parts=(per_host_optimizer_partitions, None, None, None),
            out_parts=(optimizer_partitions, None, None),
            local_out_parts=(per_host_optimizer_partitions, None, None))
    # TODO(levskaya): the in_axes spec below might be wrong, double-check.
    p_train_step = jax.pmap(p_train_step,
                            axis_name='batch',
                            in_axes=(None, 0, 0, 0),
                            donate_argnums=(0, ),
                            global_arg_shapes=(optimizer_shapes, None, None,
                                               None),
                            axis_size=topology.num_replicas,
                            devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # OR, we use an on-device loop that feeds the training step via infeed queue.
    def device_train_loop_cond(args):
        """Stopping criterion for on-device loop."""
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        """On-device loop body."""
        optimizer, dropout_rngs, metrics, token, step, epoch = args
        # Ordering input data from infeed requires threading a symbolic token
        # through the computation.
        input_data, token = lax.infeed(token,
                                       shape=tuple([
                                           jax.ShapedArray(s, jnp.int32)
                                           for s in device_train_input_shape
                                       ]))
        # Rebuild input dict from infeed data tuple.
        batch = {k: v for k, v in zip(train_keys, input_data)}
        # Run the train_step function and return the loop state.
        optimizer, metrics, dropout_rngs = train_lib.train_step(
            optimizer,
            batch,
            metrics,
            dropout_rngs,
            train_config,
            learning_rate_fn,
            num_microbatches=CFG.microbatches,
            label_smoothing=CFG.label_smoothing,
            z_loss=CFG.z_loss)
        step += 1
        return optimizer, dropout_rngs, metrics, token, step, epoch

    def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
        # Create symbolic token for threading infeed data.
        token = lax.create_token(step)
        # Run on-device loop.
        optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, dropout_rngs, metrics, token, step, epoch))
        return optimizer, dropout_rngs, metrics, step

    if num_partitions > 1:
        device_train_loop = sharded_jit(
            device_train_loop,
            in_parts=(optimizer_partitions, None, None, None, None),
            local_in_parts=(per_host_optimizer_partitions, None, None, None,
                            None),
            out_parts=(optimizer_partitions, None, None, None),
            local_out_parts=(per_host_optimizer_partitions, None, None, None))
    p_train_epoch = jax.pmap(device_train_loop,
                             axis_name='batch',
                             in_axes=(None, 0, 0, None, None),
                             donate_argnums=(0, ),
                             global_arg_shapes=(optimizer_shapes, None, None,
                                                None, None),
                             axis_size=topology.num_replicas,
                             devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # Reduction psum for metric data.

    def p_allreduce_metrics(x):
        return lax.psum(x, axis_name='batch')

    if num_partitions > 1:
        p_allreduce_metrics = sharded_jit(
            p_allreduce_metrics,
            in_parts=None,
            local_in_parts=None,
            out_parts=None,
            local_out_parts=None,
            num_partitions=num_partitions,
            local_num_partitions=topology.per_host_num_partitions)
    p_allreduce_metrics = jax.pmap(p_allreduce_metrics,
                                   axis_name='batch',
                                   global_arg_shapes=None,
                                   axis_size=topology.num_replicas,
                                   devices=topology.device_assignment)

    # Training evaluation computation.

    # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
    def p_eval_step(params, batch):
        return train_lib.eval_step(params,
                                   batch,
                                   config=eval_config,
                                   label_smoothing=CFG.label_smoothing)

    if num_partitions > 1:
        p_eval_step = sharded_jit(
            p_eval_step,
            in_parts=(optimizer_partitions.target, None),
            local_in_parts=(per_host_optimizer_partitions.target, None),
            out_parts=None,
            local_out_parts=None)
    p_eval_step = jax.pmap(p_eval_step,
                           axis_name='batch',
                           in_axes=(None, 0),
                           global_arg_shapes=(optimizer_shapes.target, None),
                           axis_size=topology.num_replicas,
                           devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # Fast autoregressive decoding loop.
    # For inference and model evaluation.

    # predict_step(inputs, params,
    #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
    def p_pred_step(inputs, params):
        return train_lib.predict_step(inputs, params, eos_id,
                                      CFG.max_eval_target_length,
                                      predict_config, CFG.beam_size)

    if num_partitions > 1:
        p_pred_step = sharded_jit(
            p_pred_step,
            in_parts=(None, optimizer_partitions.target),
            local_in_parts=(None, per_host_optimizer_partitions.target),
            out_parts=None,
            local_out_parts=None)
    p_pred_step = jax.pmap(p_pred_step,
                           axis_name='batch',
                           in_axes=(0, None),
                           global_arg_shapes=(None, optimizer_shapes.target),
                           axis_size=topology.num_replicas,
                           devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # ---------------------------------------------------------------------------
    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    # There should be a unique dropout key for each replica represented on this
    # host, but the key should be the same for the same replica on other hosts.
    # Again, this is what the replica set abstraction is for.
    dropout_rngs = random.split(random.fold_in(rng, topology.replica_set_id),
                                topology.per_replica_set_num_replicas)
    # restore step from last checkpoint
    host_step = int(optimizer.state.step)
    empty_metrics = broadcast({
        'loss': 0.0,
        'accuracy': 0.0,
        'learning_rate': 0.0,
        'denominator': 0.0
    })
    if CFG.infeed:
        # TODO(jekbradbury): support something like this for the Python-loop case
        logging.info(
            'Precompiling training loop and moving optimizer to device.')
        optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                                 empty_metrics,
                                                 jnp.array(0,
                                                           dtype=jnp.int32), 1)
        optimizer = train_lib.unbroadcast(optimizer)
        metrics['loss'].block_until_ready()

    logging.info('Starting training loop.')

    local_devices = jax.local_devices()
    device_step = broadcast(host_step)
    first_epoch = host_step // steps_per_epoch

    # Main Loop over "epochs".
    train_iter = train_ds.as_numpy_iterator()
    for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
        metrics = empty_metrics

        # NOTE: 'optimizer' is unbroadcast by construction at initialization or
        # when loading a checkpoint. It is maintained in 'unbroadcast' state to
        # enable the XLA cross-replica sharding optimization.  The broadcasting is
        # handled automatically by the pmap'd functions that use it.

        # Gather all task evaluation metrics.
        logging.info('Evaluating tasks.')
        if epoch == first_epoch + 1:
            train_lib.sync_devices()
        for task in eval_cache.tasks:
            logging.info('Evaluating task %s', task.name)
            all_predicted, all_bs = [], []
            for pred_batch in eval_cache.preprocessed_examples[task.name]:
                # Handle final odd-sized batch by padding instead of dropping it.
                input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
                    pred_batch['inputs'], per_replica_set_eval_batch_size)
                all_bs.append(unpadded_batch_size)
                # Split batch dimensions for pmap.
                input_batch = jax.tree_map(
                    lambda x: x.reshape((topology.per_replica_set_num_replicas,
                                         -1) + x.shape[1:]), input_batch)
                # Run fast inference on batch.
                all_predicted.append(p_pred_step(input_batch,
                                                 optimizer.target))

            # Pad out the number of batches so each host has the same number.
            max_host_batch_number = np.max(
                eval_cache.preprocessed_batch_sizes[task.name])
            batch_shortfall = max_host_batch_number - len(all_predicted)
            if batch_shortfall > 0:
                # TODO(levskaya): Fix for case of entirely empty all_predicted.
                # To make sure the cross-host barriers work, we run the program the same
                # number of times on all hosts. The result of this call is ignored, and
                # the predictions are populated with zeros instead.
                p_pred_step(input_batch, optimizer.target)  # Dummy call.
                all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                                     batch_shortfall)
                all_bs.extend([0] * batch_shortfall)
            all_predicted = jnp.concatenate(all_predicted)
            all_bs = jnp.array(all_bs)

            # Collect all batches from across hosts and reverse sharding.
            all_predicted = train_lib.host_allgather(
                all_predicted, topology.num_replica_sets,
                topology.replica_set_id, topology.per_replica_set_host_id == 0)
            seqlength = all_predicted.shape[-1]
            total_examples = np.sum(
                train_lib.host_allgather(
                    all_bs, topology.num_replica_sets, topology.replica_set_id,
                    topology.per_replica_set_host_id == 0))
            del all_bs
            assert total_examples == len(eval_cache.examples[task.name]), (
                'Total number of batches incorrect for task %s.' % task.name)
            # De-shard the collected predicted tokens and remove padding.
            all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
                -1, seqlength)[:total_examples]

            # We now run the post-processing and metric-fns on a single host.
            if jax.host_id() == 0:
                assert eval_summary_writer
                raw_predictions = []
                for tokens in all_predicted:
                    raw_predictions.append(decode_tokens(tokens))

                # post-process predictions for metric fns
                predictions = [
                    task.postprocess_fn(p, example=ex) for p, ex in zip(
                        raw_predictions, eval_cache.examples[task.name])
                ]

                for metric_fn in task.metric_fns:
                    scores = metric_fn(eval_cache.targets[task.name],
                                       predictions)
                    for metric_name, metric_value in scores.items():
                        tag = f'eval/{task.name}/{metric_name}'
                        eval_summary_writer.scalar(tag, metric_value,
                                                   host_step)
                        logging.info('EVAL %s at step %d: %.3f', tag,
                                     host_step, metric_value)
                    eval_summary_writer.flush()

                # Save text samples for tensorboard.
                exemplars = ''
                for n in np.random.choice(np.arange(len(predictions)), 8):
                    tgt_txt = tf.compat.as_text(
                        eval_cache.examples[task.name][n]['targets_plaintext'])
                    pred_txt = raw_predictions[n]
                    exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                                  f'target: {tgt_txt}\n\n'
                                  f'prediction: {pred_txt}\n\n')
                eval_summary_writer.text(f'{task.name} samples', exemplars,
                                         host_step)
                eval_summary_writer.flush()

        # Take an Xprof trace after the first loop has compiled everything.
        if epoch == first_epoch + 1:
            train_lib.sync_devices()

        # For on-device loop, we launch the computation before feeding data.
        logging.info('BEGIN Train loop.')
        if CFG.infeed:
            optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
                optimizer, dropout_rngs, metrics,
                train_lib.unbroadcast(device_step), epoch)
            optimizer = train_lib.unbroadcast(optimizer)

        # Epoch loop.
        while int(host_step // steps_per_epoch) == epoch:
            batch = next(train_iter)
            batch = jax.tree_map(
                lambda x: x.reshape(
                    (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
                batch)
            # Feed the on-device training loop.
            if CFG.infeed:
                for i, device in enumerate(local_devices):
                    # When using infeed to provide data to the computation, we're on our
                    # own for feeding the right values to the right devices. Each device
                    # should get the minibatch corresponding to its replica, a slice of
                    # the larger batch corresponding to the host's replica set.
                    if device.platform == 'tpu':
                        device_coords = (*device.coords, device.id % 2)
                    else:
                        device_coords = (device.host_id, i)
                    per_replica_set_device_coords = tuple(
                        dc % prsm for dc, prsm in zip(
                            device_coords, topology.per_replica_set_mesh))
                    per_replica_set_replica_coords = tuple(
                        prsdc // prm
                        for prsdc, prm in zip(per_replica_set_device_coords,
                                              topology.per_replica_mesh))
                    per_replica_set_replica_id = 0
                    for prsm, prm, prsrc in zip(
                            topology.per_replica_set_mesh,
                            topology.per_replica_mesh,
                            per_replica_set_replica_coords):
                        per_replica_set_replica_id = (
                            per_replica_set_replica_id * prsm // prm + prsrc)
                    input_tuple = tuple([
                        batch[k][per_replica_set_replica_id]
                        for k in train_keys
                    ])
                    # Safety check: infeed does not check shape or types but requires
                    # them to agree with on-device spec, otherwise the queue and program
                    # stall.
                    tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
                    tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
                    assert tuple_shapes == device_train_input_shape, (
                        'infeed shape error %s != %s' %
                        (tuple_shapes, device_train_input_shape))
                    assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
                        ('infeed dtype error %s not all of type %s' % (
                            tuple_dtypes, jnp.int32))
                    infeed_pool.submit(
                        functools.partial(device.transfer_to_infeed,
                                          input_tuple))
            # Host training loop.
            else:
                optimizer, metrics, dropout_rngs = p_train_step(
                    optimizer, batch, metrics, dropout_rngs)
                optimizer = train_lib.unbroadcast(optimizer)
            host_step += 1
        logging.info('END Train loop.')

        # Maybe save a checkpoint on one host.
        if (CFG.save_checkpoints
                and epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1
                and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

        # Gather training metrics.
        metrics = p_allreduce_metrics(metrics)
        metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
        denominator = metrics.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
        logging.info('train in step: %s, %s', host_step, summary)
        if jax.host_id() == 0:
            assert train_summary_writer
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, host_step)
            train_summary_writer.flush()

        # Gather training evaluation metrics.
        logging.info('Gathering training evaluation metrics.')
        eval_metrics = []
        eval_iter = eval_ds.as_numpy_iterator()
        for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(
                lambda x: x.reshape(
                    (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
                eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        # average metrics across devices
        eval_metrics = p_allreduce_metrics(eval_metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        # average metrics across steps
        eval_metrics = jax.tree_map(np.sum, eval_metrics)
        eval_denominator = eval_metrics.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics)
        logging.info('eval in step: %s, %s', host_step, eval_summary)
        if jax.host_id() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, host_step)
            eval_summary_writer.flush()

    # Wait until computations are done before exiting
    logging.info('Finished.')
    train_lib.sync_devices()
    # Shut down the infeed threadpool.
    if CFG.infeed:
        infeed_pool.shutdown()
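
The epoch loop above always reshapes the host batch so that its leading axis matches the number of replicas this host feeds, then either infeeds per-device slices or calls the pmapped train step directly. Below is a minimal, self-contained sketch of that second path; the loss, parameters, and learning rate are made up purely for illustration and are not taken from the example, it just shows the same reshape-then-pmap pattern with gradients all-reduced via jax.lax.pmean.

import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()

def loss_fn(params, batch):
    preds = batch['x'] @ params['w']
    return jnp.mean((preds - batch['y']) ** 2)

def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across devices, like the allreduce inside p_train_step.
    grads = jax.lax.pmean(grads, axis_name='batch')
    return jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)

p_train_step = jax.pmap(train_step, axis_name='batch')

# Replicate parameters: one copy per local device along the leading axis.
params = jax.tree_map(lambda x: jnp.stack([x] * n_devices),
                      {'w': jnp.zeros((4, 1), jnp.float32)})

# A host batch whose size is a multiple of the local device count.
host_batch = {'x': np.ones((2 * n_devices, 4), np.float32),
              'y': np.ones((2 * n_devices, 1), np.float32)}
# Reshape so the leading axis is the device axis, as the epoch loop does.
sharded = jax.tree_map(
    lambda x: x.reshape((n_devices, -1) + x.shape[1:]), host_batch)

params = p_train_step(params, sharded)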
Example #5
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': vocab_size,
        'emb_dim': FLAGS.emb_dim,
        'num_heads': FLAGS.num_heads,
        'num_layers': FLAGS.num_layers,
        'qkv_dim': FLAGS.qkv_dim,
        'mlp_dim': FLAGS.mlp_dim,
        'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
    }

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    model, cache_def = create_model(init_rng, input_shape, target_shape,
                                    transformer_kwargs)
    optimizer = create_optimizer(model, FLAGS.learning_rate,
                                 FLAGS.weight_decay)
    # We access model only from optimizer below via optimizer.target.
    del model

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(
        eval_step,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                           axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          use_bfloat16=FLAGS.use_bfloat16,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # We init the first set of dropout PRNG keys, but update them afterwards
    # inside the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if step % FLAGS.eval_frequency != 0 and step > 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for _, pred_batch in enumerate(predict_iter):
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (per_device_batchsize, FLAGS.max_predict_length),
                    dtype=cache_dtype))
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_id, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
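
One detail worth calling out from the example above: decode_tokens truncates each sequence at the first EOS token with toks[:np.argmax(toks == eos_id) + 1]. The tiny stand-alone check below (token values invented for illustration) shows what that slice does, including the edge case where no EOS is present and np.argmax of an all-False array returns 0, so only the first token survives.

import numpy as np

EOS_ID = 2  # illustrative value; the example uses decode.EOS_ID from SentencePiece
toks = np.array([17, 5, 9, 2, 0, 0])
print(toks[:np.argmax(toks == EOS_ID) + 1])      # [17  5  9  2]

no_eos = np.array([17, 5, 9, 4])
print(no_eos[:np.argmax(no_eos == EOS_ID) + 1])  # [17]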
Example #6
0
def main(unused_argv):
  config = utils.load_config()

  dataset = datasets.get_dataset('test', FLAGS.data_dir, config)
  model, init_variables = models.construct_mipnerf(
      random.PRNGKey(20200823), dataset.peek())
  optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates 'speckle' artifacts.
  def render_eval_fn(variables, _, rays):
    return jax.lax.all_gather(
        model.apply(
            variables,
            random.PRNGKey(0),  # Unused.
            rays,
            randomized=False,
            white_bkgd=config.white_bkgd),
        axis_name='batch')

  # pmap over only the data input.
  render_eval_pfn = jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=2,
      axis_name='batch',
  )

  ssim_fn = jax.jit(functools.partial(math.compute_ssim, max_val=1.))

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      'path_renders' if config.render_path else 'test_preds')
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, 'eval'))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    avg_values = []
    if not FLAGS.eval_once:
      showcase_index = random.randint(random.PRNGKey(step), (), 0, dataset.size)
    for idx in range(dataset.size):
      print(f'Evaluating {idx+1}/{dataset.size}')
      batch = next(dataset)
      pred_color, pred_distance, pred_acc = models.render_image(
          functools.partial(render_eval_pfn, state.optimizer.target),
          batch['rays'],
          None,
          chunk=FLAGS.chunk)

      vis_suite = vis.visualize_suite(pred_distance, pred_acc)

      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_acc = pred_acc
        showcase_vis_suite = vis_suite
        if not config.render_path:
          showcase_gt = batch['pixels']
      if not config.render_path:
        psnr = float(
            math.mse_to_psnr(((pred_color - batch['pixels'])**2).mean()))
        ssim = float(ssim_fn(pred_color, batch['pixels']))
        print(f'PSNR={psnr:.4f} SSIM={ssim:.4f}')
        psnr_values.append(psnr)
        ssim_values.append(ssim)
      if FLAGS.save_output and (config.test_render_interval > 0):
        if (idx % config.test_render_interval) == 0:
          utils.save_img_uint8(
              pred_color, path.join(out_dir, 'color_{:03d}.png'.format(idx)))
          utils.save_img_float32(
              pred_distance,
              path.join(out_dir, 'distance_{:03d}.tiff'.format(idx)))
          utils.save_img_float32(
              pred_acc, path.join(out_dir, 'acc_{:03d}.tiff'.format(idx)))
          for k, v in vis_suite.items():
            utils.save_img_uint8(
                v, path.join(out_dir, k + '_{:03d}.png'.format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image('pred_color', showcase_color, step)
      summary_writer.image('pred_acc', showcase_acc, step)
      for k, v in showcase_vis_suite.items():
        summary_writer.image('pred_' + k, v, step)
      if not config.render_path:
        summary_writer.scalar('psnr', np.mean(np.array(psnr_values)), step)
        summary_writer.scalar('ssim', np.mean(np.array(ssim_values)), step)
        summary_writer.image('target', showcase_gt, step)
    if FLAGS.save_output and (not config.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f'psnrs_{step}.txt'), 'w') as f:
        f.write(' '.join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f'ssims_{step}.txt'), 'w') as f:
        f.write(' '.join([str(v) for v in ssim_values]))
    if FLAGS.eval_once:
      break
    if int(step) >= config.max_steps:
      break
    last_step = step
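
The evaluation loop above turns each image's MSE into PSNR via math.mse_to_psnr before averaging over the test set. For pixel values in [0, 1] that conversion is just 10 * log10(1 / mse); the stand-alone version below (the function name and test values are mine, not from the repo) reproduces the arithmetic.

import numpy as np

def mse_to_psnr(mse, max_val=1.0):
    # Peak signal-to-noise ratio in dB for images with values in [0, max_val].
    return 10.0 * np.log10(max_val ** 2 / mse)

pred = np.full((4, 4, 3), 0.5, np.float32)
gt = np.full((4, 4, 3), 0.6, np.float32)
mse = float(np.mean((pred - gt) ** 2))
print(f'MSE={mse:.4f} PSNR={mse_to_psnr(mse):.2f} dB')  # 0.01 -> 20 dB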
Example #7
0
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    train_ds_mgr, eval_ds, predict_ds, encoder = input_pipeline.get_dynamic_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=FLAGS.vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_buckets=FLAGS.num_data_buckets)
    if FLAGS.static:
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      assert len(weights) == FLAGS.num_data_buckets
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)

  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size)

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.host_id(),
          shard_count=jax.host_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that this is supposed to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Otherwise, continue from where
    # we left off.
    model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  if FLAGS.adapter != NONE:
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
             ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.adapter != NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  p_get_diag_grads = jax.pmap(
      functools.partial(
          get_diag_grads,
          config=eval_config),
      axis_name='batch')

  p_get_bucket_score = jax.pmap(
      functools.partial(
          get_diag_score,
          strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update them afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  best_eval_loss = 1000
  curr_eval_loss = 1000
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # Dynamic macro: use gradient alignment to score different ratios
        # of top k vs bottom N-k bins
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score, aux_eval_ds,
                                              train_ds_mgr, optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins
          # take the top k bins and sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  aux_eval_ds, train_ds_mgr,
                                                  optimizer,
                                                  eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          new_distribution = np.zeros(100)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        try:
          batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
          optimizer, metrics = p_train_step(
              optimizer, batch, dropout_rng=dropout_rngs)
          train_metrics.append(metrics)
        except StopIteration:
          is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop('learning_rate').mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop('denominator')
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary['learning_rate'] = lr
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps)
          curr_eval_loss = eval_results['loss']
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            exemplars, bleu_score = translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length)
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if FLAGS.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break
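
Both WMT training loops in this section aggregate metrics the same way: every step emits summed values plus a 'denominator' (the token count), the per-step dictionaries are stacked and summed over the logging window, and only then is the ratio taken, so the reported loss is a token-weighted mean rather than a mean of per-step means. A stand-alone sketch of that bookkeeping, with invented numbers:

import jax
import jax.numpy as jnp

# What two train steps might return: metrics summed over their batches.
step_metrics = [
    {'loss': jnp.array(12.0), 'denominator': jnp.array(40.0)},  # 40 tokens
    {'loss': jnp.array(3.0), 'denominator': jnp.array(10.0)},   # 10 tokens
]

# Stack across steps (roughly what common_utils.get_metrics produces), sum,
# pop the shared denominator, and normalize everything else by it.
stacked = jax.tree_map(lambda *xs: jnp.stack(xs), *step_metrics)
metrics_sums = jax.tree_map(jnp.sum, stacked)
denominator = metrics_sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
print(summary['loss'])  # 0.3 = (12 + 3) / (40 + 10)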
Example #8
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from the repo published with "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
    pretrained_path = os.path.join(config.pretrained_dir,
                                   f'{filename}.npz')
  else:
    # ViT / Mixer papers
    filename = config.model.name

  pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
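  # If no checkpoint exists under `workdir`, restore_checkpoint returns the
  # target tuple (opt, 1) unchanged, so a fresh run starts at step 1.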
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
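
The evaluation pass in the example above scores each test batch by comparing argmax over the logits with argmax over the one-hot labels and averaging the matches across the device and batch axes. A stripped-down, single-host version of that accuracy computation (shapes and values invented for illustration):

import numpy as np

# Logits and one-hot labels with a leading device axis of 1, as returned by
# the pmapped inference function for a two-example batch with three classes.
logits = np.array([[[2.0, 0.1, 0.3],
                    [0.2, 1.5, 0.1]]])
labels = np.array([[[1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0]]])

matches = (np.argmax(logits, axis=-1) == np.argmax(labels, axis=-1))
print(matches.mean())  # 0.5: the first example is correct, the second is not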
Example #9
0
def main(_):

    if FLAGS.use_transformer:
        assert (
            FLAGS.encoder_fn_name == 'transformer'
        ), 'encoder_fn_name must be transformer if use_transformer is True!'

    assert (FLAGS.epochs % FLAGS.measurements == 0
            ), 'Number of measurements must divide number of epochs!'
    measurement_epochs = FLAGS.epochs // FLAGS.measurements

    assert FLAGS.results_save_dir != '', 'Specify results_save_dir!'

    assert FLAGS.label != '', 'Specify label!'

    if FLAGS.load_model:
        assert FLAGS.load_model_dir != '', 'Specify load_model_dir!'
        assert FLAGS.load_model_step > 0, 'Loaded model must have been trained for more than 0 steps.'

    if FLAGS.save_model:
        assert FLAGS.save_model_dir != '', 'Specify save_model_dir!'

    datum = {
        'label': FLAGS.label,
        'encoder_fn_name': FLAGS.encoder_fn_name,
        'encoder_fn_kwargs_path': FLAGS.encoder_fn_kwargs_path,
        'reduce_fn_name': FLAGS.reduce_fn_name,
        'reduce_fn_kwargs_path': FLAGS.reduce_fn_kwargs_path,
        'epochs': FLAGS.epochs,
        'measurements': FLAGS.measurements,
        'lens_batch_size': FLAGS.lens_batch_size,
        'knn_batch_size': FLAGS.knn_batch_size,
        'encoder_lr': FLAGS.encoder_lr,
        'lens_lr': FLAGS.lens_lr,
        'predictor_lr': FLAGS.predictor_lr,
        'encoder_wd': FLAGS.encoder_wd,
        'lens_wd': FLAGS.lens_wd,
        'predictor_wd': FLAGS.predictor_wd,
        'train_families': FLAGS.train_families,
        'lens_train_samples': FLAGS.lens_train_samples,
        'first_test_family': FLAGS.first_test_family,
        'last_test_family': FLAGS.last_test_family,
        'lens_shuffle_seed': FLAGS.lens_shuffle_seed,
        'lens_sample_random_state': FLAGS.lens_sample_random_state,
        'knn_shuffle_seed': FLAGS.knn_shuffle_seed,
        'knn_sample_random_state': FLAGS.knn_sample_random_state,
        'random_key': FLAGS.random_key,
        'use_transformer': FLAGS.use_transformer,
        'use_bert': FLAGS.use_bert,
        'restore_transformer_dir': FLAGS.restore_transformer_dir,
        'gcs_bucket': FLAGS.gcs_bucket,
        'data_partitions_dirpath': FLAGS.data_partitions_dirpath,
        'results_save_dir': FLAGS.results_save_dir,
        'load_model': FLAGS.load_model,
        'load_model_dir': FLAGS.load_model_dir,
        'load_model_step': FLAGS.load_model_step,
        'save_model': FLAGS.save_model,
        'save_model_dir': FLAGS.save_model_dir
    }

    gcsfs = GCSFS(FLAGS.gcs_bucket)

    print(datum)
    df = pd.DataFrame([datum])
    with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'),
                    'w') as gcs_file:
        df.to_csv(gcs_file, index=False)

    knn_train_samples_ = [1, 5, 10, 50]

    num_families = len(family_ids)
    loss_fn_kwargs = {'num_classes': num_families}

    lens_knn_train_family_accessions = []
    for _ in range(1, FLAGS.train_families + 1):
        family_name = 'PF%05d' % _
        lens_knn_train_family_accessions.append(family_name)

    knn_test_family_accessions = []
    for _ in range(FLAGS.first_test_family, FLAGS.last_test_family + 1):
        family_name = 'PF%05d' % _
        knn_test_family_accessions.append(family_name)

    encoder_fn = encoder_fn_name_to_fn(FLAGS.encoder_fn_name)
    encoder_fn_kwargs = json.load(
        open(
            resource_filename(
                'contextual_lenses.resources',
                os.path.join('encoder_fn_kwargs_resources',
                             FLAGS.encoder_fn_kwargs_path + '.json'))))

    reduce_fn = reduce_fn_name_to_fn(FLAGS.reduce_fn_name)
    reduce_fn_kwargs = json.load(
        open(
            resource_filename(
                'contextual_lenses.resources',
                os.path.join('reduce_fn_kwargs_resources',
                             FLAGS.reduce_fn_kwargs_path + '.json'))))

    layers, trainable_encoder = architecture_to_layers(FLAGS.encoder_fn_name,
                                                       FLAGS.reduce_fn_name)

    embedding_model = create_model(
        use_transformer=FLAGS.use_transformer,
        use_bert=FLAGS.use_bert,
        restore_transformer_dir=FLAGS.restore_transformer_dir,
        encoder_fn=encoder_fn,
        encoder_fn_kwargs=encoder_fn_kwargs,
        reduce_fn=reduce_fn,
        reduce_fn_kwargs=reduce_fn_kwargs,
        layers=layers,
        output='embedding')

    datum.update(
        measure_nearest_neighbor_performance(
            accuracy_label=
            'train_knn_accuracy_untrained_lens_1_knn_train_samples',
            encoder=embedding_model,
            family_accessions=lens_knn_train_family_accessions,
            batch_size=FLAGS.knn_batch_size,
            train_samples=1,
            shuffle_seed=FLAGS.knn_shuffle_seed,
            sample_random_state=FLAGS.knn_sample_random_state))

    for knn_train_samples in knn_train_samples_:

        datum.update(
            measure_nearest_neighbor_performance(
                accuracy_label='test_knn_accuracy_untrained_lens_' +
                str(knn_train_samples) + '_knn_train_samples',
                encoder=embedding_model,
                family_accessions=knn_test_family_accessions,
                batch_size=FLAGS.knn_batch_size,
                train_samples=knn_train_samples,
                shuffle_seed=FLAGS.knn_shuffle_seed,
                sample_random_state=FLAGS.knn_sample_random_state))

    encoder_fn_params = None
    reduce_fn_params = None
    predict_fn_params = None

    model = create_model(use_transformer=FLAGS.use_transformer,
                         use_bert=FLAGS.use_bert,
                         restore_transformer_dir=FLAGS.restore_transformer_dir,
                         encoder_fn=encoder_fn,
                         encoder_fn_kwargs=encoder_fn_kwargs,
                         reduce_fn=reduce_fn,
                         reduce_fn_kwargs=reduce_fn_kwargs,
                         layers=layers,
                         output='prediction',
                         encoder_fn_params=encoder_fn_params,
                         reduce_fn_params=reduce_fn_params,
                         predict_fn_params=predict_fn_params)

    optimizer = create_optimizer(
        model=model,
        learning_rate=[FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr],
        weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd],
        layers=layers)

    if FLAGS.load_model:
        optimizer = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(
            'gs://' + FLAGS.gcs_bucket, FLAGS.load_model_dir),
                                                   target=optimizer,
                                                   step=FLAGS.load_model_step)

        trained_params = optimizer.target.params
        embedding_model = set_model_parameters(model=embedding_model,
                                               params=trained_params)

    if FLAGS.save_model:
        checkpoints.save_checkpoint(ckpt_dir=os.path.join(
            'gs://' + FLAGS.gcs_bucket, FLAGS.save_model_dir),
                                    target=optimizer,
                                    step=FLAGS.load_model_step)

    for i in range(FLAGS.measurements):

        train_batches, train_indexes = create_pfam_batches(
            family_accessions=lens_knn_train_family_accessions,
            batch_size=FLAGS.lens_batch_size,
            samples=FLAGS.lens_train_samples,
            epochs=measurement_epochs,
            drop_remainder=True,
            shuffle_seed=FLAGS.lens_shuffle_seed + i,
            sample_random_state=FLAGS.lens_sample_random_state)

        optimizer = train(
            model=optimizer.target,
            train_data=train_batches,
            loss_fn=cross_entropy_loss,
            loss_fn_kwargs=loss_fn_kwargs,
            learning_rate=[
                FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr
            ],
            weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd],
            layers=layers)

        results, preds = pfam_evaluate(
            predict_fn=optimizer.target,
            test_family_accessions=lens_knn_train_family_accessions,
            title=None,
            loss_fn_kwargs=loss_fn_kwargs,
            batch_size=FLAGS.lens_batch_size,
            data_partitions_dirpath=FLAGS.data_partitions_dirpath,
            gcs_bucket=FLAGS.gcs_bucket)

        lens_accuracy = results['accuracy']
        datum['lens_accuracy' + '_measurement_' + str(i)] = lens_accuracy

        lens_cross_entropy = float(results['cross_entropy'])
        datum['lens_cross_entropy' + '_measurement_' +
              str(i)] = lens_cross_entropy

        trained_params = optimizer.target.params
        embedding_model = set_model_parameters(model=embedding_model,
                                               params=trained_params)

        datum.update(
            measure_nearest_neighbor_performance(
                accuracy_label=
                'train_knn_accuracy_trained_lens_1_knn_train_samples' +
                '_measurement_' + str(i),
                encoder=embedding_model,
                family_accessions=lens_knn_train_family_accessions,
                batch_size=FLAGS.knn_batch_size,
                train_samples=1,
                shuffle_seed=FLAGS.knn_shuffle_seed,
                sample_random_state=FLAGS.knn_sample_random_state))

        for knn_train_samples in knn_train_samples_:

            datum.update(
                measure_nearest_neighbor_performance(
                    accuracy_label='test_knn_accuracy_trained_lens_' +
                    str(knn_train_samples) + '_knn_train_samples' +
                    '_measurement_' + str(i),
                    encoder=embedding_model,
                    family_accessions=knn_test_family_accessions,
                    batch_size=FLAGS.knn_batch_size,
                    train_samples=knn_train_samples,
                    shuffle_seed=FLAGS.knn_shuffle_seed,
                    sample_random_state=FLAGS.knn_sample_random_state))

    print(datum)
    df = pd.DataFrame([datum])
    with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'),
                    'w') as gcs_file:
        df.to_csv(gcs_file, index=False)

    if FLAGS.save_model:
        checkpoints.save_checkpoint(ckpt_dir=os.path.join(
            'gs://' + FLAGS.gcs_bucket, FLAGS.save_model_dir),
                                    target=optimizer,
                                    step=FLAGS.load_model_step + FLAGS.epochs)
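
The example above builds its Pfam family accession lists by formatting integer IDs as 'PF%05d'. The compact equivalent below uses list comprehensions; the numeric ranges are placeholders standing in for FLAGS.train_families, FLAGS.first_test_family, and FLAGS.last_test_family.

# Placeholder values standing in for FLAGS.train_families,
# FLAGS.first_test_family, and FLAGS.last_test_family.
train_families, first_test_family, last_test_family = 100, 101, 110

lens_knn_train_family_accessions = [
    'PF%05d' % i for i in range(1, train_families + 1)]
knn_test_family_accessions = [
    'PF%05d' % i for i in range(first_test_family, last_test_family + 1)]

print(lens_knn_train_family_accessions[:3])  # ['PF00001', 'PF00002', 'PF00003']
print(knn_test_family_accessions[0])         # 'PF00101'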
Example #10
0
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.mkdir(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparameters.
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            if key not in hparam_str_dict:
                hparam_str_dict[key] = value

    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        io_string = ''
        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
            io_string += inps[-1] + ' < ' + outs[-1] + ' > '
        return inps, outs, io_string[:-3]  # Remove last separator.

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except:  # pylint: disable=bare-except
            return None, ''  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease the batch size of the predict dataset to handle beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    base_train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    base_eval_config = base_train_config.replace(deterministic=True,
                                                 train_vq=False)
    base_predict_config = base_train_config.replace(shift=False,
                                                    deterministic=True,
                                                    train_vq=False,
                                                    decode=True)
    train_config = models.LatentTransformerConfig(
        base_cfg=base_train_config,
        latent_vocab_size=FLAGS.latent_vocab_size,
        c=FLAGS.c,
        train_vq=True,
        commitment_cost_vq=FLAGS.commitment_cost_vq)
    eval_config = models.LatentTransformerConfig(
        base_cfg=base_eval_config,
        latent_vocab_size=FLAGS.latent_vocab_size,
        c=FLAGS.c,
        train_vq=True,
        commitment_cost_vq=FLAGS.commitment_cost_vq)
    predict_config = models.LatentTransformerConfig(
        base_cfg=base_predict_config,
        latent_vocab_size=FLAGS.latent_vocab_size,
        c=FLAGS.c,
        train_vq=True,
        commitment_cost_vq=FLAGS.commitment_cost_vq)

    # Latent Predictor.
    lp_train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=FLAGS.latent_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    lp_eval_config = lp_train_config.replace(deterministic=True)
    lp_predict_config = lp_train_config.replace(shift=False,
                                                deterministic=True,
                                                decode=True)

    rng = jax.random.PRNGKey(0)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    m = models.LatentProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))
    lp_m = models.ProgramTransformer(lp_eval_config)
    lp_initial_variables = jax.jit(lp_m.init)(init_rng,
                                              jnp.ones(io_shape, jnp.float32),
                                              jnp.ones(io_shape, jnp.float32),
                                              jnp.ones(program_shape,
                                                       jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])
    lp_optimizer = optimizer_def.create(lp_initial_variables['params'])

    state = TrainState(step=0,
                       optimizer=optimizer,
                       model_state=initial_variables['vqvae'],
                       lp_optimizer=lp_optimizer)
    # Don't keep a copy of the initial model.
    del initial_variables, lp_initial_variables

    train_rngs = jax.random.split(rng, jax.local_device_count())

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        state = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), state)
        # Grab last step.
        start_step = int(state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        bos_token=bos_token,
        eos_token=eos_token,
        learning_rate_fn=learning_rate_fn,
        config=train_config,
        lp_config=lp_train_config),
                            axis_name='batch',
                            static_broadcasted_argnums=(4, ))
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             bos_token=bos_token,
                                             eos_token=eos_token,
                                             config=eval_config,
                                             lp_config=lp_eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        lp_config=lp_predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        bos_token=bos_token,
        eos_token=eos_token,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        lp_config=lp_predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(5, ))

    metrics_all = []
    latent_metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        state, metrics, latent_metrics, train_rngs = p_train_step(
            state,
            inputs,
            outputs,
            programs,
            step <= FLAGS.num_pretrain_steps,
            train_rng=train_rngs)
        metrics, latent_metrics = jax.tree_map(np.array,
                                               (metrics, latent_metrics))
        metrics_all.append(metrics)
        latent_metrics_all.append(latent_metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(state), step)

        # Periodic metric handling.
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        latent_metrics_all = common_utils.get_metrics(latent_metrics_all)
        metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary.update(
            jax.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums))

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []
        latent_metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        latent_eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)
            all_metrics = p_eval_step(state, inputs, outputs, programs)
            metrics, latent_metrics = jax.tree_map(np.array, all_metrics)
            eval_metrics.append(metrics)
            latent_eval_metrics.append(latent_metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary.update(
            jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums))

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 50, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions, latent_predictions = [], [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_map(
                        lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache, lp_cache = p_init_cache(inputs, outputs, programs)
                predicted, latent_predicted = p_pred_step(
                    state, inputs, outputs, cache, lp_cache, beam_size)
                predicted, latent_predicted = map(
                    tohost, (predicted, latent_predicted))
                inputs, outputs, programs = map(tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_idx, p_score = eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')
                    latent_predictions.append(' '.join(
                        list(np.array(latent_predicted[i,
                                                       p_idx]).astype(str))))

            all_pred_acc, all_pred_denominator = per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries.
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n'
                        f'latent_predicted: {latent_predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step,
                    all_pred_acc / all_pred_denominator)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      all_pred_acc / all_pred_denominator,
                                      step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()
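
The prediction loop above pads the final odd-sized batch so its leading axis divides evenly across local devices before sharding. A minimal standalone sketch of that padding step (pad_to_device_multiple is an illustrative name, not part of the example's codebase):

import jax
import numpy as np
from flax.training import common_utils

def pad_to_device_multiple(batch, n_devices):
    """Pads the leading axis of every array in `batch` to a multiple of n_devices."""
    size = batch[0].shape[0]
    padded_size = int(np.ceil(size / n_devices) * n_devices)
    if padded_size == size:
        return batch
    pad = padded_size - size
    # Repeat the last row to fill the remainder; the duplicated rows should be
    # excluded when accumulating metrics.
    return jax.tree_map(
        lambda x: np.concatenate([x, np.repeat(x[-1:], pad, axis=0)]), batch)

# Usage sketch:
# pred_batch = pad_to_device_multiple(pred_batch, jax.local_device_count())
# inputs, outputs, programs = common_utils.shard(pred_batch)
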
Пример #11
0
def main(argv):
    tf.config.experimental.set_visible_devices([], 'GPU')
    del argv
    logging.info('*** Starting experiment')
    gin_configs = FLAGS.gin_configs

    logging.info('*** Loading Gin configs from: %s', str(gin_configs))
    gin.parse_config_files_and_bindings(config_files=gin_configs,
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=True)

    # Load configurations.
    exp_config = configs.ExperimentConfig()
    model_config = configs.ModelConfig()
    train_config = configs.TrainConfig()

    # Get directory information.
    exp_dir = gpath.GPath(FLAGS.base_folder)
    if exp_config.subname:
        exp_dir = exp_dir / exp_config.subname
    summary_dir = exp_dir / 'summaries' / 'train'
    checkpoint_dir = exp_dir / 'checkpoints'

    # Log and create directories if this is the main host.
    if jax.process_index() == 0:
        logging.info('exp_dir = %s', exp_dir)
        if not exp_dir.exists():
            exp_dir.mkdir(parents=True, exist_ok=True)

        logging.info('summary_dir = %s', summary_dir)
        if not summary_dir.exists():
            summary_dir.mkdir(parents=True, exist_ok=True)

        logging.info('checkpoint_dir = %s', checkpoint_dir)
        if not checkpoint_dir.exists():
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

        config_str = gin.operative_config_str()
        logging.info('Configuration: \n%s', config_str)
        with (exp_dir / 'config.gin').open('w') as f:
            f.write(config_str)

    logging.info('Starting host %d. There are %d hosts in total.',
                 jax.process_index(), jax.process_count())
    logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
                 str(jax.local_devices()))
    logging.info('Found %d total devices: %s.', jax.device_count(),
                 str(jax.devices()))

    rng = random.PRNGKey(exp_config.random_seed)
    # Shift the numpy random seed by process_index() so that data loaded by
    # different hosts is shuffled differently.
    np.random.seed(exp_config.random_seed + jax.process_index())

    if train_config.batch_size % jax.device_count() != 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices.')

    devices = jax.local_devices()
    datasource_spec = exp_config.datasource_spec
    if datasource_spec is None:
        datasource_spec = {
            'type': exp_config.datasource_type,
            'data_dir': FLAGS.data_dir,
        }
    logging.info('Creating datasource: %s', datasource_spec)
    datasource = datasets.from_config(
        datasource_spec,
        image_scale=exp_config.image_scale,
        use_appearance_id=model_config.use_appearance_metadata,
        use_camera_id=model_config.use_camera_metadata,
        use_warp_id=model_config.use_warp,
        use_time=model_config.warp_metadata_encoder_type == 'time',
        random_seed=exp_config.random_seed,
        **exp_config.datasource_kwargs)
    train_iter = datasource.create_iterator(
        datasource.train_ids,
        flatten=True,
        shuffle=True,
        batch_size=train_config.batch_size,
        prefetch_size=3,
        shuffle_buffer_size=train_config.shuffle_buffer_size,
        devices=devices,
    )

    points_iter = None
    if train_config.use_background_loss:
        points = datasource.load_points(shuffle=True)
        points_batch_size = min(
            len(points),
            len(devices) * train_config.background_points_batch_size)
        points_batch_size -= points_batch_size % len(devices)
        points_dataset = tf.data.Dataset.from_tensor_slices(points)
        points_iter = datasets.iterator_from_dataset(
            points_dataset,
            batch_size=points_batch_size,
            prefetch_size=3,
            devices=devices)

    learning_rate_sched = schedules.from_config(train_config.lr_schedule)
    warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
    time_alpha_sched = schedules.from_config(train_config.time_alpha_schedule)
    elastic_loss_weight_sched = schedules.from_config(
        train_config.elastic_loss_weight_schedule)

    rng, key = random.split(rng)
    params = {}
    model, params['model'] = models.construct_nerf(
        key,
        model_config,
        batch_size=train_config.batch_size,
        appearance_ids=datasource.appearance_ids,
        camera_ids=datasource.camera_ids,
        warp_ids=datasource.warp_ids,
        near=datasource.near,
        far=datasource.far,
        use_warp_jacobian=train_config.use_elastic_loss,
        use_weights=train_config.use_elastic_loss)

    optimizer_def = optim.Adam(learning_rate_sched(0))
    optimizer = optimizer_def.create(params)
    state = model_utils.TrainState(optimizer=optimizer,
                                   warp_alpha=warp_alpha_sched(0),
                                   time_alpha=time_alpha_sched(0))
    scalar_params = training.ScalarParams(
        learning_rate=learning_rate_sched(0),
        elastic_loss_weight=elastic_loss_weight_sched(0),
        warp_reg_loss_weight=train_config.warp_reg_loss_weight,
        warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,
        warp_reg_loss_scale=train_config.warp_reg_loss_scale,
        background_loss_weight=train_config.background_loss_weight)
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    init_step = state.optimizer.state.step + 1
    state = jax_utils.replicate(state, devices=devices)
    del params

    logging.info('Initializing models')

    summary_writer = None
    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(str(summary_dir))
        summary_writer.text('gin/train',
                            textdata=gin.config.markdown(config_str),
                            step=0)

    train_step = functools.partial(
        training.train_step,
        model,
        elastic_reduce_method=train_config.elastic_reduce_method,
        elastic_loss_type=train_config.elastic_loss_type,
        use_elastic_loss=train_config.use_elastic_loss,
        use_background_loss=train_config.use_background_loss,
        use_warp_reg_loss=train_config.use_warp_reg_loss,
    )
    ptrain_step = jax.pmap(
        train_step,
        axis_name='batch',
        devices=devices,
        # rng_key, state, batch, scalar_params.
        in_axes=(0, 0, 0, None),
        # Treat use_elastic_loss as compile-time static.
        donate_argnums=(2, ),  # Donate the 'batch' argument.
    )

    if devices:
        n_local_devices = len(devices)
    else:
        n_local_devices = jax.local_device_count()

    logging.info('Starting training')
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)
    time_tracker = utils.TimeTracker()
    time_tracker.tic('data', 'total')
    for step, batch in zip(range(init_step, train_config.max_steps + 1),
                           train_iter):
        if points_iter is not None:
            batch['background_points'] = next(points_iter)
        time_tracker.toc('data')
        # pytype: disable=attribute-error
        scalar_params = scalar_params.replace(
            learning_rate=learning_rate_sched(step),
            elastic_loss_weight=elastic_loss_weight_sched(step))
        warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
        time_alpha = jax_utils.replicate(time_alpha_sched(step), devices)
        state = state.replace(warp_alpha=warp_alpha, time_alpha=time_alpha)

        with time_tracker.record_time('train_step'):
            state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
            time_tracker.toc('total')

        if step % train_config.print_every == 0 and jax.process_index() == 0:
            logging.info('step=%d, warp_alpha=%.04f, time_alpha=%.04f, %s',
                         step, warp_alpha_sched(step), time_alpha_sched(step),
                         time_tracker.summary_str('last'))
            coarse_metrics_str = ', '.join(
                [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])
            fine_metrics_str = ', '.join(
                [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])
            logging.info('\tcoarse metrics: %s', coarse_metrics_str)
            if 'fine' in stats:
                logging.info('\tfine metrics: %s', fine_metrics_str)

        if step % train_config.save_every == 0 and jax.process_index() == 0:
            training.save_checkpoint(checkpoint_dir, state)

        if step % train_config.log_every == 0 and jax.process_index() == 0:
            # Only log via host 0.
            _log_to_tensorboard(summary_writer,
                                jax_utils.unreplicate(state),
                                scalar_params,
                                jax_utils.unreplicate(stats),
                                time_dict=time_tracker.summary('mean'))
            time_tracker.reset()

        if (step % train_config.histogram_every == 0
                and jax.process_index() == 0):
            _log_histograms(summary_writer, model,
                            jax_utils.unreplicate(state))

        time_tracker.tic('data', 'total')

    if train_config.max_steps % train_config.save_every != 0:
        training.save_checkpoint(checkpoint_dir, state)
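
The checkpoint handling above follows a common Flax pattern: restore into the unreplicated TrainState, read the step counter, then replicate the state across local devices for pmap. A minimal sketch of that pattern, assuming a TrainState with an integer `step` field (the helper name is illustrative):

from flax import jax_utils
from flax.training import checkpoints

def restore_and_replicate(checkpoint_dir, state):
    # restore_checkpoint returns `state` unchanged if no checkpoint exists yet.
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    start_step = int(state.step) + 1
    return jax_utils.replicate(state), start_step
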
Пример #12
0
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
      config)
  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  logging.info('Using state %s', type(state))
  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, model=model),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    epoch = 1
    while True:
      msg = f'Starting epoch {epoch}'
      logging.info(msg)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(
            state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(
              t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

        if step == 1000 or (step > 0 and
                            step % config.detailed_eval_every_steps == 0):
          if is_first_process:
            loss_components_path = os.path.join(workdir, 'loss_components')
            with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
              pickle.dump(kl_history[-1], fp)

          extensive_eval_rngs = extensive_eval(
              config, extensive_eval_rngs, writer, workdir,
              model, state, kl_history, test_ds, step,
              decode_tokens)

        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step, overwrite=True)
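
The periodic metric handling above relies on flax's common_utils.get_metrics, which fetches the accumulated per-step metric pytrees from the devices and stacks them along a leading step axis. A minimal sketch of that reduction (assuming the metrics were already averaged across devices inside the pmapped train step):

import jax
import jax.numpy as jnp
from flax.training import common_utils

def summarize_metrics(step_metrics, prefix='train_'):
  stacked = common_utils.get_metrics(step_metrics)  # pytree of (num_steps, ...) arrays
  summary = jax.tree_map(jnp.mean, stacked)         # average over accumulated steps
  return {prefix + k: v for k, v in summary.items()}
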
Пример #13
0
def main(_):
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0
    save_dir = FLAGS.model_dir if FLAGS.save_dir is None else FLAGS.save_dir
    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine = jax.random.split(rng, 3)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir, FLAGS.config, num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    train_items, val_items, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    step = int(state.step)
    state = jax.device_put_replicated(state, jax.local_devices())

    # TODO: TPU Colab breaks without an error message if this is a list;
    # a list would be preferable because tqdm could then show an ETA.
    render_dict = {
        "train": zip(range(train_items), train_ds),
        "val": zip(range(val_items), val_ds),
        "test": zip(range(test_items), test_ds),
        "poses": zip(range(num_poses), render_ds),
    }
    render_poses = render_dict[FLAGS.render_video_set]

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    if FLAGS.render_video:
        rgb_list = []
        disp_list = []
        losses = []
        for _, inputs in tqdm(render_poses, desc="Rays render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds)
            rgb_list.append(preds["rgb"])
            disp_list.append(preds["disp"])

            if FLAGS.config.render_factor == 1 and FLAGS.render_video_set != "render":
                loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                losses.append(loss)

        if FLAGS.config.render_factor == 1 and FLAGS.render_video_set != "render":
            loss = np.mean(losses)
            logging.info("Loss %.5f", loss)
            logging.info("PSNR %.5f", psnr_fn(loss))
        gen_video(save_dir, np.stack(rgb_list), "rgb", r_hwf, step)
        disp = np.stack(disp_list)
        gen_video(save_dir,
                  disp_post(disp, FLAGS.config),
                  "disp",
                  r_hwf,
                  step,
                  ch=1)

    if FLAGS.render_testset:
        test_losses = []
        for idx, inputs in tqdm(zip(range(test_items), test_ds),
                                desc="Test render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds)
            save_test_imgs(save_dir, preds["rgb"], r_hwf, step, idx)

            if FLAGS.config.render_factor == 1:
                loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                test_losses.append(loss)
        if FLAGS.config.render_factor == 1:
            loss = np.mean(test_losses)
            logging.info("Loss %.5f", loss)
            logging.info("PSNR %.5f", psnr_fn(loss))
Пример #14
0
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # get num_epochs from hparam instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  rng = random.PRNGKey(hparams.seed)
  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True,
      is_teacher=hparams.is_teacher)

  # pylint: disable=g-long-lambda
  if hparams.teacher_model == 'resnet50-8bit':
    teacher_config = w8a8auto_paper_config()
    teacher_hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config)
    teacher_model, _ = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=teacher_hparams.model_hparams,
        train=False,
        is_teacher=True)  # teacher model does not need to be trainable
    # Directory where checkpoints are saved
    ckpt_model_dir = FLAGS.resnet508b_ckpt_path
    # will restore to best checkpoint
    state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None)
    teacher_variables = {'params': state_load['optimizer']['target']}
    teacher_variables.update(state_load['model_state'])
    # create a dictionary for better argument passing
    teacher = {
        'model':
            lambda var, img, labels: jax.nn.softmax(
                teacher_model.apply(var, img)),
        'variables':
            teacher_variables,
    }
  elif hparams.teacher_model == 'labels':
    teacher = {
        'model':
            lambda var, img, labels: common_utils.onehot(
                labels, num_classes=1000),
        'variables': {},  # no need of variables in this case
    }
  else:
    raise ValueError('The specified teacher model is not supported.')

  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')
  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler,
                                             batch_size)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn,
          teacher=teacher),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3, 4))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch',
      static_broadcasted_argnums=(2,))

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  last_log_step = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps * steps_per_epoch:
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    # Compute whether weights should be quantized yet (based on
    # hparams.weight_quant_start_step) and pass the resulting bool to p_train_step.
    quantize_weights = train_utils.should_quantize_weights(
        hparams.weight_quant_start_step, step // steps_per_epoch)
    state, metrics = p_train_step(state, batch, hparams, update_bounds,
                                  quantize_weights)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)
    def should_log(step):
      epoch_no = step // steps_per_epoch
      step_in_epoch = step - epoch_no * steps_per_epoch
      do_log = False
      do_log = do_log or (step + 1 == num_steps)  # log at the end
      end_of_train = step / num_steps > 0.9
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 4) == 0) and not end_of_train)
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 16) == 0) and end_of_train)
      return do_log

    if should_log(step):
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start)
      last_log_step = step
      t_loop_start = time.time()

      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

        if FLAGS.write_summary:
          summary_utils.write_state_dict_summaries_to_tb(
              state_dict_summary_all, summary_writer,
              FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, quantize_weights)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
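
As a worked example of the scaling arithmetic above (with illustrative numbers rather than values from this config): the learning rate follows the linear scaling rule relative to a reference batch size of 256, and the step counts are derived from dataset size, batch size, and epochs.

batch_size = 1024
num_epochs = 90
train_images = 1281167                        # ImageNet-1k train split
steps_per_epoch = train_images // batch_size  # 1251
num_steps = steps_per_epoch * num_epochs      # 112590
base_learning_rate = 0.1 * batch_size / 256.  # 0.4, assuming a base LR of 0.1
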
Пример #15
0
def main(unused_argv):
  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnrs = []
    ssims = []
    if not FLAGS.eval_once:
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnrs.append(float(psnr))
        ssims.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssims)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnrs]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssims]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnrs))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssims))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
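
render_pfn above is pmapped over only the ray batch; the model variables and the two RNG keys are broadcast to every device via in_axes=(None, None, None, 0). A minimal self-contained sketch of that pattern with a toy apply function (all names and shapes here are illustrative):

import jax
import jax.numpy as jnp

def toy_apply(variables, key_0, key_1, rays):
  del key_0, key_1  # the toy model is deterministic
  return rays * variables["scale"]

p_toy_render = jax.pmap(toy_apply, in_axes=(None, None, None, 0),
                        axis_name="batch")

variables = {"scale": jnp.float32(2.0)}
rays = jnp.ones((jax.local_device_count(), 8, 3))  # leading axis = devices
out = p_toy_render(variables, None, None, rays)    # shape (n_devices, 8, 3)
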
Пример #16
0
def train_and_evaluate(config,
                       workdir,
                       vocab_filepath,
                       random_seed = 0):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.
    random_seed: Random number generator seed.

  Raises:
    ValueError: If the training or eval batch size is not divisible by the
      total number of devices, or if the config is underspecified.
  """
  n_processes = jax.process_count()  # Number of processes
  n_devices = jax.local_device_count()  # Number of local devices per process

  if config.train_batch_size % (n_processes * n_devices) > 0:
    raise ValueError(
        "Training batch size must be divisible by the total number of devices, "
        "but training batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.train_batch_size, n_processes * n_devices, n_processes,
         n_devices))

  if config.eval_batch_size % (n_processes * n_devices) > 0:
    raise ValueError(
        "Eval batch size must be divisible by the total number of devices, "
        "but eval batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.eval_batch_size, n_processes * n_devices, n_processes,
         n_devices))

  per_process_train_batch_size = config.train_batch_size // n_processes
  per_process_eval_batch_size = config.eval_batch_size // n_processes

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)

  ds_info = tfds.builder(config.dataset_name).info
  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

  num_train_steps = int(num_train_examples * config.num_train_epochs //
                        config.train_batch_size)
  num_warmup_steps = int(config.warmup_proportion * num_train_steps)
  # Round evaluation frequency up to a multiple of 10.
  eval_frequency = int(
      math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

  is_regression_task = config.dataset_name == "glue/stsb"

  num_classes = (1 if is_regression_task else
                 ds_info.features["label"].num_classes)

  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()

  frozen_config = ml_collections.FrozenConfigDict(config)
  model = models.SequenceClassificationModel(
      config=frozen_config, n_classes=num_classes)

  params = _init_params(model, init_rng, config)

  optimizer = _create_adam_optimizer(config.learning_rate, params)

  # In case current job restarts, ensure that we continue from where we left
  # off.
  optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
  start_step = int(optimizer.state.step)

  # Otherwise, try to restore optimizer and model state from config checkpoint.
  if (start_step == 0 and "init_checkpoint_dir" in config and
      config.init_checkpoint_dir):
    optimizer = _restore_pretrained_model(optimizer, params, config)

  # We access model state only from optimizer via optimizer.target.
  del params

  optimizer = jax_utils.replicate(optimizer)

  if is_regression_task:
    compute_stats = functools.partial(_compute_regression_stats, model=model)
  else:
    compute_stats = functools.partial(
        _compute_classification_stats, model=model)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=num_warmup_steps,
      decay_steps=num_train_steps - num_warmup_steps,
  )

  glue_inputs = functools.partial(
      input_pipeline.glue_inputs,
      dataset_name=config.dataset_name,
      max_seq_length=config.max_seq_length,
      tokenizer=tokenizer)
  train_iter = glue_inputs(
      split=tfds.Split.TRAIN,
      batch_size=per_process_train_batch_size,
      training=True)

  if config.dataset_name == "glue/mnli":
    # MNLI contains two validation and test datasets.
    split_suffixes = ["_matched", "_mismatched"]
  else:
    split_suffixes = [""]

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, n_devices)

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics, model=model)
  p_train_step = jax.pmap(
      functools.partial(
          train_utils.train_step,
          loss_and_metrics_fn=loss_and_metrics_fn,
          learning_rate_fn=learning_rate_fn),
      axis_name="batch")
  p_eval_step = jax.pmap(
      functools.partial(train_utils.eval_step, metric_fn=compute_stats),
      axis_name="batch")
  eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                            is_regression_task)

  train_metrics = []
  seconds = 0.0

  logging.info("Starting training loop.")
  logging.info("====================")

  for step, train_batch in zip(range(start_step, num_train_steps), train_iter):
    train_batch = common_utils.shard(train_batch)
    curr_time = time.time()
    optimizer, train_step_metrics, rngs = p_train_step(
        optimizer, train_batch, rng=rngs)
    seconds += time.time() - curr_time
    train_metrics.append(train_step_metrics)

    if ((step > 0 and config.save_checkpoints_steps and
         step % config.save_checkpoints_steps == 0) or
        step == num_train_steps - 1) and jax.process_index() == 0:
      # Save un-replicated optimizer and model state.
      checkpoints.save_checkpoint(
          workdir, jax_utils.unreplicate(optimizer), step, keep=2)

    # Periodic metric handling.
    if step % eval_frequency != 0 and step < num_train_steps - 1:
      continue

    logging.info("Gathering training metrics at step: %d", step)

    train_metrics = common_utils.get_metrics(train_metrics)
    train_summary = {
        "loss":
            jnp.sum(train_metrics["loss"]) /
            jnp.sum(train_metrics["num_labels"]),
        "learning_rate":
            learning_rate_fn(step)
    }
    if not is_regression_task:
      train_summary["accuracy"] = jnp.sum(
          train_metrics["correct_predictions"]) / jnp.sum(
              train_metrics["num_labels"])

    if jax.process_index() == 0:
      assert train_summary_writer
      steps_per_sec = (step - start_step + 1) / seconds
      train_summary_writer.scalar("steps per second", steps_per_sec, step)

      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    train_metrics = []

    logging.info("Gathering validation metrics at step: %d", step)

    for split_suffix in split_suffixes:
      eval_iter = glue_inputs(
          split=tfds.Split.VALIDATION + split_suffix,
          batch_size=per_process_eval_batch_size,
          training=False)

      eval_metrics = []
      for _, eval_batch in zip(range(config.max_num_eval_steps), eval_iter):
        eval_metrics.append(
            _compute_eval_metrics(p_eval_step, optimizer.target, eval_batch,
                                  n_devices))

      if eval_metrics:
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_summary = eval_metrics_fn(eval_metrics)

        if jax.process_index() == 0:
          assert eval_summary_writer
          for key, val in eval_summary.items():
            eval_summary_writer.scalar(f"{key}{split_suffix}", val, step)
          eval_summary_writer.flush()

      else:
        logging.warning("Error gathering eval_metrics at step: %d", step)
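
The learning-rate schedule above is built from the factor string "constant * linear_warmup * linear_decay". A minimal sketch of a schedule with that shape (an approximation for illustration, not train_utils' exact implementation):

import jax.numpy as jnp

def make_schedule(base_learning_rate, warmup_steps, decay_steps):
  def schedule(step):
    step = jnp.asarray(step, jnp.float32)
    warmup = jnp.minimum(1.0, step / warmup_steps)  # linear_warmup
    decay = jnp.maximum(
        0.0, 1.0 - (step - warmup_steps) / decay_steps)  # linear_decay
    return base_learning_rate * warmup * decay  # scaled by the constant factor
  return schedule

# lr_fn = make_schedule(config.learning_rate, num_warmup_steps,
#                       num_train_steps - num_warmup_steps)
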
Пример #17
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(workdir)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))

  if config.batch_size % n_devices:
    raise ValueError("Batch size must be divisible by the number of devices")

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=config.dataset_name,
      eval_dataset_name=config.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      vocab_path=vocab_path,
      target_vocab_size=config.vocab_size,
      batch_size=config.batch_size,
      max_corpus_chars=config.max_corpus_chars,
      max_length=config.max_target_length,
      max_eval_length=config.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.
  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  input_shape = (config.batch_size, config.max_target_length)
  target_shape = (config.batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng, jnp.ones(input_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config,
          label_smoothing=config.label_smoothing),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info("Starting training loop.")
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, config.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (config.save_checkpoints and step % config.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                  step)

    # Periodic metric handling.
    if step % config.eval_frequency != 0 and step > 0:
      continue

    # Training Metrics
    logging.info("Gathering training metrics.")
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop("learning_rate").mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop("denominator")
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary["learning_rate"] = lr
    steps_per_eval = config.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar("steps per second", steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info("train in step: %d, loss: %.4f", step, summary["loss"])

    # Eval Metrics
    logging.info("Gathering evaluation metrics.")
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop("denominator")
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info("eval in step: %d, loss: %.4f", step, eval_summary["loss"])
    logging.info("eval time: %.4f s step %d", time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info("Translating evaluation dataset.")
    t_inference_start = time.time()
    sources, references, predictions = [], [], []
    for pred_batch in predict_ds:
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch["inputs"].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      cache = p_init_cache(pred_batch["inputs"])
      predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache,
                              eos_id, config.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch["inputs"])
      targets = tohost(pred_batch["targets"])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info("Translation: %d predictions %d references %d sources.",
                 len(predictions), len(references), len(sources))
    logging.info("Translation time: %.4f s step %d.",
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)
    # Save translation samples for tensorboard.
    exemplars = ""
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
    if jax.host_id() == 0:
      eval_summary_writer.scalar("bleu", bleu_score, step)
      eval_summary_writer.text("samples", exemplars, step)
      eval_summary_writer.flush()
    logging.info("Translation BLEU Score %.4f", bleu_score)
Example #18
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(config)
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    num_train_steps = config.num_train_steps
    num_eval_steps = config.num_eval_steps
    eval_freq = config.eval_frequency
    random_seed = config.random_seed
    model_type = config.model_type

    max_length = config.max_length

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets(
        n_devices=jax.local_device_count(),
        task_name=FLAGS.task_name,
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        fixed_vocab=None,
        max_length=max_length)

    vocab_size = encoder.vocab_size
    logging.info('Vocab Size: %d', vocab_size)

    train_ds = train_ds.repeat()

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_length)

    model_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': config.emb_dim,
        'num_heads': config.num_heads,
        'num_layers': config.num_layers,
        'qkv_dim': config.qkv_dim,
        'mlp_dim': config.mlp_dim,
        'max_len': max_length,
        'classifier': True,
        'num_classes': CLASS_MAP[FLAGS.task_name],
        'classifier_pool': config.classifier_pool
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    if model_type == 'transformer':
        model = create_model(init_rng, transformer.TransformerEncoder,
                             input_shape, model_kwargs)
    else:
        raise ValueError('Model type not supported')

    optimizer = create_optimizer(model,
                                 learning_rate,
                                 weight_decay=FLAGS.config.weight_decay)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=config.factors,
        base_learning_rate=learning_rate,
        warmup_steps=config.warmup)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')
    # p_pred_step = jax.pmap(predict_step, axis_name='batch')

    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        logging.info('train in step: %d', step)

        # Save a Checkpoint
        if ((step % config.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('steps per second', steps_per_sec, step)
                for key, val in summary.items():
                    summary_writer.scalar(f'train_{key}', val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_metrics = []
            eval_iter = iter(eval_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, eval_batch in zip(num_iter, eval_iter):
                # pylint: disable=protected-access
                eval_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), eval_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            # Calculate (clipped) perplexity after averaging log-perplexities:
            eval_summary['perplexity'] = jnp.clip(
                jnp.exp(eval_summary['loss']), a_max=1.0e4)
            logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    summary_writer.scalar(f'eval_{key}', val, step)
                summary_writer.flush()
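The metric handling above pops a 'denominator' entry before normalizing, which only works if each step returns per-batch sums rather than means. The train_step and eval_step functions are not included in this listing; a hypothetical eval_step producing metrics in that shape (the model call and batch keys are assumptions) could look like:

import jax
import jax.numpy as jnp
from flax.training import common_utils


def eval_step(model, batch):
    """Hypothetical sketch: returns summed loss/accuracy plus an example count."""
    inputs, targets = batch['inputs'], batch['targets']
    logits = model(inputs, train=False)  # assumes a callable classifier model
    log_probs = jax.nn.log_softmax(logits)
    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * log_probs, axis=-1)
    correct = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)
    metrics = {
        'loss': jnp.sum(loss),
        'accuracy': jnp.sum(correct),
        'denominator': jnp.asarray(targets.shape[0], jnp.float32),
    }
    # Sum across all devices participating in the pmap.
    return jax.lax.psum(metrics, axis_name='batch')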
Example #19
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    rng, key = random.split(rng)
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, variables

    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    train_pstep = jax.pmap(functools.partial(train_step, model),
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=(2, ))

    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              FLAGS.randomized),
                                  axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3, ),
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step of the last checkpoint.
    init_step = state.optimizer.state.step + 1
    state = flax.jax_utils.replicate(state)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch_buffer_size = 3 x batch_size
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    reset_timer = True
    for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)
        state, stats, keys = train_pstep(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_sparsity", stats.sparsity[0],
                                      step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0],
                                      step)
                summary_writer.scalar("train_sparsity_coarse",
                                      stats.sparsity_c[0], step)

                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                avg_loss = np.mean(
                    np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(
                    np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                steps_per_sec = FLAGS.print_every / (time.time() -
                                                     t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec,
                                      step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
                print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                      f"/{FLAGS.max_steps:d}: " +
                      f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                state_to_save = jax.device_get(
                    jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(FLAGS.train_dir,
                                            state_to_save,
                                            int(step),
                                            keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(
                lambda x: x[0], state)).optimizer.target
            test_case = next(test_dataset)
            (pred_color, pred_disp, pred_acc, pred_features,
             pred_specular) = utils.render_image(functools.partial(
                 render_pfn, eval_variables),
                                                 test_case["rays"],
                                                 keys[0],
                                                 FLAGS.dataset == "llff",
                                                 chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                num_rays = jnp.prod(
                    jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(
                    f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec"
                )
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_pred_features", pred_features, step)
                summary_writer.image("test_pred_specular", pred_specular, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(FLAGS.max_steps),
                                    keep=100)
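utils.learning_rate_decay is referenced above only through functools.partial. Below is a sketch of a log-linear decay with an optional delayed ramp-up that matches the arguments used (lr_init, lr_final, max_steps, lr_delay_steps, lr_delay_mult); treat it as an illustration rather than the exact implementation:

import numpy as np


def learning_rate_decay(step, lr_init, lr_final, max_steps,
                        lr_delay_steps=0, lr_delay_mult=1.0):
    """Log-linearly interpolates lr_init -> lr_final, optionally damping early steps."""
    if lr_delay_steps > 0:
        # Ramp the multiplier from lr_delay_mult up to 1 over lr_delay_steps.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay_rate = 1.0
    t = np.clip(step / max_steps, 0, 1)
    log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
    return delay_rate * log_lerp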
Example #20
def evaluate(config, workdir):
    """Evalution function."""

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = jax.random.PRNGKey(config.seed)

    #----------------------------------------------------------------------------
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, test_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    rng, key = jax.random.split(rng)

    #----------------------------------------------------------------------------
    # Initialize model.
    learning_rate_fn = train_utils.create_learning_rate_fn(config)
    model, state = models.create_train_state(config,
                                             key,
                                             learning_rate_fn=learning_rate_fn,
                                             example_batch=example_batch)

    #----------------------------------------------------------------------------
    # Get the rendering function. Rendering is forced to be deterministic even
    # if training is randomized.
    render_pfn = render_utils.get_render_function(model,
                                                  config,
                                                  randomized=False)

    last_step = 0
    out_dir = os.path.join(
        workdir,
        "path_renders" if config.dataset.render_path else "test_preds")

    if not config.eval.eval_once:
        # Prepare Metric Writers
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))

    while True:
        state = checkpoints.restore_checkpoint(workdir, state)
        step = int(state.step)
        if step <= last_step:
            continue

        if config.eval.save_output and (not file_utils.isdir(out_dir)):
            file_utils.makedirs(out_dir)

        psnr_values = []
        ssim_values = []

        if not config.eval.eval_once:
            showcase_index = np.random.randint(0, test_ds.size)

        for idx in range(test_ds.size):
            logging.info("Evaluating [%d / %d].", idx, test_ds.size)
            batch = next(test_ds)
            test_pixels = batch.target_view.rgb
            if test_pixels is not None:
                test_pixels = model_utils.uint2float(test_pixels)

            #-----------------------------------------------------------
            # Render Image
            variables = {"params": state.params}
            pred_color, pred_disp, pred_acc = render_utils.render_image(
                functools.partial(render_pfn, variables),
                batch,
                rng,
                render_utils.normalize_disp(config.dataset.name),
                chunk=config.eval.chunk)

            if jax.process_index() != 0:
                continue

            #-----------------------------------------------------------
            # Get showcase example for logging
            if not config.eval.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not config.dataset.render_path:
                    showcase_gt = test_pixels
            #-----------------------------------------------------------
            # If ground-truth pixels are available, evaluate.
            if not config.dataset.render_path:
                psnr = model_utils.compute_psnr(
                    ((pred_color - test_pixels)**2).mean())
                ssim = skmetrics.structural_similarity(
                    pred_color.astype(np.float32),
                    test_pixels.astype(np.float32),
                    win_size=11,
                    multichannel=True,
                    gaussian_weights=True)
                logging.info(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")  # pylint: disable=logging-format-interpolation
                psnr_values.append(float(psnr))
                ssim_values.append(float(ssim))

            #-----------------------------------------------------------
            # Save generated image
            if config.eval.save_output:
                model_utils.save_img(
                    pred_color, os.path.join(out_dir,
                                             "{:03d}.png".format(idx)))
                if pred_disp is not None:
                    model_utils.save_img(
                        pred_disp[Ellipsis, 0],
                        os.path.join(out_dir, "disp_{:03d}.png".format(idx)))
        #-----------------------------------------------------------
        if (not config.eval.eval_once) and (jax.process_index() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            if showcase_disp is not None:
                summary_writer.image("pred_disp", showcase_disp, step)
            if showcase_acc is not None:
                summary_writer.image("pred_acc", showcase_acc, step)
            if not config.dataset.render_path:
                summary_writer.scalar("eval_metric/psnr",
                                      np.mean(np.array(psnr_values)), step)
                summary_writer.scalar("eval_metric/ssim",
                                      np.mean(np.array(ssim_values)), step)
                summary_writer.image("target", showcase_gt, step)

        #-----------------------------------------------------------
        # Save the metric to file
        if config.eval.save_output and (not config.dataset.render_path) and (
                jax.process_index() == 0):
            with file_utils.open_file(
                    os.path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
                f.write(" ".join([str(v) for v in psnr_values]))
            with file_utils.open_file(
                    os.path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
                f.write(" ".join([str(v) for v in ssim_values]))
            with file_utils.open_file(os.path.join(out_dir, "psnr.txt"),
                                      "w") as f:
                f.write("{}".format(np.mean(np.array(psnr_values))))
            with file_utils.open_file(os.path.join(out_dir, "ssim.txt"),
                                      "w") as f:
                f.write("{}".format(np.mean(np.array(ssim_values))))
        if config.eval.eval_once:
            break
        if int(step) >= config.train.max_steps:
            break
        last_step = step

    logging.info("Finishing evaluation at step %d", last_step)
Example #21
def main(argv):
    del argv
    logging.info("*** Starting experiment")
    gin_configs = FLAGS.gin_configs

    logging.info("*** Loading Gin configs from: %s", str(gin_configs))
    gin.parse_config_files_and_bindings(config_files=gin_configs,
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=True)

    # Load configurations.
    exp_config = configs.ExperimentConfig()
    model_config = configs.ModelConfig(use_stratified_sampling=False)
    train_config = configs.TrainConfig()
    eval_config = configs.EvalConfig()

    # Get directory information.
    exp_dir = gpath.GPath(FLAGS.exp_dir)
    if exp_config.subname:
        exp_dir = exp_dir / exp_config.subname
    logging.info("\texp_dir = %s", exp_dir)
    if not exp_dir.exists():
        exp_dir.mkdir(parents=True, exist_ok=True)

    summary_dir = exp_dir / "summaries" / "eval"
    logging.info("\tsummary_dir = %s", summary_dir)
    if not summary_dir.exists():
        summary_dir.mkdir(parents=True, exist_ok=True)

    renders_dir = exp_dir / "renders"
    logging.info("\trenders_dir = %s", renders_dir)
    if not renders_dir.exists():
        renders_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_dir = exp_dir / "checkpoints"
    logging.info("\tcheckpoint_dir = %s", checkpoint_dir)

    rng = random.PRNGKey(20200823)

    devices_to_use = jax.devices()
    n_devices = len(
        devices_to_use) if devices_to_use else jax.local_device_count()

    datasource_spec = exp_config.datasource_spec
    if datasource_spec is None:
        datasource_spec = {
            "type": exp_config.datasource_type,
            "data_dir": FLAGS.data_dir,
        }
    logging.info("Creating datasource: %s", datasource_spec)
    datasource = datasets.from_config(
        datasource_spec,
        image_scale=exp_config.image_scale,
        use_appearance_id=model_config.use_appearance_metadata,
        use_camera_id=model_config.use_camera_metadata,
        use_warp_id=model_config.use_warp,
        random_seed=exp_config.random_seed,
    )

    # Get training IDs to evaluate.
    train_eval_ids = utils.strided_subset(datasource.train_ids,
                                          eval_config.num_train_eval)
    train_eval_iter = datasource.create_iterator(train_eval_ids, batch_size=0)
    val_eval_ids = utils.strided_subset(datasource.val_ids,
                                        eval_config.num_val_eval)
    val_eval_iter = datasource.create_iterator(val_eval_ids, batch_size=0)

    test_cameras = datasource.load_test_cameras(
        count=eval_config.num_test_eval)
    if test_cameras:
        test_dataset = datasource.create_cameras_dataset(test_cameras)
        test_eval_ids = [f"{x:03d}" for x in range(len(test_cameras))]
        test_eval_iter = datasets.iterator_from_dataset(test_dataset,
                                                        batch_size=0)
    else:
        test_eval_ids = None
        test_eval_iter = None

    rng, key = random.split(rng)
    params = {}
    model, params["model"] = models.nerf(
        key,
        model_config,
        batch_size=eval_config.chunk,
        num_appearance_embeddings=len(datasource.appearance_ids),
        num_camera_embeddings=len(datasource.camera_ids),
        num_warp_embeddings=len(datasource.warp_ids),
        near=datasource.near,
        far=datasource.far,
        use_warp_jacobian=False,
        use_weights=False,
    )

    optimizer_def = optim.Adam(0.0)
    optimizer = optimizer_def.create(params)
    init_state = model_utils.TrainState(optimizer=optimizer, warp_alpha=0.0)
    del params

    def _model_fn(key_0, key_1, params, rays_dict, alpha):
        out = model.apply(
            {"params": params},
            rays_dict,
            warp_alpha=alpha,
            rngs={
                "coarse": key_0,
                "fine": key_1
            },
            mutable=False,
        )
        return jax.lax.all_gather(out, axis_name="batch")

    pmodel_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        _model_fn,
        in_axes=(0, 0, 0, 0, None),  # Only distribute the data input.
        devices=devices_to_use,
        donate_argnums=(3, ),  # Donate the 'rays' argument.
        axis_name="batch",
    )

    render_fn = functools.partial(
        evaluation.render_image,
        model_fn=pmodel_fn,
        device_count=n_devices,
        chunk=eval_config.chunk,
    )

    last_step = 0
    summary_writer = tensorboard.SummaryWriter(str(summary_dir))

    while True:
        if not checkpoint_dir.exists():
            logging.info("No checkpoints yet.")
            time.sleep(10)
            continue

        state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
        state = jax_utils.replicate(state, devices=devices_to_use)
        step = int(state.optimizer.state.step[0])
        if step <= last_step:
            logging.info("No new checkpoints (%d <= %d).", step, last_step)
            time.sleep(10)
            continue

        save_dir = renders_dir if eval_config.save_output else None
        process_iterator(
            tag="train",
            item_ids=train_eval_ids,
            iterator=train_eval_iter,
            state=state,
            rng=rng,
            step=step,
            render_fn=render_fn,
            summary_writer=summary_writer,
            save_dir=save_dir,
            datasource=datasource,
        )
        process_iterator(
            tag="val",
            item_ids=val_eval_ids,
            iterator=val_eval_iter,
            state=state,
            rng=rng,
            step=step,
            render_fn=render_fn,
            summary_writer=summary_writer,
            save_dir=save_dir,
            datasource=datasource,
        )
        if test_eval_iter:
            process_iterator(
                tag="test",
                item_ids=test_eval_ids,
                iterator=test_eval_iter,
                state=state,
                rng=rng,
                step=step,
                render_fn=render_fn,
                summary_writer=summary_writer,
                save_dir=save_dir,
                datasource=datasource,
            )

        if eval_config.eval_once:
            break
        if step >= train_config.max_steps:
            break
        last_step = step
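utils.strided_subset above selects a fixed number of evaluation IDs but is not defined in this listing. One possible sketch that spreads the selection evenly over the sequence (an assumption about its behavior):

import numpy as np


def strided_subset(sequence, count):
    """Returns `count` items spread evenly over `sequence` (everything if count is None)."""
    if count is None or count >= len(sequence):
        return list(sequence)
    indices = np.linspace(0, len(sequence) - 1, count).round().astype(int)
    return [sequence[i] for i in indices]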
Example #22
def restore_checkpoint(state, workdir):
  return checkpoints.restore_checkpoint(workdir, state)


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training will be resumed from the latest
        checkpoint.
    """
    model = create_model(config)
    optimizer = create_optimizer(config, model)
    del model  # don't keep a copy of the initial model

    output_dir = os.path.join(workdir, "pretraining")
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    start_step = int(optimizer.state.step)

    optimizer = optimizer.replicate()

    tokenizer = spm.SentencePieceProcessor()
    if "vocab_file" in config:
        tokenizer.Load(config.vocab_file)
    tokenizer.SetEncodeExtraOptions("")
    # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * cosine_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = train_utils.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_iter = input_pipeline.c4_masked_lm_inputs(
            config.train_batch_size, tokenizer, config.max_seq_length,
            config.max_predictions_per_seq)
        train_step_fn = train_utils.create_train_step(
            compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

        for step in range(start_step, config.num_train_steps):
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = next(train_iter)
                optimizer, train_state = train_step_fn(optimizer, batch,
                                                       train_state)
                if (step % config.save_checkpoints_steps == 0
                        and jax.host_id() == 0):
                    checkpoints.save_checkpoint(output_dir,
                                                optimizer.unreplicate(), step)

    if config.do_eval:
        eval_iter = input_pipeline.c4_masked_lm_inputs(
            config.eval_batch_size, tokenizer, config.max_seq_length,
            config.max_predictions_per_seq)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = train_utils.create_eval_fn(compute_pretraining_stats,
                                             sample_feature_name="input_ids")
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            "loss": jnp.mean(eval_stats["loss"]),
            "masked_lm_loss": jnp.mean(eval_stats["masked_lm_loss"]),
            "next_sentence_loss": jnp.mean(eval_stats["next_sentence_loss"]),
            "masked_lm_accuracy": (jnp.sum(eval_stats["masked_lm_correct"]) /
                                   jnp.sum(eval_stats["masked_lm_total"])),
            "next_sentence_accuracy": (
                jnp.sum(eval_stats["next_sentence_correct"]) /
                jnp.sum(eval_stats["next_sentence_total"])),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f"{name} = {val:.06f}"
            print(line, flush=True)
            logging.info(line)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")
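None of the snippets above include the launcher that wires an entry point such as train_and_evaluate(config, workdir) to command-line flags. A hypothetical absl + ml_collections launcher (flag names are illustrative, not taken from the listing):

from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
flags.DEFINE_string("workdir", None, "Directory for checkpoints and TF summaries.")
config_flags.DEFINE_config_file("config", None, "Path to a ConfigDict training configuration.")


def main(_):
    train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
    flags.mark_flags_as_required(["config", "workdir"])
    app.run(main)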