Example #1
0
def eval_step(params, batch, config, label_smoothing=0.0):
  """Evaluate the model on one batch and return its metrics.

  Target positions whose id is 0 get weight 0 (presumably padding), so they
  do not contribute to the metrics returned by `compute_metrics`.
  """
  target_ids = batch['targets']
  source_ids = batch['inputs']
  mask = jnp.where(target_ids > 0, 1.0, 0.0)
  model = models.Transformer(config)
  logits = model.apply({'params': params}, source_ids, target_ids)
  return compute_metrics(logits, target_ids, mask, label_smoothing)
Example #2
0
def initialize_cache(inputs, max_decode_len, config):
  """Create an empty decoder cache for `inputs` and `max_decode_len`.

  The model is initialized on all-ones dummy arrays. Their batch dimension
  and trailing dimensions follow `inputs`, and the target length is
  `max_decode_len`. Only the resulting 'cache' collection is returned.
  """
  batch_dim = inputs.shape[0]
  trailing_dims = inputs.shape[2:]
  dummy_targets_shape = (batch_dim, max_decode_len) + trailing_dims
  model = models.Transformer(config)
  variables = model.init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, config.dtype),
      jnp.ones(dummy_targets_shape, config.dtype))
  return variables['cache']
Example #3
0
def eval_per_pos_step(params, batch, config, label_smoothing=0.0):
    """Compute per-position losses and the count of non-zero targets.

    Returns:
      A `(losses, length)` pair. `losses` is whatever `compute_per_pos_loss`
      returns. `length` is the weight mask (1.0 where the target id is > 0,
      else 0.0) summed over its last axis.
    """
    source_ids = batch['inputs']
    target_ids = batch['targets']
    mask = jnp.where(target_ids > 0, 1.0, 0.0)
    model = models.Transformer(config)
    logits = model.apply({'params': params}, source_ids, target_ids)
    per_pos_losses = compute_per_pos_loss(logits, target_ids, mask,
                                          label_smoothing)
    return per_pos_losses, mask.sum(axis=-1)
Example #4
0
def predict_step_full(inputs,
                      params,
                      cache,
                      eos_id,
                      max_decode_len,
                      config,
                      beam_size=4):
    """Beam-search decode a batch and return every beam with its score.

    Unlike a best-beam predictor, nothing is selected or trimmed here.

    Returns:
      `(beam_seqs, scores)` exactly as produced by `decode.beam_search`. The
      expected layout of `beam_seqs` is [batch, beam, max_decode_len + 1],
      with beams in increasing log-probability and a leading dummy token
      (TODO confirm against decode.beam_search).
    """
    # The fast decoder works on a flat batch of size batch * beam_size, where
    # every example is repeated in place rather than tiled:
    #   [el0, el1, el2] --beam_size=2--> [el0, el0, el1, el1, el2, el2]
    encoder_out = models.Transformer(config).apply(
        {'params': params}, inputs, method=models.Transformer.encode)
    encoded_inputs = decode.flat_batch_beam_expand(encoder_out, beam_size)
    raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Run one decoder step and return (logits, updated cache).

        Logits come back as [batch * beam, vocab]: the length-1 sequence axis
        of the decoder output is squeezed away.
        """
        step_logits, mutated = models.Transformer(config).apply(
            {'params': params, 'cache': flat_cache},
            encoded_inputs,
            raw_inputs,  # only needed for input padding mask
            flat_ids,
            mutable=['cache'],
            method=models.Transformer.decode)
        return step_logits.squeeze(axis=1), mutated['cache']

    beam_seqs, scores = decode.beam_search(
        inputs,
        cache,
        tokens_ids_to_logits,
        beam_size=beam_size,
        alpha=0.6,
        eos_id=eos_id,
        max_decode_len=max_decode_len)
    return beam_seqs, scores
Example #5
0
  def loss_fn(params):
    """Return the weight-normalized cross-entropy of the model on the batch.

    `inputs`, `targets`, `weights` and the position/segmentation arrays are
    read from the enclosing scope.
    """
    # NOTE(review): no label smoothing and no dropout RNG are passed here;
    # confirm this is intended for the training config in use.
    packing_kwargs = dict(
        inputs_positions=inputs_positions,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation)
    model = models.Transformer(config)
    logits = model.apply({'params': params}, inputs, targets,
                         **packing_kwargs)
    total_loss, total_weight = compute_weighted_cross_entropy(
        logits, targets, weights)
    return total_loss / total_weight
Example #6
0
  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Run a single decoder step for beam search.

    Returns:
      The step logits with the length-1 sequence axis removed
      ([batch * beam, 1, vocab] -> [batch * beam, vocab]), and the updated
      'cache' collection.
    """
    variables = {'params': params, 'cache': flat_cache}
    step_logits, mutated = models.Transformer(config).apply(
        variables,
        encoded_inputs,
        raw_inputs,  # only needed for input padding mask
        flat_ids,
        mutable=['cache'],
        method=models.Transformer.decode)
    return step_logits.squeeze(axis=1), mutated['cache']
