Beispiel #1
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Builds a Transformer language model and trains it for
    ``config.num_train_steps`` steps. Every ``config.eval_every_steps`` steps
    (and on the last step) it logs training metrics, evaluates on the eval
    split and samples continuations of ``config.prompts``. Host 0 writes a
    checkpoint every ``config.checkpoint_every_steps`` steps (and on the last
    step) when ``config.save_checkpoints`` is set.

    Args:
      config: Configuration to use. When ``config.vocab_path`` is ``None`` it
        is set to ``<workdir>/sentencepiece_model``.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint and ``config.restore_checkpoints`` is set,
        training is resumed from the latest checkpoint.
    """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        """Detokenizes a single token sequence up to and including the first EOS.

        ``np.argmax(toks == eos_id)`` alone returns 0 when no EOS is present,
        which would keep only the first token; in that case the whole
        sequence is decoded instead. Assumes ``toks`` is 1-D.
        """
        eos_positions = np.flatnonzero(toks == eos_id)
        end = eos_positions[0] + 1 if eos_positions.size else len(toks)
        valid_toks = toks[:end].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    def encode_strings(strs, max_len):
        """Tokenizes ``strs`` into a zero-padded ``(len(strs), max_len)`` int32 array.

        The final token of each tokenization (the EOS token) is dropped, and
        prompts longer than ``max_len`` are truncated rather than raising a
        broadcast error.
        """
        tokenized_batch = np.zeros((len(strs), max_len), np.int32)
        for i, s in enumerate(strs):
            # Remove EOS token in prompt, then clip to the available width.
            toks = encoder.tokenize(s).numpy()[:-1][:max_len]
            tokenized_batch[i, :toks.shape[0]] = toks
        return tokenized_batch

    tokenized_prompts = encode_strings([config.prompts],
                                       config.max_predict_length)

    logging.info("Initializing model, optimizer, and step functions.")
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    rng, inference_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.TransformerLM(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step fn.
    p_train_step = jax.pmap(functools.partial(
        train_step, config=train_config, learning_rate_fn=learning_rate_fn),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name="batch")

    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          temperature=config.sampling_temperature,
                          top_k=config.sampling_top_k),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                                           metrics_sums)
                    summary["learning_rate"] = lr
                    summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]),
                                                     a_max=1.0e4)
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    # (clipped) perplexity after averaging log-perplexities
                    eval_results["perplexity"] = jnp.clip(
                        jnp.exp(eval_results["loss"]), a_max=1.0e4)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("generate_text"):
                    exemplars = generate_prediction(
                        p_pred_step=p_pred_step,
                        target=optimizer.target,
                        tokenized_prompts=tokenized_prompts,
                        eos_id=eos_id,
                        inference_rng=inference_rng,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if (config.save_checkpoints and save_checkpoint
                    and jax.host_id() == 0):
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
Beispiel #2
0
def main(argv):
    """Trains and evaluates a Transformer classifier on a matching task.

    Hyperparameters come from ``FLAGS.config``. With ``FLAGS.test_only`` the
    latest checkpoint in ``FLAGS.model_dir`` is restored, the test split is
    evaluated in full and the summary is written to ``results.json``.
    Otherwise the model is trained. Process 0 saves a checkpoint every
    ``checkpoint_freq`` steps and on the last step, when ``save_checkpoints``
    is set. Every ``eval_frequency`` steps the model is evaluated on the eval
    and test splits.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra command-line arguments are passed.
      ValueError: If the batch size is not divisible by the number of devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    cfg = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(cfg)
    (batch_size, learning_rate, num_train_steps, num_eval_steps, eval_freq,
     random_seed, model_type, max_length) = (
         cfg.batch_size, cfg.learning_rate, cfg.num_train_steps,
         cfg.num_eval_steps, cfg.eval_frequency, cfg.random_seed,
         cfg.model_type, cfg.max_length)

    # Only process 0 owns a summary writer and writes TensorBoard scalars.
    is_chief = jax.process_index() == 0
    if is_chief:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets(
        n_devices=jax.local_device_count(),
        task_name=FLAGS.task_name,
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        fixed_vocab=None,
        max_length=max_length,
        tokenizer=cfg.tokenizer,
        vocab_file_path=FLAGS.vocab_file_path)

    vocab_size = encoder.vocab_size
    logging.info('Vocab Size: %d', vocab_size)

    # Repeat the training split indefinitely; the step range bounds the loop.
    train_iter = iter(train_ds.repeat())
    example_shape = (batch_size, max_length)

    model_kwargs = dict(
        vocab_size=vocab_size,
        emb_dim=cfg.emb_dim,
        num_heads=cfg.num_heads,
        num_layers=cfg.num_layers,
        qkv_dim=cfg.qkv_dim,
        mlp_dim=cfg.mlp_dim,
        max_len=max_length,
        classifier=True,
        num_classes=2,
        classifier_pool=cfg.pooling_mode,
    )

    # Give each process its own key stream, then split off the init key.
    rng = jax.random.fold_in(random.PRNGKey(random_seed), jax.process_index())
    rng, init_rng = random.split(rng)
    # Initial per-device dropout keys; p_train_step returns fresh keys each
    # step.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model = train_utils.get_model(model_type, create_model, model_kwargs,
                                  init_rng, example_shape)
    optimizer = create_optimizer(model,
                                 learning_rate,
                                 weight_decay=FLAGS.config.weight_decay)
    del model  # Keep only the optimizer's copy of the parameters.

    start_step = 0
    if cfg.restore_checkpoints or FLAGS.test_only:
        # Restore the unreplicated optimizer (params + state) and its step.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        start_step = int(optimizer.state.step)

    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=cfg.factors,
        base_learning_rate=learning_rate,
        warmup_steps=cfg.warmup)
    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=learning_rate_fn),
        axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def shard_batch(batch):
        """Converts a TF batch to numpy and splits it across local devices."""
        return common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

    def clipped_perplexity(loss):
        """exp(loss) clipped at 1e4; applied after averaging log-perplexities."""
        return jnp.clip(jnp.exp(loss), a_max=1.0e4)

    def normalize(metric_sums):
        """Divides each summed metric by the popped 'denominator' entry."""
        denominator = metric_sums.pop('denominator')
        return jax.tree_map(lambda x: x / denominator, metric_sums)

    def write_scalars(named_values, step):
        """Writes every scalar at ``step`` and flushes; chief process only."""
        for tag, value in named_values.items():
            summary_writer.scalar(tag, value, step)
        summary_writer.flush()

    def run_eval(dataset, max_steps=-1):
        """Evaluates the current params on ``max_steps`` batches (-1: all)."""
        step_budget = (itertools.count() if max_steps == -1
                       else range(max_steps))
        collected = []
        # The budget comes first in zip so no extra batch is pulled.
        for _, raw_batch in zip(step_budget, iter(dataset)):
            sharded = shard_batch(raw_batch)
            collected.append(p_eval_step(optimizer.target, sharded))
        summary = normalize(
            jax.tree_map(jnp.sum, common_utils.get_metrics(collected)))
        summary['perplexity'] = clipped_perplexity(summary['loss'])
        return summary

    if FLAGS.test_only:
        results_path = os.path.join(FLAGS.model_dir, 'results.json')
        with tf.io.gfile.GFile(results_path, 'w') as f:
            test_summary = run_eval(test_ds)
            json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
        return

    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, shard_batch(batch), dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        logging.info('train in step: %d', step)

        # Save the unreplicated optimizer + model state from the chief only.
        should_checkpoint = ((step % cfg.checkpoint_freq == 0 and step > 0)
                             or step == num_train_steps - 1)
        if should_checkpoint and is_chief and cfg.save_checkpoints:
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer),
                                        step)

        if not (step % eval_freq == 0 and step > 0):
            continue

        # Training metrics accumulated since the previous report.
        stacked = common_utils.get_metrics(metrics_all)
        lr = stacked.pop('learning_rate').mean()
        summary = normalize(jax.tree_map(jnp.sum, stacked))
        summary['learning_rate'] = lr
        summary['perplexity'] = clipped_perplexity(summary['loss'])
        logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                     summary['loss'], summary['accuracy'])
        if is_chief:
            tock = time.time()
            steps_per_sec = eval_freq / (tock - tick)
            tick = tock
            write_scalars(
                {'steps per second': steps_per_sec,
                 **{f'train_{key}': val for key, val in summary.items()}},
                step)
        # Start a fresh accumulation window for the next report.
        metrics_all = []

        eval_summary = run_eval(eval_ds, num_eval_steps)
        logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                     eval_summary['loss'], eval_summary['accuracy'])
        if is_chief:
            write_scalars(
                {f'eval_{key}': val for key, val in eval_summary.items()},
                step)

        logging.info('Testing...')
        test_summary = run_eval(test_ds, num_eval_steps)
        logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                     test_summary['loss'], test_summary['accuracy'])
        if is_chief:
            write_scalars(
                {f'test_{key}': val for key, val in test_summary.items()},
                step)
