Exemple #1
0
def save_unreplicated_checkpoint_background(train_dir,
                                            optimizer_state,
                                            params,
                                            batch_stats,
                                            training_metrics_state,
                                            global_step,
                                            preemption_count,
                                            sum_train_cost,
                                            max_to_keep=1):
    """Unreplicates the training pytrees and passes them to the background saver.

    Every pytree argument goes through `jax_utils.unreplicate` followed by
    `jax.device_get`. The results are bundled with the scalar bookkeeping
    values into one dict that is forwarded to `save_checkpoint_background`.

    Args:
      train_dir: Forwarded unchanged to `save_checkpoint_background`.
      optimizer_state: Pytree stored under the key 'optimizer_state'.
      params: Pytree stored under the key 'params'.
      batch_stats: Pytree stored under the key 'batch_stats'.
      training_metrics_state: Pytree stored under the key
        'training_metrics_grabber'.
      global_step: Stored in the dict and also passed as the checkpoint step.
      preemption_count: Stored in the dict as-is.
      sum_train_cost: Stored in the dict as-is.
      max_to_keep: Forwarded to `save_checkpoint_background`.
    """
    logging.info('Saving checkpoint to ckpt_%d', global_step)

    def _fetch(tree):
        # Drop the leading replica axis, then pull the values off device.
        return jax.device_get(jax_utils.unreplicate(tree))

    # The fetches run in the same order as the dict entries are listed.
    state = {
        'global_step': global_step,
        'preemption_count': preemption_count,
        'sum_train_cost': sum_train_cost,
        'optimizer_state': _fetch(optimizer_state),
        'params': _fetch(params),
        'batch_stats': _fetch(batch_stats),
        'training_metrics_grabber': _fetch(training_metrics_state),
    }
    save_checkpoint_background(
        train_dir, global_step, state, max_to_keep=max_to_keep)
    logging.info('Done saving checkpoint.')
 def write_summaries(self, example, logits, summary_writer, info, step,
                     state):
     """Writes a text summary of the predictions, then the plots for `state`.

     Args:
       example: Pytree passed through `jax_utils.unreplicate` before use.
       logits: Array passed through `jax_utils.unreplicate`; the argmax over
         its last axis is used as the predicted outputs.
       summary_writer: Writer whose `text` method receives the summary.
       info: Forwarded to `summary_utils.human_readable_texts`.
       step: Step value attached to the written summaries.
       state: Forwarded to `self.generate_plots`.
     """
     single_example = jax_utils.unreplicate(example)
     single_logits = jax_utils.unreplicate(logits)
     predictions = jnp.argmax(single_logits, axis=-1)
     readable = summary_utils.human_readable_texts(single_example,
                                                   predictions, info)
     # Wrapped in <pre> tags before being handed to the writer.
     summary_writer.text('predictions', f'<pre>{readable}</pre>', step)
     self.generate_plots(state, summary_writer, step)
Exemple #3
0
def maybe_copy_model_from_pretraining(optimizer, pretrain_optimizer, step,
                                      adam_opt_def):
  """Overwrites the optimizer target with the pretraining target while pretraining.

  While `step` is below `FLAGS.num_pretrain_steps`, a state dict is built from
  the pretraining optimizer's target, the current step and the current
  optimizer's per-parameter states, and is restored into the optimizer.
  Otherwise the optimizer is returned untouched.

  Args:
    optimizer: Replicated optimizer that receives the parameters.
    pretrain_optimizer: Replicated optimizer whose `target` is copied.
    step: Current step; also written into the rebuilt optimizer state.
    adam_opt_def: Optimizer definition whose `state_dict` builds the dict.

  Returns:
    The re-replicated updated optimizer, or `optimizer` unchanged.
  """
  if not step < FLAGS.num_pretrain_steps:
    return optimizer

  local_optimizer = jax_utils.unreplicate(optimizer)
  pretrained_target = jax_utils.unreplicate(pretrain_optimizer).target
  # Keep the existing per-parameter states; only step and target change.
  rebuilt_state = optim.OptimizerState(
      jnp.asarray(step, dtype=jnp.int32),
      local_optimizer.state.param_states)
  state_dict = adam_opt_def.state_dict(
      target=pretrained_target, state=rebuilt_state)
  restored = local_optimizer.restore_state(state_dict)
  return jax_utils.replicate(restored)
Exemple #4
0
 def load_checkpoint(self, ckpt_dir):
     """Restores the optimizer from `ckpt_dir` and installs it via `set_weights`.

     Args:
       ckpt_dir: Location passed to `checkpoints.restore_checkpoint`.

     Raises:
       ValueError: If `restore_checkpoint` hands back the very same object
         that was supplied as the target, which is treated as a failed load.
     """
     template = jax_utils.unreplicate(self._optimizer)
     restored = checkpoints.restore_checkpoint(ckpt_dir, target=template)
     if restored is not template:
         self.set_weights(restored)
         return
     raise ValueError('Unable to load checkpoint from %s' % ckpt_dir)
Exemple #5
0
    def test_train_one_step(self):
        """Runs one pmapped training step and sanity-checks loss and grad norm."""
        first_batch = next(self._dataset.get_train())

        replicated_state = jax_utils.replicate(self._state)
        replicated_optimizer = jax_utils.replicate(
            self._optimizer.create(self._model))

        # Advance the test's RNG and derive the key used for this step.
        self._rng, step_key = jax.random.split(self._rng)
        sharded_batch = training._shard_batch(first_batch)
        sharded_keys = common_utils.shard_prng_key(step_key)

        step_fn = functools.partial(
            training.train_step, learning_rate_fn=self._learning_rate_fn)
        p_train_step = jax.pmap(step_fn, axis_name='batch')
        step_outputs = p_train_step(replicated_optimizer, sharded_batch,
                                    sharded_keys, replicated_state)
        # Only the last two of the four outputs are inspected.
        _, _, per_device_loss, replicated_grad_norm = step_outputs

        loss = jnp.mean(per_device_loss)
        gradient_norm = jax_utils.unreplicate(replicated_grad_norm)

        with self.subTest(name='test_loss_range'):
            self.assertBetween(loss, self._min_loss, self._max_loss)

        with self.subTest(name='test_gradient_norm'):
            self.assertGreaterEqual(gradient_norm, 0)
Exemple #6
0
def maybe_log_training_metrics(metrics_state, metrics_summary_fn,
                               metrics_logger):
    """Summarizes the training metrics and appends the summary to the logger.

    Nothing happens when `metrics_state` is falsy.

    Args:
      metrics_state: Replicated metrics state; unreplicated before use.
      metrics_summary_fn: Callable applied to the unreplicated state.
      metrics_logger: Object whose `append_pytree` receives the summary tree.
    """
    if not metrics_state:
        return
    local_state = jax_utils.unreplicate(metrics_state)
    metrics_logger.append_pytree(metrics_summary_fn(local_state))
