Example 1
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease the batch size of the predict dataset to accommodate beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  base_eval_config = base_train_config.replace(deterministic=True,
                                               train_vq=False)
  base_predict_config = base_train_config.replace(
      shift=False, deterministic=True, train_vq=False, decode=True)
  train_config = models.LatentTransformerConfig(
      base_cfg=base_train_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)
  eval_config = models.LatentTransformerConfig(
      base_cfg=base_eval_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)
  predict_config = models.LatentTransformerConfig(
      base_cfg=base_predict_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)

  # Latent Predictor.
  lp_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=FLAGS.latent_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  lp_eval_config = lp_train_config.replace(deterministic=True)
  lp_predict_config = lp_train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(0)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.LatentProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))
  lp_m = models.ProgramTransformer(lp_eval_config)
  lp_initial_variables = jax.jit(lp_m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])
  lp_optimizer = optimizer_def.create(lp_initial_variables['params'])

  state = TrainState(step=0,
                     optimizer=optimizer,
                     model_state=initial_variables['vqvae'],
                     lp_optimizer=lp_optimizer)
  # Don't keep a copy of the initial model.
  del initial_variables, lp_initial_variables

  train_rngs = jax.random.split(rng, jax.local_device_count())

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    state = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
        state)
    # Grab last step.
    start_step = int(state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          bos_token=bos_token,
          eos_token=eos_token,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          lp_config=lp_train_config),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          bos_token=bos_token,
          eos_token=eos_token,
          config=eval_config,
          lp_config=lp_eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          bos_token=bos_token,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(5,))
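  # static_broadcasted_argnums marks Python-level arguments (the pretraining
  # flag of p_train_step, the beam size of p_pred_step) as static: each new
  # value triggers a recompile instead of being traced as an array.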

  metrics_all = []
  latent_metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    state, metrics, latent_metrics, train_rngs = p_train_step(
        state, inputs, outputs, programs, step <= FLAGS.num_pretrain_steps,
        train_rng=train_rngs)
    metrics, latent_metrics = jax.tree_map(np.array, (metrics, latent_metrics))
    metrics_all.append(metrics)
    latent_metrics_all.append(latent_metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(state),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    latent_metrics_all = common_utils.get_metrics(latent_metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary.update(jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums))

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f, acc: %.4f',
                   step, summary['loss'], summary['accuracy'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []
    latent_metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    latent_eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)
      all_metrics = p_eval_step(state, inputs, outputs, programs)
      metrics, latent_metrics = jax.tree_map(np.array, all_metrics)
      eval_metrics.append(metrics)
      latent_eval_metrics.append(latent_metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary.update(jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums))

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 50, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions, latent_predictions = [], [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache, lp_cache = p_init_cache(inputs, outputs, programs)
        predicted, latent_predicted = p_pred_step(state,
                                                  inputs,
                                                  outputs,
                                                  cache,
                                                  lp_cache,
                                                  beam_size)
        predicted, latent_predicted = map(tohost, (predicted, latent_predicted))
        inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_idx, p_score = eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')
          latent_predictions.append(
              ' '.join(list(np.array(latent_predicted[i, p_idx]).astype(str))))

      all_pred_acc, all_pred_denominator = per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n'
                f'latent_predicted: {latent_predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
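Example 1 calls tohost and pad_examples, which are defined elsewhere in the script. A minimal sketch of plausible definitions, following the conventions of similar Flax examples (these names and shapes are assumptions, not the original helpers):

import numpy as np  # assumed already imported by the script


def tohost(x):
  """Collect batches from all devices to host and flatten batch dimensions."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))


def pad_examples(x, desired_batch_size):
  """Pad a batch up to `desired_batch_size` by repeating the last example."""
  batch_pad = desired_batch_size - x.shape[0]
  tile_dims = (batch_pad,) + (1,) * (x.ndim - 1)
  return np.concatenate([x, np.tile(x[-1:], tile_dims)], axis=0)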
Example 2
def run_experiment(
    model_dir,
    data_dir=None,
    xid=None,
    batch_size_per_device=128,
    eval_frequency=500,
    checkpoint_frequency=10000,
    save_checkpoints=True,
    restore_checkpoint=True,
    num_eval_steps=None,
    epochs=None,
    max_train_steps=1000000,  # 1 million
    max_train_length=512,
    train_summary_frequency=100,
    max_eval_length=None,
    model_cls=models.FlaxLM):
  """Run experiment.

  Args:
    model_dir: Directory to save checkpoints and metrics to.
    data_dir: Directory to load data.
    xid: Optional experiment id.
    batch_size_per_device: Batch size per device.
    eval_frequency: Steps per eval.
    checkpoint_frequency: How often to checkpoint. If None, only checkpoint once
      at end of run.
    save_checkpoints: If True, checkpoints the model according to
      checkpoint_frequency.
    restore_checkpoint: If True, will restore checkpoint from directory. Useful
      for robustness to preemption.
    num_eval_steps: Number of eval steps to take on eval dataset.
    epochs: Number of train epochs.
    max_train_steps: Stop training after N steps.
    max_train_length: Crop training sequences to this length.
    train_summary_frequency: Frequency to write train metrics.
    max_eval_length: Maximum eval length. Defaults to max_train_length.
    model_cls: Model class to use.

  Returns:
    FlaxLM resulting from running training.
  """
  if xid is not None:
    model_dir = os.path.join(model_dir, '%s_l%s' % (str(xid), max_train_length))
  tf.enable_v2_behavior()
  if jax.host_id() == 0:
    summary_writer = tf_summary.create_file_writer(
        os.path.join(model_dir, 'metrics'), max_queue=1, flush_millis=1000)
    train_summary_writer = logging_lib.ScalarSummary(
        step=None,
        scope='train/',
        enable_tf=True,
        verbose=0)
    eval_summary_writer = logging_lib.ScalarSummary(
        step=None,
        scope='eval/',
        enable_tf=True,
        verbose=0)

  batch_size = batch_size_per_device * jax.local_device_count()
  max_eval_length = max_eval_length or max_train_length
  train_files, test_files = data.get_train_test_files(directory=data_dir)
  train_ds, eval_ds = data.load_dataset(
      train_files=train_files,
      test_files=test_files,
      batch_size=batch_size,
      max_train_length=max_train_length,
      max_eval_length=max_eval_length,
      shuffle_buffer=16384)

  with contextlib.ExitStack() as stack:  # pylint: disable=using-constant-test
    if jax.host_id() == 0:
      # Only need metric writer context manager on host 0.
      stack.enter_context(summary_writer.as_default())
    model = model_cls(domain=data.protein_domain, batch_size=batch_size)

    if restore_checkpoint:
      try:
        model.load_checkpoint(model_dir)
      except ValueError:
        # No checkpoint to load -> raises ValueError.
        pass
    start_step = model.train_step

    train_ds = train_ds.repeat(epochs)
    train_iter = iter(train_ds)
    train_metrics = []
    tick = time.time()

    if jax.host_id() == 0:
      _write_gin_configs(os.path.join(model_dir, 'config.gin'))

    num_evals = 0
    for step, batch in zip(range(start_step, max_train_steps), train_iter):
      batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
      metrics = model.fit_batch(batch)
      train_metrics.append(metrics)

      if jax.host_id() == 0 and ((save_checkpoints and checkpoint_frequency and
                                  step % checkpoint_frequency == 0 and step > 0)
                                 or step == max_train_steps - 1):
        model.save_checkpoint(model_dir)

      if (step + 1) % train_summary_frequency == 0:
        summary = evaluation.combine_metrics(train_metrics)
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
        if jax.host_id() == 0:
          tock = time.time()
          steps_per_sec = train_summary_frequency / (tock - tick)
          tick = tock
          train_summary_writer('steps per second', steps_per_sec, step)
          for key, val in summary.items():
            if jnp.isnan(val):
              raise ValueError(f'NaN in {key} at step {step}.')
            train_summary_writer(key, val, step)

        # Reset metric accumulation for the next summary cycle.
        train_metrics = []

      if eval_frequency and (step + 1) % eval_frequency == 0:
        eval_summary = evaluation.evaluate(
            model=model, eval_ds=eval_ds, num_eval_steps=num_eval_steps)

        logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
        if jax.host_id() == 0:
          for key, val in eval_summary.items():
            eval_summary_writer(key, val, step)
          tf_summary.flush()
          summary_writer.flush()

          if num_evals == 0:
            # Write out config on first eval.
            _write_gin_configs(os.path.join(model_dir, 'config_after_eval.gin'))
          num_evals += 1

  if jax.host_id() == 0:
    tf_summary.flush()
    summary_writer.close()
    _write_gin_configs(os.path.join(model_dir, 'config_end.gin'))
  return model
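A minimal usage sketch for run_experiment; the paths and flag values below are illustrative placeholders, not settings from the original experiment:

# Hypothetical invocation with placeholder paths.
trained_model = run_experiment(
    model_dir='/tmp/flaxlm_run',
    data_dir='/tmp/protein_data',
    batch_size_per_device=64,
    eval_frequency=500,
    num_eval_steps=100,
    max_train_steps=100000)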
Example 3
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level="NOTSET",
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity of the Transformers logger to info (on the main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name,
                                data_args.dataset_config_name,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Infer the loader type from whichever data file was provided.
        extension = (data_args.train_file
                     or data_args.validation_file).split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension,
                                data_files=data_files,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            # Remove empty lines
            examples = [
                line for line in examples
                if len(line) > 0 and not line.isspace()
            ]
            return tokenizer(
                examples,
                return_special_tokens_mask=True,
                padding=padding,
                truncation=True,
                max_length=max_seq_length,
            )

        tokenized_datasets = datasets.map(
            tokenize_function,
            input_columns=[text_column_name],
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    else:
        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
        # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
        # efficient when it receives the `special_tokens_mask`.
        def tokenize_function(examples):
            return tokenizer(examples[text_column_name],
                             return_special_tokens_mask=True)

        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # max_seq_length.
        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {
                k: list(chain(*examples[k]))
                for k in examples.keys()
            }
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; we could instead pad if the model
            # supported it. You can customize this part to your needs.
            if total_length >= max_seq_length:
                total_length = (total_length //
                                max_seq_length) * max_seq_length
            # Split by chunks of max_len.
            result = {
                k: [
                    t[i:i + max_seq_length]
                    for i in range(0, total_length, max_seq_length)
                ]
                for k, t in concatenated_examples.items()
            }
            return result
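        # E.g. with max_seq_length=4, input_ids [[1, 2, 3], [4, 5, 6, 7, 8]]
        # concatenate to [1, ..., 8] and split into chunks [1, 2, 3, 4] and
        # [5, 6, 7, 8]; a ninth token would be dropped as remainder.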

        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxAutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForMaskedLM.from_config(config,
                                                     seed=training_args.seed,
                                                     dtype=getattr(
                                                         jnp,
                                                         model_args.dtype))

    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])
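    # E.g. with learning_rate=1e-4 and warmup_steps=1000, the joined schedule
    # rises linearly from 0 to 1e-4 over the first 1000 steps, then decays
    # linearly back to 0 over the remaining steps.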

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBERT-like models.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)
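    # E.g. a parameter at path (..., "bias") or (..., "LayerNorm", "scale")
    # maps to False (no weight decay), while (..., "kernel") entries map to
    # True and are decayed.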

    # Create the optimizer (Adafactor or AdamW).
    if training_args.adafactor:
        # We use the default parameters here to initialize Adafactor. For more
        # details about the parameters, see https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn)
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # compute loss, ignore padded input tokens
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])) * label_mask

            # take average
            loss = loss.sum() / label_mask.sum()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss, ignore padded input tokens
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(
            logits, onehot(labels, logits.shape[-1])) * label_mask

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {
            "loss": loss.sum(),
            "accuracy": accuracy.sum(),
            "normalizer": label_mask.sum()
        }
        metrics = jax.lax.psum(metrics, axis_name="batch")
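        # Sums are accumulated across devices here; division by the token
        # count ("normalizer") happens after all eval batches are gathered.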

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = jax.random.permutation(
            input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...",
                             position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples,
                                                 pad_to_multiple_of=16)

                    # Model forward
                    model_inputs = shard(model_inputs.data)
                    metrics = p_eval_step(state.params, model_inputs)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
                eval_normalizer = eval_metrics.pop("normalizer")
                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer,
                                            eval_metrics)

                # Update progress bar
                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save a checkpoint every save_steps and optionally push it
                # to the hub.
                if jax.process_index() == 0:
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for _, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(),
                                    eval_metrics)
        eval_normalizer = eval_metrics.pop("normalizer")
        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer,
                                    eval_metrics)

        try:
            perplexity = math.exp(eval_metrics["loss"])
        except OverflowError:
            perplexity = float("inf")
        eval_metrics["perplexity"] = perplexity

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
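Example 3 logs through write_train_metric and write_eval_metric, which are defined elsewhere. A minimal sketch of the eval-side helper, assuming the flax.metrics.tensorboard.SummaryWriter interface used above:

# Hypothetical sketch of the eval metric writer.
def write_eval_metric(summary_writer, eval_metrics, cur_step):
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, cur_step)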
Example 4
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name,
                            data_args.dataset_config_name,
                            cache_dir=model_args.cache_dir)

    if "validation" not in datasets.keys():
        # make sure only "validation" and "train" keys remain
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=
            f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=
            f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
        )
    else:
        # make sure only "validation" and "train" keys remain
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split="validation",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}",
            cache_dir=model_args.cache_dir,
        )

    # Only training on normalized inputs is supported.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        do_normalize=True)

    def prepare_dataset(batch):
        # Load the audio and resample it to the feature extractor's sampling rate.
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column],
                                          sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=datasets["train"].column_names)

    # filter audio files that are too long
    max_input_length = int(data_args.max_duration_in_seconds *
                           feature_extractor.sampling_rate)
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < max_input_length)

    def normalize(batch):
        return feature_extractor(batch["speech"],
                                 sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
        )

    model = FlaxWav2Vec2ForPreTraining(config,
                                       seed=training_args.seed,
                                       dtype=getattr(jnp, model_args.dtype))

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model,
        feature_extractor=feature_extractor,
        pad_to_multiple_of=data_args.pad_to_multiple_of)

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    # Use independent streams for dropout and Gumbel noise (splitting the same
    # rng for both would make them identical and correlate the noise).
    rng, dropout_rng, gumbel_rng = jax.random.split(rng, 3)
    dropout_rngs = jax.random.split(dropout_rng, jax.local_device_count())
    gumbel_rngs = jax.random.split(gumbel_rng, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in [("layer_norm", "scale"),
                                         ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # Create AdamW optimizer.
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            negative_indices = batch.pop("sampled_negative_indices")

            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature *
                model_args.gumbel_temperature_decay**state.step,
                a_min=model_args.min_gumbel_temperature,
            )
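            # The temperature decays geometrically in the training step and
            # is floored at min_gumbel_temperature.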

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors -
                              outputs.codevector_perplexity) / num_codevectors
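            # codevector_perplexity approaches num_codevectors under uniform
            # codebook usage, driving the diversity term toward zero.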
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors -
                          outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # summarize metrics
        metrics = {
            "loss": loss.mean(),
            "codevector_perplexity": outputs.codevector_perplexity
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(vectorized_datasets["train"])
        train_samples_idx = jax.random.permutation(
            input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                vectorized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                vectorized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
        )

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(vectorized_datasets["train"]) //
                                train_batch_size)
            write_eval_metric(summary_writer, eval_metrics, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir,
                                  params=params,
                                  push_to_hub=training_args.push_to_hub)
Esempio n. 5
0
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
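      # Truncate each partial program at its first EOS token; np.argmax
      # returns the first True index (0, i.e. an empty slice, if no EOS).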
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
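      # Resume from the step after the restored checkpoint.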
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining.
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
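    # ModelParamTraversal selects only parameters whose path passes filter_fn,
    # so frozen encoder/decoder weights receive no optimizer updates.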
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warning('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
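    # Only handle metrics at multiples of log_freq, on the last pretrain step,
    # or on the last step; step 0 is always skipped.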
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
Esempio n. 6
0
    # Enable tensorboard only on the master node.
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )

        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxAutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForMaskedLM.from_config(config,
                                                     seed=training_args.seed,
                                                     dtype=getattr(
                                                         jnp,
                                                         model_args.dtype))

    # Store some constant
Esempio n. 7
0
def main(args):
  logdir = os.path.join(args.logdir, args.name)
  logger = logging.setup_logger(logdir)
  logger.info(args)

  logger.info(f'Available devices: {jax.devices()}')

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')

  ds_train = input_pipeline.get_data(
      dataset=args.dataset,
      mode='train',
      repeats=None,
      mixup_alpha=args.mixup_alpha,
      batch_size=args.batch,
      shuffle_buffer=args.shuffle_buffer,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  batch = next(iter(ds_train))
  logger.info(ds_train)
  ds_test = input_pipeline.get_data(
      dataset=args.dataset,
      mode='test',
      repeats=1,
      batch_size=args.batch_eval,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  logger.info(ds_test)

  # Build VisionTransformer architecture
  model = models.KNOWN_MODELS[args.model]
  VisionTransformer = model.partial(num_classes=dataset_info['num_classes'])
  _, params = VisionTransformer.init_by_shape(
      jax.random.PRNGKey(0),
      # Discard the "num_local_devices" dimension for initialization.
      [(batch['image'].shape[1:], batch['image'].dtype.name)])

  pretrained_path = os.path.join(args.vit_pretrained_dir, f'{args.model}.npz')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=params,
      model_config=models.CONFIGS[args.model],
      logger=logger)

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(VisionTransformer.call)
  update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps)

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=args.optim_dtype, grad_norm_clip=args.grad_norm_clip).create(params)
  opt_repl = flax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  def copyfiles(paths):
    """Small helper to copy files to args.copy_to using tf.io.gfile."""
    if not args.copy_to:
      return
    for path in paths:
      to_path = os.path.join(args.copy_to, args.name, os.path.basename(path))
      tf.io.gfile.makedirs(os.path.dirname(to_path))
      tf.io.gfile.copy(path, to_path, overwrite=True)
      logger.info(f'Copied {path} to {to_path}.')

  total_steps = args.total_steps or (
      input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr,
                                              args.decay_type,
                                              args.warmup_steps)
  lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
  update_rngs = jax.random.split(
      jax.random.PRNGKey(0), jax.local_device_count())

  # Run training loop
  writer = metric_writers.create_default_writer(logdir, asynchronous=False)
  writer.write_hparams({k: v for k, v in vars(args).items() if v is not None})
  logger.info('Starting training loop; initial compile can take a while...')
  t0 = time.time()

  for step, batch, lr_repl in zip(
      range(1, total_steps + 1),
      input_pipeline.prefetch(ds_train, args.prefetch), lr_iter):

    opt_repl, loss_repl, update_rngs = update_fn_repl(
        opt_repl, lr_repl, batch, update_rngs)

    if step == 1:
      logger.info(f'First step took {time.time() - t0:.1f} seconds.')
      t0 = time.time()
    if args.progress_every and step % args.progress_every == 0:
      writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
      done = step / total_steps
      logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                  f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
      copyfiles(glob.glob(f'{logdir}/*'))

    # Run eval step
    if ((args.eval_every and step % args.eval_every == 0) or
        (step == total_steps)):

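      # Logits and one-hot labels are sharded as
      # [num_local_devices, per_device_batch, num_classes]; argmax over axis=2
      # yields per-example predicted and true classes on every device.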
      accuracy_test = np.mean([
          c for batch in input_pipeline.prefetch(ds_test, args.prefetch)
          for c in (
              np.argmax(vit_fn_repl(opt_repl.target, batch['image']),
                        axis=2) == np.argmax(batch['label'], axis=2)).ravel()
      ])

      lr = float(lr_repl[0])
      logger.info(f'Step: {step} '
                  f'Learning rate: {lr:.7f}, '
                  f'Test accuracy: {accuracy_test:0.5f}')
      writer.write_scalars(step, dict(accuracy_test=accuracy_test, lr=lr))
      copyfiles(glob.glob(f'{logdir}/*'))

  if args.output:
    checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
    logger.info(f'Stored fine tuned checkpoint to {args.output}')
    copyfiles([args.output])
Esempio n. 8
0
def get_data(*,
             dataset,
             mode,
             repeats,
             batch_size,
             mixup_alpha=0,
             tfds_manual_dir=None,
             inception_crop=True):
    """Returns dataset for training/eval.

  Args:
    dataset: Dataset name. Additionally to the requirement that this dataset
      must be in tensorflow_datasets, the dataset must be registered in
      `DATASET_PRESETS` (specifying crop size etc).
    mode: Must be "train" or "test".
    repeats: How many times the dataset should be repeated. For indefinite
      repeats specify None.
    batch_size: Global batch size. Note that the returned dataset will have
      dimensions [local_devices, batch_size / local_devices, ...].
    mixup_alpha: Coefficient for mixup combination. See 
      https://arxiv.org/abs/1710.09412
    tfds_manual_dir: Optional directory that contains downloaded files for
      tensorflow_dataset preparation.
    inception_crop: If set to True, tf.image.sample_distorted_bounding_box()
      will be used. If set to False, tf.image.random_crop() will be used.
  """

    preset = DATASET_PRESETS.get(dataset)
    if preset is None:
        raise KeyError(
            f'Please add "{dataset}" to {__name__}.DATASET_PRESETS')
    split = preset[mode]
    resize_size = preset['resize']
    crop_size = preset['crop']
    dataset_info = get_dataset_info(dataset, split)

    data_builder = tfds.builder(dataset)
    data_builder.download_and_prepare(
        download_config=tfds.download.DownloadConfig(
            manual_dir=tfds_manual_dir))
    data = data_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    decoder = data_builder.info.features['image'].decode_example

    def _pp(data):
        im = decoder(data['image'])
        if mode == 'train':
            if inception_crop:
                channels = im.shape[-1]
                begin, size, _ = tf.image.sample_distorted_bounding_box(
                    tf.shape(im),
                    tf.zeros([0, 0, 4], tf.float32),
                    area_range=(0.05, 1.0),
                    min_object_covered=0,  # Don't enforce a minimum area.
                    use_image_if_no_bounding_boxes=True)
                im = tf.slice(im, begin, size)
                # Unfortunately, the above operation loses the depth dimension,
                # so we need to restore it manually.
                im.set_shape([None, None, channels])
                im = tf.image.resize(im, [crop_size, crop_size])
            else:
                im = tf.image.resize(im, [resize_size, resize_size])
                im = tf.image.random_crop(im, [crop_size, crop_size, 3])
                im = tf.image.random_flip_left_right(im)
        else:
            # usage of crop_size here is intentional
            im = tf.image.resize(im, [crop_size, crop_size])
        im = (im - 127.5) / 127.5
        label = tf.one_hot(data['label'], dataset_info['num_classes'])  # pylint: disable=no-value-for-parameter
        return {'image': im, 'label': label}

    data = data.repeat(repeats)
    if mode == 'train':
        data = data.shuffle(min(dataset_info['num_examples'], MAX_IN_MEMORY))
    data = data.map(_pp, tf.data.experimental.AUTOTUNE)
    data = data.batch(batch_size, drop_remainder=True)

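    # Mixup (arXiv:1710.09412): blend each example with the reversed batch,
    # using a single Beta(alpha, alpha) coefficient shared across the batch.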
    def _mixup(data):
        beta_dist = tfp.distributions.Beta(mixup_alpha, mixup_alpha)
        beta = tf.cast(beta_dist.sample([]), tf.float32)
        data['image'] = (beta * data['image'] +
                         (1 - beta) * tf.reverse(data['image'], axis=[0]))
        data['label'] = (beta * data['label'] +
                         (1 - beta) * tf.reverse(data['label'], axis=[0]))
        return data

    if mixup_alpha is not None and mixup_alpha > 0.0 and mode == 'train':
        data = data.map(_mixup, tf.data.experimental.AUTOTUNE)

    # Shard data such that it can be distributed across devices
    num_devices = jax.local_device_count()

    def _shard(data):
        data['image'] = tf.reshape(data['image'],
                                   [num_devices, -1, crop_size, crop_size, 3])
        data['label'] = tf.reshape(
            data['label'], [num_devices, -1, dataset_info['num_classes']])
        return data

    if num_devices is not None:
        data = data.map(_shard, tf.data.experimental.AUTOTUNE)

    return data.prefetch(1)
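A minimal usage sketch for get_data, assuming a hypothetical 'cifar10' entry in DATASET_PRESETS and that tensorflow_datasets can prepare the data:

# Hypothetical usage: 'cifar10' must be registered in DATASET_PRESETS.
ds_train = get_data(dataset='cifar10', mode='train', repeats=None,
                    batch_size=512, mixup_alpha=0.2)
batch = next(iter(ds_train))
# batch['image'] has shape [num_local_devices, 512 // num_local_devices, crop, crop, 3].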
Esempio n. 9
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
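        # Keep tokens up to and including the first EOS; np.argmax returns the
        # index of the first True (0 if no EOS is present).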
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)
    target_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("translate_and_bleu"):
                    exemplars, bleu_score = translate_and_calculate_bleu(
                        p_pred_step=p_pred_step,
                        p_init_cache=p_init_cache,
                        target=optimizer.target,
                        predict_ds=predict_ds,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_scalars(step, {"bleu": bleu_score})
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if config.save_checkpoints and save_checkpoint and jax.host_id(
            ) == 0:
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
Esempio n. 10
0
  def compute_preconditioners_from_statistics(self, states, hps, step):
    """Compute preconditioners for statistics."""
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state in states:
      num_statistics = len(state.statistics)
      num_statistics_per_state.append(num_statistics)
      original_shapes_for_state = []
      if num_statistics > 0:
        for statistic in state.statistics:
          exponents.append(2 * num_statistics if hps.exponent_override ==
                           0 else hps.exponent_override)
          original_shapes_for_state.append(statistic.shape)
          max_size = max(max_size, statistic.shape[0])
        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    def pack(mat, max_size):
      """Pack a matrix to a max_size for inverse on TPUs with static shapes.

      Args:
        mat: Matrix for computing inverse pth root.
        max_size: Matrix size to pack to.

      Returns:
        Given M returns [[M, 0], [0, I]]
      """
      size = mat.shape[0]
      assert size <= max_size
      if size == max_size:
        return mat
      pad_size = max_size - size
      zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
      zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
      eye = jnp.eye(pad_size, dtype=mat.dtype)
      mat = jnp.concatenate([mat, zs1], 1)
      mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
      return mat
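    # Example: packing a 2x2 matrix M to max_size=4 yields the 4x4
    # block-diagonal matrix [[M, 0], [0, I_2]]; its inverse pth root is
    # [[M^{-1/p}, 0], [0, I_2]], so the padding does not perturb the result.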

    if not hps.batch_axis_name:
      num_devices = jax.local_device_count()
    else:
      num_devices = lax.psum(1, hps.batch_axis_name)

    # Pad statistics and exponents to next multiple of num_devices.
    packed_statistics = [pack(stat, max_size) for stat in statistics]
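    # In Python, -n % d is the non-negative remainder: e.g. 10 statistics on
    # 8 devices gives to_pad = -10 % 8 = 6, padding the list to 16 entries.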
    to_pad = -num_statistics % num_devices
    packed_statistics.extend([
        jnp.eye(max_size, dtype=packed_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])

    # Batch statistics and exponents so that the leading axis is
    # num_devices.
    def _batch(statistics, exponents, num_devices):
      assert len(statistics) == len(exponents)
      n = len(statistics)
      b = int(n / num_devices)
      batched_statistics = [
          jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
      ]
      batched_exponents = [
          jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
      ]
      return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

    # Unbatch values across leading axis and return a list of elements.
    def _unbatch(batched_values):
      b1, b2 = batched_values.shape[0], batched_values.shape[1]
      results = []
      for v_array in jnp.split(batched_values, b1, 0):
        for v in jnp.split(jnp.squeeze(v_array), b2, 0):
          results.append(jnp.squeeze(v))
      return results

    all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                           num_devices)

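    # Compute inverse pth roots for a batch of packed matrices via vmap; each
    # call handles one device's shard of statistics.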
    def _matrix_inverse_pth_root(xs, ps):
      mi_pth_root = lambda x, y: matrix_inverse_pth_root(  # pylint: disable=g-long-lambda
          x, y, ridge_epsilon=hps.matrix_eps)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    if not hps.batch_axis_name:
      preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)(
          all_statistics, all_exponents)
      preconditioners_flat = _unbatch(preconditioners)
      errors_flat = _unbatch(errors)
    else:

      def _internal_inverse_pth_root_all():
        preconditioners = jnp.array(all_statistics)
        current_replica = lax.axis_index(hps.batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root(
            all_statistics[current_replica], all_exponents[current_replica])
        preconditioners = jax.lax.all_gather(preconditioners,
                                             hps.batch_axis_name)
        errors = jax.lax.all_gather(errors, hps.batch_axis_name)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
        return preconditioners_flat, errors_flat

      if hps.preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Pass statistics instead of preconditioners, since they are similarly
        # shaped tensors; the error we pass is the failure threshold, so these
        # initial values will be ignored.
        preconditioners_init = packed_statistics
        errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] *
                       len(packed_statistics))
        init_state = [preconditioners_init, errors_init]
        perform_step = step % hps.preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = self.fast_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)

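    # Skip a freshly computed preconditioner when its inverse-pth-root error
    # is NaN or at/above the failure threshold, keeping the previous one.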
    def _skip(error):
      return jnp.logical_or(
          jnp.isnan(error),
          error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD).astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
      return lax.cond(
          _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

    new_preconditioners_flat = []
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
      new_preconditioners_flat.append(
          _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for num_statistics, state in zip(num_statistics_per_state, states):
      if num_statistics == 0:
        preconditioners_for_states.append([])
      else:
        preconditioners_for_state = new_preconditioners_flat[idx:idx +
                                                             num_statistics]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)
        idx += num_statistics
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
      new_states.append(
          _ShampooDefaultParamState(state.diagonal_statistics, state.statistics,
                                    new_preconditioners,
                                    state.diagonal_momentum, state.momentum))

    return new_states
Esempio n. 11
0
        def update_step(
            state,
            transitions,
            in_initial_bc_iters,
        ):
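            # Reshape the leading batch dimension to
            # [num_devices, batch // num_devices, ...] so each device receives
            # an equal shard under pmap.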
            def reshape_for_devices(t):
                rest_t_shape = list(t.shape[1:])
                new_shape = [
                    num_devices,
                    t.shape[0] // num_devices,
                ] + rest_t_shape
                return jnp.reshape(t, new_shape)

            transitions = jax.tree_map(reshape_for_devices, transitions)

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)
            if adaptive_entropy_coefficient:
                alpha = jnp.exp(state.alpha_params)
            else:
                alpha = entropy_coefficient

            key_critic = jax.random.split(key_critic, jax.local_device_count())
            total_critic_loss_and_aux, q_params, new_target_q_params, q_optimizer_state = pmapped_critic_update(
                state.q_params,
                state.target_q_params,
                state.q_optimizer_state,
                state.policy_params,
                alpha,
                cql_alpha,
                transitions,
                key_critic,
            )
            total_critic_loss_and_aux = jax.tree_map(
                jnp.mean, total_critic_loss_and_aux)

            key_actor = jax.random.split(key_actor, jax.local_device_count())
            policy_params, policy_optimizer_state, actor_loss, min_q, avg_log_prob, sn, new_snr_state = pmapped_actor_update(
                in_initial_bc_iters,
                state.policy_params,
                state.policy_optimizer_state,
                state.q_params,
                state.target_q_params,
                alpha,
                transitions,
                state.snr_state,
                key_actor,
            )
            avg_log_prob = jnp.mean(avg_log_prob)

            critic_loss_aux = total_critic_loss_and_aux[1]
            metrics = OrderedDict()
            metrics['actor_loss'] = jnp.mean(actor_loss)
            metrics['avg_log_prob'] = avg_log_prob
            metrics['total_critic_loss'] = total_critic_loss_and_aux[0]
            metrics['critic_loss'] = critic_loss_aux['critic_loss']
            metrics['cql_loss'] = critic_loss_aux['cql_loss']
            metrics['q/avg'] = jnp.mean(min_q)
            metrics['q/std'] = jnp.std(min_q)
            metrics['q/max'] = jnp.max(min_q)
            metrics['q/min'] = jnp.min(min_q)
            metrics['SNR/loss'] = jnp.mean(sn)

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=new_target_q_params,
                key=key,
                snr_state=new_snr_state,
            )
            if adaptive_entropy_coefficient and (not in_initial_bc_iters):
                # Apply alpha gradients
                alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                                     avg_log_prob)
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                alpha_params = optax.apply_updates(state.alpha_params,
                                                   alpha_update)
                metrics['alpha_loss'] = alpha_loss
                metrics['alpha'] = jnp.exp(alpha_params)
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=alpha_params)
            else:
                metrics['alpha_loss'] = 0.
                metrics['alpha'] = jnp.exp(state.alpha_params)
                new_state = new_state._replace(
                    alpha_optimizer_state=state.alpha_optimizer_state,
                    alpha_params=state.alpha_params)

            return new_state, metrics
Esempio n. 12
0
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, *rays), axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )
    rng, key = random.split(rng)
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr_init)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(step=0,
                                   optimizer=optimizer,
                                   model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    offset = state.step + 1
    state = jax_utils.replicate(state)
    del init_model, init_state

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    ptrain_step = jax.pmap(train_step,
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=2)
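    # in_axes: keys, state, and batch are device-sharded (axis 0) while the
    # scalar lr is broadcast (None); donate_argnums=2 donates the batch buffers.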
    # Prefetch buffer size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        if step % FLAGS.gc_every == 0:
            gc.collect()
        # --- Train logs start ---
        # Put the training time visualization before the host_id check as in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)
        if jax.host_id() != 0:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            summary_writer.scalar("train_loss", stats.loss[0], step)
            summary_writer.scalar("train_psnr", stats.psnr[0], step)
            summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
            summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
            summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
            avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
            avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
            stats_trace = []
            summary_writer.scalar("train_avg_loss", avg_loss, step)
            summary_writer.scalar("train_avg_psnr", avg_psnr, step)
            summary_writer.scalar("learning_rate", lr, step)
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " +
                  f"i_loss={stats.loss[0]:0.5f}, " +
                  f"avg_loss={avg_loss:0.5f}, " +
                  f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                  f"lr={lr:0.2e}, " + f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(FLAGS.train_dir,
                                        state_to_save,
                                        state_to_save.step,
                                        keep=100)
        # --- Train logs end ---

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(state.step),
                                    keep=100)
Esempio n. 13
0
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Set up logging: we only want one process per machine to log to the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for the token classification task available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (data_args.train_file if data_args.train_file is not None else data_args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    if raw_datasets["train"] is not None:
        column_names = raw_datasets["train"].column_names
        features = raw_datasets["train"].features
    else:
        column_names = raw_datasets["validation"].column_names
        features = raw_datasets["validation"].features

    if data_args.text_column_name is not None:
        text_column_name = data_args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    if data_args.label_column_name is not None:
        label_column_name = data_args.label_column_name
    elif f"{data_args.task_name}_tags" in column_names:
        label_column_name = f"{data_args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list
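    # For example (hypothetical labels), get_label_list([["B-PER", "O"], ["O", "B-LOC"]])
    # returns the sorted union ["B-LOC", "B-PER", "O"].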

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
        label_to_id = {i: i for i in range(len(label_list))}
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
        label_to_id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        label2id=label_to_id,
        id2label={i: l for l, i in label_to_id.items()},
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    if config.model_type in {"gpt2", "roberta"}:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            add_prefix_space=True,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    model = FlaxAutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Preprocessing the datasets
    # Tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(
            examples[text_column_name],
            max_length=data_args.max_seq_length,
            padding="max_length",
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []

        for i, label in enumerate(examples[label_column_name]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label_to_id[label[word_idx]])
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
                previous_word_idx = word_idx

            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
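
    # A hypothetical alignment, assuming a WordPiece-style tokenizer:
    #   words     = ["HuggingFace", "rocks"],  word labels = [3, 0]
    #   tokens    = [CLS], "Hugging", "##Face", "rocks", [SEP]
    #   word_ids  = [None, 0, 0, 1, None]
    #   label_ids = [-100, 3, -100, 0, -100]   (with label_all_tokens=False)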

    processed_raw_datasets = raw_datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=raw_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed. "
            "Please run `pip install tensorboard` to enable."
        )

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args)

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    metric = load_metric("seqeval")

    def get_labels(y_pred, y_true):
        # Transform prediction and reference tensors to numpy arrays

        # Remove ignored index (special tokens)
        true_predictions = [
            [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        true_labels = [
            [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        return true_predictions, true_labels
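    # For example, with label_list = ["B-LOC", "B-PER", "O"] (hypothetical):
    #   y_pred = [[1, 2, 2]], y_true = [[-100, 2, -100]]
    #   -> true_predictions = [["O"]], true_labels = [["O"]]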

    def compute_metrics():
        results = metric.compute()
        if data_args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }
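    # With return_entity_level_metrics=True, a nested seqeval result such as
    #   {"PER": {"f1": 0.91}, "overall_f1": 0.88}
    # is flattened to {"PER_f1": 0.91, "overall_f1": 0.88}.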

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
            tqdm(
                train_data_collator(input_rng, train_dataset, train_batch_size),
                total=step_per_epoch,
                desc="Training...",
                position=1,
            )
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = (epoch * step_per_epoch) + (step + 1)

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:

                eval_metrics = {}
                # evaluate
                for batch in tqdm(
                    eval_data_collator(eval_dataset, eval_batch_size),
                    total=len(eval_dataset) // eval_batch_size,
                    desc="Evaluating ...",
                    position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = p_eval_step(state, batch)
                    predictions = np.array([pred for pred in chain(*predictions)])
                    labels = np.array([label for label in chain(*labels)])
                    labels[np.array(list(chain(*batch["attention_mask"]))) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                # evaluate also on leftover examples (not divisible by batch_size)
                num_leftover_samples = len(eval_dataset) % eval_batch_size

                # make sure leftover batch is evaluated on one device
                if num_leftover_samples > 0 and jax.process_index() == 0:
                    # take leftover samples
                    batch = eval_dataset[-num_leftover_samples:]
                    batch = {k: np.array(v) for k, v in batch.items()}

                    labels = batch.pop("labels")
                    predictions = eval_step(unreplicate(state), batch)
                    labels = np.array(labels)
                    labels[np.array(batch["attention_mask"]) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                eval_metrics = compute_metrics()

                if data_args.return_entity_level_metrics:
                    logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
                else:
                    logger.info(
                        f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
                    )

                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint and optionally push it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

    # Eval after training
    if training_args.do_eval:
        eval_metrics = {}
        eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
        for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            predictions = np.array([pred for pred in chain(*predictions)])
            labels = np.array([label for label in chain(*labels)])
            labels[np.array(list(chain(*batch["attention_mask"]))) == 0] = -100
            preds, refs = get_labels(predictions, labels)
            metric.add_batch(predictions=preds, references=refs)

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: np.array(v) for k, v in batch.items()}

            labels = np.array(batch.pop("labels"))
            predictions = eval_step(unreplicate(state), batch)
            labels[np.array(batch["attention_mask"]) == 0] = -100
            preds, refs = get_labels(predictions, labels)
            metric.add_batch(predictions=preds, references=refs)

        eval_metrics = compute_metrics()

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
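# The helpers `train_data_collator`, `eval_data_collator`, and
# `create_learning_rate_fn` used above are not shown in this example. The
# sketches below are assumptions, written only to be consistent with how they
# are called (full batches, sharded for `jax.pmap`, and a warmup/decay
# schedule); they are not the original implementations.
import jax
import numpy as np
import optax
from flax.training.common_utils import shard

def train_data_collator(rng, dataset, batch_size):
    """Sketch: yield shuffled, sharded, full batches of numpy arrays."""
    num_batches = len(dataset) // batch_size  # Skip the incomplete last batch.
    perm = np.asarray(jax.random.permutation(rng, len(dataset)))
    for idx in perm[: num_batches * batch_size].reshape((num_batches, batch_size)):
        batch = {k: np.array(v) for k, v in dataset[idx.tolist()].items()}
        yield shard(batch)

def eval_data_collator(dataset, batch_size):
    """Sketch: yield sharded, full batches in order; leftovers are handled above."""
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield shard({k: np.array(v) for k, v in batch.items()})

def create_learning_rate_fn(num_examples, batch_size, num_epochs, warmup_steps, lr):
    """Sketch: linear warmup followed by linear decay to zero."""
    total_steps = (num_examples // batch_size) * num_epochs
    warmup = optax.linear_schedule(init_value=0.0, end_value=lr,
                                   transition_steps=warmup_steps)
    decay = optax.linear_schedule(init_value=lr, end_value=0.0,
                                  transition_steps=total_steps - warmup_steps)
    return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])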
Example no. 14
def main(argv):
  del argv  # unused arg

  config = FLAGS.config

  # Unpack total and warmup steps
  # TODO(nband): revert this to separate arguments.
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift)

  # LR / Optimization Flags
  batch_size = config.batch_size
  grad_clip_norm = config.grad_clip_norm
  weight_decay = config.weight_decay
  print('Standard wandb hyperparameters:')
  print({
      'batch_size': batch_size,
      'grad_clip_norm': grad_clip_norm,
      'weight_decay': weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr
  })
  print('SNGP Params:', config.gp_layer)

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  print('Number of Jax local devices:', jax.local_devices())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.seed
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    tf.io.gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d dev per host (%d dev total), that is a %d per-device batch size.',
      batch_size,
      jax.process_count(), local_batch_size, jax.local_device_count(),
      jax.device_count(), local_batch_size // jax.local_device_count())

  write_note('Initializing preprocessing function...')
  # Same preprocessing function for training and evaluation
  preproc_fn = preprocess_spec.parse(
      spec=config.pp_train, available_ops=preprocess_utils.all_ops())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_base_dataset = ub.datasets.get(
      dataset_names['in_domain_dataset'],
      split=split_names['train_split'],
      data_dir=config.get('data_dir'))
  train_dataset_builder = train_base_dataset._dataset_builder  # pylint: disable=protected-access
  train_ds = input_utils.get_data(
      dataset=train_dataset_builder,
      split=split_names['train_split'],
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preproc_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))
  logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  # Load in-domain and OOD validation and/or test datasets.
  # Please specify the desired shift (Country Shift or Severity Shift)
  # in the config.
  eval_iter_splits = vit_utils.init_evaluation_datasets(
      use_validation=config.use_validation,
      use_test=config.use_test,
      dataset_names=dataset_names,
      split_names=split_names,
      config=config,
      preproc_fn=preproc_fn,
      batch_size_eval=batch_size_eval,
      local_batch_size_eval=local_batch_size_eval)

  ntrain_img = input_utils.get_num_examples(
      train_dataset_builder,
      split=split_names['train_split'],
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = ntrain_img / batch_size

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))

  # Specify Gaussian process layer configs.
  gp_config = config.get('gp_layer', {})
  model_dict = vit_utils.initialize_model('sngp', config)
  model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer']

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent to the devices later as needed. Otherwise, we have already
  # encountered situations where they were allocated twice.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    # Split model parameters into trainable and untrainable collections.
    states, params = variables.pop('params')
    del variables

    # Set bias in the head to a low value, such that loss is small initially.
    params = flax.core.unfreeze(params)
    if use_gp_layer:
      # Modify the head parameter in the GP head.
      params['head']['output_layer']['bias'] = jnp.full_like(
          params['head']['output_layer']['bias'],
          config.get('init_head_bias', 0))
    else:
      params['head']['bias'] = jnp.full_like(
          params['head']['bias'], config.get('init_head_bias', 0))

    return params, states

  rng, rng_init = jax.random.split(rng)
  params_cpu, states_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @functools.partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, states, images, labels):
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')
    top1_idx = jnp.argmax(logits, axis=1)

    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]

    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = batch_size_eval
    metric_args = jax.lax.all_gather([
        logits, labels, out['pre_logits']], axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, states, images, labels):
      # Specify mutable collection to update untrainable GP parameters.
      variable_dict = {'params': flax.core.freeze(params), **states}
      model_results, updated_states = model.apply(
          variable_dict,
          images,
          train=True,
          rngs={'dropout': rng_model_local},
          mutable=list(states.keys()),
          mean_field_factor=gp_config.get('mean_field_factor', -1.))

      logits, _ = model_results
      loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)
      return loss, updated_states

    # Performs an exact covariance update (i.e., resets the precision matrix
    # at the beginning of each new epoch) if covmat_momentum is a null value.
    if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0:
      # Resets the precision matrix to identity * ridge_penalty at the
      # beginning of a new epoch. This should be done before accumulating
      # gradients.
      ridge_penalty = gp_config.get('ridge_penalty', 1.)
      prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix']
      prec_mat_new = (
          (1. - reset_covmat) * prec_mat_old +
          reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)

      states = flax.core.unfreeze(states)
      states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix'] = prec_mat_new
      states = flax.core.freeze(states)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, s), g = vit_utils.accumulate_gradient_with_states(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images,
        labels, config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if grad_clip_norm is not None:
      g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
    measurements['reset_covmat'] = reset_covmat

    return opt, s, l, rng, measurements

  # Set config checkpoint resume path, if provided in args.
  if config.resume_checkpoint_path is not None:
    config.resume = config.resume_checkpoint_path

  default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias',
                           'head/kernel', 'head/bias')
  rng, train_loop_rngs = jax.random.split(rng)
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=states_cpu,
      default_reinit_params=default_reinit_params,
      config=config)
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  states_cpu = checkpoint_data.fixed_model_states
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu,
      loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and ensure we get similar performance in every run.
      logdir=output_dir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))

  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
  reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
  reset_covmat_iter = train_utils.prefetch_scalar(
      map(reset_covmat_fn, range(first_step, total_steps)),
      nprefetch=config.get('prefetch_to_device', 1))
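  # This schedule yields 1.0 whenever step % steps_per_epoch == 0 and 0.0
  # otherwise, so the pmapped update_fn resets the GP precision matrix once
  # per epoch and leaves it untouched in between.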

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax.jax_utils.replicate(opt_cpu)
  states_repl = flax.jax_utils.replicate(states_cpu)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  # train_loss = -jnp.inf
  # val_loss = -jnp.inf
  # results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.

  # Make sure log_eval_steps equals steps_per_epoch: the precision matrix
  # needs to be fully updated (at the end of each epoch) whenever evaluation
  # takes place.
  log_eval_steps = steps_per_epoch
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl, reset_covmat_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter,
      reset_covmat_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, train_loop_rngs,
       extra_measurements) = update_fn(
           opt_repl,
           states_repl,
           lr_repl,
           reset_covmat_repl,
           train_batch['image'],
           train_batch['labels'],
           rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving
    if train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      # For GP layer, we will also do the same for untrainable parameters
      # (`states`). This is ok since `random features` are frozen throughout
      # pre-training, and `precision matrix` is a finetuning-specific parameter
      # that will be re-learned in the finetuning task.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          optimizer=opt_cpu,
          fixed_model_states=states_cpu,
          train_loop_rngs=train_loop_rngs,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()

      all_eval_results = {}

      for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
        start_time = time.time()

        # Runs evaluation loop.
        results_arrs = {
            'y_true': [],
            'y_pred': [],
            'y_pred_entropy': []
        }

        for _, batch in zip(range(eval_steps), eval_iter):
          batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
              evaluation_fn(
                  opt_repl.target, states_repl, batch['image'],
                  batch['labels']))

          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.

          # Here we parse batch_metric_args to compute uncertainty metrics.
          logits, labels, _ = batch_metric_args
          logits = np.array(logits[0])
          probs = jax.nn.softmax(logits)

          # From one-hot to integer labels.
          int_labels = np.argmax(np.array(labels[0]), axis=-1)

          probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1))
          int_labels = int_labels.flatten()
          y_pred = probs[:, 1]
          results_arrs['y_true'].append(int_labels)
          results_arrs['y_pred'].append(y_pred)

          # Entropy is computed at the per-epoch level (see below).
          results_arrs['y_pred_entropy'].append(probs)

        results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                axis=0)
        results_arrs['y_pred'] = np.concatenate(
            results_arrs['y_pred'], axis=0).astype('float64')
        results_arrs['y_pred_entropy'] = vit_utils.entropy(
            np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1)

        time_elapsed = time.time() - start_time
        results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
        results_arrs['dataset_size'] = eval_steps * batch_size_eval

        all_eval_results[eval_name] = results_arrs

      per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
          dataset_split_to_containers=all_eval_results,
          is_deterministic=True,
          num_bins=15,
          return_per_pred_results=True
      )

      # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
      # dataset. Flatten this dict so we can pass to the writer and remove empty
      # entries.
      flattened_metric_results = {}
      for dic in metrics_results.values():
        for key, value in dic.items():
          if value is not None:
            flattened_metric_results[key] = value
      writer.write_scalars(step, flattened_metric_results)

      # Optionally log to wandb
      if config.use_wandb:
        wandb.log(metrics_results, step=step)

      # Save per-prediction metrics
      results_storage_utils.save_per_prediction_results(
          output_dir, step, per_pred_results, verbose=False)

      chrono.resume()

      # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  if wandb_run is not None:
    wandb_run.finish()
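# `vit_utils.entropy` above converts concatenated class probabilities into
# per-example predictive entropies. A minimal sketch consistent with that call
# (an assumption, not the original implementation):
import numpy as np

def entropy(probs, axis=-1):
    # Predictive entropy -sum_c p_c * log(p_c), clipped to avoid log(0).
    return -np.sum(probs * np.log(np.clip(probs, 1e-12, 1.0)), axis=axis)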
Example no. 15
def eval_dataset_and_unshard(viewdir_mlp_model, viewdir_mlp_params,
                             rgb_features, directions, source_dataset,
                             scene_params):
    """Evaluates view-dependence on a sharded dataset and unshards the result.

  This function evaluates the view-dependence MLP on a dataset, adding back
  effects such as highlights.

  To make use of multi-host parallelism provided by JAX, this function takes as
  input a sharded dataset, so each host only evaluates a slice of the data.
  Note that this function unshards the data before returning, which broadcasts
  the results back to all JAX hosts.

  Args:
    viewdir_mlp_model: A nerf.model_utils.MLP that predicts the per-ray
      view-dependent residual color.
    viewdir_mlp_params: A dict containing the MLP parameters for the per-ray
      view-dependence MLP.
    rgb_features: The RGB (+ features) input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 7) numpy array.
    directions: The direction vectors for the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.
    source_dataset: The nerf.datasets.Dataset we are evaluating.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).

  Returns:
    A list of color images, each stored as a (H, W, 3) numpy array.
  """
    @functools.partial(jax.pmap, in_axes=(0, 0), axis_name="batch")
    def pmap_eval_fn(rgb_and_feature_chunk, direction_chunk):
        """We need an inner function as only JAX types can be passed to a pmap."""
        residual = model_utils.viewdir_fn(viewdir_mlp_model,
                                          viewdir_mlp_params,
                                          rgb_and_feature_chunk,
                                          direction_chunk, scene_params)
        output = jnp.minimum(1.0,
                             rgb_and_feature_chunk[Ellipsis, 0:3] + residual)
        return jax.lax.all_gather(output, axis_name="batch")

    num_hosts = jax.host_count()
    num_local_devices = jax.local_device_count()
    num_images = source_dataset.camtoworlds.shape[0]
    num_batches = math.ceil(num_images / num_hosts)
    num_batches = num_local_devices * math.ceil(
        num_batches / num_local_devices)

    outputs = []
    for i in range(len(rgb_features)):
        # First, evaluate the view-dependence MLP in parallel across all devices.
        output_batch = pmap_eval_fn(rgb_features[i], directions[i])
        output_batch = np.reshape(output_batch[0],
                                  (num_hosts, num_local_devices,
                                   source_dataset.h, source_dataset.w, 3))

        # Then, make sure to populate the output array in the same order
        # as the original dataset.
        for j in range(num_local_devices):
            base_index = (i * num_local_devices + j) * num_hosts
            for k in range(num_hosts):
                gathered_dataset_index = base_index + k
                if gathered_dataset_index >= num_images:
                    break

                outputs.append(
                    np.array(output_batch[k][j]).reshape(
                        (source_dataset.h, source_dataset.w, 3)))

    return outputs
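# A hypothetical walk-through of the unsharding indices above, with
# num_hosts = 2 and num_local_devices = 2:
#   i = 0, j = 0 -> dataset images 0 (k = 0) and 1 (k = 1)
#   i = 0, j = 1 -> dataset images 2 and 3
#   i = 1, j = 0 -> dataset images 4 and 5, and so on.
# That is, gathered_dataset_index = (i * num_local_devices + j) * num_hosts + k
# recovers the interleaved order produced by the sharded dataset builder.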
Example no. 16
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    batch_size = FLAGS.batch_size
    learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    num_eval_steps = FLAGS.num_eval_steps
    eval_freq = FLAGS.eval_frequency
    max_target_length = FLAGS.max_target_length
    max_eval_target_length = FLAGS.max_eval_target_length
    random_seed = FLAGS.random_seed

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
        n_devices=jax.local_device_count(),
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        dynamic_batching=True,
        max_target_length=max_target_length,
        max_eval_target_length=max_eval_target_length)
    vocab_size = info_ds['text'].encoder.vocab_size
    encoder = info_ds['text'].encoder

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_target_length)

    transformer_lm_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': 512,
        'num_heads': 8,
        'num_layers': 6,
        'qkv_dim': 512,
        'mlp_dim': 2048,
        'max_len': max(max_target_length, max_eval_target_length)
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model, cache_def = create_model(init_rng, input_shape,
                                    transformer_lm_kwargs)
    optimizer = create_optimizer(model, learning_rate)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')
    p_pred_step = jax.pmap(predict_step, axis_name='batch')

    metrics_all = []
    tick = time.time()
    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and FLAGS.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f', step,
                         summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                train_summary_writer.scalar('steps per second', steps_per_sec,
                                            step)
                for key, val in summary.items():
                    train_summary_writer.scalar(key, val, step)
                train_summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_metrics = []
            eval_iter = iter(eval_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, eval_batch in zip(num_iter, eval_iter):
                # pylint: disable=protected-access
                eval_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), eval_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            # Calculate (clipped) perplexity after averaging log-perplexities:
            eval_summary['perplexity'] = jnp.clip(jnp.exp(
                eval_summary['loss']),
                                                  a_max=1.0e4)
            logging.info('eval in step: %d, loss: %.4f', step,
                         eval_summary['loss'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(key, val, step)
                eval_summary_writer.flush()

            # Fast inference of prompt extension using trained LM.
            rng, subrng = jax.random.split(rng)
            pred_rngs = random.split(subrng, jax.local_device_count())
            prompt = jnp.array(encoder.encode(FLAGS.prompt))
            prompt = jax_utils.replicate(prompt)
            prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (1, FLAGS.max_predict_token_length)))
            predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)
            predicted = tohost(predicted)
            exemplars = ''
            for n in range(predicted.shape[0]):
                exemplars += encoder.decode(predicted[n]) + '\n\n'
            if jax.host_id() == 0:
                eval_summary_writer.text('samples', exemplars, step)
                eval_summary_writer.flush()
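# `tohost` above moves the pmapped predictions to host memory before decoding.
# A minimal sketch consistent with that use (an assumption, not necessarily
# the original implementation):
import numpy as np

def tohost(x):
    # Collect a (num_devices, batch, ...) array to the host and merge the
    # leading device and batch dimensions.
    n_device, n_batch, *remaining_dims = x.shape
    return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))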
Example no. 17
def build_sharded_dataset_for_view_dependence(source_dataset, atlas_t,
                                              atlas_block_indices_t,
                                              atlas_params, scene_params,
                                              grid_params):
    """Builds a dataset that we can run the view-dependence MLP on.

  We ray march through a baked SNeRG model to generate images with RGB colors
  and features. These serve as the input for the view-dependence MLP which adds
  back the effects such as highlights.

  To make use of multi-host parallelism provided by JAX, this function shards
  the dataset, so that each host contains only a slice of the data.

  Args:
    source_dataset: The nerf.datasets.Dataset we should compute data for.
    atlas_t: A tensorflow tensor containing the texture atlas.
    atlas_block_indices_t: A tensorflow tensor containing the indirection grid.
    atlas_params: A dict with params for building and rendering with
      the 3D texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).
    grid_params: A dict with parameters describing the high-res voxel grid which
      the atlas is representing.

  Returns:
    rgb_data: The RGB (+ features) input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 7) numpy array.
    alpha_data: The alpha channel of the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 1) numpy array.
    direction_data: The direction vectors for the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.
    ref_data: The reference RGB colors for each input data sample, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.
  """

    num_hosts = jax.host_count()
    num_local_devices = jax.local_device_count()
    host_id = jax.host_id()
    num_images = source_dataset.camtoworlds.shape[0]
    num_batches = math.ceil(num_images / num_hosts)
    num_batches = num_local_devices * math.ceil(
        num_batches / num_local_devices)

    rgb_list = []
    alpha_list = []
    viewdir_list = []
    ref_list = []
    for i in range(num_batches):
        base_index = i * num_hosts
        dataset_index = base_index + host_id

        rgb = np.zeros(
            (source_dataset.h, source_dataset.w, scene_params["_channels"]),
            dtype=np.float32)
        alpha = np.zeros((source_dataset.h, source_dataset.w, 1),
                         dtype=np.float32)
        viewdirs = np.zeros((source_dataset.h, source_dataset.w, 3),
                            dtype=np.float32)

        if dataset_index < num_images:
            rgb, alpha = rendering.atlas_raymarch_image_tf(
                source_dataset.h, source_dataset.w, source_dataset.focal,
                source_dataset.camtoworlds[dataset_index], atlas_t,
                atlas_block_indices_t, atlas_params, scene_params, grid_params)
            _, _, viewdirs = datasets.rays_from_camera(
                scene_params["_use_pixel_centers"], source_dataset.h,
                source_dataset.w, source_dataset.focal,
                np.expand_dims(source_dataset.camtoworlds[dataset_index], 0))

        np_rgb = np.array(rgb).reshape(
            (source_dataset.h, source_dataset.w, scene_params["_channels"]))
        np_alpha = np.array(alpha).reshape(
            (source_dataset.h, source_dataset.w, 1))
        np_viewdirs = viewdirs.reshape((np_rgb.shape[0], np_rgb.shape[1], 3))
        if scene_params["white_bkgd"]:
            np_rgb[Ellipsis, 0:3] = np.ones_like(np_rgb[Ellipsis, 0:3]) * (
                1.0 - np_alpha) + np_rgb[Ellipsis, 0:3]

        rgb_list.append(np_rgb)
        alpha_list.append(np_alpha)
        viewdir_list.append(np_viewdirs)
        ref_list.append(source_dataset.images[dataset_index % num_images])

    rgb_data = np.stack(rgb_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w,
         scene_params["_channels"]))
    alpha_data = np.stack(alpha_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w, 1))
    viewdir_data = np.stack(viewdir_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w, 3))
    ref_data = np.stack(ref_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w, 3))

    return rgb_data, alpha_data, viewdir_data, ref_data
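Note: the loop above assigns images to hosts round-robin, so each host ray marches only its own slice of the dataset. A standalone sketch of the same index arithmetic (num_images is a hypothetical value):

import jax

num_hosts = jax.host_count()   # jax.process_count() in newer JAX releases
host_id = jax.host_id()        # jax.process_index() in newer JAX releases
num_images = 10                # hypothetical dataset size

# Host h handles images h, h + num_hosts, h + 2 * num_hosts, ...; indices past
# the dataset end fall through to the zero-filled placeholders, as above.
for i in range(4):
    dataset_index = i * num_hosts + host_id
    if dataset_index < num_images:
        print('host', host_id, 'renders image', dataset_index)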
Example no. 18
def train_and_evaluate(random_seed, batch_size, learning_rate, num_train_steps,
                       num_eval_steps, eval_freq, max_target_length,
                       max_eval_target_length, weight_decay, data_dir,
                       model_dir, restore_checkpoints, save_checkpoints,
                       checkpoint_freq, max_predict_token_length,
                       sampling_temperature, sampling_top_k, prompt_str):
    """Executes model training and evaluation loop.
  
  Args:
    random_seed: Seed for initializing the PRNG.
    batch_size: Batch size for training.
    learning_rate: Learning rate for the Adam optimizer.
    num_train_steps: Number of training steps.
    num_eval_steps: Number of evaluation steps.
    eval_freq: Frequency of evaluation during training.
    max_target_length: Maximum length of training examples.
    max_eval_target_length: Maximum length of eval examples.
    weight_decay: Decay factor for AdamW-style weight decay.
    data_dir: Directory containing TFDS lm1b/subwords32k datasets.
    model_dir: Directory where to store model data.
    restore_checkpoints: Whether to restore from existing model checkpoints.
    save_checkpoints: Whether to save model checkpoints.
    checkpoint_freq: Save a checkpoint every this many steps.
    max_predict_token_length: Maximum example text inference token length.
    sampling_temperature: Sampling temperature for language model inference.
    sampling_top_k: Top k cutoff for logit sampling.
    prompt_str: Prompt for language model sampling.
  """
    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(model_dir, 'eval'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
        n_devices=jax.local_device_count(),
        data_dir=data_dir,
        batch_size=batch_size,
        dynamic_batching=True,
        max_target_length=max_target_length,
        max_eval_target_length=max_eval_target_length)
    vocab_size = info_ds['text'].encoder.vocab_size
    encoder = info_ds['text'].encoder

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_target_length)

    transformer_lm_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': 512,
        'num_heads': 8,
        'num_layers': 6,
        'qkv_dim': 512,
        'mlp_dim': 2048,
        'max_len': max(max_target_length, max_eval_target_length)
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model, cache_def = create_model(init_rng, input_shape,
                                    transformer_lm_kwargs)
    optimizer = create_optimizer(model, learning_rate, weight_decay)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')
    p_pred_step = jax.pmap(predict_step, axis_name='batch')

    metrics_all = []
    tick = time.time()
    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f', step,
                         summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                train_summary_writer.scalar('steps per second', steps_per_sec,
                                            step)
                for key, val in summary.items():
                    train_summary_writer.scalar(key, val, step)
                train_summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_metrics = []
            eval_iter = iter(eval_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, eval_batch in zip(num_iter, eval_iter):
                # pylint: disable=protected-access
                eval_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), eval_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            # Calculate (clipped) perplexity after averaging log-perplexities:
            eval_summary['perplexity'] = jnp.clip(jnp.exp(
                eval_summary['loss']),
                                                  a_max=1.0e4)
            logging.info('eval in step: %d, loss: %.4f', step,
                         eval_summary['loss'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(key, val, step)
                eval_summary_writer.flush()

            # Fast inference of prompt extension using trained LM.
            rng, subrng = jax.random.split(rng)
            pred_rngs = random.split(subrng, jax.local_device_count())
            prompt = jnp.array(encoder.encode(prompt_str))
            prompt = jax_utils.replicate(prompt)
            prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
            cache = jax_utils.replicate(
                cache_def.initialize_cache((1, max_predict_token_length)))
            predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs,
                                    max_predict_token_length,
                                    sampling_temperature, sampling_top_k)
            predicted = tohost(predicted)
            exemplars = ''
            for n in range(predicted.shape[0]):
                exemplars += encoder.decode(predicted[n]) + '\n\n'
            if jax.host_id() == 0:
                eval_summary_writer.text('samples', exemplars, step)
                eval_summary_writer.flush()
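Note: the function above threads one dropout PRNG key per local device through the pmapped train step, splitting on-device so the host never has to re-split per iteration. A minimal runnable sketch of that pattern (the train-step body here is hypothetical):

import jax
import jax.numpy as jnp

def train_step(params, dropout_rng):
    # Split on-device: one half would be consumed by dropout this step,
    # the other half is carried over to the next step.
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    del dropout_rng  # ...used inside the model in real code.
    return params, new_dropout_rng

p_train_step = jax.pmap(train_step, axis_name='batch')
dropout_rngs = jax.random.split(jax.random.PRNGKey(0),
                                jax.local_device_count())
params = jnp.zeros((jax.local_device_count(), 1))  # placeholder replicated params
params, dropout_rngs = p_train_step(params, dropout_rngs)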
Example no. 19
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    batch_size_per_device=gin.REQUIRED,
    validation_example_count=gin.REQUIRED,
):
    """Train the maze automaton.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
    batch_size_per_device: Batch size for each device.
    validation_example_count: How many examples to use when computing validation
      metrics.

  Returns:
    Optimizer at the end of training (for interactive debugging).
  """
    num_devices = jax.local_device_count()
    logging.info("Found %d devices: %s", num_devices, jax.devices())

    with contextlib.ExitStack() as exit_stack:
        logging.info("Setting up datasets...")
        raw_train_iterator = runner.build_sampling_iterator(
            dataset_paths["train_dataset"],
            example_type=graph_bundle.GraphBundle)

        raw_valid_iterator_factory = runner.build_one_pass_iterator_factory(
            dataset_paths["eval_dataset"],
            example_type=graph_bundle.GraphBundle,
            truncate_at=validation_example_count)

        # Add the example id into the example itself, so that we can use it to
        # randomly choose a goal.
        def reify_id(it):
            for item in it:
                yield dataclasses.replace(item,
                                          example=(item.example,
                                                   item.example_id))

        def reify_id_and_batch(it):
            return data_loading.batch(reify_id(it),
                                      (num_devices, batch_size_per_device),
                                      remainder_behavior=data_loading.
                                      BatchRemainderBehavior.PAD_ZERO)

        train_iterator = reify_id_and_batch(raw_train_iterator)
        valid_iterator_factory = (
            lambda: reify_id_and_batch(raw_valid_iterator_factory()))

        if prefetch:
            train_iterator = exit_stack.enter_context(
                data_loading.ThreadedPrefetcher(train_iterator, prefetch))

        logging.info("Setting up model...")
        padding_config = maze_task.PADDING_CONFIG
        model_def = automaton_layer.FiniteStateGraphAutomaton.partial(
            static_metadata=padding_config.static_max_metadata,
            builder=maze_task.BUILDER)

        # Initialize parameters randomly.
        _, initial_params = model_def.init(
            jax.random.PRNGKey(int(time.time() * 1000)),
            graph_bundle.zeros_like_padded_example(
                padding_config).automaton_graph,
            dynamic_metadata=padding_config.static_max_metadata)

        model = flax.nn.Model(model_def, initial_params)
        optimizer = flax.optim.Adam().create(model)

        extra_artifacts = {
            "builder.pickle": maze_task.BUILDER,
        }

        return runner.training_loop(
            optimizer=optimizer,
            train_iterator=train_iterator,
            loss_fn=loss_fn,
            validation_fn=train_util.build_averaging_validator(
                loss_fn, valid_iterator_factory),
            extra_artifacts=extra_artifacts)
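Note: data_loading.batch above stacks a flat example stream into leading (num_devices, batch_size_per_device) axes, zero-padding the last short batch. A generic NumPy sketch of that behavior (an assumption about the helper, not its actual source):

import itertools
import numpy as np

def batch_for_pmap(example_iter, num_devices, batch_size_per_device):
    per_step = num_devices * batch_size_per_device
    while True:
        chunk = list(itertools.islice(example_iter, per_step))
        if not chunk:
            return
        while len(chunk) < per_step:  # cf. BatchRemainderBehavior.PAD_ZERO
            chunk.append(np.zeros_like(chunk[0]))
        yield np.stack(chunk).reshape(
            (num_devices, batch_size_per_device) + chunk[0].shape)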
Example no. 20
def main(config, output_dir):

    seed = config.get('seed', 0)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)
    tf.io.gfile.makedirs(output_dir)

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in an async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      num_epochs=1,
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)

        return val_ds

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        # config.ood_methods is not an empty list.
        logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
        ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
            config.dataset,
            config.ood_datasets,
            config.ood_split,
            config.pp_eval,
            config.pp_eval_ood,
            config.ood_methods,
            config.train_split,
            config.get('data_dir'),
            _get_val_split,
        )

    write_note('Initializing model...')
    logging.info('config.model = %s', config.model)
    model = ub.models.vision_transformer(num_classes=config.num_classes,
                                         **config.model)

    ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply)

    @functools.partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        # params is a dict of the form:
        #   {'model_1': params_model_1, 'model_2': params_model_2, ...}
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        loss_as_str = config.get('loss', 'sigmoid_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images,
                                                     loss_as_str)

        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        # Note that logits and labels are usually of shape [batch, num_classes].
        # But for OOD data, when num_classes_ood > num_classes_ind, we need to
        # adjust labels to labels[:, :config.num_classes] to match the shape of
        # logits. This only avoids a shape mismatch; the resulting losses carry
        # no meaning for OOD data, since OOD examples belong to no IND class.
        losses = getattr(train_utils, loss_as_str)(
            logits=ens_logits,
            labels=labels[:, :(
                len(label_indices) if label_indices else config.num_classes)],
            reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [ens_logits, labels, ens_prelogits, mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @functools.partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        loss_as_str = config.get('loss', 'softmax_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images,
                                                     loss_as_str)
        label_indices = config.get('label_indices')
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        losses = getattr(train_utils, loss_as_str)(logits=ens_logits,
                                                   labels=labels,
                                                   reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [ens_logits, labels, ens_prelogits, mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @functools.partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        # Return shape [batch_size, representation_size * ensemble_size]. During
        # few-shot eval, a single linear regressor is applied over all dimensions.
        representation = []
        for p in params.values():
            _, outputs = model.apply({'params': flax.core.freeze(p)},
                                     images,
                                     train=False)
            representation += [outputs[config.fewshot.representation_layer]]
        representation = jnp.concatenate(representation, axis=1)
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    write_note('Load checkpoints...')
    ensemble_params = load_checkpoints(config)

    write_note('Replicating...')
    ensemble_params = flax.jax_utils.replicate(ensemble_params)

    if jax.process_index() == 0:
        writer.write_hparams(dict(config))

    write_note('Initializing few-shotters...')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}
    step = 1

    # Report validation performance.
    write_note('Evaluating on the validation set...')
    for val_name, val_ds in val_ds_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005,
                                                       num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01,
                                                     num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02,
                                                     num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05,
                                                     num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))
        ncorrect, loss, nseen = 0, 0, 0
        for batch in val_iter:
            if val_name == 'cifar_10h':
                batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                    cifar_10h_evaluation_fn(ensemble_params, batch['image'],
                                            batch['labels'], batch['mask']))
            else:
                batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                    evaluation_fn(ensemble_params, batch['image'],
                                  batch['labels'], batch['mask']))
            # All results are a replicated array shaped as follows:
            # (local_devices, per_device_batch_size, elem_shape...)
            # with each local device's entry being identical as they got psum'd.
            # So let's just take the first one to the host as numpy.
            ncorrect += np.sum(np.array(batch_ncorrect[0]))
            loss += np.sum(np.array(batch_losses[0]))
            nseen += np.sum(np.array(batch_n[0]))
            if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                # Here we parse batch_metric_args to compute uncertainty metrics.
                # (e.g., ECE or Calibration AUC).
                logits, labels, _, masks = batch_metric_args
                masks = np.array(masks[0], dtype=bool)  # np.bool was removed from NumPy.
                logits = np.array(logits[0])
                probs = jax.nn.softmax(logits)
                # From one-hot to integer labels, as required by ECE.
                int_labels = np.argmax(np.array(labels[0]), axis=-1)
                int_preds = np.argmax(logits, axis=-1)
                confidence = np.max(probs, axis=-1)
                for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                                int_preds, masks, labels[0]):
                    ece.add_batch(p[m, :], label=l[m])
                    calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
                    # TODO(jereliu): Extend to support soft multi-class probabilities.
                    oc_auc_0_5.add_batch(d[m],
                                         label=l[m],
                                         custom_binning_score=c[m])
                    oc_auc_1.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])
                    oc_auc_2.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])
                    oc_auc_5.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])

                    if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                        batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                            label[m], p[m, :], config.num_classes)
                        label_diversity.update_state(batch_label_diversity)
                        sample_diversity.update_state(batch_sample_diversity)
                        ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name],
        }
        if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
            val_measurements[f'{val_name}_ece'] = ece.result()['ece']
            val_measurements[f'{val_name}_calib_auc'] = calib_auc.result(
            )['calibration_auc']
            val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
            )['collaborative_auc']
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
            cifar_10h_measurements = {
                f'{val_name}_label_diversity': label_diversity.result(),
                f'{val_name}_sample_diversity': sample_diversity.result(),
                f'{val_name}_ged': ged.result(),
            }
            writer.write_scalars(step, cifar_10h_measurements)

    # OOD eval
    # Entries in the ood_ds dict include:
    # (ind_dataset, ood_dataset1, ood_dataset2, ...).
    # OOD metrics are computed using ind_dataset paired with each of the
    # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
    # is also included in the ood_ds.
    if ood_ds and config.ood_methods:
        ood_measurements = ood_utils.eval_ood_metrics(ood_ds,
                                                      ood_ds_names,
                                                      config.ood_methods,
                                                      evaluation_fn,
                                                      ensemble_params,
                                                      n_prefetch=config.get(
                                                          'prefetch_to_device',
                                                          1))
        writer.write_scalars(step, ood_measurements)

    if 'fewshot' in config and fewshotter is not None:
        # Compute few-shot on-the-fly evaluation.
        write_note('Few-shot evaluation...')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(ensemble_params,
                                                      config.fewshot.datasets)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):
            def writer_measure(name, value):
                writer.write_scalars(step, {name: value})

            return writer_measure

        fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results,
                                best_l2)

    write_note('Done!')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return val_loss, fewshot_results
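Note: the evaluation loops above index pmap outputs with [0] before accumulating. That is valid because jax.lax.psum leaves every local device holding the identical total. A minimal runnable sketch:

import functools
import jax
import jax.numpy as jnp
import numpy as np

@functools.partial(jax.pmap, axis_name='batch')
def count_correct(correct):
    # After psum, all devices hold the same cross-device total.
    return jax.lax.psum(jnp.sum(correct), axis_name='batch')

n_dev = jax.local_device_count()
correct = jnp.ones((n_dev, 4))           # hypothetical per-device hit mask
batch_ncorrect = count_correct(correct)  # shape [n_dev], identical entries
ncorrect = np.array(batch_ncorrect[0])   # one replicated copy suffices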
Example no. 21
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):

  loss_fn = jax.jit(loss_fn)

  optimizer_def = optim.Adam()
  local_optimizer = optimizer_def.create(init_params)
  local_optimizer = util.maybe_load_checkpoint(
      logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
  first_step = local_optimizer.state.step
  repl_optimizer = jax_utils.replicate(local_optimizer)

  lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

  @functools.partial(jax.pmap, axis_name="batch")
  def train_step(optimizer, key):
    key, subkey = jax.random.split(key)
    loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key)
    loss_grad = jax.lax.pmean(loss_grad, "batch")
    new_optimizer = optimizer.apply_gradient(
        loss_grad, learning_rate=lr_fn(optimizer.state.step))
    return new_optimizer, subkey

  sw = SummaryWriter(logdir)

  repl_key = jax.pmap(jax.random.PRNGKey)(jnp.arange(jax.local_device_count()))
  start = timeit.default_timer()
  for t in range(first_step, num_steps):
    if t % checkpoint_every == 0 and t != first_step:
      optimizer = jax_utils.unreplicate(repl_optimizer)
      checkpoints.save_checkpoint(logdir,
                                  optimizer,
                                  optimizer.state.step, keep=3)
      print("Checkpoint saved for step %d" % optimizer.state.step)

    repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)

    if t % summarize_every == 0:
      key, subkey = jax.random.split(jax_utils.unreplicate(repl_key))
      optimizer = jax_utils.unreplicate(repl_optimizer)
      loss_val = loss_fn(optimizer.target, key)
      print("Step %d loss: %0.4f" % (t, loss_val))
      sw.scalar("loss", loss_val, step=t)
      summarize_fn(sw, t, optimizer.target, subkey)
      end = timeit.default_timer()
      if t == 0:
        steps_per_sec = 1. / (end - start)
      else:
        steps_per_sec = summarize_every / (end - start)
      print("Steps/sec: %0.2f" % steps_per_sec)
      sw.scalar("steps_per_sec", steps_per_sec, step=t)
      start = end
      sw.flush()
      sys.stdout.flush()
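Note: the function above seeds one PRNG key per device by pmapping jax.random.PRNGKey over a range of seeds. An equivalent, more common sketch splits a single host key instead:

import jax

n_dev = jax.local_device_count()
repl_key = jax.random.split(jax.random.PRNGKey(0), n_dev)  # shape [n_dev, 2]
# Each row is an independent key, so passing repl_key to a pmapped function
# hands every local device its own key, as train_step above expects.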
Example no. 22
def run_train(run_configuration):
    """Runs the training workflow."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    log_dir = os.path.join(run_dir, 'train')
    checkpoint_path = run_configuration.original_checkpoint_path

    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)

    summary_freq = config.logging.summary_freq
    metrics_all = []
    tick = time.time()
    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rngs, logits, state = train_step(
            optimizer, train_inputs, dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % config.logging.save_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.logging.save_freq:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(
                    checkpoint_dir, jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if summary_freq and step % summary_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train step: %d, loss: %.4f', step, summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = summary_freq / (tock - tick)
                examples_per_sec = denominator / (tock - tick)
                tick = tock
                summary_writer.scalar('per-second/steps', steps_per_sec, step)
                summary_writer.scalar('per-second/examples', examples_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar(key, val, step)

                adapter.write_summaries(example, logits, summary_writer, info,
                                        step, state)

                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []
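Note: the 'denominator' bookkeeping above normalizes summed metrics by the total token count rather than averaging per-step means. A simplified sketch of that normalization with toy values (common_utils.get_metrics additionally stacks the per-device axes):

import jax
import jax.numpy as jnp

metrics_all = [{'loss': jnp.array(6.0), 'denominator': jnp.array(3.0)},
               {'loss': jnp.array(2.0), 'denominator': jnp.array(1.0)}]
metrics_sums = jax.tree_map(lambda *xs: sum(xs), *metrics_all)
denominator = metrics_sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
print(summary)  # {'loss': 2.0}: 8 total loss over 4 total tokens.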
Example no. 23
def predict_and_compute_score(*, p_pred_step, p_init_cache, target,
                              predict_ds,
                              decode_io,
                              decode_program,
                              beam_size,
                              num_partial_programs,
                              use_best_first_search = False,
                              slow_decode = False):
  """Generates program and computes score."""
  n_devices = jax.local_device_count()

  pred_acc = 0
  pred_denominator = 0

  ios, targets, predictions = [], [], []
  for batches in predict_ds.as_numpy_iterator():
    pred_batch = batches
    # Handle final odd-sized batch by padding instead of dropping it.
    cur_pred_batch_size = pred_batch[0].shape[0]
    if cur_pred_batch_size % n_devices:
      padded_size = int(
          np.ceil(cur_pred_batch_size / n_devices) * n_devices)
      # pylint: disable=cell-var-from-loop
      pred_batch = jax.tree_map(
          lambda x: pad_examples(x, padded_size), pred_batch)
    inputs, outputs, programs, split_outputs = common_utils.shard(pred_batch)

    cache = (p_init_cache(inputs, outputs, programs[:, :, 0])
             if not slow_decode else None)
    predicted, log_probs = p_pred_step(target,
                                       inputs,
                                       outputs,
                                       cache,
                                       beam_size,
                                       split_outputs=split_outputs)
    predicted, log_probs = map(tohost, (predicted, log_probs))
    inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

    pred_denominator += programs.shape[0]
    for i, partial_beams in enumerate(predicted):
      inps, outs = decode_io(inputs[i], outputs[i])

      # Find the best orderings of partial programs.
      # partial_seqs shape == [n_beam, n_partial]
      if use_best_first_search:
        partial_seqs = best_first_search(log_probs[i], beam_size)
      else:
        partial_seqs = beam_decoder(log_probs[i], beam_size)
      # beams shape == [n_beam, n_partial, length]
      beams = partial_beams[np.arange(num_partial_programs), partial_seqs]

      # Execute predicted programs on i/o examples.
      p, p_score = compute_score(beams, inps, outs, decode_program)
      if p_score >= len(inps):
        pred_acc += 1
      ios.append(' ; '.join(map(str, zip(inps, outs))))
      targets.append(decode_program(programs[i]).to_string())
      try:
        predictions.append(p.to_string())
      except:  # pylint: disable=bare-except
        predictions.append('')
      logging.info('ios: %s', ios[-1])
      logging.info('target: %s', targets[-1])
      beams_log = []
      for beam in beams:
        try:
          beams_log.append(decode_program(beam).to_string())
        except:  # pylint: disable=bare-except
          beams_log.append('None')
      logging.info('predicted beam: %s', '\n'.join(beams_log))

  all_pred_acc, all_pred_denominator = per_host_sum_pmap(
      jax.tree_map(np.array, (pred_acc, pred_denominator)))

  # Record beam search results as text summaries.
  message = []
  for n in np.random.choice(np.arange(len(predictions)), 8):
    text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
            f'predicted: {predictions[n]}\n\n')
    message.append(text)

  return all_pred_acc / all_pred_denominator, message
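Note: pad_examples above is not shown in this excerpt. A plausible sketch (an assumption about the helper) tiles the first example until the batch size divides evenly across devices:

import numpy as np

def pad_examples(x, padded_size):
    pad_count = padded_size - x.shape[0]
    pad = np.tile(x[:1], (pad_count,) + (1,) * (x.ndim - 1))
    return np.concatenate([x, pad], axis=0)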
Example no. 24
        split=process_split,
        batch_dims=(),
        rng=rng,
        filter_fn=None,
        preprocess_fn=preprocess_fn,
        decoders={"image": tfds.decode.SkipDecoding()},
        cache=cache == "loaded",
        num_epochs=num_epochs if not repeat_after_batching else 1,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        prefetch_size=0,
        pad_up_to_batches=None,
        drop_remainder=drop_remainder,
    )

    num_devices = jax.local_device_count()
    if drop_remainder:
        # If we're dropping the remainder, we can take the fast path of double
        # batching to [num_devices, batch_size_per_device] and then adding a mask of
        # ones for the two batch dimensions.
        batch_size_per_device = process_batch_size // num_devices
        batch_dims = [num_devices, batch_size_per_device]
        for batch_size in reversed(batch_dims):
            dataset = dataset.batch(batch_size, drop_remainder=True)

        dataset = dataset.map(lambda xs: _add_mask(xs, 2),
                              num_parallel_calls=tf.data.AUTOTUNE)
    else:
        # If we're not dropping the remainder, then we define a flattened batch size
        # that would divide evenly across devices, and then batch to that size with
        # drop_remainder=False. Then we add a mask of ones for the examples given,
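Note: _add_mask above is not shown in this excerpt. A minimal sketch (an assumption, keyed on the 'image' feature decoded earlier) attaches a mask of ones over the leading batch dimensions:

import tensorflow as tf

def _add_mask(xs, num_batch_dims):
    mask_shape = tf.shape(xs['image'])[:num_batch_dims]
    xs['mask'] = tf.ones(mask_shape, dtype=tf.float32)
    return xs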
Example no. 25
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Grab pretrain text data
    if FLAGS.target_text:
        targets_decoded_pt = []
        for i in range(1, 9):
            with tf.io.gfile.GFile(FLAGS.target_text % i, 'rb') as f:
                pt_targs_tmp = pickle.load(f)
            targets_decoded_pt.extend(pt_targs_tmp)
    else:
        train_ds, (encoder_in,
                   encoder_tgt) = input_pipeline.get_wmt_is_datasets(
                       n_devices=jax.local_device_count(),
                       dataset_name=FLAGS.dataset_name,
                       shard_idx=jax.process_index(),
                       shard_count=jax.process_count(),
                       data_dir=FLAGS.data_dir,
                       vocab_path=FLAGS.vocab_path,
                       target_vocab_size=32000,
                       batch_size=1024,
                       max_length=256,
                       paracrawl_size=FLAGS.paracrawl_size,
                       split_tokenizer=FLAGS.split_tokenizer)

        train_data = iter(train_ds)
        eos_id = decode.EOS_ID

        def decode_tokens(encoder, toks):
            valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
            return encoder.detokenize(valid_toks).numpy().decode('utf-8')

        targets = []
        inputs = []
        for x in train_data:
            trg = x['targets']._numpy()  # pylint:disable=protected-access
            ins = x['inputs']._numpy()  # pylint:disable=protected-access
            targets.append(trg)
            inputs.append(ins)

        # Flatten the batched targets and inputs.
        # pylint:disable=g-complex-comprehension
        targets_flat = [t for batch_t in targets for t in batch_t]
        inputs_flat = [t for batch_t in inputs for t in batch_t]
        # pylint:enable=g-complex-comprehension

        # Decode only the slice assigned to this job.
        targets_decoded_pt = []
        start = PROC_SIZE * FLAGS.slice
        end = PROC_SIZE * (FLAGS.slice + 1)
        if FLAGS.slice == 14:
            end = 9999999
        for i, x in enumerate(targets_flat[start:end]):
            if FLAGS.clf_inputs:
                input_decode = decode_tokens(encoder_in,
                                             inputs_flat[i + start])
            if FLAGS.clf_targets:
                target_decode = decode_tokens(encoder_tgt, x)
            if FLAGS.clf_inputs and FLAGS.clf_targets:
                decode_tok = input_decode + ' [SEP] ' + target_decode
            else:
                decode_tok = target_decode if FLAGS.clf_targets else input_decode
            targets_decoded_pt.append(decode_tok)

    # Load model
    cache_dir = '/tmp/'  # model weights get temporarily written to this directory
    path = FLAGS.bert_base_dir
    trained_path = FLAGS.bert_clf_dir
    config = transformers.BertConfig.from_pretrained(os.path.join(
        trained_path, 'config.json'),
                                                     num_labels=2,
                                                     cache_dir=cache_dir)
    tokenizer = transformers.BertTokenizer.from_pretrained(path,
                                                           cache_dir=cache_dir)
    model = transformers.TFBertForSequenceClassification.from_pretrained(
        os.path.join(trained_path, 'tf_model.h5'),
        config=config,
        cache_dir=cache_dir)

    if FLAGS.target_text:
        # If we read the entire dataset from text, select the slice to encode
        start = PROC_SIZE * FLAGS.slice
        end = PROC_SIZE * (FLAGS.slice + 1)
        if FLAGS.slice == 14:
            end = 9999999
        input_targets = targets_decoded_pt[start:end]
    else:
        # The targets were already decoded above, so use them directly.
        input_targets = targets_decoded_pt
    encoding = tokenizer(input_targets,
                         return_tensors='tf',
                         padding=True,
                         truncation=True,
                         max_length=512)

    train_dataset = tf.data.Dataset.from_tensor_slices((dict(encoding), ))
    batch_size = 256
    if FLAGS.clf_inputs and FLAGS.clf_targets:
        # multiling model is larger
        batch_size = 128
    train_dataset = train_dataset.batch(batch_size)
    logits = model.predict(train_dataset)

    probs = softmax(logits.logits, axis=1)

    clf_score_name = FLAGS.save_dir + '/CLR_scores_' + str(
        FLAGS.slice) + '.csv'
    with tf.io.gfile.GFile(clf_score_name, 'w') as f:
        writer = csv.writer(f)
        for p in probs:
            writer.writerow([p[1]])
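Note: softmax above is presumably scipy.special.softmax; an equivalent, numerically stable sketch:

import numpy as np

def softmax(x, axis=-1):
    z = x - np.max(x, axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)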
Example no. 26
def barrier():
  """MPI-like barrier."""
  jax.device_get(_barrier(jnp.ones((jax.local_device_count(),))))
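Note: _barrier above is not shown in this excerpt. A common sketch (an assumption) is a pmapped psum, which cannot complete until every local device participates:

import functools
import jax

_barrier = jax.pmap(
    functools.partial(jax.lax.psum, axis_name='i'), axis_name='i')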
Example no. 27
def create_buffers(self, name, param):
    """Prepares all momentum buffers for each parameter."""
    state = {'step': jnp.zeros(jax.local_device_count())}
    if self.get_hyper(name, 'momentum') is not None:
        state['momentum'] = jnp.zeros_like(param)
    return state
Example no. 28
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    FLAGS.alsologtostderr = True

    train_split = dataset.Split.from_string(FLAGS.train_split)
    eval_split = dataset.Split.from_string(FLAGS.eval_split)

    # The total batch size is the batch size across all hosts and devices. In a
    # multi-host training setup each host will only see a batch size of
    # `total_train_batch_size / jax.host_count()`.
    total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
    num_train_steps = ((train_split.num_examples * FLAGS.train_epochs) //
                       total_train_batch_size)

    local_device_count = jax.local_device_count()
    train_dataset = dataset.load(
        train_split,
        is_training=True,
        batch_dims=[local_device_count, FLAGS.train_device_batch_size],
        bfloat16=FLAGS.train_bfloat16,
        transpose=FLAGS.dataset_transpose)

    if jax.default_backend() == 'gpu':
        # TODO(tomhennigan): This could be removed if XLA:GPU's allocator changes.
        train_dataset = dataset.double_buffer(train_dataset)

    # For initialization we need the same random key on each device.
    rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
    rng = jnp.broadcast_to(rng, (local_device_count, ) + rng.shape)
    # Initialization requires an example input.
    batch = next(train_dataset)
    params, state, opt_state = jax.pmap(make_initial_state)(rng, batch)

    # Print a useful summary of the execution of our module.
    summary = hk.experimental.tabulate(train_step)(params, state, opt_state,
                                                   batch)
    for line in summary.split('\n'):
        logging.info(line)

    eval_every = FLAGS.train_eval_every
    log_every = FLAGS.train_log_every

    with time_activity('train'):
        for step_num in range(num_train_steps):
            # Take a single training step.
            with jax.profiler.StepTraceContext('train', step_num=step_num):
                params, state, opt_state, train_scalars = (train_step(
                    params, state, opt_state, next(train_dataset)))

            # By default we do not evaluate during training, but you can configure
            # this with a flag.
            if eval_every > 0 and step_num and step_num % eval_every == 0:
                with time_activity('eval during train'):
                    eval_scalars = evaluate(eval_split, params, state)
                logging.info('[Eval %s/%s] %s', step_num, num_train_steps,
                             eval_scalars)

            # Log progress at fixed intervals.
            if step_num and step_num % log_every == 0:
                train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                             jax.device_get(train_scalars))
                logging.info('[Train %s/%s] %s', step_num, num_train_steps,
                             train_scalars)

    # Once training has finished we run eval one more time to get final results.
    with time_activity('final eval'):
        eval_scalars = evaluate(eval_split, params, state)
    logging.info('[Eval FINAL]: %s', eval_scalars)
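Note: time_activity above is not shown in this excerpt. A minimal sketch (an assumption) logs the wall-clock duration of a named block:

import contextlib
import logging
import time

@contextlib.contextmanager
def time_activity(activity_name):
    start = time.time()
    yield
    logging.info('[Timing] %s finished in %.2fs', activity_name,
                 time.time() - start)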
Example no. 29
def main(_):

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    # Performance gains on TPU by switching to hardware bernoulli.
    def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None):
        lax_key = jax.lax.tie_in(rng_key, 0.0)
        return jax.lax.rng_uniform(lax_key, 1.0, shape) < p

    def set_hardware_bernoulli():
        jax.random.bernoulli = hardware_bernoulli

    set_hardware_bernoulli()

    # As we gridsearch the weight decay and the learning rate, we add them to the
    # output directory path so that each model has its own directory to save the
    # results in. We also add the `run_seed` which is "gridsearched" on to
    # replicate an experiment several times.
    output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                     'wd_' + str(FLAGS.weight_decay),
                                     'rho_' + str(FLAGS.sam_rho),
                                     'seed_' + str(FLAGS.run_seed))

    output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)

    if not gfile.exists(output_dir):
        gfile.makedirs(output_dir)

    num_devices = jax.local_device_count() * jax.host_count()
    assert FLAGS.batch_size % num_devices == 0
    local_batch_size = FLAGS.batch_size // num_devices
    info = 'Total batch size: {} ({} x {} replicas)'.format(
        FLAGS.batch_size, local_batch_size, num_devices)
    logging.info(info)

    if FLAGS.dataset == 'cifar10':
        if FLAGS.from_pretrained_checkpoint:
            image_size = efficientnet.name_to_image_size(FLAGS.model_name)
        else:
            image_size = None
        dataset_source = dataset_source_lib.Cifar10(
            FLAGS.batch_size // jax.host_count(),
            FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations,
            image_size=image_size)
    elif FLAGS.dataset == 'cifar100':
        if FLAGS.from_pretrained_checkpoint:
            image_size = efficientnet.name_to_image_size(FLAGS.model_name)
        else:
            image_size = None
        dataset_source = dataset_source_lib.Cifar100(
            FLAGS.batch_size // jax.host_count(),
            FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations,
            image_size=image_size)

    elif FLAGS.dataset == 'fashion_mnist':
        dataset_source = dataset_source_lib.FashionMnist(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset == 'svhn':
        dataset_source = dataset_source_lib.SVHN(
            FLAGS.batch_size, FLAGS.image_level_augmentations,
            FLAGS.batch_level_augmentations)
    elif FLAGS.dataset == 'imagenet':
        imagenet_image_size = efficientnet.name_to_image_size(FLAGS.model_name)
        dataset_source = dataset_source_imagenet.Imagenet(
            FLAGS.batch_size // jax.host_count(), imagenet_image_size,
            FLAGS.image_level_augmentations)
    else:
        raise ValueError('Dataset not recognized.')

    if 'cifar' in FLAGS.dataset or 'svhn' in FLAGS.dataset:
        if image_size is None or 'svhn' in FLAGS.dataset:
            image_size = 32
        num_channels = 3
        num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
    elif FLAGS.dataset == 'fashion_mnist':
        image_size = 28  # For Fashion Mnist
        num_channels = 1
        num_classes = 10
    elif FLAGS.dataset == 'imagenet':
        image_size = imagenet_image_size
        num_channels = 3
        num_classes = 1000
    else:
        raise ValueError('Dataset not recognized.')

    try:
        model, state = load_imagenet_model.get_model(FLAGS.model_name,
                                                     local_batch_size,
                                                     image_size, num_classes)
    except load_imagenet_model.ModelNameError:
        model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                            image_size, num_classes,
                                            num_channels)

    # Learning rate will be overwritten by the lr schedule, we set it to zero.
    optimizer = flax_training.create_optimizer(model, 0.0)

    flax_training.train(optimizer, state, dataset_source, output_dir,
                        FLAGS.num_epochs)
Example no. 30
def main(config, output_dir):
    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('dataset_dir'):
        logging.info('data_dir=%s', config.dataset_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in an async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())
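    # A worked example of the arithmetic above (hypothetical numbers): a
    # global batch size of 512 on 2 hosts with 8 devices each gives
    # local_batch_size = 512 // 2 = 256 and a per-device batch of
    # 256 // 8 = 32.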

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
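        # E.g. (hypothetical numbers): 10000 validation images with
        # batch_size_eval=1024 give ceil(10000/1024) = 10 steps; the final
        # partial batch is padded, and the padding is masked out downstream
        # via the batch 'mask' entries consumed by evaluation_fn.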
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      num_epochs=1,
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)

        return val_ds

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset, config.val_split, config.pp_eval,
                       config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = (
            data_uncertainty_utils.create_imagenet_to_real_fn())
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = None
    if config.get('ood_datasets') and config.get('ood_methods'):
        # The condition above already guarantees that config.ood_methods is a
        # non-empty list.
        logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
        ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
            config.dataset,
            config.ood_datasets,
            config.ood_split,
            config.pp_eval,
            config.pp_eval_ood,
            config.ood_methods,
            config.train_split,
            config.get('data_dir'),
            _get_val_split,
        )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = ntrain_img // batch_size

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get('total_steps'), (
            'Set either num_epochs or total_steps.')
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
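    # E.g. (hypothetical numbers): ntrain_img = 1_281_167 and batch_size =
    # 4096 give steps_per_epoch = 312; with num_epochs = 90 this means
    # total_steps = 28080.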

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))
    model = ub.models.het_vision_transformer(num_classes=config.num_classes,
                                             **config.get('model', {}))

    # We want all parameters to be created in host RAM, not on any device;
    # they'll be sent to the devices later as needed. Otherwise we risk
    # allocating them twice, which we have already run into in two situations.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)

        rng, diag_noise_rng, standard_noise_rng = jax.random.split(rng, num=3)
        init_rngs = {
            'params': rng,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        }

        params = flax.core.unfreeze(
            model.init(init_rngs, dummy_input, train=False))['params']

        head = ('multiclass_head' if config.get('model', {}).get('multiclass')
                else 'multilabel_head')
        # Set bias in the head to a low value, such that loss is small initially.
        if head in params:
            params[head]['loc_layer']['bias'] = jnp.full_like(
                params[head]['loc_layer']['bias'],
                config.get('init_head_bias', 0))
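
        # E.g. with sigmoid_xent and a hypothetical init_head_bias of -10,
        # initial per-class probabilities are sigmoid(-10) ~= 4.5e-5, so the
        # multilabel loss starts out near zero.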

        # Init the head kernels to all zeros for fine-tuning.
        if config.get('model_init'):
            params[head]['loc_layer']['kernel'] = jnp.full_like(
                params[head]['loc_layer']['kernel'], 0)
            params[head]['diag_layer']['kernel'] = jnp.full_like(
                params[head]['diag_layer']['kernel'], 0)
            params[head]['diag_layer']['bias'] = jnp.full_like(
                params[head]['diag_layer']['bias'], 0)

            if 'scale_layer_homoscedastic' in params[head]:
                params[head]['scale_layer_homoscedastic'][
                    'kernel'] = jnp.full_like(
                        params[head]['scale_layer_homoscedastic']['kernel'], 0)
                params[head]['scale_layer_homoscedastic'][
                    'bias'] = jnp.full_like(
                        params[head]['scale_layer_homoscedastic']['bias'], 0)
            if 'scale_layer_heteroscedastic' in params[head]:
                params[head]['scale_layer_heteroscedastic'][
                    'kernel'] = jnp.full_like(
                        params[head]['scale_layer_heteroscedastic']['kernel'],
                        0)
                params[head]['scale_layer_heteroscedastic'][
                    'bias'] = jnp.full_like(
                        params[head]['scale_layer_heteroscedastic']['bias'], 0)

        return params

    (rng, rng_init, rng_dropout, diag_noise_rng,
     standard_noise_rng) = jax.random.split(rng, num=5)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False,
                                  rngs={
                                      'dropout':
                                      rng_dropout,
                                      'diag_noise_samples':
                                      diag_noise_rng,
                                      'standard_norm_noise_samples':
                                      standard_noise_rng
                                  })
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]

        # Note that logits and labels are usually of shape [batch, num_classes],
        # but for OOD data, where num_classes_ood > num_classes_ind, we need to
        # truncate labels to labels[:, :config.num_classes] to match the shape
        # of logits. This only avoids a shape mismatch; the resulting losses
        # are meaningless for OOD data, since OOD examples belong to no IND
        # class.
        losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits,
            labels=labels[:, :(
                len(label_indices) if label_indices else config.num_classes)],
            reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')
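        # After the psums above, every device holds identical global totals;
        # e.g. with 8 devices each masking in 32 examples, n == 256 on all of
        # them, so the host can later read any single replica.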

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False,
                                  rngs={
                                      'dropout':
                                      rng_dropout,
                                      'diag_noise_samples':
                                      diag_noise_rng,
                                      'standard_norm_noise_samples':
                                      standard_noise_rng
                                  })
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        _, outputs = model.apply({'params': flax.core.freeze(params)},
                                 images,
                                 train=False,
                                 rngs={
                                     'dropout':
                                     rng_dropout,
                                     'diag_noise_samples':
                                     diag_noise_rng,
                                     'standard_norm_noise_samples':
                                     standard_noise_rng
                                 })
        representation = outputs[config.fewshot.representation_layer]
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the arrays are created on the same device as the
    # input (here, the CPU). Otherwise they would end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)
    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)
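    # Presumably the rescaling decouples the decay from the schedule: with
    # `weight_decay_decouple`, the per-step decay inside weight_decay_fn is
    # divided by config.lr.base, so its effective strength does not scale with
    # the base learning rate (AdamW-style decoupled weight decay).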

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))
        rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
            rng_model_local, num=3)

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={
                                        'dropout':
                                        rng_model_local,
                                        'diag_noise_samples':
                                        diag_noise_rng,
                                        'standard_norm_noise_samples':
                                        standard_noise_rng
                                    })
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyway (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_map(lambda p: g_factor * p, g)
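            # E.g. (hypothetical numbers): l2_g = 20.0 with grad_clip_norm =
            # 10.0 gives g_factor = 0.5, halving every leaf of the gradient
            # pytree so the global norm lands exactly on the cap.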
        opt = opt.apply_gradient(g, learning_rate=lr)
        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements

    if config.get('model.multiclass', False):
        default_reinit_params = []
    else:
        default_reinit_params = [
            'head/scale_layer_homoscedastic/kernel',
            'head/scale_layer_homoscedastic/bias',
            'head/scale_layer_heteroscedastic/kernel',
            'head/scale_layer_heteroscedastic/bias', 'head/loc_layer/kernel',
            'head/diag_layer/kernel', 'head/loc_layer/bias',
            'head/diag_layer/bias'
        ]
        default_reinit_params = (
            default_reinit_params +
            list(map(lambda k: 'multilabel_' + k, default_reinit_params)) +
            list(map(lambda k: 'multiclass_' + k, default_reinit_params)))
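        # The map() calls cover both head namings: e.g. 'head/loc_layer/kernel'
        # also gets 'multilabel_head/loc_layer/kernel' and
        # 'multiclass_head/loc_layer/kernel' variants, so whichever head the
        # model built is re-initialized for fine-tuning.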

    rng, train_loop_rngs = jax.random.split(rng)

    if config.get('only_eval', False) or not config.get('reint_head', True):
        default_reinit_params = []

    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=train_loop_rngs,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=default_reinit_params,
        config=config,
    )
    train_loop_rngs = checkpoint_data.train_loop_rngs
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data becomes much larger and takes longer to load for analysis.)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate schedule and pre-fetch it to device to avoid
    # delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl,
                    lr_repl,
                    train_batch['image'],
                    train_batch['labels'],
                    rng=train_loop_rngs)

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now, or else we risk keeping
            # them alive while they are updated in a future step, creating
            # hard-to-debug memory errors (see b/160593526). Also, this takes
            # device 0's params only.
            opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or Flax dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                optimizer=opt_cpu,
                train_loop_rngs=train_loop_rngs,
                accumulated_train_time=accumulated_train_time)
            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            # Keep to return for reproducibility tests.
            train_loss = loss_value[0]
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, val_ds in val_ds_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()
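                # As I understand the robustness-metrics API, oracle_fraction
                # is the share of least-confident predictions deferred to a
                # perfect oracle; e.g. oc_auc_1 scores the model with 1% of
                # the hardest examples handed off.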

                # Runs evaluation loop.
                val_iter = input_utils.start_input_pipeline(
                    val_ds, config.get('prefetch_to_device', 1))
                ncorrect, loss, nseen = 0, 0, 0
                for batch in val_iter:
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, batch['image'],
                                          batch['labels'], batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))

                    # Here we parse batch_metric_args to compute uncertainty metrics.
                    # (e.g., ECE or Calibration AUC).
                    logits, labels, _, masks = batch_metric_args
                    masks = np.array(masks[0], dtype=bool)
                    logits = np.array(logits[0])
                    probs = jax.nn.softmax(logits)
                    # From one-hot to integer labels, as required by ECE.
                    int_labels = np.argmax(np.array(labels[0]), axis=-1)
                    int_preds = np.argmax(logits, axis=-1)
                    confidence = np.max(probs, axis=-1)
                    for p, c, l, d, m, label in zip(probs, confidence,
                                                    int_labels, int_preds,
                                                    masks, labels[0]):
                        ece.add_batch(p[m, :], label=l[m])
                        calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
                        # TODO(jereliu): Extend to support soft multi-class probabilities.
                        oc_auc_0_5.add_batch(d[m],
                                             label=l[m],
                                             custom_binning_score=c[m])
                        oc_auc_1.add_batch(d[m],
                                           label=l[m],
                                           custom_binning_score=c[m])
                        oc_auc_2.add_batch(d[m],
                                           label=l[m],
                                           custom_binning_score=c[m])
                        oc_auc_5.add_batch(d[m],
                                           label=l[m],
                                           custom_binning_score=c[m])

                        if val_name in ('cifar_10h', 'imagenet_real'):
                            batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                label[m], p[m, :], config.num_classes)
                            label_diversity.update_state(batch_label_diversity)
                            sample_diversity.update_state(
                                batch_sample_diversity)
                            ged.update_state(batch_ged)

                # Keep for reproducibility tests.
                val_loss[val_name] = loss / nseen
                val_measurements = {
                    f'{val_name}_prec@1':
                    ncorrect / nseen,
                    f'{val_name}_loss':
                    val_loss[val_name],
                    f'{val_name}_ece':
                    ece.result()['ece'],
                    f'{val_name}_calib_auc':
                    calib_auc.result()['calibration_auc'],
                    f'{val_name}_oc_auc_0.5%':
                    oc_auc_0_5.result()['collaborative_auc'],
                    f'{val_name}_oc_auc_1%':
                    oc_auc_1.result()['collaborative_auc'],
                    f'{val_name}_oc_auc_2%':
                    oc_auc_2.result()['collaborative_auc'],
                    f'{val_name}_oc_auc_5%':
                    oc_auc_5.result()['collaborative_auc'],
                }
                writer.write_scalars(step, val_measurements)

                if val_name in ('cifar_10h', 'imagenet_real'):
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # There are two entries in the ood_ds dict (in-dist, ood), and this
            # section computes metrics using both pieces. This is in contrast to
            # normal validation eval above where we eval metrics separately for each
            # val split in val_ds.
            if ood_ds and config.ood_methods:
                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds,
                    ood_ds_names,
                    config.ood_methods,
                    evaluation_fn,
                    opt_repl.target,
                    n_prefetch=config.get('prefetch_to_device', 1))
                writer.write_scalars(step, ood_measurements)
            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target, config.fewshot.datasets)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step') == step:
            # Break early to simulate infra failures in test cases.
            break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results