Beispiel #3
0
def main(argv):
  """Trains an image classifier and evaluates it after every epoch.

  Batch size, epoch count, learning rate, momentum and precision are read
  from FLAGS. The initial state is passed through ``restore_checkpoint``, so
  training resumes at the restored step. A checkpoint is saved every 10
  epochs and after the final step.

  Args:
    argv: Positional command-line arguments; only the program name is
      accepted.

  Raises:
    app.UsageError: If extra command-line arguments are passed.
    ValueError: If the batch size is not divisible by the number of devices.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # Hide GPUs from TensorFlow so it does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Only host 0 owns a summary writer.
  is_chief = jax.host_id() == 0
  if is_chief:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)
  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  # Full precision unless requested; half precision is bfloat16 on TPU and
  # float16 on every other platform.
  if not FLAGS.half_precision:
    model_dtype, input_dtype = jnp.float32, tf.float32
  elif platform == 'tpu':
    model_dtype, input_dtype = jnp.bfloat16, tf.bfloat16
  else:
    model_dtype, input_dtype = jnp.float16, tf.float16

  train_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=True, cache=FLAGS.cache)
  eval_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=False, cache=FLAGS.cache)

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = 10 * steps_per_epoch
  num_steps = num_epochs * steps_per_epoch

  # Scale the learning rate linearly with batch size (relative to 256), then
  # divide by the loss-scaling factor.
  base_learning_rate = (
      FLAGS.learning_rate * batch_size / 256.) / FLAGS.loss_scaling

  model, model_state = create_model(
      rng, device_batch_size, image_size, model_dtype)
  momentum_def = optim.Momentum(beta=FLAGS.momentum, nesterov=True)
  state = TrainState(step=0, optimizer=momentum_def.create(model),
                     model_state=model_state)
  del model, model_state  # Only the TrainState keeps the initial weights.

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # Non-zero when resuming from a checkpoint.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, num_epochs)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    completed = step + 1

    if completed % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      train_history = common_utils.get_metrics(epoch_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), train_history)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, train_summary['loss'],
                   train_summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if is_chief:
        for key, per_step_values in train_history.items():
          tag = 'train_%s' % key
          # Re-emit each recorded step of this epoch at its own step number.
          first_step = step - len(per_step_values) + 1
          for offset, value in enumerate(per_step_values):
            summary_writer.scalar(tag, value, first_step + offset)
        summary_writer.scalar('steps per second', steps_per_sec, step)

      epoch_metrics = []

      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      eval_history = common_utils.get_metrics(
          [p_eval_step(state, next(eval_iter)) for _ in range(steps_per_eval)])
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_history)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, eval_summary['loss'], eval_summary['accuracy'] * 100)
      if is_chief:
        for key, values in eval_history.items():
          summary_writer.scalar('eval_%s' % key, values.mean(), step)
        summary_writer.flush()

    if completed % steps_per_checkpoint == 0 or completed == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Beispiel #4
0
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.mkdir(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparmaters
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            if key not in hparam_str_dict:
                hparam_str_dict[key] = value

    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        io_string = ''
        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
            io_string += inps[-1] + ' < ' + outs[-1] + ' > '
        return inps, outs, io_string[:-3]  # Remove last separator.

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except:  # pylint: disable=bare-except
            return None, ''  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.LatentTransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        latent_vocab_size=FLAGS.latent_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        deterministic=False,
        decode=False,
        c=FLAGS.c,
        train_vq=True,
        commitment_cost_vq=FLAGS.commitment_cost_vq,
        bos_token=bos_token)
    eval_config = train_config.replace(deterministic=True, train_vq=False)
    predict_config = train_config.replace(shift=False,
                                          deterministic=True,
                                          train_vq=False,
                                          decode=True)

    # Latent Predictor.
    lp_train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=FLAGS.latent_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    lp_eval_config = lp_train_config.replace(deterministic=True)
    lp_predict_config = lp_train_config.replace(shift=False,
                                                deterministic=True,
                                                decode=True)

    rng = jax.random.PRNGKey(0)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    m = models.LatentProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))
    lp_m = models.ProgramTransformer(lp_eval_config)
    lp_initial_variables = jax.jit(lp_m.init)(init_rng,
                                              jnp.ones(io_shape, jnp.float32),
                                              jnp.ones(io_shape, jnp.float32),
                                              jnp.ones(program_shape,
                                                       jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])
    lp_optimizer = optimizer_def.create(lp_initial_variables['params'])

    state = TrainState(step=0,
                       optimizer=optimizer,
                       model_state=initial_variables['vqvae'],
                       lp_optimizer=lp_optimizer)
    # Don't keep a copy of the initial model.
    del initial_variables, lp_initial_variables

    train_rngs = jax.random.split(rng, jax.local_device_count())

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        state = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), state)
        # Grab last step.
        start_step = int(state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        bos_token=bos_token,
        eos_token=eos_token,
        learning_rate_fn=learning_rate_fn,
        config=train_config,
        lp_config=lp_train_config),
                            axis_name='batch',
                            static_broadcasted_argnums=(4, ))
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             bos_token=bos_token,
                                             eos_token=eos_token,
                                             config=eval_config,
                                             lp_config=lp_eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        lp_config=lp_predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        bos_token=bos_token,
        eos_token=eos_token,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        lp_config=lp_predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(5, ))

    metrics_all = []
    latent_metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        state, metrics, latent_metrics, train_rngs = p_train_step(
            state,
            inputs,
            outputs,
            programs,
            step <= FLAGS.num_pretrain_steps,
            train_rng=train_rngs)
        metrics, latent_metrics = jax.tree_map(np.array,
                                               (metrics, latent_metrics))
        metrics_all.append(metrics)
        latent_metrics_all.append(latent_metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(state), step)

        # Periodic metric handling.
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        latent_metrics_all = common_utils.get_metrics(latent_metrics_all)
        metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary.update(
            jax.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums))

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []
        latent_metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        latent_eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)
            all_metrics = p_eval_step(state, inputs, outputs, programs)
            metrics, latent_metrics = jax.tree_map(np.array, all_metrics)
            eval_metrics.append(metrics)
            latent_eval_metrics.append(latent_metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary.update(
            jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums))

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 50, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions, latent_predictions = [], [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_map(
                        lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache, lp_cache = p_init_cache(inputs, outputs, programs)
                predicted, latent_predicted = p_pred_step(
                    state, inputs, outputs, cache, lp_cache, beam_size)
                predicted, latent_predicted = map(
                    tohost, (predicted, latent_predicted))
                inputs, outputs, programs = map(tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_idx, p_score = eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')
                    latent_predictions.append(' '.join(
                        list(np.array(latent_predicted[i,
                                                       p_idx]).astype(str))))

            all_pred_acc, all_pred_denominator = per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries.
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n'
                        f'latent_predicted: {latent_predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step,
                    all_pred_acc / all_pred_denominator)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      all_pred_acc / all_pred_denominator,
                                      step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()
Beispiel #5
0
 def replicate(self):
     """Replicates this state onto the local devices.

     `jax_utils.replicate` replicates the whole pytree; the `dropout_rng`
     field is then swapped for `shard_prng_key(self.dropout_rng)`
     (presumably one distinct key per device -- confirm against its
     definition). Assumes `self` supports `.replace(...)`, e.g. a flax
     struct dataclass.
     """
     replicated_state = jax_utils.replicate(self)
     device_rngs = shard_prng_key(self.dropout_rng)
     return replicated_state.replace(dropout_rng=device_rngs)
        beta1=training_args.adam_beta1,
        beta2=training_args.adam_beta2,
    ).create(model.params)

    # Create learning rate scheduler
    # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
    lr_scheduler_fn = create_learning_rate_scheduler(
        base_learning_rate=training_args.learning_rate,
        warmup_steps=max(training_args.warmup_steps, 1))

    # Create parallel version of the training and evaluation steps
    p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the optimizer on each device
    optimizer = jax_utils.replicate(optimizer)

    # Store some constant
    nb_epochs = int(training_args.num_train_epochs)
    batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    epochs = tqdm(range(nb_epochs),
                  desc=f"Epoch ... (1/{nb_epochs})",
                  position=0)
    for epoch in epochs:

        # ======================== Training ================================
        # Create sampling rng
Beispiel #7
0
def main(config, output_dir):
    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        """Logs `note` as a progress message, from process 0 only."""
        if jax.process_index() != 0:
            return
        logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be'
            f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d dev per host (%d dev total), that is a %d per-device batch size.',
        batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))
    logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        """Builds the unshuffled evaluation dataset for one split.

        Args:
          dataset: dataset name forwarded to `input_utils`.
          split: split name.
          pp_eval: either a preprocessing spec string (parsed with
            `preprocess_spec.parse`) or an already-built preprocessing callable.
          data_dir: optional data directory.

        Returns:
          The dataset produced by `input_utils.get_data` with `shuffle=False`
          and `drop_remainder=False`.
        """
        # Ceil rounding, so the trailing incomplete batch is counted as a step.
        num_examples = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        num_steps = int(np.ceil(num_examples / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', num_steps,
                     dataset, split)

        if not isinstance(pp_eval, str):
            preprocess_fn = pp_eval
        else:
            preprocess_fn = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        return input_utils.get_data(
            dataset=dataset,
            split=split,
            rng=None,
            process_batch_size=local_batch_size_eval,
            preprocess_fn=preprocess_fn,
            cache=config.get('val_cache', 'batched'),
            num_epochs=1,
            repeat_after_batching=True,
            shuffle=False,
            prefetch_size=config.get('prefetch_to_host', 2),
            drop_remainder=False,
            data_dir=data_dir)

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset, config.val_split, config.pp_eval,
                       config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get(
                'ood_methods'):  #  config.ood_methods is not a empty list
            logging.info('loading OOD dataset = %s',
                         config.get('ood_datasets'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = int(ntrain_img / batch_size)

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))

    # Specify Gaussian process layer configs.
    use_gp_layer = config.get('use_gp_layer', True)
    gp_config = config.get('gp_layer', {})
    gp_layer_kwargs = get_gp_kwargs(gp_config)

    # Process ViT backbone model configs.
    vit_kwargs = config.get('model')

    model = ub.models.vision_transformer_gp(num_classes=config.num_classes,
                                            use_gp_layer=use_gp_layer,
                                            vit_kwargs=vit_kwargs,
                                            gp_layer_kwargs=gp_layer_kwargs)

    # We want all parameters to be created in host RAM, not on any device, they'll
    # be sent there later as needed, otherwise we already encountered two
    # situations where we allocate them twice.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        """Initializes the model variables on the CPU backend.

        Returns:
          (params, states): the unfrozen trainable 'params' collection, with the
          head bias filled with `config.init_head_bias` (default 0), and the
          remaining non-trainable variable collections.
        """
        # Drop the two leading batch dims of the training images (presumably
        # [devices, per-device batch, ...] -- confirm against input_utils).
        image_shape = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_shape)
        dummy_images = jnp.zeros((local_batch_size, *image_shape), jnp.float32)
        all_variables = model.init(rng, dummy_images, train=False)
        # `pop` separates the trainable 'params' from every other collection.
        states, params = all_variables.pop('params')
        del all_variables

        # A low head bias keeps the initial loss small.
        params = flax.core.unfreeze(params)
        if use_gp_layer:
            head = params['head']['output_layer']  # bias lives in the GP head.
        else:
            head = params['head']
        head['bias'] = jnp.full_like(head['bias'],
                                     config.get('init_head_bias', 0))
        return params, states

    rng, rng_init = jax.random.split(rng)
    params_cpu, states_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, states, images, labels, mask):
        """Evaluates one sharded batch and reduces the counts across devices.

        Returns:
          (ncorrect, loss, n, metric_args): masked top-1 hits, masked
          per-example loss and the mask itself, each psum-reduced over the
          'batch' axis, plus the all-gathered [logits, labels, pre_logits, mask].
        """
        # Rows whose labels are all zero get a zero weight.
        mask = mask * labels.max(axis=1)
        variables = {'params': flax.core.freeze(params), **states}
        logits, out = model.apply(variables,
                                  images,
                                  train=False,
                                  mean_field_factor=gp_config.get(
                                      'mean_field_factor', -1.))

        # For OOD data, num_classes_ood may exceed num_classes, so the labels
        # are truncated to the logits' width purely to avoid a shape mismatch;
        # the resulting loss carries no meaning for OOD examples.
        loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent'))
        per_example_loss = loss_fn(logits=logits,
                                   labels=labels[:, :config.num_classes],
                                   reduction=False)
        loss = jax.lax.psum(per_example_loss * mask, axis_name='batch')

        # Label value at the arg-max logit (1 for a correct guess when the
        # labels are one-hot).
        predicted = jnp.argmax(logits, axis=1)
        hits = jnp.take_along_axis(labels, predicted[:, None], axis=1)[:, 0]
        ncorrect = jax.lax.psum(hits * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, states, images, labels, mask):
        """Evaluates one sharded CIFAR-10H batch with soft (annotator) labels.

        Top-1 accuracy is measured against the arg-max of each label row.
        Examples with `mask == 0` (presumably padding that completes the final
        `drop_remainder=False` batch, whose label rows are all zero) are
        excluded from the loss, the correct count and the example count, the
        same way `evaluation_fn` excludes them.

        Args:
          params: trainable parameters.
          states: non-trainable variable collections.
          images: images for this device.
          labels: label rows, one column per class; the arg-max column is
            treated as the target class.
          mask: per-example validity weights.

        Returns:
          (ncorrect, loss, n, metric_args). The first three are psum-reduced
          over the 'batch' axis: masked top-1 hits per example, masked
          per-example loss, and masked one-hot targets (one row per example,
          so summing `n` gives the number of valid examples). `metric_args`
          is the all-gathered [logits, labels, pre_logits, mask].
        """
        variable_dict = {'params': flax.core.freeze(params), **states}
        logits, out = model.apply(variable_dict,
                                  images,
                                  train=False,
                                  mean_field_factor=gp_config.get(
                                      'mean_field_factor', -1.))

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        # Masked-out examples must not contribute to the loss.
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # One-hot of the arg-max label; the class count comes from the labels
        # instead of being hard-coded to CIFAR-10's 10 classes.
        num_classes = labels.shape[-1]
        one_hot_labels = jnp.eye(num_classes)[jnp.argmax(labels, axis=1)]
        # Zero the rows of masked-out examples. Otherwise an all-zero padding
        # row (arg-max 0) is counted in `n` and scores as correct whenever the
        # model predicts class 0.
        one_hot_labels = one_hot_labels * mask[:, None]

        # Extracts the (masked) target at the highest logit index per image.
        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask, states):
        """Computes few-shot features and gathers them from every device.

        Returns:
          (representation, labels, mask), each all-gathered over the 'batch'
          axis. The representation is the model output named by
          `config.fewshot.representation_layer`.
        """
        variables = {'params': flax.core.freeze(params), **states}
        _, layer_outputs = model.apply(
            variables,
            images,
            train=False,
            mean_field_factor=gp_config.get('mean_field_factor', -1.))
        features = layer_outputs[config.fewshot.representation_layer]
        return tuple(
            jax.lax.all_gather(x, 'batch') for x in (features, labels, mask))

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
        """Runs one data-parallel training step (pmapped over 'batch').

        Args:
          opt: flax optimizer whose `target` holds the parameters. It is
            donated (`donate_argnums=(0,)`), so the caller must not reuse the
            buffers it passed in.
          states: non-trainable variable collections, applied as mutable (for
            the GP head this includes the 'laplace_covariance' precision
            matrix).
          lr: learning rate for this step.
          reset_covmat: 0./1. float flag (see `reset_covmat_fn` in `main`).
            When it is 1. and exact covariance updates are enabled, the
            precision matrix is reset to `ridge_penalty * I` before gradients
            are computed.
          images: training images for this device.
          labels: training labels for this device.
          rng: PRNG key. It is split here and the remainder is returned so the
            caller can thread it into the next step.

        Returns:
          (opt, states, loss, rng, measurements): the updated optimizer; the
          updated mutable collections (not averaged across devices); the
          device-averaged loss; the next rng; and a dict with 'l2_params' and
          'reset_covmat', plus 'l2_grads' when it was computed.
        """
        measurements = {}

        # Get device-specific loss rng: fold in the device index so each device
        # draws different dropout masks.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, states, images, labels):
            # Returns (loss, updated_states) so that
            # `value_and_grad(..., has_aux=True)` carries the updated states.
            # Specify mutable collection to update untrainable GP parameters.
            variable_dict = {'params': flax.core.freeze(params), **states}
            model_results, updated_states = model.apply(
                variable_dict,
                images,
                train=True,
                rngs={'dropout': rng_model_local},
                mutable=list(states.keys()),
                mean_field_factor=gp_config.get('mean_field_factor', -1.))

            logits, _ = model_results
            loss = getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)
            return loss, updated_states

        # Performs the exact covariance update (i.e., resets the precision
        # matrix at the beginning of each new epoch) when `covmat_momentum` is
        # negative, which includes the -1. default used when it is unset.
        if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0:
            # Resets precision matrix to Identity * ridge_penalty at the
            # beginning of a new epoch. This should be done before the
            # gradients are accumulated.
            ridge_penalty = gp_config.get('ridge_penalty', 1.)
            prec_mat_old = states['laplace_covariance']['head'][
                'covmat_layer']['precision_matrix']
            # `reset_covmat` is a traced 0./1. value inside pmap, so a Python
            # `if` cannot branch on it; blend arithmetically instead.
            prec_mat_new = (
                (1. - reset_covmat) * prec_mat_old +
                reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)

            states = flax.core.unfreeze(states)
            states['laplace_covariance']['head']['covmat_layer'][
                'precision_matrix'] = prec_mat_new
            states = flax.core.freeze(states)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        (l, s), g = train_utils.accumulate_gradient_with_states(
            jax.value_and_grad(loss_fn, has_aux=True), opt.target, states,
            images, labels, config.get('grad_accum_steps'))
        # Loss and gradients are averaged across devices; the states `s` are
        # not.
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        do_grad_clip = config.get('grad_clip_norm', -1.) > 0.
        if config.get('grad_accum_steps', 1) == 1 or do_grad_clip:
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        # `do_grad_clip` guarantees the branch above ran, so `l2_g` is defined.
        if do_grad_clip:
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))
        measurements['reset_covmat'] = reset_covmat

        return opt, s, l, rng, measurements

    default_reinit_params = ('head/output_layer/kernel',
                             'head/output_layer/bias', 'head/kernel',
                             'head/bias')
    rng, train_loop_rngs = jax.random.split(rng)
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=train_loop_rngs,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=states_cpu,
        default_reinit_params=default_reinit_params,
        config=config)
    train_loop_rngs = checkpoint_data.train_loop_rngs
    opt_cpu = checkpoint_data.optimizer
    states_cpu = checkpoint_data.fixed_model_states
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
    reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
    reset_covmat_iter = train_utils.prefetch_scalar(
        map(reset_covmat_fn, range(first_step, total_steps)),
        nprefetch=config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)
    states_repl = flax_utils.replicate(states_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.

    # Makes sure log_eval_steps is same as steps_per_epoch. This is because
    # the precision matrix needs to be updated fully (at the end of each epoch)
    # when eval takes place.
    log_eval_steps = steps_per_epoch
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl, reset_covmat_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter,
            reset_covmat_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            # TODO(jereliu): Expand to allow precision matrix resetting.
            (opt_repl, states_repl, loss_value, train_loop_rngs,
             extra_measurements) = update_fn(opt_repl,
                                             states_repl,
                                             lr_repl,
                                             reset_covmat_repl,
                                             train_batch['image'],
                                             train_batch['labels'],
                                             rng=train_loop_rngs)

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if train_utils.itstime(step,
                               config.get('checkpoint_steps'),
                               total_steps,
                               process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            # For GP layer, we will also do the same for untrainable parameters
            # (`states`). This is ok since `random features` are frozen throughout
            # pre-training, and `precision matrix` is a finetuning-specific parameters
            # that will be re-learned in the finetuning task.
            opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
            states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or FLAX datataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                optimizer=opt_cpu,
                fixed_model_states=states_cpu,
                train_loop_rngs=train_loop_rngs,
                accumulated_train_time=accumulated_train_time)
            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if train_utils.itstime(step,
                               config.log_training_steps,
                               total_steps,
                               process=0):
            write_note('Reporting training progress...')
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, val_ds in val_ds_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                # TODO(jereliu): Extend to support soft multi-class probabilities.
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                val_iter = input_utils.start_input_pipeline(
                    val_ds, config.get('prefetch_to_device', 1))
                ncorrect, loss, nseen = 0, 0, 0
                for batch in val_iter:
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    states_repl,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, states_repl,
                                          batch['image'], batch['labels'],
                                          batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        masks = np.array(masks[0], dtype=np.bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                                batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                    label[m], p[m, :], config.num_classes)
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name]
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # There are two entries in the ood_ds dict (in-dist, ood), and that this
            # section computes metrics using both pieces. This is in contrast to
            # normal validation eval above where we eval metrics separately for each
            # val split in val_ds.
            if ood_ds and config.ood_methods:

                def make_sngp_eval_fn(states):
                    def sngp_eval_fn(params, images, labels, mask):
                        return evaluation_fn(params=params,
                                             states=states,
                                             images=images,
                                             labels=labels,
                                             mask=mask)

                    return sngp_eval_fn

                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds,
                    ood_ds_names,
                    config.ood_methods,
                    make_sngp_eval_fn(states_repl),
                    opt_repl.target,
                    n_prefetch=config.get('prefetch_to_device', 1))
                writer.write_scalars(step, ood_measurements)

            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target,
                    datasets=config.fewshot.datasets,
                    states=states_repl)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
Beispiel #8
0
def create_synchronized_rng_seed():
    """Builds a random seed intended to be shared across devices.

    Each process draws 8 bytes from `os.urandom`, reads them as a signed
    64-bit integer, replicates that value over the local devices and feeds it
    to `_sum_seeds_pmapped` (defined elsewhere; presumably a cross-device
    sum — confirm). The per-device results are finally summed with `np.sum`.

    Returns:
      The summed seed as a NumPy scalar.
    """
    (local_seed,) = struct.unpack('q', os.urandom(8))
    replicated_seed = jax_utils.replicate(np.int64(local_seed))
    per_device_seeds = _sum_seeds_pmapped(replicated_seed)
    return np.sum(per_device_seeds)
Beispiel #9
0
    def train_and_evaluate(self, workdir):
        """Runs a training and evaluation loop.

        Restores (or initializes) the train state from `workdir/checkpoints`,
        replicates it over the local devices and repeatedly calls
        `self._update_func`, which must advance `state.step` by exactly
        `config.training.substeps` per call. Along the way it periodically:
          * writes training scalars averaged over the substeps
            (`config.logs.log_loss_every_steps`),
          * runs `self._eval_epoch` on `state.ema_params`
            (`config.logs.eval_full_every_steps` and on the last step),
          * saves a checkpoint (`config.logs.checkpoint_every_steps` and on
            the last step).

        Args:
          workdir: Working directory for checkpoints and TF summaries. If this
            contains checkpoint training will be resumed from the latest
            checkpoint.
        """

        tf.io.gfile.makedirs(workdir)
        config = self.config
        # Number of optimizer steps performed by one call to `_update_func`.
        substeps = config.training.substeps

        # Learning rate schedule.
        # NOTE(review): no schedule is built here; this only reads the total
        # number of training steps.
        num_train_steps = config.training.num_train_steps
        logging.info('num_train_steps=%d', num_train_steps)

        # Get train state
        state = self._train_state

        # Set up checkpointing of the model and the input pipeline.
        # NOTE(review): only `state` is handed to the checkpoint below; the
        # position of `self._train_iter` does not appear to be saved — confirm.
        checkpoint_dir = os.path.join(workdir, 'checkpoints')
        ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
        # Resume from the latest checkpoint if one exists, otherwise keep the
        # freshly created state.
        state = ckpt.restore_or_initialize(state)
        initial_step = int(state.step)

        # Distribute training.
        state = flax_utils.replicate(state)

        # Processes other than 0 get a logging-only writer.
        writer = metric_writers.create_default_writer(
            workdir, just_logging=jax.process_index() > 0)
        # Hyperparameters are recorded only for a fresh run, not on resume.
        if initial_step == 0:
            writer.write_hparams(dict(config))

        logging.info('Starting training loop at step %d.', initial_step)
        hooks = []
        # `report_progress` is also used below for timing eval/checkpointing
        # on every process, but it is only registered as a per-step hook
        # (together with the profiler) on process 0.
        report_progress = periodic_actions.ReportProgress(
            num_train_steps=num_train_steps, writer=writer)
        if jax.process_index() == 0:
            hooks += [
                report_progress,
                periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
            ]
        step = initial_step
        with metric_writers.ensure_flushes(writer):
            while step < num_train_steps:
                # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
                # devices.
                # Computed before the update: True if this call reaches or
                # passes `num_train_steps`.
                is_last_step = step + substeps >= num_train_steps

                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    # Convert the host batch to NumPy arrays before the update.
                    inputs = jax.tree_map(np.asarray, next(self._train_iter))
                    state, outputs = self._update_func(state, inputs)

                # Quick indication that training is happening.
                logging.log_first_n(logging.INFO, 'Finished training step %d.',
                                    5, step)
                # Hooks receive the step number from *before* this update.
                for h in hooks:
                    h(step)

                # Read the step from the first device replica and check that
                # the update advanced it by exactly `substeps`.
                new_step = int(state.step[0])
                assert new_step == step + substeps
                step = new_step

                is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
                # Training scalars are skipped on eval steps.
                # NOTE(review): with substeps > 1, `step` moves in increments
                # of `substeps`, so these modulo checks only fire on multiples
                # that are also reachable from `initial_step`.
                if step % config.logs.log_loss_every_steps == 0 and not is_eval:

                    # Every output leaf is expected to carry a leading axis of
                    # length `substeps`; average it into a Python float.
                    def avg_over_substeps(x):
                        assert x.shape[0] == substeps
                        return float(x.mean(axis=0))

                    # Extract scalars and images.
                    outputs = flax_utils.unreplicate(outputs)
                    outputs = jax.tree_map(avg_over_substeps, outputs)
                    scalars = outputs['scalars']
                    writer.write_scalars(step, scalars)

                if is_eval:
                    # Full evaluation uses the EMA parameters of the
                    # (replicated) state.
                    with report_progress.timed('eval_full'):
                        outputs = self._eval_epoch(params=state.ema_params)
                        outputs = flax_utils.unreplicate(outputs)
                        scalars = outputs['scalars']
                        writer.write_scalars(step, scalars)

                if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
                    # Checkpoints store the unreplicated state.
                    with report_progress.timed('checkpoint'):
                        ckpt.save(flax_utils.unreplicate(state))

        # NOTE(review): logs `num_train_steps`, not the final `step`, which can
        # be larger when `substeps` does not divide the remaining steps.
        logging.info('Finishing training at step %d', num_train_steps)
Beispiel #10
0
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
    use_deprecated_checkpointing=True,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over the checkpoints in `checkpoint_dir` selected by
  `iterate_checkpoints(checkpoint_dir, min_global_step, max_global_step)`,
  restores each one, logs its evaluation metrics and evaluates the Hessian
  spectrum, curvature statistics and interpolations on it. On host 0, one dict
  per checkpoint is appended with `utils.MetricLogger` under
  checkpoint_dir/hessian_eval_config['name']/<suffix>, where <suffix> is empty
  when `min_global_step` is None and '<min_global_step>_<max_global_step>'
  otherwise.

  When `min_global_step == 0`, host 0 also writes the hparams
  (`hparams.json`) and the Hessian eval config (`hconfig.json`) into
  checkpoint_dir/hessian_eval_config['name'].

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) Number of batches passed to
      `trainer.eval_metrics` when computing the per-checkpoint evaluation
      report. Set to None to evaluate on the whole set.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval. Keys read directly by this function: 'name', 'compute_stats' and
      'compute_interps'; the whole dict is also passed to
      `hessian_eval.CurvatureEvaluator`.
    min_global_step: Lower bound on what steps to filter checkpoints. Set to
      None to evaluate all checkpoints in the directory.
    max_global_step: Upper bound on what steps to filter checkpoints.
    use_deprecated_checkpointing: Whether to use deprecated checkpointing.
  """
  rng, init_rng = jax.random.split(rng)
  # Derive a host-specific rng for the data pipeline.
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  flax_module, batch_stats = trainer.initialize(model.flax_module_def,
                                                initializer, model.loss_fn,
                                                hps.input_shape,
                                                hps.output_shape, hps, init_rng,
                                                None)

  # Fold the batch_stats and rng into the loss used by hessian eval.
  # NOTE(review): `batch_stats` is looked up when `batch_loss` is *called*,
  # not when it is defined (Python late binding). It is rebound right below
  # (replicated) and again per checkpoint (restored and synced), so the loss
  # sees the current replicated stats rather than the unreplicated initial
  # ones — confirm this is intended.
  def batch_loss(module, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(module, batch_stats, batch, rng)[0]
  batch_stats = jax_utils.replicate(batch_stats)

  if jax.host_id() == 0:
    utils.log_pytree_shape_and_statistics(flax_module.params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the hessian computation hps to the experiment directory
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  # The optimizer is only a restore target for the checkpoints below.
  optimizer = trainer.get_optimizer(hps).create(flax_module)
  optimizer = jax_utils.replicate(optimizer)
  data_rng = jax.random.fold_in(data_rng, 0)

  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmap functions for the training loop
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.host_id() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.host_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      optimizer.target,
      hessian_eval_config,
      dataset,
      batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=(optimizer, batch_stats),
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    results = trainer.restore_checkpoint(
        ckpt,
        (optimizer, batch_stats),
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    optimizer, batch_stats = results[0]
    # pylint: disable=protected-access
    batch_stats = trainer._maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(optimizer.target, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.host_id() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    row = {}
    # Directions stay empty unless stats or interpolations are requested.
    grads, updates = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        optimizer.target, step)
    row.update(stats)
    if (hessian_eval_config['compute_stats'] or
        hessian_eval_config['compute_interps']):
      grads, updates = hessian_evaluator.compute_dirs(optimizer)
    row.update(hessian_evaluator.evaluate_stats(optimizer.target, grads,
                                                updates, hess_evecs,
                                                cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(optimizer.target, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    if jax.host_id() == 0:
      logger.append_pytree(row)
Beispiel #11
0
def compute_is_scores(filename):
  """Compute IS scores for training data.

  Restores a trained Transformer (from FLAGS.pretrained_model_dir if set,
  otherwise FLAGS.model_dir), runs `eval_for_is_step` over every batch of the
  training set and writes, one CSV row per batch, the returned losses to
  `<FLAGS.is_save_path>/<filename>.txt` and the returned lengths to
  `<FLAGS.is_save_path>/<filename>-lengths.txt`. Padding added to make a batch
  divisible by the device count is stripped before writing. Rows are only
  written by host 0.

  Args:
    filename: Base name (without extension) of the two output files.

  Raises:
    ValueError: If FLAGS.batch_size is not divisible by the local device count.
    RuntimeError: If FLAGS.restore_checkpoints is false; scoring needs a
      restored model.
  """

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, encoder = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size)
  print('Datasets created')

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where leftoff
    model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
    # When loading a checkpoint trained with adapters (ie. frozen weights)
    # restoring from the base optimizer fails. We catch this error and create
    # the optimizer with frozen weights.
    try:
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
      # Grab last step.
      start_step = int(optimizer.state.step)
    except ValueError:
      adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
      optimizer = optimizer_def.create(optimizer.target, focus=adapter)
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
      start_step = optimizer.state[0].step

  else:
    raise RuntimeError('Must restore checkpoint for IS')

  if FLAGS.adapter != NONE and not isinstance(optimizer, optim.MultiOptimizer):
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)
  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(
          eval_for_is_step,
          config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  lengths_file = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
  losses_file = FLAGS.is_save_path + '/' + filename + '.txt'
  # Both outputs are managed by one `with`, so they are flushed and closed
  # even if scoring raises part-way through.
  with tf.io.gfile.GFile(lengths_file, 'w') as length_fp, \
      tf.io.gfile.GFile(losses_file, 'w') as fp:
    lengths_writer = csv.writer(length_fp)
    writer = csv.writer(fp)

    for batch_idx, eval_batch in enumerate(train_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
      # Pad a ragged (final) batch up to a multiple of the device count so it
      # can be sharded; the padded rows are dropped before writing.
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        eval_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
      eval_batch = common_utils.shard(eval_batch)
      losses, lengths = p_eval_step(optimizer.target, eval_batch)
      if jax.host_id() == 0:
        losses = common.tohost(losses)
        lengths = common.tohost(lengths)
        if cur_pred_batch_size % n_devices:
          writer.writerow(losses[:cur_pred_batch_size])
          lengths_writer.writerow(lengths[:cur_pred_batch_size])
        else:
          writer.writerow(losses)
          lengths_writer.writerow(lengths)

      if batch_idx % 500 == 0:
        print('Batch', batch_idx)
        print(time.time() - t_loop_start)


def inference_time(config: ml_collections.ConfigDict, workdir: str):
  """Runs a number of steps and measures inference time.

  Builds the VisionTransformer named by `config.model_name`, compiles its
  forward pass with `jax.pmap`, runs `config.initial_steps` untimed warm-up
  inferences and then times `config.steps` inferences. The min/max/mean/std
  of images per second per local device are logged and written as scalars at
  step 0 with a metric writer rooted at `workdir`.

  Args:
    config: Configuration; must provide `batch`, `num_classes`, `image_size`,
      `model_name`, `initial_steps` and `steps`.
    workdir: Directory used by the metric writer.
  """

  assert config.batch, f'Expected --config.batch={config.batch} > 0'
  assert config.num_classes, (
      f'Expected --config.num_classes={config.num_classes} > 0')
  assert config.image_size, (
      f'Expected --config.image_size={config.image_size} > 0')
  # The throughput statistics below need at least one timed step.
  assert config.steps, f'Expected --config.steps={config.steps} > 0'

  # Build VisionTransformer architecture
  model_config = config_lib.MODEL_CONFIGS[config.model_name]
  model = models.VisionTransformer(
      num_classes=config.num_classes, **model_config)

  # Make sure initial model parameters (before replication) are on CPU only.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    return model.init(
        rng,
        # Discard the "num_local_devices" dimension for initialization.
        inputs=jnp.ones([1, config.image_size, config.image_size, 3],
                        jnp.float32),
        train=False)

  variables = init(jax.random.PRNGKey(0))

  params_repl = flax_utils.replicate(variables['params'])
  # Built once, outside the timed loop, so host-side dict construction is not
  # counted as inference time.
  frozen_params = flax.core.FrozenDict(params=params_repl)

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
  num_devices = jax.local_device_count()
  # Images actually fed to each device per step; any remainder of
  # `config.batch` is dropped by the integer division.
  per_device_batch = config.batch // num_devices
  images = jnp.ones([
      num_devices, per_device_batch,
      config.image_size, config.image_size, 3
  ], jnp.float32)

  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())

  logging.info('Starting training loop; initial compile can take a while...')
  logits = vit_fn_repl(frozen_params, images)
  logits.block_until_ready()
  logging.info('Done.')

  logging.info('Going to run %d inferences WITHOUT measuring...',
               config.initial_steps)
  for _ in range(config.initial_steps):
    logits = vit_fn_repl(frozen_params, images)
    logits.block_until_ready()

  logging.info('Going to run %d inferences measuring...', config.steps)
  times = []
  # Time exactly `config.steps` inferences, as announced in the log above.
  # `perf_counter` is monotonic and high-resolution, unlike `time.time`.
  for _ in range(config.steps):
    t0 = time.perf_counter()
    logits = vit_fn_repl(frozen_params, images)
    logits.block_until_ready()
    times.append(time.perf_counter() - t0)
  logging.info('times=%s', times)
  imgs_sec_core = per_device_batch / np.array(times)
  stats = dict(
      imgs_sec_core_min=imgs_sec_core.min(),
      imgs_sec_core_max=imgs_sec_core.max(),
      imgs_sec_core_mean=imgs_sec_core.mean(),
      imgs_sec_core_std=imgs_sec_core.std(),
  )
  for stat_name, stat_value in stats.items():
    logging.info('%s=%f', stat_name, stat_value)
  writer.write_scalars(0, stats)
Beispiel #13
0
def predict_and_evaluate(config, workdir, ckpt_path=None):
    """Restores a checkpoint and evaluates it on every configured test split.

  Args:
    config: Configuration to use.
    workdir: Working directory. Checkpoints are read from
      `<workdir>/checkpoints` and metric summaries are written under `workdir`.
    ckpt_path: The checkpoint to evaluate. If not specified, use the latest
      checkpoint. If restoring raises `FileNotFoundError`, a warning is logged
      and `restore_or_initialize` is used instead, which may evaluate the
      latest checkpoint or the freshly initialized state.
  """
    logging.info('Starting testing at %s', workdir)
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)
    # Consume one split exactly as before so that `model_rng` below is
    # unchanged; the data RNG itself was never used by the test pipeline.
    rng, _ = jax.random.split(rng)

    # Build one (non-shuffled) input pipeline per test split.
    test_ds = [
        input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        for split in config.dataset.test_splits
    ]

    # Initialize model in decoding mode; shapes are taken from the first split.
    inputs = train_utils.get_init_inputs(test_ds[0])
    rng, model_rng = jax.random.split(rng)
    predict_config = models.TransformerConfig(**config.model.to_dict())
    predict_config = predict_config.replace(decode=True)
    model = models.Model(predict_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    # Only process 0 writes real summaries; other processes just log.
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)

    logging.info('Testing and evaluating checkpoint %s', ckpt_path)
    try:
        state = ckpt.restore(state, ckpt_path)
    except FileNotFoundError as e:
        # Do not silently evaluate a different (or untrained) model.
        logging.warning(
            'Could not restore checkpoint %s (%s); falling back to '
            'restore_or_initialize() on %s, which may use the latest '
            'checkpoint or the freshly initialized state.', ckpt_path, e,
            checkpoint_dir)
        state = ckpt.restore_or_initialize(state)
    step = int(state.step)
    logging.info('Evaluating state at step %d.', step)

    # Argument index 3 of predict_step is treated as a static (hashable) value.
    p_pred_step = jax.pmap(functools.partial(predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(3, ))
    p_init_cache = jax.pmap(functools.partial(init_cache,
                                              config=predict_config),
                            axis_name='batch')

    # Distribute testing.
    state = flax_utils.replicate(state)
    with metric_writers.ensure_flushes(writer):
        test_metrics = {}
        for ds, split in zip(test_ds, config.dataset.test_splits):
            ds_metrics = evaluate_sequence_accuracy(p_pred_step, p_init_cache,
                                                    state, ds, config, split,
                                                    workdir,
                                                    config.num_test_steps)
            # Suffix every metric with its split name to keep keys unique.
            ds_metrics = {f'{k}_{split}': v for k, v in ds_metrics.items()}
            test_metrics.update(ds_metrics)
        writer.write_scalars(step, test_metrics)
def main(argv):
  """Runs the training loop configured by `FLAGS.config`.

  Builds the train and validation input pipelines, initializes the model and
  optimizer (or resumes / initializes them from a checkpoint), then trains for
  `total_steps` steps while periodically checkpointing to `FLAGS.workdir`,
  logging training metrics, evaluating on the validation split(s) and, if
  `fewshot` is configured, running few-shot evaluation.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir

  logging.info("Workdir: %s", workdir)

  save_checkpoint_path = None
  if config.get("checkpoint_steps"):
    tf.io.gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  def write_note(note):
    if jax.host_id() == 0:
      logging.info("NOTE: %s", note)
  write_note("Initializing...")

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get("keep_checkpoint_steps"):
    assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`."
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f"`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be "
        f"divisible by `checkpoint_steps` ({config.checkpoint_steps}).")

  batch_size = config.batch_size
  batch_size_eval = config.get("batch_size_eval", batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must "
                     f"be divisible by device number ({jax.device_count()})")

  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      "Global batch size %d on %d hosts results in %d local batch size. "
      "With %d dev per host (%d dev total), that's a %d per-device batch size.",
      batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      split=config.train_split,
      data_dir=fillin(config.get("dataset_dir")),
      batch_size=local_batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch=config.get("prefetch_to_host", 2),
      cache=False)

  # Start prefetching already.
  train_iter = u.start_input_pipeline(
      train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size)

  write_note("Initializing val dataset(s)...")
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    """Returns `(val_iter, val_steps)` for one validation split."""
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s", val_steps,
                 dataset, split)

    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    # We always pad to local_batch_size_eval even when less would be enough in
    # order to minimize memory fragmentation.
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)

    return (val_it, val_steps)

  # `val_split` is either a single split name, or a sequence of
  # (name, dataset, split, pp_eval[, data_dir]) tuples.
  if isinstance(config.val_split, str):
    val_ds = {"val": _get_val_split(config.dataset, config.val_split,
                                    config.pp_eval, config.get("dataset_dir"))}
  else:
    val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  steps_per_epoch = ntrain_img / batch_size

  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get("total_steps"), "Set either num_epochs or total_steps"
  else:
    total_steps = config.total_steps

  logging.info(
      "Running for %d steps, that means %f epochs and %f steps per epoch",
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
  mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    mw.measure("num_params", num_params)

  @partial(jax.pmap, axis_name="batch")
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)

    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name="batch")
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({"params": flax.core.freeze(params)}, images)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, "batch")
    labels = jax.lax.all_gather(labels, "batch")
    mask = jax.lax.all_gather(mask, "batch")
    return representation, labels, mask

  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    # (`l2_g` is always defined here: the branch above runs whenever
    # `grad_clip_norm` is set.)
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr/config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e,g, start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl,
          lr_repl,
          train_batch["image"],
          train_batch["labels"],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX datataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()
    mw.step_end()

  write_note(f"Done!\n{chrono.note}")
  # Wait for any pending async checkpoint writes before exiting.
  pool.close()
  pool.join()
  mw.close()
Beispiel #15
0
    def test_train(self):
        """Trains a tiny ProgramTransformer on the bundled test records.

        Runs 1000 training steps on the RobustFill test dataset, evaluates on
        one pass over the same dataset, and (on host 0) asserts that the
        resulting accuracy exceeds 0.1.
        """
        tf.enable_v2_behavior()

        # Seed every RNG source the test touches.
        tf.random.set_seed(0)
        np.random.seed(0)
        random.seed(0)

        dataset_filepattern = os.path.join(
            os.path.dirname(__file__),
            'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*')
        print('dataset_filepattern = {}'.format(dataset_filepattern))

        batch_size = 4
        num_strings_per_task = 4
        max_characters = 10
        max_program_length = 15
        io_shape = (num_strings_per_task, max_characters)

        # Character ids start at 1 so that id 0 is left for padding.
        id_char_table = dict(enumerate(dsl.CHARACTER, start=1))
        char_id_table = {ch: char_id for char_id, ch in id_char_table.items()}
        _, token_id_table = dsl_tokens.build_token_tables()
        io_vocab_size = len(char_id_table) + 1  # +1 for the padding id.
        program_vocab_size = len(token_id_table) + 1
        bos_token = token_id_table[dsl.BOS]

        # (inputs, outputs, program) triples, padded and batched.
        dataset = input_pipeline.create_dataset_from_tf_record(
            dataset_filepattern, token_id_table, char_id_table)
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes=(io_shape, io_shape, (max_program_length,)),
            drop_remainder=True)
        dataset_iter = dataset.repeat().as_numpy_iterator()

        train_config = models.TransformerConfig(
            vocab_size=io_vocab_size,
            output_vocab_size=program_vocab_size,
            shift=True,
            emb_dim=32,
            num_heads=4,
            num_layers=2,
            qkv_dim=32,
            mlp_dim=32,
            max_len=max(max_characters, max_program_length),
            deterministic=False,
            decode=False,
            bos_token=bos_token)
        eval_config = train_config.replace(deterministic=True)

        rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

        model = models.ProgramTransformer(eval_config)
        io_dummy = jnp.ones((batch_size,) + io_shape, jnp.float32)
        program_dummy = jnp.ones((batch_size, max_program_length), jnp.float32)
        initial_variables = jax.jit(model.init)(
            init_rng, io_dummy, io_dummy, program_dummy)

        optimizer = optim.Adam(
            1e-2, beta1=0.9, beta2=0.98, eps=1e-9,
            weight_decay=0.1).create(initial_variables['params'])
        del initial_variables  # Don't keep a copy of the initial model.
        optimizer = jax_utils.replicate(optimizer)

        learning_rate_fn = train_lib.create_learning_rate_scheduler(
            base_learning_rate=1e-2)
        p_train_step = jax.pmap(
            functools.partial(train_lib.train_step,
                              learning_rate_fn=learning_rate_fn,
                              config=train_config),
            axis_name='batch')
        p_eval_step = jax.pmap(
            functools.partial(train_lib.eval_step, config=eval_config),
            axis_name='batch')

        # One key per local device; the step's third output is fed back in.
        train_rngs = jax.random.split(rng, jax.local_device_count())
        del rng

        for _ in range(1000):
            inputs, outputs, programs = common_utils.shard(next(dataset_iter))
            optimizer, _, train_rngs = p_train_step(
                optimizer, inputs, outputs, programs, train_rng=train_rngs)

        # Single evaluation pass over the (non-repeated) dataset.
        per_batch_metrics = []
        for batch in dataset.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batch)
            per_batch_metrics.append(
                p_eval_step(optimizer.target, inputs, outputs, programs))

        metric_sums = jax.tree_map(
            jnp.sum, common_utils.get_metrics(per_batch_metrics))
        denominator = metric_sums.pop('denominator')
        eval_summary = jax.tree_map(lambda total: total / denominator,
                                    metric_sums)

        if jax.host_id() == 0:
            self.assertGreater(eval_summary['accuracy'], 0.1)
Beispiel #16
0
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

    if "validation" not in datasets.keys():
        # make sure only "validation" and "train" keys remain"
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
        )
    else:
        # make sure only "validation" and "train" keys remain"
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split="validation",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}",
            cache_dir=model_args.cache_dir,
        )

    # only normalized-inputs-training is supported
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
    )

    def prepare_dataset(batch):
        # check that all files have the correct sampling rate
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
    )

    # filter audio files that are too long
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    )

    def normalize(batch):
        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
        )

    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

    # Activate gradient checkpointing if needed
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    gumbel_rngs = jax.random.split(rng, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            negative_indices = batch.pop("sampled_negative_indices")

            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
                a_min=model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # summarize metrics
        metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(vectorized_datasets["train"])
        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs
            )
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
        )

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
            write_eval_metric(summary_writer, eval_metrics, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
Beispiel #17
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a PixelCNN++ training and evaluation loop.

    The model is initialized data-dependently from the first training batch and
    trained with Adam. A second parameter tree, `ema`, is threaded through
    `train_step` (presumably an exponential moving average of the parameters).
    It is used for evaluation at the end of every epoch and saved alongside the
    optimizer in checkpoints.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training is resumed from the latest checkpoint,
        including the restored `ema` parameters.

    Raises:
      ValueError: If run on more than one host, if the batch size is not
        divisible by the number of devices, or if `config.init_batch_size`
        exceeds `config.batch_size`.
    """
    tf.io.gfile.makedirs(workdir)

    batch_size = config.batch_size
    n_devices = jax.device_count()
    if jax.process_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_summary_writer, eval_summary_writer = get_summary_writers(workdir)
    # Load dataset
    data_source = input_pipeline.DataSource(config)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds
    steps_per_epoch = data_source.ds_info.splits[
        'train'].num_examples // config.batch_size
    # Create dataset batch iterators
    train_iter = iter(train_ds)
    # NOTE(review): tf.data reports a negative cardinality for infinite or
    # unknown-size datasets, which would make the training loop below empty.
    # This assumes the input pipeline yields a finite dataset.
    num_train_steps = train_ds.cardinality().numpy()
    steps_per_checkpoint = 1000

    # Create the model using data-dependent initialization. Don't shard the init
    # batch. Validate with an exception rather than `assert`, because asserts
    # are stripped when Python runs with -O.
    if config.init_batch_size > batch_size:
        raise ValueError(
            'config.init_batch_size must not exceed config.batch_size')
    # The first training batch is consumed here and used only for init.
    init_batch = next(train_iter)['image'].numpy()[:config.init_batch_size]

    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng, dropout_rng = jax.random.split(rng, 3)

    initial_variables = model(config).init(
        {
            'params': init_rng,
            'dropout': dropout_rng
        }, init_batch, train=False)['params']
    optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
    optimizer = optimizer_def.create(initial_variables)

    # `initial_variables` is passed as the EMA template (presumably returned
    # unchanged when no checkpoint exists). Keep the EMA that comes back:
    # overwriting it with `initial_variables` would reset the average on every
    # resume.
    optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
    step_offset = int(optimizer.state.step)

    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule
    learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(functools.partial(train_step, config,
                                              learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, config=config),
                           axis_name='batch')

    # Gather metrics
    train_metrics = []

    for step, batch in zip(range(step_offset, num_train_steps), train_iter):
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)

        # Generate a PRNG key that will be rolled into the batch.
        rng, step_rng = jax.random.split(rng)
        sharded_rngs = common_utils.shard_prng_key(step_rng)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_rngs)
        train_metrics.append(metrics)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                            step)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                                   train_metrics)
            # Send per-step stats to Tensorboard, indexed back from `step`.
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val,
                                                step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation (uses the `ema` parameters)
            eval_metrics = []
            for eval_batch in eval_ds:
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                                  eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if ((step + 1) % steps_per_checkpoint == 0
                or step + 1 == num_train_steps):
            save_checkpoint(workdir, optimizer, ema, step)