Exemple #7
0
def write_images(H, ema_params, viz_batch_original, viz_batch_processed, fname,
                 logprint):
    """Renders a grid of originals and samples, and saves it to `fname`.

    The first grid row is `viz_batch_original`. It is followed by one row per
    selected latent cut-off (samples conditioned on the first `i` latents at
    t=0.1) and one row per visualised temperature (unconditional samples).

    Args:
      H: Hyperparameters; reads `seed_sample`, `num_variables_visualize` and
        `num_temperatures_visualize`, and is used to construct `VAE`.
      ema_params: Replicated parameters; unreplicated before being applied.
      viz_batch_original: Batch exposing `.numpy()`; used as the first row.
      viz_batch_processed: Model input batch; its shape drives the layout.
      fname: Output path handed to `Image.save`.
      logprint: Callable used to report the output path.
    """
    rng = random.PRNGKey(H.seed_sample)
    variables = {'params': jax_utils.unreplicate(ema_params)}
    ema_apply = partial(VAE(H).apply, variables)
    forward_get_latents = partial(ema_apply, method=VAE(H).forward_get_latents)
    forward_samples_set_latents = partial(
        ema_apply, method=VAE(H).forward_samples_set_latents)
    forward_uncond_samples = partial(
        ema_apply, method=VAE(H).forward_uncond_samples)

    latents = [
        stats['z'] for stats in forward_get_latents(viz_batch_processed, rng)
    ]
    grid_rows = [viz_batch_original.numpy()]
    mb = viz_batch_processed.shape[0]

    # Evenly spaced cut-offs into the latent list, excluding both endpoints.
    fractions = np.linspace(0, 1, H.num_variables_visualize + 2)
    lv_points = np.floor(fractions * len(latents)).astype(int)[1:-1]
    for cutoff in lv_points:
        grid_rows.append(
            forward_samples_set_latents(mb, latents[:cutoff], rng, t=0.1))

    temperatures = [1.0, 0.9, 0.8, 0.7][:H.num_temperatures_visualize]
    for temperature in temperatures:
        grid_rows.append(forward_uncond_samples(mb, rng, t=temperature))

    n_rows = len(grid_rows)
    tile_rows = viz_batch_processed.shape[1]
    tile_cols = viz_batch_processed.shape[2]
    # Stack every row, then interleave so each tile lands in its grid cell.
    grid = np.concatenate(grid_rows, axis=0)
    grid = grid.reshape((n_rows, mb, *viz_batch_processed.shape[1:]))
    grid = grid.transpose([0, 2, 1, 3, 4])
    grid = grid.reshape([n_rows * tile_rows, mb * tile_cols, 3])

    logprint(f'printing samples to {fname}')
    Image.fromarray(grid).save(fname)
Exemple #8
0
def save_model(path, optimizer, ema, epoch, H):
    """Checkpoints the optimizer and EMA state, then copies the run log.

    Args:
      path: Checkpoint location; the EMA checkpoint uses `path + '_ema'` and
        the copied log is written next to it as `<basename>-log.jsonl`.
      optimizer: Replicated optimizer; saved together with `epoch`.
      ema: Replicated EMA state.
      epoch: Value stored alongside the optimizer.
      H: Hyperparameters; `H.save_dir` holds the source `log.jsonl`.
    """
    optimizer, ema = jax_utils.unreplicate((optimizer, ema))
    current_step = optimizer.state.step
    checkpoints.save_checkpoint(path, (optimizer, epoch), current_step)
    checkpoints.save_checkpoint(path + '_ema', ema, current_step)

    log_source = os.path.join(H.save_dir, 'log.jsonl')
    log_target = f'{os.path.dirname(path)}/{os.path.basename(path)}-log.jsonl'
    # Uses the external `cp` binary; a non-zero exit raises CalledProcessError.
    subprocess.check_output(['cp', log_source, log_target])
Exemple #9
0
    def run_eval(self, flax_module, batch_stats, optimizer_state, global_step):
        """Evaluates the Hessian spectrum and reports its max eigenvalue.

        The full metrics tree returned by the Hessian evaluator is appended
        to `self.logger` on host 0 only.

        Args:
          flax_module: Module forwarded to the Hessian evaluator.
          batch_stats: Unused; discarded immediately.
          optimizer_state: Replicated optimizer state; only read when the
            'precondition' callback option is truthy.
          global_step: Current training step, forwarded to the evaluator.

        Returns:
          A dict with the single key `self.name + '/max_eig'` mapped to the
          'max_eig_hess' entry of the Hessian metrics.
        """
        del batch_stats

        diag_preconditioner = None
        if self.callback_config.get('precondition'):
            precondition_config = self.callback_config.get(
                'precondition_config', default=FrozenConfigDict())
            local_optimizer_state = jax_utils.unreplicate(optimizer_state)
            diag_preconditioner = precondition.make_diag_preconditioner(
                self.hps.optimizer, self.hps.opt_hparams,
                local_optimizer_state, precondition_config)

        spectrum = self.hessian_evaluator.evaluate_spectrum(
            flax_module, global_step, diag_preconditioner=diag_preconditioner)
        hessian_metrics, _, _ = spectrum

        if jax.host_id() == 0:
            self.logger.append_pytree(hessian_metrics)

        return {self.name + '/max_eig': hessian_metrics['max_eig_hess']}
Exemple #10
0
 def save_checkpoint(self, ckpt_dir):
     """Writes the unreplicated optimizer to `ckpt_dir` at `self._train_step`.

     Args:
       ckpt_dir: Destination passed to `checkpoints.save_checkpoint`.
     """
     local_optimizer = jax_utils.unreplicate(self._optimizer)
     checkpoints.save_checkpoint(
         ckpt_dir, target=local_optimizer, step=self._train_step)
Exemple #11
0
 def evaluate(self, state, dataset):
     """Returns the mean validation loss over all batches of `dataset`.

     Args:
       state: Training state forwarded to `self.val_step_fn`.
       dataset: Dataset batched with `self.args.batch_size`.

     Returns:
       The accumulated loss divided by the number of batches processed.
     """
     batch_size = self.args.batch_size
     dataloader = get_batched_dataset(dataset, batch_size)
     expected_batches = len(dataset) // batch_size
     running_loss = jnp.array(0, dtype=jnp.float32)
     # Stays 0 if the loader yields nothing, as the manual counter did.
     batches_seen = 0
     progress = tqdm(dataloader, total=expected_batches,
                     desc="Evaluating ... ")
     for batches_seen, raw_batch in enumerate(progress, start=1):
         collated = self.data_collator(raw_batch)
         metrics = self.val_step_fn(state, **collated)
         running_loss += jax_utils.unreplicate(metrics["loss"])
     return running_loss / batches_seen
Exemple #12
0
 def save_checkpoint(self, save_dir, state):
     """Saves model params, optimizer state, args, collator and step counter.

     Args:
       save_dir: Directory that receives every output file.
       state: Replicated training state; read for `params`, `opt_state`
         and `step` after unreplication.
     """
     state = jax_utils.unreplicate(state)
     print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
     self.model_save_fn(save_dir, params=state.params)

     opt_state_path = os.path.join(save_dir, "opt_state.msgpack")
     with open(opt_state_path, "wb") as opt_file:
         opt_file.write(to_bytes(state.opt_state))

     # Python-side objects are persisted with joblib, in this order.
     joblib_payloads = (
         ("args.joblib", self.args),
         ("data_collator.joblib", self.data_collator),
     )
     for filename, payload in joblib_payloads:
         joblib.dump(payload, os.path.join(save_dir, filename))

     training_state_path = os.path.join(save_dir, "training_state.json")
     with open(training_state_path, "w") as json_file:
         json.dump({"step": state.step.item()}, json_file)
     print("DONE")
