Example #1
def write_per_example_losses(*, p_eval_step, target, eval_ds, num_eval_steps,
                             loss_filename):
    """Evaluate the target an return a dictionary with the metrics."""
    logging.info('Gathering evaluation metrics.')
    losses = []
    lengths = []
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    for _, eval_batch in zip(range(num_eval_steps), eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        eval_batch = common_utils.shard(eval_batch)
        loss, length = p_eval_step(target, eval_batch)
        losses.append(common.tohost(loss))
        lengths.append(common.tohost(length))
    # Write losses and lengths
    if jax.process_index() == 0:
        with tf.io.gfile.GFile(loss_filename, 'w') as f:
            writer = csv.writer(f)
            for pos_losses in losses:
                for val in pos_losses:
                    writer.writerow(list(val))
        with tf.io.gfile.GFile(loss_filename.replace('.csv', '_length.csv'),
                               'w') as f:
            writer = csv.writer(f)
            for val in lengths:
                writer.writerow([int(v) for v in list(val)])
    return
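
For context, a minimal sketch of how this helper might be wired up. The names `eval_for_is_step`, `eval_config`, `optimizer`, and `eval_ds` are assumptions for illustration, not part of the example above:

p_eval_step = jax.pmap(
    functools.partial(eval_for_is_step, config=eval_config),
    axis_name='batch')
write_per_example_losses(
    p_eval_step=p_eval_step,
    target=optimizer.target,
    eval_ds=eval_ds,
    num_eval_steps=100,
    loss_filename='/tmp/per_example_losses.csv')
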
def get_losses(ds_iter, optimizer, p_eval_step, model_id, test_cluster_id):
  """Given optimizer and dataset, compute losses and write to file."""
  logging.info('Start scoring loop.')
  n_devices = jax.local_device_count()
  t_loop_start = time.time()

  filename = '/losses_testcluster{test_cluster_id}_ftid{model_id}.csv'
  save_file = filename.format(test_cluster_id=test_cluster_id,
                              model_id=model_id)
  save_file = FLAGS.save_path + save_file
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)

    for batch_idx, eval_batch in enumerate(ds_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
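      # Handle a final odd-sized batch by padding instead of dropping it.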
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        eval_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
      eval_batch = common_utils.shard(eval_batch)
      losses, lengths = p_eval_step(optimizer.target, eval_batch)
      if jax.process_index() == 0:
        losses = common.tohost(losses)
        lengths = common.tohost(lengths)
        if cur_pred_batch_size % n_devices:
          writer.writerow(losses[:cur_pred_batch_size])
        else:
          writer.writerow(losses)

      if batch_idx % 500 == 0:
        print('Batch', batch_idx)
        print(time.time() - t_loop_start)

      if batch_idx >= FLAGS.limit:
        break
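
The `common.pad_examples` and `common.tohost` helpers used above are not shown here. A plausible sketch, assuming 2-D [batch, length] token arrays and the usual [n_devices, batch, ...] pmap output layout:

def pad_examples(x, desired_batch_size):
  # Grow the batch to `desired_batch_size` by repeating the last example;
  # the padded rows are sliced off again after scoring.
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)

def tohost(x):
  # Fetch a pmapped output to the host, merging the leading device and
  # per-device batch axes into one flat batch axis.
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))
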
def get_decodes(ds_iter, optimizer, p_init_cache, p_pred_step, model_id,
                test_cluster_id, decode_tokens):
  """Given optimizer and dataset, compute losses and write to file."""
  logging.info('Start scoring loop.')
  n_devices = jax.local_device_count()
  predictions = []
  max_predict_length = 256

  filename = '/decodes_testcluster{test_cluster_id}_ftid{model_id}.csv'
  save_file = filename.format(test_cluster_id=test_cluster_id,
                              model_id=model_id)
  save_file = FLAGS.save_path + save_file
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)

    for batch_idx, pred_batch in enumerate(ds_iter):
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
            pred_batch)
      pred_batch = common_utils.shard(pred_batch)

      cache = p_init_cache(pred_batch['inputs'])
      predicted, _ = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                                 decode.EOS_ID, max_predict_length)
      if jax.process_index() == 0:
        predicted = common.tohost(predicted)
        # Iterate through non-padding examples of batch.
        for s in predicted[:cur_pred_batch_size]:
          predictions.append(decode_tokens(s))

      if batch_idx >= FLAGS.limit:
        break
    writer.writerow(predictions)
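
`decode_tokens` is supplied by the caller. One plausible definition, assuming a SentencePiece `encoder` like the one built by the input pipeline; it keeps tokens through the first EOS and detokenizes:

def decode_tokens(toks):
  # Truncate after the first EOS, then detokenize to a UTF-8 string.
  valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
  return encoder.detokenize(valid_toks).numpy().decode('utf-8')
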
Example #4
def compute_is_scores(filename):
    """Compute IS scores for training data."""

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    print('Loading data')
    logging.info('Initializing dataset.')
    train_ds, encoder = input_pipeline.get_wmt_is_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        paracrawl_size=FLAGS.paracrawl_size)
    print('Datasets created')

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.
    print('data iterators created')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    eval_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=True,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # Note: this may be intended to be the per-device batch size.
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, continue where we left off.
        model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
        # When loading a checkpoint trained with adapters (i.e. frozen weights),
        # restoring from the base optimizer fails. We catch this error and create
        # the optimizer with frozen weights.
        try:
            optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
            # Grab last step.
            start_step = int(optimizer.state.step)
        except ValueError:
            adapter = optim.ModelParamTraversal(
                lambda path, _: FLAGS.adapter in path)
            optimizer = optimizer_def.create(optimizer.target, focus=adapter)
            optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
            start_step = optimizer.state[0].step

    else:
        raise RuntimeError('Must restore checkpoint for IS')

    if FLAGS.adapter != NONE and not isinstance(optimizer,
                                                optim.MultiOptimizer):
        adapter = optim.ModelParamTraversal(
            lambda path, _: FLAGS.adapter in path)
        optimizer = optimizer_def.create(optimizer.target, focus=adapter)
    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    p_eval_step = jax.pmap(functools.partial(eval_for_is_step,
                                             config=eval_config),
                           axis_name='batch')

    logging.info('Start scoring loop.')
    metrics_all = []
    t_loop_start = time.time()

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    t_eval_start = time.time()
    save_file = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
    length_fp = tf.io.gfile.GFile(save_file, 'w')
    lengths_writer = csv.writer(length_fp)

    save_file = FLAGS.is_save_path + '/' + filename + '.txt'
    with tf.io.gfile.GFile(save_file, 'w') as fp:
        writer = csv.writer(fp)

        for batch_idx, eval_batch in enumerate(train_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            cur_pred_batch_size = eval_batch['inputs'].shape[0]
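            # Handle a final odd-sized batch by padding instead of dropping it.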
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                eval_batch = jax.tree_map(
                    lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
            eval_batch = common_utils.shard(eval_batch)
            losses, lengths = p_eval_step(optimizer.target, eval_batch)
            if jax.host_id() == 0:
                losses = common.tohost(losses)
                lengths = common.tohost(lengths)
                if cur_pred_batch_size % n_devices:
                    writer.writerow(losses[:cur_pred_batch_size])
                    lengths_writer.writerow(lengths[:cur_pred_batch_size])
                else:
                    writer.writerow(losses)
                    lengths_writer.writerow(lengths)

            if batch_idx % 500 == 0:
                print('Batch', batch_idx)
                print(time.time() - t_loop_start)
    length_fp.close()
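
The pmapped `eval_for_is_step` is defined elsewhere in this codebase. A minimal sketch of what it plausibly computes, given how its outputs are used above (per-example summed negative log-likelihood and non-padding token counts); the exact loss definition is an assumption:

def eval_for_is_step(params, batch, config):
    # A sketch, not the original implementation: per-example NLL and length.
    inputs, targets = batch['inputs'], batch['targets']
    weights = jnp.where(targets > 0, 1.0, 0.0)  # mask padding positions
    logits = models.Transformer(config).apply({'params': params}, inputs,
                                              targets)
    log_probs = jax.nn.log_softmax(logits)
    token_nll = -jnp.sum(
        common_utils.onehot(targets, logits.shape[-1]) * log_probs, axis=-1)
    per_example_loss = jnp.sum(token_nll * weights, axis=-1)  # [batch]
    per_example_length = jnp.sum(weights, axis=-1)  # [batch]
    return per_example_loss, per_example_length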