Beispiel #18
0
def main(unused_argv):
    """Trains a NeRF model configured by command-line FLAGS.

    Resumes from the latest checkpoint in `FLAGS.train_dir` (if any), then runs
    the pmapped `train_step` up to `FLAGS.max_steps`. Along the way it
    periodically renders a test view, writes TensorBoard summaries and saves
    checkpoints. Only process 0 writes summaries and checkpoints.

    Args:
      unused_argv: Positional command-line arguments; ignored.

    Raises:
      ValueError: If `FLAGS.batch_size` is not divisible by the device count,
        or if `FLAGS.train_dir` or `FLAGS.data_dir` is unset.
    """
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by the process index to shuffle data loaded
    # by different hosts.
    np.random.seed(20201473 + jax.process_index())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, rays.origins, rays.directions, rays.viewdirs),
            axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )
    rng, key = random.split(rng)
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(step=0,
                                   optimizer=optimizer,
                                   model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # A restored step may be an array scalar. Convert it so it works as a
    # `range` bound; the final save below converts it the same way.
    offset = int(state.step) + 1
    state = jax_utils.replicate(state)
    del init_model, init_state

    is_main_process = jax.process_index() == 0
    if is_main_process:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         init_lr=FLAGS.lr,
                                         decay_steps=FLAGS.lr_decay * 1000,
                                         decay_rate=0.1)
    ptrain_step = jax.pmap(train_step,
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=2)
    # Prefetch_buffer_size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)
        if step % FLAGS.gc_every == 0:
            gc.collect()
        # --- Train logs start ---
        # Put the training-time visualization before the process-index check:
        # in multi-host evaluation all hosts need to run inference, even though
        # only host 0 records the results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            state_to_eval = jax.device_get(
                jax.tree_util.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if is_main_process:
                summary_writer.image("pred_color", pred_color, step)
                summary_writer.image("pred_disp", pred_disp, step)
                summary_writer.image("pred_acc", pred_acc, step)
                summary_writer.image("target", test_case["pixels"], step)
        if not is_main_process:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("loss", stats[0].loss[0], step)
            summary_writer.scalar("psnr", stats[0].psnr[0], step)
            summary_writer.scalar("learning_rate", lr, step)
            if len(stats) > 1:
                summary_writer.scalar("loss_coarse", stats[1].loss[0], step)
                summary_writer.scalar("psnr_coarse", stats[1].psnr[0], step)
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            # Width used to right-align the step counter in the progress line.
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " +
                  f"loss={stats[0].loss[0]:0.5f}, " +
                  f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            state_to_save = jax.device_get(
                jax.tree_util.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(FLAGS.train_dir,
                                        state_to_save,
                                        int(state_to_save.step),
                                        keep=100)
        # --- Train logs end ---

    # Save a final checkpoint unless the last step already triggered one.
    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(state.step),
                                    keep=100)