Exemple #13
0
def update_params(
        workload: spec.Workload,
        current_param_container: spec.ParameterContainer,
        current_params_types: spec.ParameterTypeTree,
        model_state: spec.ModelAuxiliaryState,
        hyperparameters: spec.Hyperparamters,
        input_batch: spec.Tensor,
        label_batch: spec.Tensor,
        # This will define the output activation via `output_activation_fn`.
        loss_type: spec.LossType,
        optimizer_state: spec.OptimizerState,
        eval_results: List[Tuple[int, float]],
        global_step: int,
        rng: spec.RandomState) -> spec.UpdateReturn:
    """Runs one pmapped update step and returns the unreplicated results.

    The input and label batches are reshaped so their leading axis equals the
    local device count. Parameters, model state and optimizer state are
    replicated before the call and unreplicated afterwards.

    Args:
      workload: Forwarded to `pmapped_update_params`.
      current_param_container: Current parameters (unreplicated).
      current_params_types: Unused.
      model_state: Auxiliary model state (unreplicated).
      hyperparameters: Forwarded to `pmapped_update_params`.
      input_batch: Inputs; the leading axis is split across local devices.
      label_batch: Labels; the leading axis is split across local devices.
      loss_type: Unused.
      optimizer_state: Current optimizer state (unreplicated).
      eval_results: Unused.
      global_step: Unused.
      rng: Forwarded to `pmapped_update_params`.

    Returns:
      (updated_optimizer_state, updated_params, updated_model_state).
    """
    del current_params_types, loss_type, eval_results, global_step

    num_devices = jax.local_device_count()

    def _split_across_devices(batch):
        # (N, ...) -> (num_devices, N // num_devices, ...).
        per_device = batch.shape[0] // num_devices
        return jnp.reshape(batch, (num_devices, per_device, *batch.shape[1:]))

    sharded_inputs = _split_across_devices(input_batch)
    sharded_labels = _split_across_devices(label_batch)

    # TODO(znado) we should be more efficient than replicating state each step.
    step_outputs = pmapped_update_params(
        workload, jax_utils.replicate(current_param_container),
        jax_utils.replicate(model_state), hyperparameters,
        sharded_inputs, sharded_labels,
        jax_utils.replicate(optimizer_state), rng, jnp.arange(num_devices))
    new_optimizer_state, updated_params, new_model_state = step_outputs

    return (jax_utils.unreplicate(new_optimizer_state),
            jax_utils.unreplicate(updated_params),
            jax_utils.unreplicate(new_model_state))
Exemple #14
0
        def maybe_restore_checkpoint(params, train_dir,
                                     external_checkpoint_path):
            """Restores params through the checkpoint helper and unreplicates them.

            Empty dicts are supplied for every pytree slot other than the
            parameters when calling
            `checkpoint.replicate_and_maybe_restore_checkpoint`.

            Returns:
              (params, global_step, sum_train_cost, preemption_count,
              is_restored), where params has been unreplicated and fetched
              with `jax.device_get`.
            """
            restored = checkpoint.replicate_and_maybe_restore_checkpoint(
                {}, params, {}, {}, train_dir, external_checkpoint_path)
            # Positions 0, 2 and 3 of the 8-tuple are not needed here.
            (_, replicated_params, _, _, global_step, sum_train_cost,
             preemption_count, is_restored) = restored

            local_params = jax.device_get(
                jax_utils.unreplicate(replicated_params))

            return (local_params, global_step, sum_train_cost,
                    preemption_count, is_restored)
Exemple #15
0
def evaluate(model, state, eval_ds, num_eval_steps=-1):
    """Accumulates evaluation metrics over `eval_ds`.

    Args:
      model: Forwarded to `eval_step`.
      state: Forwarded to `eval_step`.
      eval_ds: Iterable of batches; each leaf is converted with `np.asarray`.
      num_eval_steps: If positive, stop after this many batches; otherwise
        consume the whole iterable.

    Returns:
      The merged metrics, or None if `eval_ds` yielded no batches.
    """
    logging.info("Starting evaluation.")
    merged_metrics = None
    with StepTraceContextHelper("eval", 0) as trace_context:
        for step, batch in enumerate(eval_ds):  # pytype: disable=wrong-arg-types
            host_batch = jax.tree_map(np.asarray, batch)
            metrics_update = flax_utils.unreplicate(
                eval_step(model, state, host_batch))
            if merged_metrics is None:
                merged_metrics = metrics_update
            else:
                merged_metrics = merged_metrics.merge(metrics_update)
            steps_done = step + 1
            if num_eval_steps > 0 and steps_done == num_eval_steps:
                break
            trace_context.next_step()
    return merged_metrics