Example #7
0
def main(argv):
    """Train (or only evaluate) a WMT Transformer with gradual data selection.

    Builds datasets and the model from FLAGS, optionally restores a checkpoint,
    then runs the pmapped training loop. Every `dynamic_eval_freq` steps (and
    on the last step) it does the following:
      - writes training metrics;
      - evaluates on the eval set and any auxiliary eval sets;
      - optionally computes BLEU;
      - decides whether the training subset should be shrunk and resampled.
    On host 0, checkpoints are written at `FLAGS.checkpoint_freq` only when
    the latest eval loss beats the best one seen so far.

    With FLAGS.eval_only a single step is run. It writes per-example eval
    losses (and decodes, if FLAGS.compute_bleu) instead of training.

    Args:
      argv: Positional command-line arguments; only the program name allowed.

    Raises:
      app.UsageError: If extra positional arguments are given.
      ValueError: If FLAGS.batch_size is not divisible by the number of local
        devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    def load_wmt_datasets(dataset_name, eval_dataset_name, num_to_keep):
        """Return (train, eval, predict, encoder) built with this run's settings.

        Every caller shares the same tokenizer options (including
        `split_tokenizer`). Auxiliary eval sets are therefore tokenized
        exactly like the main datasets.
        """
        return input_pipeline.get_wmt_datasets(
            dataset_name=dataset_name,
            eval_dataset_name=eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_to_keep=num_to_keep,
            pseudo_path=FLAGS.pseudo_path,
            repeat_count=FLAGS.repeat_count,
            newscommentary_size=FLAGS.newscommentary_size,
            split_tokenizer=FLAGS.split_tokenizer)

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = load_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        num_to_keep=FLAGS.data_selection_size)

    if FLAGS.aux_eval_dataset:
        aux_datasets = []
        for name in FLAGS.aux_eval_dataset.split(','):
            # Only the eval split of each auxiliary dataset is used.
            _, aux_eval_ds, _, _ = load_wmt_datasets(
                dataset_name=name,
                eval_dataset_name=None,
                num_to_keep=FLAGS.data_selection_size)
            aux_datasets.append(aux_eval_ds)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        # Keep tokens up to and including the first EOS.
        # NOTE(review): with no EOS present, np.argmax of an all-False mask is
        # 0, so only the first token is kept; confirm this is acceptable.
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # These shapes only drive parameter initialization; they may be intended
    # to be per-device batch sizes (TODO confirm).
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, continue where we
        # left off.
        model_path = (FLAGS.pretrained_model_dir
                      if FLAGS.pretrained_model_dir else FLAGS.model_dir)
        optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)

    # Record the flags of the first module whose name contains 'wmt.par' as
    # hyperparameters.
    flag_key = [
        k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
    ]
    if flag_key:
        flag_key = flag_key[0]
        local_flags = {
            f.name: f.value
            for f in FLAGS.flags_by_module_dict()[flag_key]
        }
        writer.write_hparams(local_flags)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_util.train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing),
                            axis_name='batch',
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(train_util.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_util.initialize_cache,
        max_decode_len=FLAGS.max_predict_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(train_util.predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info('Starting training loop.')
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=FLAGS.model_dir,
                                     num_profile_steps=5)
        ]
    train_metrics = []
    total_steps = start_step + FLAGS.num_train_steps
    if FLAGS.eval_only:
        total_steps = start_step + 1
    # Start from +inf so that any real eval loss counts as an improvement.
    best_eval_loss = float('inf')
    curr_eval_loss = float('inf')
    eval_loss_history = []
    last_eval_step = 0
    do_resample_data = False
    gradual_selection_size = FLAGS.data_selection_size
    dynamic_eval_freq = FLAGS.eval_frequency
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, total_steps):
            is_last_step = step == total_steps - 1

            # Resample training data for gradual FT: keep 70% of the current
            # selection and evaluate more often as the selection shrinks.
            if do_resample_data:
                do_resample_data = False
                gradual_selection_size *= .7
                # Clamp to >= 1: once the selection drops below 4000 examples
                # the raw value is 0, and the modulo below would divide by 0.
                dynamic_eval_freq = max(
                    1, int(gradual_selection_size / 1000 / 4))

                train_ds, eval_ds, predict_ds, encoder = load_wmt_datasets(
                    dataset_name=FLAGS.dataset_name,
                    eval_dataset_name=FLAGS.eval_dataset_name,
                    num_to_keep=int(gradual_selection_size))
                train_iter = iter(train_ds)

            # Shard data to devices and do a training step.
            if not FLAGS.eval_only:
                logging.info('Doing Training.')
                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    try:
                        batch = common_utils.shard(
                            jax.tree_map(np.asarray, next(train_iter)))
                        optimizer, metrics = p_train_step(
                            optimizer, batch, dropout_rng=dropout_rngs)
                        train_metrics.append(metrics)
                    except StopIteration:
                        is_last_step = True

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
                # train_metrics is empty when the iterator ran out right after
                # a previous flush. get_metrics presumably cannot stack an
                # empty list, so the summary is skipped in that case.
                if not FLAGS.eval_only and train_metrics:
                    with report_progress.timed('training_metrics'):
                        logging.info('Gathering training metrics.')
                        train_metrics = common_utils.get_metrics(train_metrics)
                        lr = train_metrics.pop('learning_rate').mean()
                        metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                        denominator = metrics_sums.pop('denominator')
                        summary = jax.tree_map(lambda x: x / denominator,
                                               metrics_sums)  # pylint: disable=cell-var-from-loop
                        summary['learning_rate'] = lr
                        summary = {'train_' + k: v for k, v in summary.items()}
                        writer.write_scalars(step, summary)
                        train_metrics = []

                if FLAGS.eval_only:
                    p_eval_per_pos_step = jax.pmap(functools.partial(
                        train_util.eval_per_pos_step, config=eval_config),
                                                   axis_name='batch')
                    # Get per example loss
                    loss_filename = os.path.join(FLAGS.model_dir,
                                                 'test_losses.csv')
                    train_util.write_per_example_losses(
                        p_eval_step=p_eval_per_pos_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=FLAGS.num_eval_steps,
                        loss_filename=loss_filename)
                else:
                    with report_progress.timed('eval'):
                        eval_results = train_util.evaluate(
                            p_eval_step=p_eval_step,
                            target=optimizer.target,
                            eval_ds=eval_ds,
                            num_eval_steps=FLAGS.num_eval_steps)
                        curr_eval_loss = eval_results['loss']
                        eval_loss_history.append(curr_eval_loss)
                        if len(eval_loss_history) > 1:
                            # Resample when the loss drop since the previous
                            # eval is below a fixed per-step rate.
                            improvement_rate = 0.000004
                            orig_loss = eval_loss_history[-2]
                            true_improvement = orig_loss - curr_eval_loss
                            expected_improvement = (
                                step - last_eval_step) * improvement_rate
                            if true_improvement < expected_improvement:
                                do_resample_data = True
                        last_eval_step = step
                        writer.write_scalars(
                            step,
                            {'eval_' + k: v
                             for k, v in eval_results.items()})

                if FLAGS.aux_eval_dataset:
                    for aux_i, aux_eval_ds in enumerate(aux_datasets):
                        with report_progress.timed('aux_eval'):
                            eval_results = train_util.evaluate(
                                p_eval_step=p_eval_step,
                                target=optimizer.target,
                                eval_ds=aux_eval_ds,
                                num_eval_steps=FLAGS.num_eval_steps)
                            writer.write_scalars(
                                step, {
                                    'aux' + str(aux_i) + '_eval_' + k: v
                                    for k, v in eval_results.items()
                                })

                if FLAGS.compute_bleu:
                    with report_progress.timed('translate_and_bleu'):
                        decode_file = os.path.join(FLAGS.model_dir,
                                                   'decodes.csv')
                        exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                            p_pred_step=p_pred_step,
                            p_init_cache=p_init_cache,
                            target=optimizer.target,
                            predict_ds=predict_ds,
                            decode_tokens=decode_tokens,
                            max_predict_length=FLAGS.max_predict_length,
                            num_eval_steps=FLAGS.num_eval_steps,
                            decode_file=decode_file if FLAGS.eval_only else '')
                        writer.write_scalars(step, {'bleu': bleu_score})
                        writer.write_texts(step, {'samples': exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0
                               or is_last_step)
            if (FLAGS.save_checkpoints and save_checkpoint
                    and jax.process_index() == 0):
                if curr_eval_loss < best_eval_loss:  # only save better checkpoints
                    best_eval_loss = curr_eval_loss
                    with report_progress.timed('checkpoint'):
                        checkpoints.save_checkpoint(
                            FLAGS.model_dir,
                            jax_utils.unreplicate(optimizer),
                            step,
                            keep=FLAGS.chkpts_to_keep,
                            overwrite=True)

            if is_last_step:
                break
def compute_is_scores(filename):
    """Compute IS scores for training data.

    Restores a trained model (FLAGS.restore_checkpoints is required) and runs
    the pmapped `eval_for_is_step` over every batch of the IS dataset. For each
    batch it writes one CSV row of scores to `<FLAGS.is_save_path>/<filename>.txt`
    and one CSV row of lengths to `<FLAGS.is_save_path>/<filename>-lengths.txt`.
    Batches whose size is not a multiple of the local device count are padded
    for the pmap, and the padding is trimmed before writing. Only host 0 writes
    rows.

    Args:
      filename: Base name (without extension) of the two output files.

    Raises:
      ValueError: If FLAGS.batch_size is not divisible by the number of local
        devices.
      RuntimeError: If FLAGS.restore_checkpoints is not set.
    """
    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # jax.process_index/process_count replace the deprecated host_id/host_count
    # aliases and match the rest of this module.
    if jax.process_index() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    print('Loading data')
    logging.info('Initializing dataset.')
    train_ds, encoder = input_pipeline.get_wmt_is_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        paracrawl_size=FLAGS.paracrawl_size)
    print('Datasets created')

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    print('data iterators created')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    eval_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=True,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # These shapes only drive parameter initialization; they may be intended
    # to be per-device batch sizes (TODO confirm).
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, continue where we
        # left off.
        model_path = (FLAGS.pretrained_model_dir
                      if FLAGS.pretrained_model_dir else FLAGS.model_dir)
        # When loading a checkpoint trained with adapters (ie. frozen weights)
        # restoring from the base optimizer fails. We catch this error and
        # create the optimizer with frozen weights.
        try:
            optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
            # Grab last step.
            start_step = int(optimizer.state.step)
        except ValueError:
            adapter = optim.ModelParamTraversal(
                lambda path, _: FLAGS.adapter in path)
            optimizer = optimizer_def.create(optimizer.target, focus=adapter)
            optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
            start_step = optimizer.state[0].step

    else:
        raise RuntimeError('Must restore checkpoint for IS')

    # NONE is presumably a module-level sentinel defined outside this view.
    if FLAGS.adapter != NONE and not isinstance(optimizer,
                                                optim.MultiOptimizer):
        adapter = optim.ModelParamTraversal(
            lambda path, _: FLAGS.adapter in path)
        optimizer = optimizer_def.create(optimizer.target, focus=adapter)
    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    p_eval_step = jax.pmap(functools.partial(eval_for_is_step,
                                             config=eval_config),
                           axis_name='batch')

    logging.info('Start scoring loop.')
    t_loop_start = time.time()

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    lengths_path = os.path.join(FLAGS.is_save_path, filename + '-lengths.txt')
    scores_path = os.path.join(FLAGS.is_save_path, filename + '.txt')
    # `with` guarantees both files are flushed and closed even if scoring
    # raises part-way. Previously the lengths file leaked on error.
    # NOTE(review): every host opens (and truncates) both paths but only host
    # 0 writes rows; on multi-host runs with a shared directory this may race.
    with tf.io.gfile.GFile(lengths_path, 'w') as length_fp:
        lengths_writer = csv.writer(length_fp)
        with tf.io.gfile.GFile(scores_path, 'w') as fp:
            writer = csv.writer(fp)

            for batch_idx, eval_batch in enumerate(train_iter):
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                cur_pred_batch_size = eval_batch['inputs'].shape[0]
                if cur_pred_batch_size % n_devices:
                    # Pad up to a multiple of n_devices so the batch shards.
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    eval_batch = jax.tree_map(
                        lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
                eval_batch = common_utils.shard(eval_batch)
                losses, lengths = p_eval_step(optimizer.target, eval_batch)
                if jax.process_index() == 0:
                    losses = common.tohost(losses)
                    lengths = common.tohost(lengths)
                    if cur_pred_batch_size % n_devices:
                        # Drop the rows that came from padding.
                        writer.writerow(losses[:cur_pred_batch_size])
                        lengths_writer.writerow(lengths[:cur_pred_batch_size])
                    else:
                        writer.writerow(losses)
                        lengths_writer.writerow(lengths)

                if batch_idx % 500 == 0:
                    print('Batch', batch_idx)
                    print(time.time() - t_loop_start)
def setup(max_decode_len=256, beam_size=4):
  """Build the pmapped step functions and a fresh optimizer for IS/decoding.

  Loads the IS dataset only to obtain the target-side encoder (and hence the
  vocab size). It then builds an eval config and a decode config, initializes
  model parameters from FLAGS.random_seed, and pmaps the evaluation,
  cache-initialization and prediction steps. No checkpoint is restored here:
  the returned optimizer holds freshly initialized parameters.

  Args:
    max_decode_len: Maximum decode length for the cache initializer. The
      default keeps the previously hard-coded value of 256.
    beam_size: Beam width for the prediction step. The default keeps the
      previously hard-coded value of 4.

  Returns:
    A tuple `(p_eval_step, optimizer, p_init_cache, p_pred_step,
    decode_tokens)`, where `decode_tokens` maps a token-id array to a
    detokenized UTF-8 string.

  Raises:
    ValueError: If FLAGS.batch_size is not divisible by the number of local
      devices.
    RuntimeError: If FLAGS.vocab_path is not provided.
  """

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    raise RuntimeError('Vocab path must be provided')

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  # Only the target-side encoder is needed; the dataset is discarded.
  _, (_, encoder_tgt) = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.process_index(),
      shard_count=jax.process_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      split_tokenizer=FLAGS.split_tokenizer)
  print('Datasets created')

  encoder = encoder_tgt
  vocab_size = int(encoder.vocab_size())

  def decode_tokens(toks):
    # Keep tokens up to and including the first EOS.
    # NOTE(review): with no EOS present, np.argmax of an all-False mask is 0,
    # so only the first token is kept; confirm this is acceptable.
    valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  predict_config = eval_config.replace(deterministic=True, decode=True)
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  _, init_rng = jax.random.split(rng)
  # These shapes only drive parameter initialization; they may be intended to
  # be per-device batch sizes (TODO confirm).
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  p_eval_step = jax.pmap(
      functools.partial(
          eval_for_is_step,
          config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=max_decode_len,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  return p_eval_step, optimizer, p_init_cache, p_pred_step, decode_tokens