Beispiel #19
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Execute model training and evaluation loop.

    Trains the ImageNet classifier named by `config.model` with a pmapped train
    step and evaluates after every epoch. It checkpoints every 10 epochs and at
    the final step, and resumes from a checkpoint in `workdir` when one exists.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries and checkpoints are
        written to.

    Raises:
      ValueError: If `config.batch_size` is not divisible by the total number
        of devices.
    """

    is_main_process = jax.process_index() == 0
    if is_main_process:
        summary_writer = tensorboard.SummaryWriter(workdir)
        summary_writer.hparams(dict(config))

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # Each process loads its share of the global batch.
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder('imagenet2012:5.*.*')
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)

    if config.num_train_steps == -1:
        num_steps = steps_per_epoch * config.num_epochs
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    # Linear learning-rate scaling with the global batch size.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    state = create_train_state(rng, config, model, image_size)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch,
                                               config.num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, model.apply, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, model.apply),
                           axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        is_epoch_end = (step + 1) % steps_per_epoch == 0
        if is_epoch_end:
            epoch = step // steps_per_epoch
            # Number of train steps actually run since the timer started. This
            # is fewer than `steps_per_epoch` when training resumed mid-epoch.
            num_timed_steps = len(epoch_metrics)
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                             epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            # The timer is read after the metrics are fetched and logged, as
            # before, so that queued device work is included.
            steps_per_sec = num_timed_steps / (time.time() - t_loop_start)
            if is_main_process:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if is_main_process:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)
        if is_epoch_end:
            # Restart the throughput timer only after eval and checkpointing,
            # so 'steps per second' measures training steps alone.
            t_loop_start = time.time()

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main():
    """Fine-tunes a Flax sequence-classification model on GLUE or local data.

    Steps:
      * Parses command-line arguments.
      * Loads either a GLUE task from the datasets Hub or local CSV/JSON files,
        then tokenizes them.
      * Trains with a pmapped train step for `args.num_train_epochs` epochs,
        evaluating after each epoch.
      * From process 0 only, writes TensorBoard summaries and saves the final
        model to `args.output_dir`.

    Raises:
      ValueError: If no task name is given and neither a training nor a
        validation file is provided.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # Infer the loader name ("csv"/"json") from whichever file was given.
        reference_file = (args.train_file if args.train_file is not None
                          else args.validation_file)
        if reference_file is None:
            raise ValueError(
                "Need either a task name or a training/validation file.")
        extension = reference_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in [
            "float32", "float64"
        ]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path,
                                        num_labels=num_labels,
                                        finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [
            name for name in raw_datasets["train"].column_names
            if name != "label"
        ]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        elif len(non_label_column_names) >= 2:
            sentence1_key, sentence2_key = non_label_column_names[:2]
        else:
            sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (model.config.label2id !=
            PretrainedConfig(num_labels=num_labels).label2id
            and args.task_name is not None and not is_regression):
        # Some have all caps in their config, some don't.
        label_name_to_id = {
            k.lower(): v
            for k, v in model.config.label2id.items()
        }
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!")
            label_to_id = {
                i: label_name_to_id[label_list[i]]
                for i in range(num_labels)
            }
        else:
            # A single concatenated message: separate positional arguments
            # would be treated as %-format args and break the log record.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result.")
    elif args.task_name is None and not is_regression:
        # Regression targets are used as-is; `label_list` only exists for
        # classification.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = ((examples[sentence1_key], ) if sentence2_key is None else
                 (examples[sentence1_key], examples[sentence2_key]))
        result = tokenizer(*texts,
                           padding="max_length",
                           max_length=args.max_length,
                           truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names)

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name ==
                                      "mnli" else "validation"]

    # Log a few random samples from the training set (fewer if it is tiny).
    num_samples_to_log = min(3, len(train_dataset))
    for index in random.sample(range(len(train_dataset)), num_samples_to_log):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer. Only process 0 writes summaries, so that several
    # processes do not write event files into the same output directory.
    summary_writer = None
    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(args.output_dir)
        summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        """Writes per-step train metrics and per-epoch eval metrics."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(len(train_dataset),
                                               train_batch_size,
                                               args.num_train_epochs,
                                               args.num_warmup_steps,
                                               args.learning_rate)

    state = create_train_state(model,
                               learning_rate_fn,
                               is_regression,
                               num_labels=num_labels,
                               weight_decay=args.weight_decay)

    # define step functions
    def train_step(
            state: train_state.TrainState, batch: Dict[str, Array],
            dropout_rng: PRNGKey) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0, ))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset,
                                              train_batch_size):
            state, metrics, dropout_rngs = p_train_step(
                state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        if train_metrics:
            logger.info(
                f"    Done! Training metrics: {unreplicate(train_metrics[-1])}")
        else:
            logger.warning("    No training batches were produced this epoch.")

        logger.info("  Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions),
                             references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        if summary_writer is not None:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(train_metrics, eval_metric, train_time, cur_step)

    # save last checkpoint
    if jax.process_index() == 0:
        params = jax.device_get(
            jax.tree_util.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(args.output_dir, params=params)
Beispiel #21
0
def main():
    """Fine-tune a Flax extractive question-answering model and evaluate it.

    Pipeline, in order:

    1. Parse ``ModelArguments``, ``DataTrainingArguments`` and
       ``TrainingArguments`` from a single ``.json`` file (when that is the
       only CLI argument) or from the command line.
    2. Configure logging so that only process 0 logs at INFO level.
    3. If ``push_to_hub`` is set, prepare the hub ``Repository`` in
       ``output_dir``.
    4. Load the raw datasets (a hub dataset, or local files read with
       ``field="data"``), the model config and a *fast* tokenizer.
    5. Tokenize the train / validation / test splits into overlapping
       features (``doc_stride``); train features are labelled with the token
       positions of the first answer (CLS index when no answer fits).
    6. Build the model, train state and ``pmap``-ed train / eval steps.
    7. Train for ``num_train_epochs`` epochs: log every ``logging_steps``,
       evaluate every ``eval_steps`` (when set) and at each epoch boundary,
       and checkpoint every ``save_steps`` and at the final step.
    8. If ``do_eval``, run a final evaluation and have process 0 write
       ``<output_dir>/eval_results.json``.

    Raises:
        ValueError: if neither a dataset name nor any data file is given,
            the tokenizer is not a fast tokenizer, a requested split is
            missing, or ``do_train`` is not set (the training loop always
            needs a train split).
    """
    # region Argument parsing
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )
    # endregion

    # region Logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # endregion

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # region Load Data
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(data_args.dataset_name,
                                    data_args.dataset_config_name,
                                    cache_dir=model_args.cache_dir)
    else:
        # Loading the dataset from local csv or json file. The loader type is
        # taken from the extension of the last file given (test > validation > train).
        data_files = {}
        extension = None
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        if extension is None:
            # Without this check `extension` would be unbound below (NameError).
            raise ValueError(
                "Need either a dataset name or a train/validation/test file.")
        raw_datasets = load_dataset(extension,
                                    data_files=data_files,
                                    field="data",
                                    cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    # endregion

    # region Load pretrained model and tokenizer
    #
    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=True,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    # endregion

    # region Tokenizer check: this script requires a fast tokenizer.
    # Offset mappings and overflow-to-sample mappings used below are only
    # produced by fast tokenizers.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError(
            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
            "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
            "requirement")
    # endregion

    # region Preprocessing the datasets
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names
    # Fall back to positional columns when the SQuAD-style names are absent.
    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"
    # Sequence id (as returned by `sequence_ids`) that marks context tokens.
    context_index = 1 if pad_on_right else 0

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        """Tokenize a batch of examples into labelled training features.

        Each example may yield several features (overlapping windows of the
        context). Every feature gets ``start_positions`` / ``end_positions``
        token indices of the first answer, or the CLS index when the example
        has no answer or the answer is not fully inside the window.
        """
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[
                question_column_name if pad_on_right else context_column_name],
            examples[
                context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != context_index:
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != context_index:
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char
                        and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[
                            token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(
                        token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(
                        token_end_index + 1)

        return tokenized_examples

    processed_raw_datasets = dict()
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            # We will select sample from whole data if agument is specified
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        # Create train feature from dataset
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_train_samples is not None:
            # Number of samples might increase during Feature Creation, We select only specified max samples
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        processed_raw_datasets["train"] = train_dataset

    # Validation preprocessing
    def prepare_validation_features(examples):
        """Tokenize a batch of examples into evaluation features.

        Each feature keeps the ``example_id`` of its source example and an
        ``offset_mapping`` in which non-context tokens are set to ``None``, so
        predictions can be mapped back to substrings of the context.
        """
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[
                question_column_name if pad_on_right else context_column_name],
            examples[
                context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(
                examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_examples = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            # We will select sample from whole data
            eval_examples = eval_examples.select(
                range(data_args.max_eval_samples))
        # Validation Feature Creation
        eval_dataset = eval_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_eval_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))
        processed_raw_datasets["validation"] = eval_dataset

    if training_args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_examples = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            # We will select sample from whole data
            predict_examples = predict_examples.select(
                range(data_args.max_predict_samples))
        # Predict Feature Creation
        predict_dataset = predict_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_predict_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            predict_dataset = predict_dataset.select(
                range(data_args.max_predict_samples))
        processed_raw_datasets["test"] = predict_dataset
    # endregion

    # region Metrics and Post-processing:
    def post_processing_function(examples,
                                 features,
                                 predictions,
                                 stage="eval"):
        """Turn raw start/end logits into an ``EvalPrediction`` for the metric.

        ``predictions`` is the ``(start_logits, end_logits)`` pair aligned with
        ``features``; answers are decoded by ``postprocess_qa_predictions``.
        """
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
            output_dir=training_args.output_dir,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if data_args.version_2_with_negative:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v,
                "no_answer_probability": 0.0
            } for k, v in predictions.items()]
        else:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v
            } for k, v in predictions.items()]

        references = [{
            "id": ex["id"],
            "answers": ex[answer_column_name]
        } for ex in examples]
        return EvalPrediction(predictions=formatted_predictions,
                              label_ids=references)

    metric = load_metric(
        "squad_v2" if data_args.version_2_with_negative else "squad")

    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions,
                              references=p.label_ids)

    # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
    def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
        """
        Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor

        Args:
            start_or_end_logits(:obj:`list` of :obj:`np.ndarray`):
                Per-batch output predictions of the model, each of shape (batch, cols). Pass either the start
                or the end logits, not both.
            dataset: Evaluation features; only ``len(dataset)`` is used (number of output rows).
            max_len(:obj:`int`):
                Number of output columns. Columns beyond a batch's width stay at -100.

        Returns:
            :obj:`np.ndarray` of shape (len(dataset), max_len), dtype float64. Rows past ``len(dataset)`` in
            the inputs are dropped.
        """

        step = 0
        # create a numpy array and fill it with -100.
        logits_concat = np.full((len(dataset), max_len),
                                -100,
                                dtype=np.float64)
        # Now since we have create an array now we will populate it with the outputs of the model.
        for output_logit in start_or_end_logits:  # populate rows batch by batch
            batch_size = output_logit.shape[0]
            cols = output_logit.shape[1]

            if step + batch_size < len(dataset):
                logits_concat[step:step + batch_size, :cols] = output_logit
            else:
                # Last (possibly overflowing) batch: keep only the rows that fit.
                logits_concat[step:, :cols] = output_logit[:len(dataset) -
                                                           step]

            step += batch_size

        return logits_concat

    # endregion

    # region Training steps and logging init
    if "train" not in processed_raw_datasets:
        raise ValueError(
            "This script requires --do_train: the training loop needs a train dataset.")
    train_dataset = processed_raw_datasets["train"]
    # Only present with --do_eval; every use below is guarded by do_eval.
    eval_dataset = processed_raw_datasets.get("validation")

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)),
                               min(3, len(train_dataset))):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            summary_writer.hparams({
                **training_args.to_dict(),
                **vars(model_args),
                **vars(data_args)
            })
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when TensorBoard is really missing, not merely because
        # this is a non-main process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Log ``train_time`` and every per-step train metric ending at ``step``."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        """Log each eval metric as ``eval_<name>`` at ``step``."""
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count(
    )
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count(
    )
    # endregion

    # region Load model
    model = FlaxAutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
    )

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model,
                               learning_rate_fn,
                               num_labels=max_seq_length,
                               training_args=training_args)

    # endregion

    # region Define train step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array],
        dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one optimizer step on `batch`.

        Returns `(new_state, metrics, new_dropout_rng)`, where `metrics` holds
        the cross-device mean `loss` and the `learning_rate` at this step.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        start_positions = batch.pop("start_positions")
        end_positions = batch.pop("end_positions")
        targets = (start_positions, end_positions)

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices before applying them.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0, ))

    # endregion

    # region Define eval step functions
    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    def run_evaluation(state):
        """Evaluate the replicated `state` on `eval_dataset`; return metric dict.

        Full batches go through `p_eval_step` on all local devices; the
        remaining `len(eval_dataset) % eval_batch_size` features are evaluated
        on a single device, by process 0 only.
        """
        all_start_logits = []
        all_end_logits = []
        for batch in tqdm(
                eval_data_collator(eval_dataset, eval_batch_size),
                total=len(eval_dataset) // eval_batch_size,
                desc="Evaluating ...",
                position=2,
        ):
            # These columns are needed for post-processing, not by the model.
            batch.pop("example_id")
            batch.pop("offset_mapping")
            predictions = p_eval_step(state, batch)
            # Flatten the leading device axis into one batch axis.
            start_logits = np.array([pred for pred in chain(*predictions[0])])
            end_logits = np.array([pred for pred in chain(*predictions[1])])
            all_start_logits.append(start_logits)
            all_end_logits.append(end_logits)

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # Drop the non-model columns BEFORE converting: `offset_mapping`
            # rows mix pairs and None, which np.array rejects as ragged on
            # recent NumPy versions.
            batch = {
                k: np.array(v)
                for k, v in eval_dataset[-num_leftover_samples:].items()
                if k not in ("example_id", "offset_mapping")
            }

            predictions = eval_step(unreplicate(state), batch)
            start_logits = np.array([pred for pred in predictions[0]])
            end_logits = np.array([pred for pred in predictions[1]])
            all_start_logits.append(start_logits)
            all_end_logits.append(end_logits)

        max_len = max(x.shape[1] for x in all_start_logits)  # Get the max_length of the tensor

        # concatenate the numpy array
        start_logits_concat = create_and_fill_np_array(all_start_logits,
                                                       eval_dataset, max_len)
        end_logits_concat = create_and_fill_np_array(all_end_logits,
                                                     eval_dataset, max_len)

        # delete the list of numpy arrays
        del all_start_logits
        del all_end_logits
        outputs_numpy = (start_logits_concat, end_logits_concat)
        prediction = post_processing_function(eval_examples, eval_dataset,
                                              outputs_numpy)
        return compute_metrics(prediction)

    # endregion

    # region Define train and eval loop
    logger.info(f"===== Starting training ({num_epochs} epochs) =====")

    # make sure weights are replicated on each device
    state = replicate(state)

    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    # `eval_steps` may be None; in that case evaluate only at epoch boundaries.
    eval_steps = training_args.eval_steps
    # Wall-clock start of training. `train_time` reported at each logging
    # step is the elapsed time since then (including eval/checkpoint time).
    train_start = time.time()
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:

        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
                tqdm(
                    train_data_collator(input_rng, train_dataset,
                                        train_batch_size),
                    total=step_per_epoch,
                    desc="Training...",
                    position=1,
                ),
                1,
        ):
            state, train_metric, dropout_rngs = p_train_step(
                state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * step_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time = time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if (training_args.do_eval and cur_step > 0
                    and ((eval_steps and cur_step % eval_steps == 0)
                         or cur_step % step_per_epoch == 0)):
                eval_metrics = run_evaluation(state)

                logger.info(
                    f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})"
                )

                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0
                    and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
    # endregion

    # Eval after training
    if training_args.do_eval:
        eval_metrics = run_evaluation(state)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
Beispiel #22
0
def train_and_evaluate(config, workdir):
    """Runs a multi-scene training and evaluation loop.

  Training starts on the first scene returned by
  `train_utils.get_train_scene_list(config)`. Every `step_per_scene` steps a
  new scene is drawn at random (using numpy's global RNG) and the training
  input pipeline is rebuilt for it.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
      Otherwise weights are restored from `config.train.pretrain_dir` when
      that is set.

  Raises:
    ValueError: If `config.dataset.batch_size` is not divisible by
      `jax.device_count()`.
  """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts. The numpy global RNG also picks the training scene
    # (and the DTU lighting index) inside the training loop below.
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    # NOTE: `data_rng` is not used after this point, but the split still
    # advances `rng`; removing it would change the model-init key below.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())

    scene_path_list = train_utils.get_train_scene_list(config)

    train_ds = datasets.create_train_dataset(config, scene_path_list[0])
    # One entry is popped from the returned eval-dataset mapping and used for
    # every evaluation in this run.
    _, eval_ds_dict = datasets.create_eval_dataset(config)
    _, eval_ds = eval_ds_dict.popitem()
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        # -1 means: derive the step count from the training dataset size.
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.

    # Check whether the job was stopped and relaunched.
    latest_ckpt = checkpoints.latest_checkpoint(workdir)
    if latest_ckpt is None:
        # No previous checkpoint. Then check for pretrained weights.
        if config.train.pretrain_dir:
            state = checkpoints.restore_checkpoint(config.train.pretrain_dir,
                                                   state)
    else:
        state = checkpoints.restore_checkpoint(workdir, state)

    # The loop counts steps from 1; resume at the step after the restored one.
    initial_step = int(state.step) + 1
    step_per_scene = config.train.switch_scene_iter
    if config.dev_run:
        # Dev runs print compact arrays and switch scenes very frequently.
        jnp.set_printoptions(precision=2)
        np.set_printoptions(precision=2)
        step_per_scene = 3

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
        ]
    train_metrics = None

    # Prefetch buffer holds 6 batches (6 x batch_size examples).
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    # Make random seed separate across hosts.
    # NOTE(review): this is integer arithmetic on the raw key array;
    # jax.random.fold_in is the documented way to derive per-host keys, but
    # switching would change the random streams of existing runs.
    rng = rng + jax.process_index()
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
            # devices.
            if step % step_per_scene == 0:
                # Switch to a randomly chosen training scene and rebuild the
                # input pipeline for it.
                scene_idx = np.random.randint(len(scene_path_list))
                curr_scene = scene_path_list[scene_idx]
                logging.info("Loading scene %s", curr_scene)
                if config.dataset.name == "dtu":
                    # lighting can take values between 0 and 6 (both included)
                    config.dataset.dtu_light_idx = np.random.randint(low=0,
                                                                     high=7)
                train_ds = datasets.create_train_dataset(config, curr_scene)
                ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)

            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                # The pmapped train step also hands back the per-device keys
                # to use on the next step.
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                # Write metrics accumulated since the last write, then reset.
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    # SSIM is not computed here; 0 is logged as a placeholder.
                    ssim = 0.
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if jax.process_index() == 0 and (
                    step % config.train.checkpoint_every_steps == 0
                    or is_last_step):
                # Write the most recent step's metrics (not the accumulated
                # window) to file, overwriting any previous contents.
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = metric_update.compute()
                    for k, v in log_dict.items():
                        log_dict[k] = v.item()
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    # Save one host-side copy of the replicated state.
                    state_to_save = jax.device_get(
                        jax.tree_util.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)
Beispiel #23
0
def experiment(
        model_dir='.',  # pylint: disable=dangerous-default-value
        imagenet_subset_dir=None,
        dataset='cifar10',
        batch_size=256,
        eval_batch_size=1024,
        num_epochs=200,
        learning_rate=0.1,
        aug_imagenet_apply_colour_jitter=False,
        aug_imagenet_greyscale_prob=0.0,
        sgd_momentum=0.9,
        sgd_nesterov=True,
        lr_schedule='stepped',
        lr_sched_steps=[[60, 0.2], [120, 0.04], [160, 0.008]],
        lr_sched_halfcoslength=400.0,
        lr_sched_warmup=5.0,
        l2_reg=0.0005,
        weight_decay=0.0,
        architecture='wrn22_10',
        n_val=5000,
        n_sup=1000,
        teacher_alpha=0.999,
        anneal_teacher_alpha=False,
        unsupervised_regularizer='none',
        cons_weight=1.0,
        conf_thresh=0.97,
        conf_avg=False,
        cut_backg_noise=1.0,
        cut_prob=1.0,
        box_reg_scale_mode='fixed',
        box_reg_scale=0.25,
        box_reg_random_aspect_ratio=False,
        cow_sigma_range=(4.0, 8.0),
        cow_prop_range=(0.25, 1.0),
        mix_regularizer='none',
        mix_aug_separately=False,
        mix_logits=True,
        mix_weight=1.0,
        mix_conf_thresh=0.97,
        mix_conf_avg=True,
        mix_conf_mode='mix_prob',
        ict_alpha=0.1,
        mix_box_reg_scale_mode='fixed',
        mix_box_reg_scale=0.25,
        mix_box_reg_random_aspect_ratio=False,
        mix_cow_sigma_range=(4.0, 8.0),
        mix_cow_prop_range=(0.0, 1.0),
        subset_seed=12345,
        val_seed=131,
        run_seed=None,
        log_fn=print,
        checkpoints='on',
        on_epoch_finished_fn=None,
        debug=False):
    """Run a student/teacher semi-supervised training experiment.

    A student model is trained with a Momentum optimizer while a teacher model
    and its batch statistics are updated by `train_step`. Both are evaluated
    at the end of every epoch.

    Args:
        model_dir: Directory for TensorBoard summaries (written by process 0)
            and, unless `checkpoints == 'none'`, the `checkpoint.pkl` file.
        imagenet_subset_dir: Directory of ImageNet subsets; required when
            `dataset == 'imagenet'`.
        dataset: One of 'svhn', 'cifar10', 'cifar100', 'imagenet'.
        batch_size: Global training batch size; must be divisible by
            `jax.device_count()`.
        eval_batch_size: Global evaluation batch size; must be divisible by
            `jax.device_count()`.
        num_epochs: Number of epochs; the run lasts
            `(n_train // batch_size) * num_epochs` steps.
        learning_rate: Rate for a batch size of 256; the base rate is scaled
            by `batch_size / 256`.
        aug_imagenet_apply_colour_jitter, aug_imagenet_greyscale_prob:
            Forwarded to the ImageNet data source.
        sgd_momentum, sgd_nesterov: Momentum optimizer settings.
        lr_schedule: One of 'constant', 'stepped', 'cosine'.
        lr_sched_steps, lr_sched_halfcoslength, lr_sched_warmup: Forwarded to
            the learning-rate schedule builders.
        l2_reg, weight_decay: Forwarded to `train_step`.
        architecture: Model name passed to `create_model`.
        n_val: Validation set size. When 0, evaluation uses the test set.
        n_sup: Forwarded to the data source as the supervised subset size.
        teacher_alpha: Teacher update coefficient passed to `train_step` via
            `teacher_alpha_fn`.
        anneal_teacher_alpha: If True, `1 - teacher_alpha` follows the same
            schedule shape as the learning rate (without warmup).
        unsupervised_regularizer, cut_backg_noise, cut_prob,
        box_reg_scale_mode, box_reg_scale, box_reg_random_aspect_ratio,
        cow_sigma_range, cow_prop_range: Forwarded to `build_pert_reg`.
        cons_weight, conf_thresh, conf_avg: Forwarded to `train_step`.
        mix_regularizer, ict_alpha, mix_box_reg_scale_mode, mix_box_reg_scale,
        mix_box_reg_random_aspect_ratio, mix_cow_sigma_range,
        mix_cow_prop_range: Forwarded to `build_mix_reg`.
        mix_aug_separately, mix_logits, mix_weight, mix_conf_thresh,
        mix_conf_avg: Forwarded to `train_step`.
        mix_conf_mode: 'mix_prob' or 'mix_conf'; forwarded to `train_step`.
        subset_seed, val_seed: Forwarded to the data source.
        run_seed: Seed for the JAX PRNG. Defaults to
            `subset_seed << 32 | n_val`.
        log_fn: Callable that receives progress messages.
        checkpoints: 'none' disables checkpointing. 'on' saves after every
            epoch and deletes the checkpoint when training finishes. 'retain'
            saves after every epoch and keeps the final checkpoint.
        on_epoch_finished_fn: Optional callback, called after each epoch with
            the epoch number and the student/teacher eval error rates (plus
            top-5 error rates for ImageNet).
        debug: If True, log the elapsed time after every step.

    Raises:
        ValueError: On an unknown option value, a batch size not divisible by
            the device count, or a missing `imagenet_subset_dir`.
    """
    if checkpoints not in {'none', 'on', 'retain'}:
        raise ValueError('checkpoints should be one of (none|on|retain)')

    if checkpoints != 'none':
        checkpoint_path = os.path.join(model_dir, 'checkpoint.pkl')
        # Checkpoints are first written here and then renamed over
        # `checkpoint_path`, so an interrupted write cannot clobber it.
        checkpoint_new_path = os.path.join(model_dir, 'checkpoint.pkl.new')
    else:
        checkpoint_path = None
        checkpoint_new_path = None

    if dataset not in {'svhn', 'cifar10', 'cifar100', 'imagenet'}:
        raise ValueError('Unknown dataset \'{}\''.format(dataset))

    # NOTE(review): the default architecture='wrn22_10' is not in this set,
    # so calling experiment() without an explicit architecture raises here.
    # Confirm whether the default or this set should change.
    if architecture not in {
            'wrn20_10', 'wrn26_10', 'wrn26_2', 'wrn20_6_shakeshake',
            'wrn26_6_shakeshake', 'wrn26_2_shakeshake', 'pyramid', 'resnet50',
            'resnet101', 'resnet152', 'resnet50x2', 'resnet101x2',
            'resnet152x2', 'resnet50x4', 'resnet101x4', 'resnet152x4',
            'resnext50_32x4d', 'resnext101_32x8d', 'resnext152_32x4d'
    }:
        raise ValueError('Unknown architecture \'{}\''.format(architecture))

    if lr_schedule not in {'constant', 'stepped', 'cosine'}:
        raise ValueError('Unknown LR schedule \'{}\''.format(lr_schedule))

    if mix_conf_mode not in {'mix_prob', 'mix_conf'}:
        raise ValueError('Unknown mix_conf_mode \'{}\''.format(mix_conf_mode))

    # Only process 0 writes TensorBoard summaries.
    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(model_dir)
    else:
        summary_writer = None

    unsup_reg, augment_twice = build_pert_reg(
        unsupervised_regularizer,
        cut_backg_noise=cut_backg_noise,
        cut_prob=cut_prob,
        box_reg_scale_mode=box_reg_scale_mode,
        box_reg_scale=box_reg_scale,
        box_reg_random_aspect_ratio=box_reg_random_aspect_ratio,
        cow_sigma_range=cow_sigma_range,
        cow_prop_range=cow_prop_range)

    mix_reg = build_mix_reg(
        mix_regularizer,
        ict_alpha=ict_alpha,
        box_reg_scale_mode=mix_box_reg_scale_mode,
        box_reg_scale=mix_box_reg_scale,
        box_reg_random_aspect_ratio=mix_box_reg_random_aspect_ratio,
        cow_sigma_range=mix_cow_sigma_range,
        cow_prop_range=mix_cow_prop_range)

    if run_seed is None:
        run_seed = subset_seed << 32 | n_val
    train_rng = jax.random.PRNGKey(run_seed)
    init_rng, train_rng = jax.random.split(train_rng)

    if batch_size % jax.device_count() > 0:
        raise ValueError('Train batch size must be divisible by the number of '
                         'devices')
    if eval_batch_size % jax.device_count() > 0:
        raise ValueError('Eval batch size must be divisible by the number of '
                         'devices')
    # Per-process batch sizes for the input pipeline; per-device size for the
    # model.
    local_batch_size = batch_size // jax.process_count()
    local_eval_batch_size = eval_batch_size // jax.process_count()
    device_batch_size = batch_size // jax.device_count()

    if dataset == 'svhn':
        image_size = 32
        top5_err_required = False
        data_source = small_image_data_source.SVHNDataSource(
            n_val=n_val,
            n_sup=n_sup,
            train_batch_size=local_batch_size,
            eval_batch_size=local_eval_batch_size,
            augment_twice=augment_twice,
            subset_seed=subset_seed,
            val_seed=val_seed)
    elif dataset == 'cifar10':
        image_size = 32
        top5_err_required = False
        data_source = small_image_data_source.CIFAR10DataSource(
            n_val=n_val,
            n_sup=n_sup,
            train_batch_size=local_batch_size,
            eval_batch_size=local_eval_batch_size,
            augment_twice=augment_twice,
            subset_seed=subset_seed,
            val_seed=val_seed)
    elif dataset == 'cifar100':
        image_size = 32
        top5_err_required = False
        data_source = small_image_data_source.CIFAR100DataSource(
            n_val=n_val,
            n_sup=n_sup,
            train_batch_size=local_batch_size,
            eval_batch_size=local_eval_batch_size,
            augment_twice=augment_twice,
            subset_seed=subset_seed,
            val_seed=val_seed)
    elif dataset == 'imagenet':
        image_size = 224
        top5_err_required = True
        if imagenet_subset_dir is None:
            raise ValueError(
                'Please provide a directory to the imagenet_subset_dir '
                'command line arg to specify where the ImageNet '
                'subsets are stored')
        data_source = imagenet_data_source.ImageNetDataSource(
            imagenet_subset_dir,
            n_val,
            n_sup,
            local_batch_size,
            local_eval_batch_size,
            augment_twice,
            apply_colour_jitter=aug_imagenet_apply_colour_jitter,
            greyscale_prob=aug_imagenet_greyscale_prob,
            load_test_set=(n_val == 0),
            image_size=image_size,
            subset_seed=subset_seed,
            val_seed=val_seed)
    else:
        raise RuntimeError

    n_train = data_source.n_train
    train_ds = data_source.train_semisup_ds

    # With no validation split, evaluate on the test set instead.
    if n_val == 0:
        eval_ds = data_source.test_ds
        n_eval = data_source.n_test
    else:
        eval_ds = data_source.val_ds
        n_eval = data_source.n_val

    log_fn(
        'DATA: |train|={}, |sup|={}, |eval|={}, (|val|={}, |test|={})'.format(
            data_source.n_train, data_source.n_sup, n_eval, data_source.n_val,
            data_source.n_test))

    log_fn('Loaded dataset')

    steps_per_epoch = n_train // batch_size
    # Round up so the final partial eval batch (padded below) is included.
    steps_per_eval = n_eval // eval_batch_size
    if n_eval % eval_batch_size > 0:
        steps_per_eval += 1
    num_steps = steps_per_epoch * num_epochs

    # Create model
    model_stu, state_stu = create_model(init_rng, architecture,
                                        device_batch_size, image_size,
                                        data_source.n_classes)
    state_stu = jax_utils.replicate(state_stu)
    log_fn('Built model')

    # Create optimizer
    optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                   beta=sgd_momentum,
                                   nesterov=sgd_nesterov)

    optimizer_stu = optimizer_def.create(model_stu)
    optimizer_stu = optimizer_stu.replicate()
    del model_stu  # don't keep a copy of the initial model

    # Create learning rate function (linear scaling relative to batch 256).
    base_learning_rate = learning_rate * batch_size / 256.
    if lr_schedule == 'constant':
        learning_rate_fn = create_constant_learning_rate_fn(base_learning_rate)
    elif lr_schedule == 'stepped':
        learning_rate_fn = create_stepped_learning_rate_fn(
            base_learning_rate,
            steps_per_epoch,
            lr_sched_steps=lr_sched_steps,
            warmup_length=lr_sched_warmup)
    elif lr_schedule == 'cosine':
        learning_rate_fn = create_cosine_learning_rate_fn(
            base_learning_rate,
            steps_per_epoch,
            halfcoslength_epochs=lr_sched_halfcoslength,
            warmup_length=lr_sched_warmup)
    else:
        raise RuntimeError

    if anneal_teacher_alpha:
        # Anneal (1 - alpha) with the same schedule shape as the learning
        # rate, but without warmup.
        if lr_schedule == 'constant':
            one_minus_alpha_fn = create_constant_learning_rate_fn(
                1.0 - teacher_alpha)
        elif lr_schedule == 'stepped':
            one_minus_alpha_fn = create_stepped_learning_rate_fn(
                1.0 - teacher_alpha,
                steps_per_epoch,
                lr_sched_steps=lr_sched_steps)
        elif lr_schedule == 'cosine':
            one_minus_alpha_fn = create_cosine_learning_rate_fn(
                1.0 - teacher_alpha,
                steps_per_epoch,
                halfcoslength_epochs=lr_sched_halfcoslength)
        else:
            raise RuntimeError
        teacher_alpha_fn = lambda step: 1.0 - one_minus_alpha_fn(step)
    else:
        teacher_alpha_fn = lambda step: teacher_alpha

    log_fn('Built optimizer')

    # Teacher model is just the student as we duplicate it when we modify it
    model_tea = optimizer_stu.target
    # Replicate batch stats
    state_tea = jax.tree_util.tree_map(lambda x: x, state_stu)

    # Set up epoch and step counter
    # Load existing checkpoint if available
    epoch = 1
    step = 0

    if checkpoints != 'none':
        if tf.io.gfile.exists(checkpoint_path):
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point model_dir at trusted locations.
            with tf.io.gfile.GFile(checkpoint_path, 'rb') as f_in:
                check = pickle.load(f_in)

                # Student optimizer and batch stats
                optimizer_stu = util.restore_state_list(
                    optimizer_stu, check['optimizer_stu'])

                state_stu = util.restore_state_list(state_stu,
                                                    check['state_stu'])

                # Teacher model and batch stats
                model_tea = util.restore_state_list(model_tea,
                                                    check['model_tea'])

                state_tea = util.restore_state_list(state_tea,
                                                    check['state_tea'])

                epoch = check['epoch']
                step = check['step']

                log_fn('Loaded checkpoint from {}'.format(checkpoint_path))

    #
    # Training and evaluation step functions
    #
    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        l2_reg=l2_reg,
        weight_decay=weight_decay,
        teacher_alpha_fn=teacher_alpha_fn,
        unsup_reg=unsup_reg,
        cons_weight=cons_weight,
        conf_thresh=conf_thresh,
        conf_avg=conf_avg,
        mix_reg=mix_reg,
        mix_aug_separately=mix_aug_separately,
        mix_logits=mix_logits,
        mix_weight=mix_weight,
        mix_conf_thresh=mix_conf_thresh,
        mix_conf_avg=mix_conf_avg,
        mix_conf_mode=mix_conf_mode),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eval_top_5=top5_err_required),
                           axis_name='batch')

    # Create dataset batch iterators
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    #
    # Training loop
    #

    log_fn('Training...')
    epoch_metrics_stu = []
    t1 = time.time()
    while step < num_steps:
        train_rng, iter_rng = jax.random.split(train_rng)
        batch = next(train_iter)
        batch = jax.tree_util.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
        batch = shard(batch, iter_rng)

        optimizer_stu, state_stu, metrics_stu, model_tea, state_tea = p_train_step(
            optimizer_stu, state_stu, model_tea, state_tea, batch)

        if debug:
            log_fn('Step {} time {}'.format(step, time.time() - t1))

        epoch_metrics_stu.append(metrics_stu)
        if (step + 1) % steps_per_epoch == 0:
            # End of epoch: summarize training metrics, then evaluate.
            epoch_metrics_stu = util.get_metrics(epoch_metrics_stu)
            train_epoch_metrics = jax.tree_util.tree_map(lambda x: x.mean(),
                                                         epoch_metrics_stu)
            if summary_writer is not None:
                for key, vals in epoch_metrics_stu.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)

            epoch_metrics_stu = []
            eval_stu_metrics = []
            eval_tea_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # TF to NumPy
                eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                # Pad short batches
                eval_batch = util.pad_classification_batch(
                    eval_batch, local_eval_batch_size)
                # Shard across local devices
                eval_batch = shard(eval_batch)
                metrics_stu = p_eval_step(optimizer_stu.target, state_stu,
                                          eval_batch)
                metrics_tea = p_eval_step(model_tea, state_tea, eval_batch)
                eval_stu_metrics.append(metrics_stu)
                eval_tea_metrics.append(metrics_tea)
            eval_stu_metrics = util.get_metrics(eval_stu_metrics)
            eval_tea_metrics = util.get_metrics(eval_tea_metrics)
            eval_stu_epoch_metrics = jax.tree_util.tree_map(lambda x: x.sum(),
                                                            eval_stu_metrics)
            eval_tea_epoch_metrics = jax.tree_util.tree_map(lambda x: x.sum(),
                                                            eval_tea_metrics)
            eval_stu_epoch_metrics = avg_eval_metrics(eval_stu_epoch_metrics)
            eval_tea_epoch_metrics = avg_eval_metrics(eval_tea_epoch_metrics)

            t2 = time.time()

            if top5_err_required:
                log_fn(
                    'EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
                    'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
                    'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
                    'top-5-err={:.3%}, TEA Eval loss={:.6f}, err={:.3%}, '
                    'top-5-err={:.3%}'.format(
                        epoch,
                        t2 - t1,
                        train_epoch_metrics['loss'],
                        train_epoch_metrics['error_rate'],
                        train_epoch_metrics['cons_loss'],
                        train_epoch_metrics['conf_rate'],
                        train_epoch_metrics['mix_loss'],
                        train_epoch_metrics['mix_conf_rate'],
                        eval_stu_epoch_metrics['loss'],
                        eval_stu_epoch_metrics['error_rate'],
                        eval_stu_epoch_metrics['top5_error_rate'],
                        eval_tea_epoch_metrics['loss'],
                        eval_tea_epoch_metrics['error_rate'],
                        eval_tea_epoch_metrics['top5_error_rate'],
                    ))
            else:
                log_fn(
                    'EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
                    'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
                    'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
                    'TEA Eval loss={:.6f}, err={:.3%}'.format(
                        epoch,
                        t2 - t1,
                        train_epoch_metrics['loss'],
                        train_epoch_metrics['error_rate'],
                        train_epoch_metrics['cons_loss'],
                        train_epoch_metrics['conf_rate'],
                        train_epoch_metrics['mix_loss'],
                        train_epoch_metrics['mix_conf_rate'],
                        eval_stu_epoch_metrics['loss'],
                        eval_stu_epoch_metrics['error_rate'],
                        eval_tea_epoch_metrics['loss'],
                        eval_tea_epoch_metrics['error_rate'],
                    ))

            if on_epoch_finished_fn is not None:
                if top5_err_required:
                    on_epoch_finished_fn(
                        epoch,
                        eval_stu_err=eval_stu_epoch_metrics['error_rate'],
                        eval_tea_err=eval_tea_epoch_metrics['error_rate'],
                        eval_stu_top5_err=eval_stu_epoch_metrics[
                            'top5_error_rate'],
                        eval_tea_top5_err=eval_tea_epoch_metrics[
                            'top5_error_rate'],
                    )
                else:
                    on_epoch_finished_fn(
                        epoch,
                        eval_stu_err=eval_stu_epoch_metrics['error_rate'],
                        eval_tea_err=eval_tea_epoch_metrics['error_rate'],
                    )

            t1 = t2

            if summary_writer is not None:
                summary_writer.scalar('eval_stu_loss',
                                      eval_stu_epoch_metrics['loss'], epoch)
                summary_writer.scalar('eval_stu_error_rate',
                                      eval_stu_epoch_metrics['error_rate'],
                                      epoch)
                summary_writer.scalar('eval_tea_loss',
                                      eval_tea_epoch_metrics['loss'], epoch)
                summary_writer.scalar('eval_tea_error_rate',
                                      eval_tea_epoch_metrics['error_rate'],
                                      epoch)
                if top5_err_required:
                    summary_writer.scalar(
                        'eval_stu_top5_error_rate',
                        eval_stu_epoch_metrics['top5_error_rate'], epoch)
                    summary_writer.scalar(
                        'eval_tea_top5_error_rate',
                        eval_tea_epoch_metrics['top5_error_rate'], epoch)
                summary_writer.flush()

            # Advance the epoch counter on every process, not only the one
            # owning the summary writer, so epoch numbers passed to log_fn and
            # on_epoch_finished_fn stay correct on all hosts.
            epoch += 1

            if checkpoints != 'none' and jax.process_index() == 0:
                # Write to new checkpoint file so that we don't immediately
                # overwrite the old one
                with tf.io.gfile.GFile(checkpoint_new_path, 'wb') as f_out:
                    check = dict(
                        optimizer_stu=util.to_state_list(optimizer_stu),
                        state_stu=util.to_state_list(state_stu),
                        model_tea=util.to_state_list(model_tea),
                        state_tea=util.to_state_list(state_tea),
                        epoch=epoch,
                        step=step + 1,
                    )
                    pickle.dump(check, f_out)
                    del check
                # Remove old checkpoint and rename
                if tf.io.gfile.exists(checkpoint_path):
                    tf.io.gfile.remove(checkpoint_path)
                tf.io.gfile.rename(checkpoint_new_path, checkpoint_path)

        step += 1

    # With checkpoints='on' the checkpoint only serves resumption, so delete
    # it once training completes ('retain' keeps it).
    if checkpoints == 'on' and jax.process_index() == 0:
        if tf.io.gfile.exists(checkpoint_path):
            tf.io.gfile.remove(checkpoint_path)
Beispiel #24
0
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

  Restores the latest checkpoint from `workdir` (if any), then trains until
  `num_steps`. At the end of every epoch it evaluates; every 10 epochs and
  after the final step it saves a checkpoint.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries, profiles and
      checkpoints are written to.

  Returns:
    Final TrainState.

  Raises:
    ValueError: If the batch size is not divisible by the number of devices,
      or if the training split does not contain a single full batch.
  """

    # Only process 0 writes summary files; the others only log.
    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=jax.process_index() != 0)

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # Each process feeds its own share of the global batch.
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    # Half precision uses bfloat16 on TPU and float16 on other platforms.
    if config.half_precision:
        input_dtype = tf.bfloat16 if platform == 'tpu' else tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder(config.dataset)
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)
    # steps_per_epoch is used as a modulus in the loop below and scales the
    # checkpoint interval, so it must be positive.
    if steps_per_epoch == 0:
        raise ValueError(
            'Batch size must not exceed the number of training examples')

    # num_train_steps == -1 means "train for config.num_epochs epochs".
    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    # steps_per_eval == -1 means "evaluate on the whole validation split".
    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    # Linear learning-rate scaling relative to a reference batch size of 256.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    learning_rate_fn = create_learning_rate_fn(config, base_learning_rate,
                                               steps_per_epoch)

    state = create_train_state(rng, config, model, image_size,
                               learning_rate_fn)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    train_metrics = []
    hooks = []
    if jax.process_index() == 0:
        hooks += [
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics_last_t = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        if step == step_offset:
            logging.info('Initial compilation completed.')

        if config.get('log_every_steps'):
            train_metrics.append(metrics)
            if (step + 1) % config.log_every_steps == 0:
                train_metrics = common_utils.get_metrics(train_metrics)
                summary = {
                    f'train_{k}': v
                    for k, v in jax.tree_util.tree_map(
                        lambda x: x.mean(), train_metrics).items()
                }
                summary['steps_per_second'] = config.log_every_steps / (
                    time.time() - train_metrics_last_t)
                writer.write_scalars(step + 1, summary)
                train_metrics = []
                train_metrics_last_t = time.time()

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            writer.write_scalars(
                step + 1, {f'eval_{key}': val
                           for key, val in summary.items()})
            writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state
Beispiel #25
0
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {
            "loss": loss.sum(),
            "accuracy": accuracy.sum(),
            "normalizer": label_mask.sum()
        }
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
Beispiel #26
0
def main(_):
    """Train and evaluate a DecomposeAttentionTransformer configured by FLAGS.

    Steps:
      1. Build the token tables and datasets for `FLAGS.dataset_type` and
         `FLAGS.model_type`.
      2. Optionally restore the latest checkpoint for this hyperparameter
         setting.
      3. Run the training loop.

    The loop produces training metrics, evaluation metrics, beam-search
    accuracy (on the predict and test sets) and checkpoints. These happen
    every `log_freq`, `eval_freq`, `predict_freq` and `checkpoint_freq` steps
    respectively, and always on the last step.

    Raises:
      ValueError: For an unhandled dataset or model type, a missing
        `dataset_filepattern`, or a restored checkpoint whose step is before
        `finetune_start_step`.
      NotImplementedError: If fast decoding is requested
        (`FLAGS.slow_decode` is false).
    """
    tf.enable_v2_behavior()

    # Seed TensorFlow, NumPy and Python RNGs for reproducibility.
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    # Short hyperparameter string; it names this run's TensorBoard and
    # checkpoint subdirectories.
    hparam_str_dict = json.loads(FLAGS.xm_parameters)
    hparam_str = ','.join([
        '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
        for k in hparam_str_dict.keys()
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    predict_io_shape = (FLAGS.per_device_batch_size,
                        FLAGS.num_strings_per_task,
                        FLAGS.predict_max_characters)
    target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
        spec_vocab = robust_fill_dsl.CHARACTER + input_pipeline.SEPARATOR_TOKEN
        # Ids 0-2 are reserved: 0 for padding, 1 for BOS and 2 for EOS.
        spec_id_token_table = {
            i + 3: token
            for i, token in enumerate(spec_vocab)
        }
        bos_id = 1
        eos_id = 2
        spec_id_token_table[bos_id] = robust_fill_dsl.BOS
        spec_id_token_table[eos_id] = robust_fill_dsl.EOS
        spec_token_id_table = {
            token: token_id
            for token_id, token in spec_id_token_table.items()
        }
        spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
        program_id_token_table, _ = dsl_tokens.build_token_tables()
        program_vocab_size = len(program_id_token_table) + 1
    elif FLAGS.dataset_type == 'scan':
        # TODO(jxihong): Scan is not handled yet.
        raise ValueError('Unhandled dataset_type: {}'.format(
            FLAGS.dataset_type))
    else:
        raise ValueError('Unhandled dataset_type: {}'.format(
            FLAGS.dataset_type))

    # Parse io and program token sequences (for eval).
    def decode_io(inputs, outputs):
        """Convert from int tensors to strings."""
        if FLAGS.dataset_type == 'robust_fill':

            def decode_str(s):
                """Decode string tokens."""
                return ''.join(
                    [spec_id_token_table[t_id] for t_id in s if t_id > 0])

            inps, outs = [], []
            for inp, out in zip(inputs, outputs):
                inps.append(decode_str(inp))
                outs.append(decode_str(out))
            return inps, outs

        elif FLAGS.dataset_type == 'scan':

            def decode_str(s):
                """Decode string tokens."""
                return ' '.join(
                    [spec_id_token_table[t_id] for t_id in s if t_id > 0])

            inps = [decode_str(inp) for inp in inputs]
            dummy_outs = [''] * len(inps)
            return inps, dummy_outs

        else:
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))

    def decode_spec(target):
        """Convert from int tensor to a string."""
        # Truncate before the first EOS.
        # NOTE(review): np.argmax returns 0 when no EOS is present, so a
        # sequence without EOS decodes to an empty spec. Confirm this is
        # intended.
        target = target[:np.argmax(target == eos_id)].astype(np.int32)

        if FLAGS.dataset_type == 'robust_fill':
            target = target[target != bos_id].tolist()
            return ''.join(
                [spec_id_token_table[t_id] for t_id in target if t_id > 0])
        elif FLAGS.dataset_type == 'scan':
            # TODO(jxihong): Scan is not handled yet.
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))
        else:
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))

    def decode_program(program):
        """Decode program tokens into a program (program object or string)."""
        # Keep everything up to and including the first EOS.
        # NOTE(review): without any EOS, np.argmax returns 0 and only the
        # first token is kept. Confirm this is intended.
        program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

        if FLAGS.dataset_type == 'robust_fill':
            # Returns either a Concat program object, or None.
            program = program[program != bos_id].tolist()
            try:
                return robust_fill_dsl.decode_program(program,
                                                      program_id_token_table)
            except Exception:  # pylint: disable=broad-except
                return None  # Program does not compile.
        elif FLAGS.dataset_type == 'scan':
            # Returns a string. `program` is a host numpy array here.
            program = program[np.logical_and(program != bos_id,
                                             program != eos_id)].tolist()
            return ' '.join(scan_vocab.decode(program, program_id_token_table))
        else:
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))

    def decode_program_str(program):  # pylint: disable=unused-variable
        """Decode program tokens into a string."""
        decoded = decode_program(program)
        if FLAGS.dataset_type == 'robust_fill':
            # decode_program returns None for programs that do not compile.
            try:
                return decoded.to_string()  # pytype: disable=attribute-error
            except Exception:  # pylint: disable=broad-except
                return 'did not compile'
        else:
            assert isinstance(decoded,
                              str), '{} should be string'.format(decoded)
            return decoded

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
    padded_shapes = {
        'inputs': io_shape[1:],
        'outputs': io_shape[1:],
        'target': target_shape[1:],
    }
    logging.info('padded_shapes: %s', padded_shapes)

    if FLAGS.dataset_type == 'robust_fill':
        if FLAGS.model_type == 'spec_decomposer_model':
            create_dataset_fn = input_pipeline.create_robust_fill_dataset_for_spec_decomposer_model
        elif FLAGS.model_type == 'synthesizer_model':
            create_dataset_fn = input_pipeline.create_robust_fill_dataset_for_synthesizer_model
        else:
            raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

    elif FLAGS.dataset_type == 'scan':
        raise NotImplementedError()  # TODO(kshi): Implement.
        # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
    else:
        raise ValueError('Unhandled dataset_type: {}'.format(
            FLAGS.dataset_type))

    dataset = create_dataset_fn(FLAGS.dataset_filepattern, spec_token_id_table,
                                FLAGS.num_strings_per_task)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=padded_shapes,
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_padded_shapes = padded_shapes.copy()
    predict_padded_shapes['inputs'] = predict_io_shape[1:]
    predict_padded_shapes['outputs'] = predict_io_shape[1:]
    logging.info('predict_padded_shapes: %s', predict_padded_shapes)
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)
    train_ds = dataset.skip(FLAGS.num_eval_steps)
    if FLAGS.train_set_batches > 0:
        train_ds = train_ds.take(FLAGS.train_set_batches)
    train_ds = train_ds.repeat()

    test_dataset = create_dataset_fn(FLAGS.test_dataset_filepattern,
                                     spec_token_id_table,
                                     FLAGS.num_strings_per_task)
    test_dataset = test_dataset.padded_batch(
        batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
    quick_test_dataset = (test_dataset.take(
        FLAGS.num_quick_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))
    final_test_dataset = (test_dataset.take(
        FLAGS.num_final_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    if FLAGS.model_type == 'spec_decomposer_model':
        output_vocab_size = spec_vocab_size
    elif FLAGS.model_type == 'synthesizer_model':
        output_vocab_size = program_vocab_size
    else:
        raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

    base_config = base_models.TransformerConfig(
        vocab_size=spec_vocab_size,
        output_vocab_size=output_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_id,
        num_input_relative_position_buckets=FLAGS.num_position_buckets,
        max_input_distance=FLAGS.max_distance,
        num_output_relative_position_buckets=FLAGS.num_position_buckets,
        max_output_distance=FLAGS.max_distance,
        num_input_cross_output_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_input_cross_output_distance=FLAGS.max_distance,
        num_program_relative_position_buckets=FLAGS.num_position_buckets,
        max_program_distance=FLAGS.max_distance,
        num_program_cross_embed_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_program_cross_embed_distance=FLAGS.
        max_program_cross_embed_distance,
        num_flat_encoding_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_flat_encoding_distance=FLAGS.max_distance)
    train_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config,
        dataset_type=FLAGS.dataset_type,
        flat_encoded_self_attention=FLAGS.flat_encoded_self_attention)
    eval_config = train_config.replace(base_config=base_config.replace(
        deterministic=True))
    predict_config = train_config.replace(base_config=base_config.replace(
        shift=False,
        deterministic=True,
        decode=not FLAGS.slow_decode,
        max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)))

    # Per-host RNG stream: fold in the process index so hosts differ.
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, init_rng = jax.random.split(rng)

    # One dropout key per local device.
    dropout_rng = jax.random.split(rng, jax.local_device_count())
    del rng

    m = models.DecomposeAttentionTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)
        if FLAGS.finetune_start_step > 0:
            logging.info(
                'Checking that start_step (%s) >= finetune_start_step (%s)',
                start_step, FLAGS.finetune_start_step)
            # Explicit check instead of `assert`, which is stripped under -O.
            if start_step < FLAGS.finetune_start_step:
                raise ValueError(
                    f'Restored checkpoint step {start_step} is before '
                    f'finetune_start_step {FLAGS.finetune_start_step}.')
            steps_to_skip = start_step - FLAGS.finetune_start_step
        else:
            steps_to_skip = start_step

        # TODO(kshi): It is likely that this code can lead to the job stalling for
        # 10+ hours when restarting from a checkpoint that had been trained a long
        # time, possibly because dataset skipping is slow.
        logging.info('Skipping %s steps...', steps_to_skip)
        train_ds = train_ds.skip(steps_to_skip)
        # Advance dropout_rng as p_train_step would have, so a resumed run
        # continues the same RNG sequence.
        # NOTE(review): assumes train_step derives the next dropout key via
        # jax.random.split(...)[1] — confirm against train_step.
        dummy_p_train_step = jax.pmap(
            lambda dropout_rng: jax.random.split(dropout_rng)[1])
        for _ in range(steps_to_skip):
            dropout_rng = dummy_p_train_step(dropout_rng)
        logging.info('Finished skipping steps')
        logging.info('Host %s has dropout_rng = %s', jax.process_index(),
                     dropout_rng)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    # TODO(jxihong): Implement fast decoding.
    if not FLAGS.slow_decode:
        raise NotImplementedError('Fast decoding is not implemented yet.')

    if FLAGS.finetune_start_step <= 0:
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr)
    else:
        # Constant LR for finetuning.
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr, factors='constant')
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eos_token=eos_id,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_target_length,
        config=predict_config),
                            axis_name='batch')
    # beam_size (argument 4) is static so each beam size compiles separately.
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        eos_token=eos_id,
        max_decode_len=FLAGS.max_target_length,
        config=predict_config,
        slow_decode=FLAGS.slow_decode),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, ))

    # Main Train Loop
    # ---------------------------------------------------------------------------

    logging.info('Starting training!')
    metrics_all = []
    tick = time.time()
    train_iter = train_ds.as_numpy_iterator()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, targets = load_data(next(train_iter))

        optimizer, metrics, dropout_rng = p_train_step(optimizer,
                                                       inputs,
                                                       outputs,
                                                       targets,
                                                       dropout_rng=dropout_rng)
        metrics_all.append(metrics)
        is_last_step = step == FLAGS.num_train_steps - 1

        # Periodic metric handling.

        # Training Metrics
        if (step and step % FLAGS.log_freq == 0) or is_last_step:
            logging.info('Gathering training metrics.')
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_util.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_util.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums)
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)

            if jax.process_index() == 0:
                logging.info('Train in step: %d, loss: %.4f', step,
                             summary['loss'])
                tock = time.time()
                steps_per_sec = FLAGS.log_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('train/steps per second', steps_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar('train/' + key, val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

        # Evaluation Metrics
        if (step and step % FLAGS.eval_freq == 0) or is_last_step:
            logging.info('Gathering evaluation metrics.')
            t_evaluation_start = time.time()
            eval_metrics = []
            for batches in eval_ds.as_numpy_iterator():
                inputs, outputs, targets = load_data(batches)

                metrics = p_eval_step(optimizer.target, inputs, outputs,
                                      targets)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_util.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            if jax.process_index() == 0:
                logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                             time.time() - t_evaluation_start, step,
                             eval_summary['loss'])
                for key, val in eval_summary.items():
                    summary_writer.scalar('eval/' + key, val, step)
                summary_writer.flush()

        # Beam search metrics.
        if (step and step % FLAGS.predict_freq == 0) or is_last_step:
            logging.info('Gathering beam search metrics.')
            test_ds = final_test_dataset if is_last_step else quick_test_dataset

            for pred_dataset, predict_or_test in [(predict_ds, 'predict'),
                                                  (test_ds, 'test')]:

                for beam_size in [1, 10]:
                    t_inference_start = time.time()
                    total_successes = 0
                    total_denominator = 0

                    ios, targets_list, predictions, top_of_beams, scores = (
                        [], [], [], [], [])
                    for batches in pred_dataset.as_numpy_iterator():
                        pred_batch = batches
                        # Handle final odd-sized batch by padding instead of dropping it.
                        cur_pred_batch_size = pred_batch['inputs'].shape[0]
                        if cur_pred_batch_size % n_devices:
                            padded_size = int(
                                np.ceil(cur_pred_batch_size / n_devices) *
                                n_devices)
                            # pylint: disable=cell-var-from-loop
                            pred_batch = jax.tree_util.tree_map(
                                lambda x: pad_examples(x, padded_size),
                                pred_batch)
                        inputs, outputs, targets = load_data(pred_batch)

                        cache = (p_init_cache(inputs, outputs, targets)
                                 if not FLAGS.slow_decode else None)
                        predicted = p_pred_step(optimizer.target, inputs,
                                                outputs, cache, beam_size)
                        predicted = tohost(predicted)
                        inputs, outputs, targets = map(
                            tohost, (inputs, outputs, targets))

                        for i, beams in enumerate(predicted):
                            inps, outs = decode_io(inputs[i], outputs[i])

                            if FLAGS.model_type == 'spec_decomposer_model':
                                ground_truth = decode_spec(targets[i])
                                best_prediction, score = eval_predicted_spec_decomposer_model(
                                    beams, ground_truth, decode_spec)
                                decode_to_str_fn = decode_spec
                            elif FLAGS.model_type == 'synthesizer_model':
                                ground_truth = decode_program_str(targets[i])
                                best_prediction, score = eval_predicted_synthesizer_model(
                                    beams, inps, outs, decode_program)
                                decode_to_str_fn = decode_program_str
                            else:
                                raise ValueError(
                                    f'Unknown model type {FLAGS.model_type}')

                            if score > 0:
                                total_successes += 1
                            total_denominator += 1

                            beams_target = [
                                decode_to_str_fn(beam) for beam in beams
                            ]

                            ios.append(' ; '.join(map(str, zip(inps, outs))))
                            targets_list.append(ground_truth)
                            predictions.append(best_prediction)
                            scores.append(score)
                            logging.info('')
                            logging.info('ios: %s', ios[-1])
                            logging.info('targets[%s]: %s', i, targets[i])
                            logging.info('ground_truth: %s', ground_truth)
                            logging.info('predicted beam: %s',
                                         '\n'.join(beams_target))
                            logging.info('best_prediction: %s',
                                         best_prediction)
                            logging.info('score: %s', score)
                            logging.info('beams: %s', beams)

                            if not ground_truth:
                                logging.warning('ground_truth is empty!')

                            # Last four beams, in reverse order.
                            top_of_beam = []
                            for index, beam in enumerate(beams[:-5:-1]):
                                top_of_beam.append(
                                    'index: {}, decoded: {}, tokens: {}'.
                                    format(index, decode_to_str_fn(beam),
                                           beam))
                            top_of_beams.append('\n\n'.join(top_of_beam))

                    all_total_successes, all_total_denominator = per_host_sum_pmap(
                        jax.tree_util.tree_map(
                            np.array, (total_successes, total_denominator)))

                    # Record beam search results as text summaries.
                    # np.random.choice raises on an empty population, so only
                    # sample when this host produced predictions.
                    message = []
                    if predictions:
                        for n in np.random.choice(np.arange(len(predictions)),
                                                  8):
                            text = (
                                f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                                f'predicted: {predictions[n]}\n\n'
                                f'score: {scores[n]}\n\n'
                                f'top of beam:\n\n{top_of_beams[n]}\n\n')
                            message.append(text)

                    # Write to tensorboard.
                    if jax.process_index() == 0:
                        accuracy = 100 * all_total_successes / all_total_denominator
                        logging.info(
                            '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                            predict_or_test, step, beam_size,
                            all_total_successes, all_total_denominator,
                            accuracy,
                            time.time() - t_inference_start)
                        summary_writer.scalar(
                            '{}/beam-size-{}'.format(predict_or_test,
                                                     beam_size), accuracy,
                            step)

                        summary_writer.text(
                            '{}-samples-beam-{}'.format(
                                predict_or_test, beam_size),
                            '\n------\n'.join(message), step)
                        summary_writer.flush()

        # Save a Checkpoint. Do this at the end of the training loop, so that if a
        # worker is descheduled during a round of prediction (which takes a while),
        # we will redo prediction upon restarting (to avoid losing data).
        if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
            if jax.process_index() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)