Exemple #16
0
    def train(self, state, tr_dataset, val_dataset):
        """Trains for `args.max_epochs` epochs with periodic eval and saving.

        Every `args.logging_steps` batches within an epoch the running train
        loss, learning rate and validation loss are logged. Every
        `args.save_steps` batches a checkpoint is written to a directory
        derived from `args.save_dir`, the epoch and the batch index. The
        batch counter and running loss restart at the beginning of each epoch.

        Args:
          state: Replicated training state fed to `self.train_step_fn`.
          tr_dataset: Training dataset, re-batched each epoch with seed=epoch.
          val_dataset: Dataset handed to `self.evaluate`.
        """
        args = self.args
        batches_per_epoch = len(tr_dataset) // args.batch_size

        base_rng = jax.random.PRNGKey(0)
        drp_rng = jax.random.split(base_rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(
                tr_dataset, args.batch_size, seed=epoch)
            progress = tqdm(tr_dataloader,
                            total=batches_per_epoch,
                            desc=f"Running EPOCH-{epoch}")
            for batch_index, raw_batch in enumerate(progress, start=1):
                collated = self.data_collator(raw_batch)
                state, metrics, drp_rng = self.train_step_fn(
                    state, drp_rng, **collated)
                running_loss += jax_utils.unreplicate(metrics["loss"])

                if batch_index % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / batch_index
                    # The scheduler is queried at the step that just ran.
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = {
                        "step": state_step.item(),
                        "eval_loss": eval_loss.item(),
                        "tr_loss": tr_loss,
                        "lr": lr.item(),
                    }
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if batch_index % args.save_steps == 0:
                    checkpoint_dir = args.save_dir + f"-e{epoch}-s{batch_index}"
                    self.save_checkpoint(checkpoint_dir, state=state)
Exemple #17
0
    def maybe_reset_train_state(self):
        """Optionally re-initialises the model params and/or the optimizer.

        Behaviour is controlled by two hparams:
          * 'reinitilize_params_at_each_step' (sic, this is the key read at
            runtime): when truthy, a fresh flax model is created from the
            current target's module; otherwise the current target is reused.
          * 'reinitialize_optimizer_at_each_step': when truthy, a new
            optimizer is created around the model and `self.lr_start_step`
            is set to the current global step; otherwise only the existing
            optimizer's target is replaced.

        The resulting optimizer is replicated and stored back on
        `self.train_state`.
        """
        optimizer = jax_utils.unreplicate(self.train_state.optimizer)

        if self.hparams.get('reinitilize_params_at_each_step', False):
            # Read the module off the current target. The previous code ran
            # `del optimizer.target` first and then dereferenced
            # `optimizer.target.module`, so this branch could never succeed.
            module = optimizer.target.module
            (flax_model, _, _) = pipeline_utils.create_flax_module(
                module,
                self.task.dataset.meta_data['input_shape'], self.hparams,
                nn.make_rng(),
                self.task.dataset.meta_data.get('input_dtype', jnp.float32))
        else:
            flax_model = optimizer.target

        # Reset optimizer
        if self.hparams.get('reinitialize_optimizer_at_each_step', False):
            self.lr_start_step = jax_utils.unreplicate(
                self.train_state.global_step)
            optimizer = optimizers.get_optimizer(
                self.hparams).create(flax_model)
        else:
            optimizer = optimizer.replace(target=flax_model)

        optimizer = jax_utils.replicate(optimizer)
        self.train_state = self.train_state.replace(optimizer=optimizer)
Exemple #18
0
    def evaluate(self, workdir, dir_name='eval', ckpt_name=None):
        """Restores EMA params from a checkpoint and writes one eval's results.

        Args:
          workdir: Root directory; checkpoints are read from its
            'checkpoints-0' subdirectory and results go to
            `workdir/dir_name`.
          dir_name: Name of the subdirectory that receives eval summaries.
          ckpt_name: File name of the checkpoint inside 'checkpoints-0'.
        """
        checkpoint_dir = os.path.join(workdir, 'checkpoints-0')
        ckpt = checkpoint.Checkpoint(checkpoint_dir)
        # NOTE(review): the default ckpt_name=None is passed straight to
        # os.path.join, which rejects None - callers seemingly must pass a
        # name; confirm.
        ckpt_file = os.path.join(checkpoint_dir, ckpt_name)
        state_dict = ckpt.restore_dict(ckpt_file)
        ema_params = flax.core.FrozenDict(state_dict['ema_params'])
        step = int(state_dict['step'])

        # Replicate the restored parameters before running the eval epoch.
        ema_params = flax_utils.replicate(ema_params)

        eval_logdir = os.path.join(workdir, dir_name)
        tf.io.gfile.makedirs(eval_logdir)
        is_secondary_process = jax.process_index() > 0
        writer = metric_writers.create_default_writer(
            eval_logdir, just_logging=is_secondary_process)

        outputs = flax_utils.unreplicate(self._eval_epoch(params=ema_params))
        writer.write_scalars(step, outputs['scalars'])
        writer.write_images(step, outputs['images'])
def save_checkpoint(experiment_dir, train_state, keep=3):
    """Saves a checkpoint of the unreplicated train state on host 0.

    The train state is unreplicated, fetched with `jax.device_get`, and
    written only if no checkpoint file for its global step exists yet.
    Hosts other than host 0 do nothing.

    Args:
      experiment_dir: str; Experiment directory for saving the checkpoint.
      train_state: Dataclass; An instance of TrainState that holds the state of
        training.
      keep: int; Number of checkpoints to keep.
    """
    if jax.host_id() != 0:
        return

    checkpoint_state = jax.device_get(jax_utils.unreplicate(train_state))
    step = int(checkpoint_state.global_step)
    ckpt_path = checkpoint_path(experiment_dir, step)
    if tf.io.gfile.exists(ckpt_path):
        # A checkpoint for this step is already on disk.
        return

    checkpoints.save_checkpoint(
        experiment_dir, checkpoint_state, step, keep=keep)
Exemple #20
0
def save_checkpoint(
    model,
    save_dir,
    state,
    cur_step: int,
    with_opt: bool = True,
    push_to_hub: bool = False,
):
    """Saves the model weights and, optionally, the optimizer/training state.

    Args:
      model: Object exposing `save_pretrained(save_dir, params=..., ...)`.
      save_dir: Output directory. It is created if missing before the
        optimizer files are written.
      state: Replicated training state; read for `opt_state`, `step` and
        `params` after unreplication.
      cur_step: Step number embedded in the commit message.
      with_opt: When True, also write `opt_state.msgpack` and
        `training_state.json` into `save_dir`.
      push_to_hub: Forwarded to `model.save_pretrained`.
    """
    state = jax_utils.unreplicate(state)
    if with_opt:
        logger.info(f"Saving optimizer and training state in {save_dir}...")
        # These files are written before `save_pretrained` runs, so nothing
        # has created the directory yet; without this the open() calls fail
        # on a fresh `save_dir`.
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
    logger.info(
        f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}'
    )
    model.save_pretrained(
        save_dir,
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs of step {cur_step}",
    )
  def get_env_aligned_pairs_idx(self, env_reps, env_batches, env_ids):
    """Computes alignments between all environment pairs.

    The sharded inputs are unsharded, the task's alignment routine runs inside
    an `nn.stochastic` scope seeded from the train state's rng, and the result
    is sharded again.

    Args:
      env_reps: jnp array; Reps for different environments (sharded).
      env_batches: list of dict; Batches of different environments (sharded).
      env_ids: jnp array; Environment ids.

    Returns:
      alignment between batches of environment pairs (sharded).
    """
    # TODO(riannevdberg, samiraabnar): aligning is done on the total
    #  unsharded batch, but that requires access between local batches
    #  when computing the loss. Unsure why this works! To be compatible
    #  with random alignment and sinkhorn soft alignment we should do
    #  alignment only within local batches.
    unsharded_reps = shard_util.unshard_env_batch(env_reps)
    unsharded_batches = shard_util.unshard(env_batches)
    local_rng = jax_utils.unreplicate(self.train_state.rng)
    with nn.stochastic(local_rng):
      pair_alignments = self.task.get_env_aligned_pairs_idx(
          unsharded_reps, unsharded_batches, env_ids)

    return dataset_utils.shard(pair_alignments)