def main():
    """Fine-tune a Flax token-classification model and evaluate it with seqeval.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``, either from the command line or, when the single
    argument is a path ending in ``.json``, from that JSON file.

    Training/eval metrics are written to TensorBoard in
    ``training_args.output_dir``. Model weights are saved there (and optionally
    pushed to the hub) every ``save_steps`` steps and after the final step.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # The loader name is taken from the file extension. Fall back to the
        # validation file when no training file was given (the attribute is
        # `validation_file`; there is no `valid_file`).
        source_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = source_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Read the column schema from the train split when present, else from validation.
    # Indexing a missing split would raise KeyError, so test membership instead.
    schema_split = raw_datasets["train"] if "train" in raw_datasets else raw_datasets["validation"]
    column_names = schema_split.column_names
    features = schema_split.features

    if data_args.text_column_name is not None:
        text_column_name = data_args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    if data_args.label_column_name is not None:
        label_column_name = data_args.label_column_name
    elif f"{data_args.task_name}_tags" in column_names:
        label_column_name = f"{data_args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        """Return the sorted list of distinct labels across all label sequences."""
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
        label_to_id = {i: i for i in range(len(label_list))}
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
        label_to_id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        label2id=label_to_id,
        id2label={i: l for l, i in label_to_id.items()},
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    if config.model_type in {"gpt2", "roberta"}:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            add_prefix_space=True,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    model = FlaxAutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Preprocessing the datasets
    # Tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
        """Tokenize pre-split words and map word-level labels onto sub-word tokens."""
        tokenized_inputs = tokenizer(
            examples[text_column_name],
            max_length=data_args.max_seq_length,
            padding="max_length",
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []

        for i, label in enumerate(examples[label_column_name]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label_to_id[label[word_idx]])
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
                previous_word_idx = word_idx

            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    processed_raw_datasets = raw_datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=raw_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set (fewer if the set is tiny,
    # since random.sample raises when asked for more items than exist).
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
    summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Write elapsed train time and every buffered per-step train metric."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            # The buffered values end at `step`; back-date earlier ones.
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        """Write each eval metric as an `eval_<name>` scalar at `step`."""
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args)

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices on the pmapped "batch" axis.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        """Run the model in inference mode and map logits through `state.logits_fn`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    metric = load_metric("seqeval")

    def get_labels(y_pred, y_true):
        """Convert id sequences to label-name sequences, dropping positions labelled -100."""
        # Remove ignored index (special tokens and padding)
        true_predictions = [
            [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        true_labels = [
            [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        return true_predictions, true_labels

    def compute_metrics():
        """Compute seqeval metrics over all batches added since the last call."""
        results = metric.compute()
        if data_args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")

    # make sure weights are replicated on each device
    state = replicate(state)

    train_time = 0
    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
            tqdm(
                train_data_collator(input_rng, train_dataset, train_batch_size),
                total=step_per_epoch,
                desc="Training...",
                position=1,
            )
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            # 0-based global step index; the last step of training is total_steps - 1.
            cur_step = epoch * step_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:

                # evaluate
                for batch in tqdm(
                    eval_data_collator(eval_dataset, eval_batch_size),
                    total=len(eval_dataset) // eval_batch_size,
                    desc="Evaluating ...",
                    position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = p_eval_step(state, batch)
                    # Flatten the leading device axis so rows line up with examples.
                    predictions = np.array([pred for pred in chain(*predictions)])
                    labels = np.array([label for label in chain(*labels)])
                    # Materialize the chain first: np.array(<iterator>) would build a
                    # 0-d object array and the mask would select nothing.
                    attention_mask = np.array(list(chain(*batch["attention_mask"])))
                    labels[attention_mask == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                # evaluate also on leftover examples (not divisible by batch_size)
                num_leftover_samples = len(eval_dataset) % eval_batch_size

                # make sure leftover batch is evaluated on one device
                if num_leftover_samples > 0 and jax.process_index() == 0:
                    # take leftover samples
                    batch = eval_dataset[-num_leftover_samples:]
                    batch = {k: np.array(v) for k, v in batch.items()}

                    labels = batch.pop("labels")
                    predictions = eval_step(unreplicate(state), batch)
                    labels = np.array(labels)
                    labels[np.array(batch["attention_mask"]) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                eval_metrics = compute_metrics()

                if data_args.return_entity_level_metrics:
                    logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
                else:
                    logger.info(
                        f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
                    )

                if jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            # cur_step is 0-based, so `cur_step == total_steps` could never be true.
            is_last_step = cur_step + 1 == total_steps
            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or is_last_step:
                # save a checkpoint every `save_steps` steps and after the final step,
                # optionally pushing it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(
                        training_args.output_dir,
                        params=params,
                        push_to_hub=training_args.push_to_hub,
                        commit_message=f"Saving weights and logs of step {cur_step}",
                    )
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
Beispiel #28
0
def train(model_def,
          model_dir,
          batch_size,
          num_epochs,
          learning_rate,
          sgd_momentum,
          make_lr_fun=None,
          l2_reg=0.0005,
          run_seed=0):
    """Train a model on CIFAR-10, logging train and eval summaries once per epoch.

    Args:
      model_def: model definition passed to `create_model`.
      model_dir: directory that receives the TensorBoard summaries.
      batch_size: global batch size; must be divisible by `jax.device_count()`.
      num_epochs: number of passes over the training set.
      learning_rate: base learning rate for the optimizer and the schedule.
      sgd_momentum: momentum passed to `create_optimizer`.
      make_lr_fun: optional `(base_lr, steps_per_epoch) -> schedule` factory.
        When omitted, a stepped schedule (CIFAR-10 / Wide ResNet default) is used.
      l2_reg: L2 regularisation weight bound into `train_step`.
      run_seed: seed for the JAX PRNG.

    Raises:
      ValueError: when run on more than one host, or when `batch_size` is not
        divisible by the device count.
    """
    if jax.host_count() > 1:
        raise ValueError('CIFAR-10 example should not be run on '
                         'more than 1 host (for now)')

    def default_lr_fun(base_lr, steps_per_epoch):
        # Stepped LR schedule for CIFAR-10 and Wide ResNet; each pair is
        # presumably [epoch, factor] — see lr_schedule to confirm.
        return lr_schedule.create_stepped_learning_rate_schedule(
            base_lr, steps_per_epoch,
            [[60, 0.2], [120, 0.04], [160, 0.008]])

    lr_factory = default_lr_fun if make_lr_fun is None else make_lr_fun

    summary_writer = tensorboard.SummaryWriter(model_dir)
    rng = random.PRNGKey(run_seed)

    n_devices = jax.device_count()
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    device_batch_size = batch_size // n_devices

    # Data: the same global batch size is used for training and evaluation.
    data_source = input_pipeline.CIFAR10DataSource(train_batch_size=batch_size,
                                                   eval_batch_size=batch_size)
    train_ds, eval_ds = data_source.train_ds, data_source.eval_ds

    steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
    steps_per_eval = data_source.EVAL_IMAGES // batch_size
    num_steps = num_epochs * steps_per_epoch

    # Model/optimizer: state is replicated per device; the raw model is
    # dropped once the optimizer holds it.
    image_size = 32
    model, state = create_model(rng, device_batch_size, image_size, model_def)
    state = jax_utils.replicate(state)
    optimizer = create_optimizer(model, learning_rate, sgd_momentum)
    del model

    learning_rate_fn = lr_factory(learning_rate, steps_per_epoch)

    bound_train_step = functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, l2_reg=l2_reg)
    p_train_step = jax.pmap(bound_train_step, axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    def summarize(collected):
        """Stack per-step metrics and also return their per-key means."""
        stacked = common_utils.get_metrics(collected)
        return stacked, jax.tree_map(lambda x: x.mean(), stacked)

    pending_train = []
    epoch = 1
    for step, tf_batch in zip(range(num_steps), train_iter):
        # Fresh PRNG key per step, sharded across devices.
        rng, step_key = jax.random.split(rng)
        sharded_batch = load_and_shard_tf_batch(tf_batch)
        sharded_keys = common_utils.shard_prng_key(step_key)

        optimizer, state, step_metrics = p_train_step(
            optimizer, state, sharded_batch, sharded_keys)
        pending_train.append(step_metrics)

        if (step + 1) % steps_per_epoch != 0:
            continue

        # ---- End of epoch: flush train metrics, one scalar per step. ----
        per_key, train_summary = summarize(pending_train)
        for key, series in per_key.items():
            first_step = step - len(series) + 1
            for offset, value in enumerate(series):
                summary_writer.scalar('train_%s' % key, value,
                                      first_step + offset)
        pending_train = []

        # ---- Evaluate on `steps_per_eval` batches. ----
        collected_eval = []
        for _ in range(steps_per_eval):
            sharded_eval = load_and_shard_tf_batch(next(eval_iter))
            collected_eval.append(
                p_eval_step(optimizer.target, state, sharded_eval))
        _, eval_summary = summarize(collected_eval)

        logging.info(
            'Epoch %d: TRAIN loss=%.6f, err=%.2f, EVAL loss=%.6f, err=%.2f',
            epoch, train_summary['loss'],
            train_summary['error_rate'] * 100.0, eval_summary['loss'],
            eval_summary['error_rate'] * 100.0)

        # Eval scalars are indexed by epoch, not by step.
        summary_writer.scalar('eval_loss', eval_summary['loss'], epoch)
        summary_writer.scalar('eval_error_rate',
                              eval_summary['error_rate'], epoch)
        summary_writer.flush()

        epoch += 1
Beispiel #29
0
def main(args):
    """Fine-tune a pretrained VisionTransformer and evaluate it on the test split.

    Logs go to ``args.logdir/args.name``. When ``args.copy_to`` is set, that
    directory is mirrored there at every progress/eval step. When
    ``args.output`` is set, the final parameters are saved to it after training.
    """
    logdir = os.path.join(args.logdir, args.name)
    logger = logging.setup_logger(logdir)
    logger.info(args)

    logger.info(f'Available devices: {jax.devices()}')

    # Setup input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')

    ds_train = input_pipeline.get_data(dataset=args.dataset,
                                       mode='train',
                                       repeats=None,
                                       mixup_alpha=args.mixup_alpha,
                                       batch_size=args.batch,
                                       shuffle_buffer=args.shuffle_buffer,
                                       tfds_data_dir=args.tfds_data_dir,
                                       tfds_manual_dir=args.tfds_manual_dir)
    # One batch is pulled only to obtain the input shape/dtype for init.
    batch = next(iter(ds_train))
    logger.info(ds_train)
    ds_test = input_pipeline.get_data(dataset=args.dataset,
                                      mode='test',
                                      repeats=1,
                                      batch_size=args.batch_eval,
                                      tfds_data_dir=args.tfds_data_dir,
                                      tfds_manual_dir=args.tfds_manual_dir)
    logger.info(ds_test)

    # Build VisionTransformer architecture
    model = models.KNOWN_MODELS[args.model]
    VisionTransformer = model.partial(num_classes=dataset_info['num_classes'])
    _, params = VisionTransformer.init_by_shape(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        [(batch['image'].shape[1:], batch['image'].dtype.name)])

    pretrained_path = os.path.join(args.vit_pretrained_dir,
                                   f'{args.model}.npz')
    params = checkpoint.load_pretrained(
        pretrained_path=pretrained_path,
        init_params=params,
        model_config=models.CONFIGS[args.model],
        logger=logger)

    # pmap replicates the models over all TPUs/GPUs
    vit_fn_repl = jax.pmap(VisionTransformer.call)
    update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps)

    # Create optimizer and replicate it over all TPUs/GPUs
    opt = momentum_clip.Optimizer(
        dtype=args.optim_dtype,
        grad_norm_clip=args.grad_norm_clip).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Delete references to the objects that are not needed anymore
    del opt
    del params

    def copyfiles(paths):
        """Small helper to copy files to args.copy_to using tf.io.gfile."""
        if not args.copy_to:
            return
        for path in paths:
            to_path = os.path.join(args.copy_to, args.name,
                                   os.path.basename(path))
            tf.io.gfile.makedirs(os.path.dirname(to_path))
            tf.io.gfile.copy(path, to_path, overwrite=True)
            logger.info(f'Copied {path} to {to_path}.')

    total_steps = args.total_steps or (
        input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr,
                                                args.decay_type,
                                                args.warmup_steps)
    lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
    update_rngs = jax.random.split(jax.random.PRNGKey(0),
                                   jax.local_device_count())

    # Run training loop
    writer = metric_writers.create_default_writer(logdir, asynchronous=False)
    writer.write_hparams(
        {k: v
         for k, v in vars(args).items() if v is not None})

    def sync_logdir():
        """Flush buffered summaries to disk, then mirror logdir to args.copy_to."""
        # Without the flush, the copied summary files can lag behind what was
        # written, including the final evaluation.
        writer.flush()
        copyfiles(glob.glob(f'{logdir}/*'))

    logger.info('Starting training loop; initial compile can take a while...')
    t0 = time.time()

    for step, batch, lr_repl in zip(
            range(1, total_steps + 1),
            input_pipeline.prefetch(ds_train, args.prefetch), lr_iter):

        opt_repl, loss_repl, update_rngs = update_fn_repl(
            opt_repl, lr_repl, batch, update_rngs)

        if step == 1:
            # Exclude the (compilation-dominated) first step from the ETA clock.
            logger.info(f'First step took {time.time() - t0:.1f} seconds.')
            t0 = time.time()
        if args.progress_every and step % args.progress_every == 0:
            writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
            done = step / total_steps
            logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                        f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
            sync_logdir()

        # Run eval step
        if ((args.eval_every and step % args.eval_every == 0)
                or (step == total_steps)):

            # Top-1 accuracy on the test split; labels are compared by argmax
            # over axis 2 (presumably one-hot class scores — confirm).
            accuracy_test = np.mean([
                c for batch in input_pipeline.prefetch(ds_test, args.prefetch)
                for c in (np.argmax(
                    vit_fn_repl(opt_repl.target, batch['image']), axis=2) ==
                          np.argmax(batch['label'], axis=2)).ravel()
            ])

            lr = float(lr_repl[0])
            logger.info(f'Step: {step} '
                        f'Learning rate: {lr:.7f}, '
                        f'Test accuracy: {accuracy_test:0.5f}')
            writer.write_scalars(step, dict(accuracy_test=accuracy_test,
                                            lr=lr))
            sync_logdir()

    # Release the writer once training is finished.
    writer.close()

    if args.output:
        checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
        logger.info(f'Stored fine tuned checkpoint to {args.output}')
        copyfiles([args.output])
Beispiel #30
0
def train(config, workdir: str):
    """Runs the training loop for a score-based generative model.

    Builds the model named by `config.model.name`, restores (or initializes)
    the training state from checkpoints under `workdir`, then runs steps from
    the restored step up to and including `config.training.n_iters`. Along
    the way it logs the training loss every 50 steps, reports the eval loss
    every 100 steps, writes pre-emption and periodic checkpoints, and
    optionally draws samples at each periodic checkpoint.

    Args:
      config: Configuration to use. Fields read directly here include
        `seed`, `model.name`, `model.ema_rate`, `data.image_size`,
        `optim.lr`, `training.loss`, `training.n_iters`,
        `training.snapshot_freq`, `training.snapshot_freq_for_preemption`,
        `training.snapshot_sampling` and `sampling.final_only`
        (`training.anneal_power` only for non-"ddpm" losses). The whole
        config is also passed on to the model/loss/dataset/sampling helpers.
      workdir: Working directory for checkpoints, samples and TF summaries.
        `samples/` and `tensorboard/` are created here; checkpoints go to
        `checkpoints/` and `checkpoints-meta/`. If `checkpoints-meta/`
        contains a checkpoint, training is resumed from it.
    """

    # Create directories for experimental logs
    tf.io.gfile.makedirs(workdir)
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)
    rng = jax.random.PRNGKey(config.seed)
    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    # Only host 0 gets a summary writer; every later use of `writer` is
    # guarded by the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model_name = config.model.name
    ncsn_def = mutils.get_model(model_name).partial(config=config)
    rng, run_rng = jax.random.split(rng)
    # Whether the generative model is conditioned on class labels
    class_conditional = "conditional" in config.training.loss.lower()
    # Initialize the parameters (and any mutable model state, collected into
    # `init_model_state`) by tracing the model on dummy input shapes.
    with nn.stateful() as init_model_state:
        with nn.stochastic(run_rng):
            # Dummy image batch: leading dim = local device count, square
            # images of side `config.data.image_size`, 3 channels.
            input_shape = (jax.local_device_count(), config.data.image_size,
                           config.data.image_size, 3)
            # Model inputs: a float32 image batch plus one int32 value per
            # example (presumably the noise-level index — TODO confirm).
            # Class-conditional losses add a second int32 input of the same
            # shape (presumably class labels — confirm).
            input_list = [(input_shape, jnp.float32),
                          (input_shape[:1], jnp.int32)]
            if class_conditional:
                input_list.append(input_list[-1])
            _, initial_params = ncsn_def.init_by_shape(model_rng,
                                                       input_list,
                                                       train=True)
            ncsn = nn.Model(ncsn_def, initial_params)

    optimizer = losses.get_optimizer(config).create(ncsn)

    # Training state. The EMA copy of the parameters starts out equal to the
    # freshly initialized parameters.
    state = mutils.State(step=0,
                         optimizer=optimizer,
                         lr=config.optim.lr,
                         model_state=init_model_state,
                         ema_rate=config.model.ema_rate,
                         params_ema=initial_params,
                         rng=rng)  # pytype: disable=wrong-keyword-args

    del ncsn, init_model_state  # Do not keep a copy of the initial model.

    # Create checkpoints directory and the initial checkpoint
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # max_to_keep=None: presumably keeps every periodic checkpoint — confirm
    # against utils.Checkpoint.
    ckpt = utils.Checkpoint(checkpoint_dir, max_to_keep=None)
    # The return value is discarded; the state used for resuming comes from
    # `ckpt_meta` below.
    ckpt.restore_or_initialize(state)

    # Save intermediate checkpoints to resume training automatically
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
    # max_to_keep=1: only the latest pre-emption checkpoint is needed.
    ckpt_meta = utils.Checkpoint(checkpoint_meta_dir, max_to_keep=1)
    state = ckpt_meta.restore_or_initialize(state)
    # Resume the step counter and RNG from the (possibly restored) state.
    initial_step = int(state.step)
    rng = state.rng

    # Build input pipeline.
    rng, ds_rng = jax.random.split(rng)
    train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    scaler = datasets.get_data_scaler(config)  # data normalizer
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Distribute training.
    # Loss selection: "ddpm" uses losses.ddpm_loss, any other value uses
    # losses.ncsn_loss. In both cases the train and eval variants differ only
    # in `train=` and in that only the train variant gets `optimize_fn`.
    optimize_fn = losses.optimization_manager(config)
    if config.training.loss.lower() == "ddpm":
        # Use score matching loss with DDPM-type perturbation.
        ddpm_params = mutils.get_ddpm_params()
        train_step = functools.partial(losses.ddpm_loss,
                                       ddpm_params=ddpm_params,
                                       train=True,
                                       optimize_fn=optimize_fn)
        eval_step = functools.partial(losses.ddpm_loss,
                                      ddpm_params=ddpm_params,
                                      train=False)
    else:
        # Use score matching loss with NCSN-type perturbation.
        sigmas = mutils.get_sigmas(config)
        # Whether to use a continuous distribution of noise levels
        continuous = "continuous" in config.training.loss.lower()
        train_step = functools.partial(
            losses.ncsn_loss,
            sigmas=sigmas,
            class_conditional=class_conditional,
            continuous=continuous,
            train=True,
            optimize_fn=optimize_fn,
            anneal_power=config.training.anneal_power)
        eval_step = functools.partial(
            losses.ncsn_loss,
            sigmas=sigmas,
            class_conditional=class_conditional,
            continuous=continuous,
            train=False,
            anneal_power=config.training.anneal_power)

    # Map the step functions over local devices. axis_name="batch" names the
    # device axis for any cross-device collectives inside the loss functions.
    p_train_step = jax.pmap(train_step, axis_name="batch")
    p_eval_step = jax.pmap(eval_step, axis_name="batch")
    # Place one copy of the state on each local device for the pmapped steps.
    state = flax_utils.replicate(state)

    num_train_steps = config.training.n_iters

    logging.info("Starting training loop at step %d.", initial_step)
    # Give each host its own RNG stream.
    # NOTE(review): on a resumed run `rng` derives from the saved `state.rng`,
    # which was stored after this same fold_in, so the host id is folded in a
    # second time — confirm this is acceptable.
    rng = jax.random.fold_in(rng, jax.host_id())
    for step in range(initial_step, num_train_steps + 1):
        # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
        # devices.

        # Convert data to JAX arrays. Use ._numpy() to avoid copy.
        # Every leaf of the batch is also passed through the data scaler.
        batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

        # One fresh RNG key per local device for the pmapped step.
        rng, *next_rng = jax.random.split(rng,
                                          num=jax.local_device_count() + 1)
        next_rng = jnp.asarray(next_rng)
        loss, state = p_train_step(next_rng, state, batch)
        # Keep device 0's copy of the loss (presumably already averaged across
        # devices inside the step — confirm in losses.*).
        loss = flax.jax_utils.unreplicate(loss)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)

        # Training loss is logged every 50 steps, on host 0 only.
        if jax.host_id() == 0 and step % 50 == 0:
            logging.info("step: %d, training_loss: %.5e", step, loss)
            writer.scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption.
        # Because 0 % n == 0, this also fires at step 0. The current RNG is
        # stored in the saved state so a resumed run continues from it.
        if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
        ) == 0:
            saved_state = flax_utils.unreplicate(state)
            saved_state = saved_state.replace(rng=rng)
            ckpt_meta.save(saved_state)

        # Report the loss on an evaluation dataset.
        # Computed on every host every 100 steps; only host 0 logs/writes it.
        if step % 100 == 0:
            rng, *next_rng = jax.random.split(rng,
                                              num=jax.local_device_count() + 1)
            next_rng = jnp.asarray(next_rng)
            eval_batch = jax.tree_map(lambda x: scaler(x._numpy()),
                                      next(eval_iter))  # pylint: disable=protected-access
            eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
            eval_loss = flax.jax_utils.unreplicate(eval_loss)
            if jax.host_id() == 0:
                logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
                writer.scalar("eval_loss", eval_loss, step)

        # Save a checkpoint periodically and generate samples.
        # Always done on the final step.
        # NOTE(review): this test uses `(step + 1) % snapshot_freq`, while the
        # logging/eval/pre-emption checks use `step % ...`; confirm the offset
        # is intended.
        if (step + 1
            ) % config.training.snapshot_freq == 0 or step == num_train_steps:
            # Save the checkpoint.
            if jax.host_id() == 0:
                saved_state = flax_utils.unreplicate(state)
                saved_state = saved_state.replace(rng=rng)
                ckpt.save(saved_state)

            # Generate and save samples
            # Not host-guarded: every host samples and writes into its own
            # `iter_{step}_host_{host_id}` directory.
            if config.training.snapshot_sampling:
                rng, sample_rng = jax.random.split(rng)
                # Sample with the same image shape as the training dataset.
                init_shape = tuple(train_ds.element_spec["image"].shape)
                samples = sampling.get_samples(
                    sample_rng,
                    config,
                    flax_utils.unreplicate(state),
                    init_shape,
                    scaler,
                    inverse_scaler,
                    class_conditional=class_conditional)
                this_sample_dir = os.path.join(
                    sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
                tf.io.gfile.makedirs(this_sample_dir)

                if config.sampling.final_only:  # Do not save intermediate samples
                    sample = samples[-1]
                    # Merge the first two axes into a single batch axis and tile
                    # the images into a grid with nrow = floor(sqrt(count)).
                    image_grid = sample.reshape((-1, *sample.shape[2:]))
                    nrow = int(np.sqrt(image_grid.shape[0]))
                    # NOTE(review): `image_grid` was taken before this
                    # `* 255`/clip/uint8 conversion, so the PNG gets the
                    # unconverted values while the .np file stores uint8.
                    sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
                    # np.save on an open file object does not append ".npy",
                    # so the NPY data is stored under the name "sample.np".
                    with tf.io.gfile.GFile(
                            os.path.join(this_sample_dir, "sample.np"),
                            "wb") as fout:
                        np.save(fout, sample)

                    with tf.io.gfile.GFile(
                            os.path.join(this_sample_dir, "sample.png"),
                            "wb") as fout:
                        utils.save_image(image_grid,
                                         fout,
                                         nrow=nrow,
                                         padding=2)
                else:  # Save all intermediate samples produced during sampling.
                    # Same per-sample handling as above, indexed by position i.
                    for i, sample in enumerate(samples):
                        image_grid = sample.reshape((-1, *sample.shape[2:]))
                        nrow = int(np.sqrt(image_grid.shape[0]))
                        sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
                        with tf.io.gfile.GFile(
                                os.path.join(this_sample_dir,
                                             "sample_{}.np".format(i)),
                                "wb") as fout:
                            np.save(fout, sample)

                        with tf.io.gfile.GFile(
                                os.path.join(this_sample_dir,
                                             "sample_{}.png".format(i)),
                                "wb") as fout:
                            utils.save_image(image_grid,
                                             fout,
                                             nrow=nrow,
                                             padding=2)