Exemple #22
0
def run_train(run_configuration):
    """Runs the training workflow.

    Builds the optimizer and train step through the adapter, optionally
    restores from a checkpoint, then loops over the dataset. Checkpoints are
    written on host 0 every `config.logging.save_freq` steps and at the final
    step; summaries are written every `config.logging.summary_freq` steps.
    A falsy `save_freq` or `summary_freq` disables the respective feature.

    Args:
      run_configuration: Object providing `config`, `run_dir`, `adapter`,
        `original_checkpoint_path` and `dataset_info` (with `dataset` and
        `info`).
    """
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    log_dir = os.path.join(run_dir, 'train')
    checkpoint_path = run_configuration.original_checkpoint_path

    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
    # Fold in the host id so each host gets a distinct stream.
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step()

    # Only host 0 owns a summary writer; every use below is guarded likewise.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)

    summary_freq = config.logging.summary_freq
    metrics_all = []
    tick = time.time()
    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rngs, logits, state = train_step(
            optimizer, train_inputs, dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        # Test `save_freq` for truthiness BEFORE using it as a modulus: the
        # previous ordering evaluated `step % save_freq` first, so a zero or
        # None save_freq crashed instead of disabling checkpointing.
        save_freq = config.logging.save_freq
        is_last_step = step == num_train_steps - 1
        if save_freq and ((step % save_freq == 0 and step > 0)
                          or is_last_step):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(
                    checkpoint_dir, jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if summary_freq and step % summary_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train step: %d, loss: %.4f', step, summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = summary_freq / (tock - tick)
                examples_per_sec = denominator / (tock - tick)
                tick = tock
                summary_writer.scalar('per-second/steps', steps_per_sec, step)
                summary_writer.scalar('per-second/examples', examples_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar(key, val, step)

                adapter.write_summaries(example, logits, summary_writer, info,
                                        step, state)

                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []
Exemple #23
0
def main(argv):
    """Trains a Transformer on WMT data with periodic eval and BLEU scoring.

    Workflow, as implemented below:
      1. Validate flags, create TensorBoard writers on host 0 and load the
         train / eval / predict datasets plus the tokenizer.
      2. Build the model and optimizer, optionally restore the latest
         checkpoint from FLAGS.model_dir, and replicate the optimizer.
      3. pmap the train, eval and predict step functions.
      4. Loop over training batches. Every FLAGS.checkpoint_freq steps host 0
         saves a checkpoint; every FLAGS.eval_frequency steps the loop
         reports training metrics, runs evaluation, translates the predict
         dataset and logs a BLEU score with sample translations.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If FLAGS.batch_size is not divisible by the number of local
        devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = "tpu_driver"
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # Summary writers exist only on host 0; all uses below are guarded by the
    # same host check.
    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')

    # Load Dataset
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_token = 2  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        """Detokenizes `toks` up to and including the first EOS token."""
        valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': vocab_size,
        'emb_dim': 1024,
        'num_heads': 16,
        'num_layers': 6,
        'qkv_dim': 1024,
        'mlp_dim': 4096,
        'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
    }

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    model, cache_def = create_model(init_rng, input_shape, target_shape,
                                    transformer_kwargs)
    optimizer = create_optimizer(model, FLAGS.learning_rate,
                                 FLAGS.weight_decay)
    # We access model only from optimizer below via optimizer.target.
    del model

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(
        eval_step,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                           axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        # Everything below this check runs only every eval_frequency steps.
        if step % FLAGS.eval_frequency != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        # Each summed metric is normalised by the summed 'denominator' entry.
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        # At step 0 only a single step has elapsed since the timer started.
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset the accumulator for the next reporting window.
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        # zip() with range() caps evaluation at FLAGS.num_eval_steps batches.
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for _, pred_batch in enumerate(predict_iter):
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
            # A fresh decoding cache is initialised and replicated per batch.
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (per_device_batchsize, FLAGS.max_predict_length),
                    dtype=cache_dtype))
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_token, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        # Eight indices are drawn with replacement, so samples may repeat.
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
def main():
    """Pre-train a Flax T5 model with the T5-style span-masking objective.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` (from the command line, or from a single ``.json``
    file when that is the only argument), loads and tokenizes the dataset,
    groups the tokens into fixed-length chunks and runs a ``jax.pmap``
    data-parallel training loop with periodic logging, evaluation and
    checkpointing.  When ``do_eval`` is set, a final evaluation is written to
    ``eval_results.json`` inside the output directory.

    Raises:
        ValueError: If the output directory already exists, is not empty and
            ``overwrite_output_dir`` is not set while ``do_train`` is, or if
            neither a tokenizer name nor a model name/path was given.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = (
            parser.parse_args_into_dataclasses())

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_t5_mlm",
                           model_args,
                           data_args,
                           framework="flax")

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # The same auth-token value is forwarded to every hub call below.
    auth_token = True if model_args.use_auth_token else None

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        load_args = (data_args.dataset_name, data_args.dataset_config_name)
        load_kwargs = {}
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        load_args = (extension, )
        load_kwargs = {"data_files": data_files}
    load_kwargs.update(cache_dir=model_args.cache_dir,
                       use_auth_token=auth_token)

    datasets = load_dataset(*load_args, **load_kwargs)
    if "validation" not in datasets:
        # No validation split available: carve the first
        # `validation_split_percentage` percent out of the train split.
        percentage = data_args.validation_split_percentage
        datasets["validation"] = load_dataset(
            *load_args, split=f"train[:{percentage}%]", **load_kwargs)
        datasets["train"] = load_dataset(
            *load_args, split=f"train[{percentage}%:]", **load_kwargs)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    tokenizer_source = (model_args.tokenizer_name
                        or model_args.model_name_or_path)
    if not tokenizer_source:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        use_auth_token=auth_token,
    )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=auth_token,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=auth_token,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # We tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of
    # expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length //
                            expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of expanded_inputs_length.
        return {
            k: [
                t[i:i + expanded_inputs_length]
                for i in range(0, total_length, expanded_inputs_length)
            ]
            for k, t in concatenated_examples.items()
        }

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node. The "not installed" warning
    # is emitted only when the package really is missing, not merely because
    # this is a non-zero process.
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=auth_token,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    num_of_hosts = jax.process_count()
    current_host_idx = jax.process_index()

    # Create learning rate schedule: linear warmup followed by linear decay.
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in layer_norm_named_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # Create the optimizer (Adafactor or AdamW).
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    # NOTE(review): donate_argnums=(0,) donates the params argument, which is
    # reused after evaluation - confirm that this donation is intended.
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    def evaluate(params):
        """Run one pass over the validation split; return stacked per-batch metrics."""
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training
        eval_samples_idx = np.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size,
                                               drop_last=False)
        padded_eval_step = pad_shard_unpad(p_eval_step, static_return=True)

        eval_metrics = []
        for batch_idx in tqdm(eval_batch_idx,
                              desc="Evaluating ...",
                              position=2):
            samples = [
                tokenized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            eval_inputs = data_collator(samples)

            # Model forward
            eval_metrics.append(
                padded_eval_step(params,
                                 eval_inputs.data,
                                 min_device_batch=per_device_eval_batch_size))

        return get_metrics(eval_metrics)

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Every host builds the full global batch and keeps only its own
            # slice along the batch axis.
            local_host_model_inputs = {
                key: np.split(value, num_of_hosts, axis=0)[current_host_idx]
                for key, value in model_inputs.data.items()
            }

            # Model forward
            model_inputs = shard(local_host_model_inputs)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                # Add only the time elapsed since the previous measurement,
                # then restart the interval; measuring from the epoch start
                # each time would count earlier intervals repeatedly.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = jax.tree_util.tree_map(
                    jnp.mean, evaluate(state.params))

                # Update progress bar
                epochs.write(
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
                )

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save a checkpoint every `save_steps` steps and optionally
                # push it to the hub (process 0 only).
                if jax.process_index() == 0:
                    # Take the first replica of every parameter.
                    params = jax.device_get(
                        jax.tree_util.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        # Reduce every metric to a plain Python float so it can be serialized.
        eval_metrics = jax.tree_util.tree_map(
            lambda metric: jnp.mean(metric).item(), evaluate(state.params))

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
Exemple #25
0
def train(config, workdir):
  """Runs the training loop with periodic evaluation, checkpoints and samples.

  Creates `workdir` together with its `samples` and `tensorboard`
  sub-directories, builds the model and optimizer, restores the training
  state through the `checkpoints-meta` checkpoint manager and then iterates
  from the restored step up to and including `config.training.n_iters`.

  Args:
    config: Configuration to use. Fields read directly in this function are
      `seed`, `model.name`, `model.ema_rate`, `optim.lr`, `data.image_size`,
      `training.loss`, `training.n_iters`, `training.anneal_power` (NCSN
      branch only), `training.snapshot_freq`,
      `training.snapshot_freq_for_preemption`, `training.snapshot_sampling`
      and `sampling.final_only`; the whole object is also handed to helpers.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest one.
  """

  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  # Only host 0 owns a summary writer; every later use of `writer` is guarded
  # by the same host check.
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      # Shapes/dtypes used for shape-based initialization: a float32 image
      # batch with one image per local device, plus an int32 vector of the
      # same leading size.
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # Class-conditional models take one more int32 vector input
        # (presumably the class labels - confirm against the model).
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  # The EMA parameters start out equal to the freshly initialized parameters.
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  # max_to_keep=None: no limit is passed for the periodic snapshots.
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  # NOTE(review): the return value is discarded here, so nothing is restored
  # from this directory; presumably the call only creates the initial
  # checkpoint - confirm against utils.Checkpoint.
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  # The training state, start step and PRNG key all come from this
  # "checkpoints-meta" manager.
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  # Both step functions are called below as f(rng, state, batch) and return a
  # (loss, state) pair.
  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  # Fold the host id into the key (presumably so that each host draws
  # different random numbers).
  rng = jax.random.fold_in(rng, jax.host_id())
  # The range is inclusive of num_train_steps so the final step also runs the
  # "step == num_train_steps" snapshot branch below.
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    # One fresh key per local device, stacked so pmap can split it across
    # devices.
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    # Training loss is logged every 50 steps (hard-coded), on host 0 only.
    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
    ) == 0:
      saved_state = flax_utils.unreplicate(state)
      # Store the current host-side key so a resumed run continues from it.
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    # Evaluation runs every 100 steps (hard-coded) on a single eval batch.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      # The state returned by the eval step is deliberately ignored.
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if (step +
        1) % config.training.snapshot_freq == 0 or step == num_train_steps:
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples
      # Unlike checkpointing, sampling runs on every host; each host writes
      # to its own "iter_<step>_host_<id>" directory.
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples
          sample = samples[-1]
          # Merge the two leading axes into one image axis for the grid. The
          # grid is built from the values before the 0-255 conversion below.
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          # Scale by 255, clip to [0, 255] and store the raw array as uint8.
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          # Same processing as the final-only branch, applied to every entry
          # of `samples` and written with an index suffix.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
Exemple #26
0
        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
Exemple #27
0
def main():
    """Fine-tune a Flax sequence-classification model on GLUE or a local CSV/JSON dataset.

    The function parses the command-line arguments, loads and tokenizes the
    dataset, builds the replicated train state, and then for every epoch:
    trains with a pmapped step, evaluates (including the leftover examples that
    do not fill a full batch), optionally writes TensorBoard summaries, and
    saves the model/tokenizer to ``args.output_dir`` (optionally pushing to the
    Hub). After the last epoch the final evaluation metrics are written to
    ``eval_results.json`` in ``args.output_dir`` by process 0.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo = Repository(args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # The loader name is the file extension; fall back to the validation file when no
        # training file was given (same attribute as the one read just above).
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # Single concatenated message: passing the pieces as separate positional
            # arguments would make logging treat them as %-format args and fail.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None and not is_regression:
        # `label_list` only exists for classification data, so regression must be excluded here.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        """Tokenize one batch of examples and attach a `labels` column when labels are present."""
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(args.output_dir)
            summary_writer.hparams(vars(args))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    def write_metric(train_metrics, eval_metrics, train_time, step):
        """Write the train time, per-step train metrics and eval metrics to the summary writer."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            # The last value is logged at `step`; earlier values at the steps preceding it.
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`.

        Returns a triple `(new_state, metrics, new_dropout_rng)` where `metrics`
        holds the cross-device mean of the loss and the learning rate.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients over the pmapped "batch" axis before applying them.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        """Run the model in eval mode on `batch` and post-process the logits with `state.logits_fn`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            # Flatten the leading per-device axis of predictions and labels.
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            write_metric(train_metrics, eval_metric, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter and copy it to host memory.
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(args.output_dir, params=params)
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # save the eval metrics in json
    if jax.process_index() == 0:
        eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
        path = os.path.join(args.output_dir, "eval_results.json")
        with open(path, "w") as f:
            json.dump(eval_metric, f, indent=4, sort_keys=True)
Exemple #28
0
def train_loop(config, dropout_rngs, eval_ds, eval_freq, num_eval_steps,
               num_train_steps, optimizer, state, p_eval_step, p_train_step,
               start_step, train_iter, summary_writer):
    """Run the training loop with periodic checkpointing, logging and evaluation.

    Args:
      config: experiment config; `checkpoint_freq`, `save_checkpoints` and
        `batch_size` are read from it.
      dropout_rngs: PRNG key(s) passed to (and returned by) `p_train_step`.
      eval_ds: evaluation dataset; iterated afresh at every evaluation.
      eval_freq: int; train metrics are reported and evaluation is run every
        `eval_freq` steps (never at step 0).
      num_eval_steps: int; number of evaluation batches, or -1 to consume the
        whole evaluation dataset.
      num_train_steps: int; exclusive upper bound on the step counter.
      optimizer: replicated optimizer passed through `p_train_step`.
      state: replicated model state passed through `p_train_step`.
      p_eval_step: pmapped evaluation step function.
      p_train_step: pmapped train step function.
      start_step: int; first step index.
      train_iter: iterator over training batches.
      summary_writer: summary writer, used on host 0 only.

    Returns:
      A tuple `(optimizer, state, step)` where `step` is the last step that was
      run (0 if no step was run).
    """

    def _mean_by_denominator(stacked):
        # Sum every metric, then divide each sum by the summed 'denominator' entry.
        totals = jax.tree_map(jnp.sum, stacked)
        denom = totals.pop('denominator')
        return jax.tree_map(lambda x: x / denom, totals)

    pending_train_metrics = []
    last_report_time = time.time()
    logging.info('Starting training')
    logging.info('====================')

    step = 0
    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        np_batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
        batch = common_utils.shard(np_batch)
        optimizer, state, metrics, dropout_rngs = p_train_step(
            optimizer, state, batch, dropout_rng=dropout_rngs)
        pending_train_metrics.append(metrics)

        # Checkpoint at the configured interval and on the final step.
        at_interval = step % config.checkpoint_freq == 0 and step > 0
        is_last = step == num_train_steps - 1
        if (at_interval or is_last) and jax.host_id() == 0 \
                and config.save_checkpoints:
            # Save unreplicated optimizer + model state.
            unreplicated = (jax_utils.unreplicate(optimizer),
                            jax_utils.unreplicate(state))
            checkpoints.save_checkpoint(FLAGS.model_dir, unreplicated, step)

        # Everything below only happens on reporting/evaluation steps.
        if not (step % eval_freq == 0 and step > 0):
            continue

        # ---- Training metrics accumulated since the previous report ----
        stacked_train = common_utils.get_metrics(pending_train_metrics)
        lr = stacked_train.pop('learning_rate').mean()
        summary = _mean_by_denominator(stacked_train)
        summary['learning_rate'] = lr
        logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                     summary['loss'], summary['accuracy'])
        if jax.host_id() == 0:
            now = time.time()
            steps_per_sec = eval_freq / (now - last_report_time)
            last_report_time = now
            summary_writer.scalar('examples_per_second',
                                  steps_per_sec * config.batch_size, step)
            for name, value in summary.items():
                summary_writer.scalar(f'train_{name}', value, step)
            summary_writer.flush()
        # Start a fresh accumulation window for the next report.
        pending_train_metrics = []

        # ---- Evaluation metrics ----
        eval_iter = iter(eval_ds)
        if num_eval_steps == -1:
            # Unbounded counter: stop only when the dataset is exhausted.
            step_counter = itertools.repeat(1)
        else:
            step_counter = range(num_eval_steps)
        eval_metrics = []
        for _, eval_batch in zip(step_counter, eval_iter):
            # pylint: disable=protected-access
            np_eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)
            # pylint: enable=protected-access
            sharded_eval_batch = common_utils.shard(np_eval_batch)
            eval_metrics.append(
                p_eval_step(optimizer.target, state, sharded_eval_batch))
        eval_summary = _mean_by_denominator(
            common_utils.get_metrics(eval_metrics))
        logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                     eval_summary['loss'], eval_summary['accuracy'])
        if jax.host_id() == 0:
            for name, value in eval_summary.items():
                summary_writer.scalar(f'val_{name}', value, step)
            summary_writer.flush()

    return optimizer, state, step
Exemple #29
0
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # BOS special attention only makes sense if we are using relative attention
    # and it's not the baseline.
    if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention
                                        or FLAGS.attention_mask_type
                                        == 'baseline'):
        raise ValueError(
            "bos_special_attention doesn't work when use_relative_attention={} and "
            'attention_mask_type={}'.format(FLAGS.use_relative_attention,
                                            FLAGS.attention_mask_type))

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    hparam_str_dict = json.loads(FLAGS.xm_parameters)
    hparam_str = ','.join([
        '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
        for k in hparam_str_dict.keys()
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    predict_io_shape = (FLAGS.per_device_batch_size,
                        FLAGS.num_strings_per_task,
                        FLAGS.predict_max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    # Parse io and program token sequences (for eval).
    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
        return inps, outs

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        program = program[program != bos_token]

        try:
            return dsl.decode_program(program.tolist(), id_token_table)
        except:  # pylint: disable=bare-except
            return None  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
    padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:])
    logging.info('padded_shapes: %s', padded_shapes)
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=padded_shapes,
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_padded_shapes = (predict_io_shape[1:], predict_io_shape[1:],
                             program_shape[1:])
    logging.info('predict_padded_shapes: %s', predict_padded_shapes)
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)
    train_ds = dataset.skip(FLAGS.num_eval_steps)
    if FLAGS.train_set_batches > 0:
        train_ds = train_ds.take(FLAGS.train_set_batches)
    train_ds = train_ds.repeat()

    test_dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.test_dataset_filepattern, token_id_table, char_id_table)
    test_dataset = test_dataset.padded_batch(
        batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
    quick_test_dataset = (test_dataset.take(
        FLAGS.num_quick_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))
    final_test_dataset = (test_dataset.take(
        FLAGS.num_final_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    default_config = base_models.TransformerConfig(
        vocab_size=io_vocab_size, output_vocab_size=program_vocab_size)
    base_config = base_models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token,
        num_input_relative_position_buckets=FLAGS.num_position_buckets,
        max_input_distance=min(FLAGS.max_distance,
                               default_config.max_input_distance),
        num_output_relative_position_buckets=FLAGS.num_position_buckets,
        max_output_distance=min(FLAGS.max_distance,
                                default_config.max_output_distance),
        num_input_cross_output_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_input_cross_output_distance=min(
            FLAGS.max_distance,
            default_config.max_input_cross_output_distance),
        num_program_relative_position_buckets=FLAGS.num_position_buckets,
        max_program_distance=min(FLAGS.max_distance,
                                 default_config.max_program_distance),
        num_program_cross_embed_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_program_cross_embed_distance=min(
            FLAGS.max_distance,
            default_config.max_program_cross_embed_distance),
        bidirectional_program_attention=FLAGS.bidirectional_program_attention)
    train_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config,
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)
    eval_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config.replace(deterministic=True),
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)
    predict_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config.replace(
            shift=False,
            deterministic=True,
            decode=not FLAGS.slow_decode,
            max_len=max(FLAGS.max_characters, FLAGS.max_program_length,
                        FLAGS.predict_max_characters)),
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    dropout_rng = jax.random.split(rng, jax.local_device_count())
    del rng

    m = models.DecomposeAttentionTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)
        if FLAGS.finetune_start_step > 0:
            logging.info(
                'Checking that start_step (%s) == finetune_start_step (%s)',
                start_step, FLAGS.finetune_start_step)
            assert start_step >= FLAGS.finetune_start_step
            steps_to_skip = start_step - FLAGS.finetune_start_step
        else:
            steps_to_skip = start_step

        # TODO(kshi): It is likely that this code can lead to the job stalling for
        # 10+ hours when restarting from a checkpoint that had been trained a long
        # time, possibly because dataset skipping is slow.
        logging.info('Skipping %s steps...', steps_to_skip)
        train_ds = train_ds.skip(steps_to_skip)
        dummy_p_train_step = jax.pmap(
            lambda dropout_rng: jax.random.split(dropout_rng)[1])
        for _ in range(steps_to_skip):
            dropout_rng = dummy_p_train_step(dropout_rng)
        logging.info('Finished skipping steps')
        logging.info('Host %s has dropout_rng = %s', jax.host_id(),
                     dropout_rng)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    # TODO(jxihong): Implement fast decoding.
    assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

    if FLAGS.finetune_start_step <= 0:
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr)
    else:
        # Constant LR for finetuning.
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr, factors='constant')
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eos_token=eos_token,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        eos_token=eos_token,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        slow_decode=FLAGS.slow_decode),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, ))

    # Main Train Loop
    # ---------------------------------------------------------------------------

    logging.info('Starting training!')
    metrics_all = []
    tick = time.time()
    train_iter = train_ds.as_numpy_iterator()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, dropout_rng = p_train_step(optimizer,
                                                       inputs,
                                                       outputs,
                                                       programs,
                                                       dropout_rng=dropout_rng)
        metrics_all.append(metrics)
        is_last_step = step == FLAGS.num_train_steps - 1

        # Save a Checkpoint
        if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.

        # Training Metrics
        if (step and step % FLAGS.log_freq == 0) or is_last_step:
            logging.info('Gathering training metrics.')
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums)
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)

            if jax.host_id() == 0:
                logging.info('Train in step: %d, loss: %.4f', step,
                             summary['loss'])
                tock = time.time()
                steps_per_sec = FLAGS.log_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('train/steps per second', steps_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar('train/' + key, val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

        # Evaluation Metrics
        if (step and step % FLAGS.eval_freq == 0) or is_last_step:
            logging.info('Gathering evaluation metrics.')
            t_evaluation_start = time.time()
            eval_metrics = []
            for batches in eval_ds.as_numpy_iterator():
                inputs, outputs, programs = common_utils.shard(batches)

                metrics = p_eval_step(optimizer.target, inputs, outputs,
                                      programs)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            if jax.host_id() == 0:
                logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                             time.time() - t_evaluation_start, step,
                             eval_summary['loss'])
                for key, val in eval_summary.items():
                    summary_writer.scalar('eval/' + key, val, step)
                summary_writer.flush()

        # Beam search metrics.
        if (step and step % FLAGS.predict_freq == 0) or is_last_step:
            logging.info('Gathering beam search metrics.')
            test_ds = final_test_dataset if is_last_step else quick_test_dataset

            for dataset, predict_or_test in [(predict_ds, 'predict'),
                                             (test_ds, 'test')]:

                for beam_size in [1, 10]:
                    t_inference_start = time.time()
                    total_successes = 0
                    total_denominator = 0
                    pred_successes = collections.defaultdict(int)
                    pred_denominators = collections.defaultdict(int)

                    ios, targets, predictions, top_of_beams = [], [], [], []
                    for batches in dataset.as_numpy_iterator():
                        pred_batch = batches
                        # Handle final odd-sized batch by padding instead of dropping it.
                        cur_pred_batch_size = pred_batch[0].shape[0]
                        if cur_pred_batch_size % n_devices:
                            padded_size = int(
                                np.ceil(cur_pred_batch_size / n_devices) *
                                n_devices)
                            # pylint: disable=cell-var-from-loop
                            pred_batch = jax.tree_map(
                                lambda x: pad_examples(x, padded_size),
                                pred_batch)
                        inputs, outputs, programs = common_utils.shard(
                            pred_batch)

                        cache = (p_init_cache(inputs, outputs, programs)
                                 if not FLAGS.slow_decode else None)
                        predicted = p_pred_step(optimizer.target, inputs,
                                                outputs, cache, beam_size)
                        predicted = tohost(predicted)
                        inputs, outputs, programs = map(
                            tohost, (inputs, outputs, programs))

                        for i, beams in enumerate(predicted):
                            inps, outs = decode_io(inputs[i], outputs[i])
                            p, p_score = eval_predicted(
                                beams,
                                inps,
                                outs,
                                parse_beam_fn=decode_program)

                            # Split by length of program.
                            program = programs[i]
                            num_expressions = len(
                                decode_program(program).expressions)
                            pred_denominators[num_expressions] += 1
                            total_denominator += 1
                            if p_score >= len(inps):
                                pred_successes[num_expressions] += 1
                                total_successes += 1

                            ios.append(' ; '.join(map(str, zip(inps, outs))))
                            targets.append(
                                decode_program(programs[i]).to_string())
                            try:
                                predictions.append(p.to_string())
                            except:  # pylint: disable=bare-except
                                predictions.append('Did not compile')
                            logging.info('ios: %s', ios[-1])
                            logging.info('target: %s', targets[-1])
                            beams_log = []
                            for beam in beams:
                                try:
                                    beams_log.append(
                                        decode_program(beam).to_string())
                                except:  # pylint: disable=bare-except
                                    beams_log.append('Did not compile')
                            logging.info('predicted beam: %s',
                                         '\n'.join(beams_log))

                            top_of_beam = []
                            for index, beam in enumerate(beams[:-5:-1]):
                                try:
                                    decoded_program = decode_program(
                                        beam).to_string()
                                except:  # pylint: disable=bare-except
                                    decoded_program = 'Did not compile'
                                top_of_beam.append(
                                    'index: {}, decoded: {}, tokens: {}'.
                                    format(index, decoded_program, beam))
                            top_of_beams.append('\n\n'.join(top_of_beam))

                    all_total_successes, all_total_denominator = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (total_successes, total_denominator)))
                    all_pred_successes, all_pred_denominators = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (pred_successes, pred_denominators)))

                    # Record beam search results as text summaries.
                    message = []
                    for n in np.random.choice(np.arange(len(predictions)), 8):
                        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                                f'predicted: {predictions[n]}\n\n'
                                f'top of beam:\n\n{top_of_beams[n]}\n\n')
                        message.append(text)

                    # Write to tensorboard.
                    if jax.host_id() == 0:
                        accuracy = 100 * all_total_successes / all_total_denominator
                        logging.info(
                            '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                            predict_or_test, step, beam_size,
                            all_total_successes, all_total_denominator,
                            accuracy,
                            time.time() - t_inference_start)
                        summary_writer.scalar(
                            '{}/beam-size-{}'.format(predict_or_test,
                                                     beam_size), accuracy,
                            step)

                        for length in sorted(all_pred_successes.keys()):
                            this_length_accuracy = (
                                100 * all_pred_successes[length] /
                                all_pred_denominators[length])
                            logging.info(
                                '  accuracy for length %s: %s / %s = %.2f%%',
                                length, all_pred_successes[length],
                                all_pred_denominators[length],
                                this_length_accuracy)
                            summary_writer.scalar(
                                '{}-by-length/beam-size-{}-length-{}'.format(
                                    predict_or_test, beam_size, length),
                                this_length_accuracy, step)

                        summary_writer.text(
                            '{}-samples-beam-{}'.format(
                                predict_or_test, beam_size),
                            '\n------\n'.join(message), step)
                        summary_writer.flush()
Exemple #30
0
def main(_):
    """Trains a ProgramTransformer and periodically evaluates it.

    Builds the DSL token tables and the TFRecord input pipeline, creates (or
    restores from a checkpoint) an Adam optimizer over the model parameters,
    and runs the pmapped training loop for `FLAGS.num_train_steps` steps.

    A checkpoint is saved every `FLAGS.checkpoint_freq` steps (step 0
    excluded) and on the final step.  Every `FLAGS.log_freq` steps (step 0
    excluded) the loop also reports training metrics, evaluation metrics and
    beam-search accuracy.  Only host 0 writes checkpoints and summaries.

    Args:
      _: Unused positional argument (presumably the argv list passed by the
        application runner -- confirm against the call site).

    Raises:
      ValueError: If `FLAGS.dataset_filepattern` is not set.
    """
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # makedirs creates missing parents and succeeds if the directory already
    # exists, which avoids a check-then-create race when several hosts start
    # at the same time.
    gfile.makedirs(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparameters; `seed` and `lr` from the flags take precedence.
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            hparam_str_dict.setdefault(key, value)

    hparam_str = ','.join(
        f'{k}={hparam_str_dict[k]!s}' for k in sorted(hparam_str_dict))

    # Checkpoints are both restored from and saved to this directory.
    checkpoint_dir = os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # The summary writer only exists on host 0; every use below is guarded by
    # the same host check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.  Character ids start at 1; 0 is left for padding.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {
        char: char_id for char_id, char in id_char_table.items()
    }
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decodes io example tokens.

        Returns:
          A tuple `(inps, outs, io_string)`: the decoded input strings, the
          decoded output strings, and a single string of the form
          'in < out > in < out ...' joining all pairs.
        """
        def decode_str(s):
            """Decodes string tokens, skipping ids that are not positive."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
        io_string = ' > '.join(
            f'{inp} < {out}' for inp, out in zip(inps, outs))
        return inps, outs, io_string

    def decode_program(program):
        """Decodes program tokens.

        Returns:
          A tuple `(program, string)`, or `(None, '')` when decoding fails.
        """
        # Keep everything up to and including the first EOS token.
        # NOTE(review): if no EOS token is present np.argmax returns 0, so
        # only the first token is kept -- confirm this is intended.
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except Exception:  # pylint: disable=broad-except
            # Deliberately best-effort: any decoding failure means the
            # program does not compile.  KeyboardInterrupt and SystemExit
            # still propagate.
            return None, ''

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training: the first `num_eval_steps` batches are
    # held out for evaluation, the rest are repeated for training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(shift=False,
                                          deterministic=True,
                                          decode=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    m = models.ProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(checkpoint_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_lib.train_step,
        learning_rate_fn=learning_rate_fn,
        config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(train_lib.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_lib.initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
    # Arguments 4-6 (eos token, max decode length, beam size) are passed as
    # static broadcasted values.
    p_pred_step = jax.pmap(functools.partial(train_lib.predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, 5, 6))

    # Main Train Loop
    # ---------------------------------------------------------------------------
    train_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, train_rngs = p_train_step(optimizer,
                                                      inputs,
                                                      outputs,
                                                      programs,
                                                      train_rng=train_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    checkpoint_dir, jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling: everything below runs only on logging
        # steps (never on step 0).
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f', step,
                         summary['loss'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)

            metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
            eval_metrics.append(metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions = [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    # pylint: disable=cell-var-from-loop
                    pred_batch = jax.tree_map(
                        lambda x: train_lib.pad_examples(x, padded_size),
                        pred_batch)
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache = p_init_cache(inputs, outputs, programs)
                predicted = p_pred_step(optimizer.target, inputs, outputs,
                                        cache, eos_token, programs.shape[-1],
                                        beam_size)
                predicted = train_lib.tohost(predicted)
                inputs, outputs, programs = map(train_lib.tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_score = train_lib.eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    # Count the task as solved when the score reaches the
                    # number of io examples.
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')

            all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries (8 examples drawn
            # at random, with replacement).
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                score = all_pred_acc / all_pred_denominator
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step, score)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      score, step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()