Example #1
def eval_step(params, batch, config, label_smoothing=0.0):
    """Calculate evaluation metrics on a batch."""
    inputs = batch["inputs"]
    weights = jnp.where(inputs > 0, 1.0, 0.0)
    logits = models.TransformerLM(config).apply({"params": params}, inputs)

    return compute_metrics(logits, inputs, weights, label_smoothing)
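
A minimal sketch of how an eval_step like the one above can be driven over an evaluation dataset; the eval_ds iterator, the deterministic eval_config, the replicated params, and the "denominator" metric key are assumptions taken from the other examples on this page, not part of this snippet.

# Sketch only: assumes eval_ds yields host-side batches with an "inputs" field
# and that params are already replicated across the local devices.
import functools

import jax
import jax.numpy as jnp
import numpy as np
from flax.training import common_utils

p_eval_step = jax.pmap(
    functools.partial(eval_step, config=eval_config), axis_name="batch")

eval_metrics = []
for eval_batch in eval_ds:
    # Shard the batch so its leading axis maps onto the local devices.
    eval_batch = common_utils.shard(jax.tree_map(np.asarray, eval_batch))
    eval_metrics.append(p_eval_step(params, eval_batch))

# Stack the per-step metric dicts, sum them, and normalize by the weight sum.
eval_metrics = common_utils.get_metrics(eval_metrics)
metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
denominator = metrics_sums.pop("denominator")
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)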
Example #2
def predict_step(inputs, params, rngkey, eos_id, max_decode_len, config,
                 temperature, top_k):
    """Predict language model on a batch."""
    target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
    initial_variables = models.TransformerLM(config).init(
        jax.random.PRNGKey(0), jnp.ones(target_shape, config.dtype))
    cache = initial_variables["cache"]

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.TransformerLM(config).apply(
            {
                "params": params,
                "cache": flat_cache
            },
            flat_ids,
            mutable=["cache"])
        new_flat_cache = new_vars["cache"]
        # Remove singleton sequence-length dimension:
        # [batch, 1, vocab] --> [batch, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_cache

    # Using the single-step decoder function defined above, autoregressively
    # sample continuations of the prompts via temperature sampling.
    seqs = temperature_sampler.temperature_sample(inputs,
                                                  cache,
                                                  tokens_ids_to_logits,
                                                  rngkey,
                                                  temperature=temperature,
                                                  topk=top_k,
                                                  eos_token=eos_id)

    return seqs
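
To make the role of the cached single-step decoder clearer, here is an illustrative, deliberately naive sampling loop driven by tokens_ids_to_logits. It is not the temperature_sampler.temperature_sample implementation used above; it ignores top-k filtering and EOS handling and uses a plain Python loop for readability.

# Illustrative sketch only, not the library sampler: feed one token at a time
# through the cached decoder and sample the next token at the given temperature.
import jax
import jax.numpy as jnp

def naive_temperature_sample(inputs, cache, tokens_ids_to_logits, rng,
                             temperature=1.0):
    # inputs: [batch, max_decode_len] int32 prompts, left-aligned, zero-padded.
    max_len = inputs.shape[1]
    seqs = inputs
    cur_token = inputs[:, :1]  # [batch, 1], the first prompt token.
    for i in range(1, max_len):
        logits, cache = tokens_ids_to_logits(cur_token, cache)  # [batch, vocab]
        rng, sample_rng = jax.random.split(rng)
        sampled = jax.random.categorical(sample_rng, logits / temperature)
        # Keep the prompt token where one exists; otherwise use the sample.
        next_token = jnp.where(inputs[:, i] > 0, inputs[:, i], sampled)
        seqs = seqs.at[:, i].set(next_token)
        cur_token = next_token[:, None]
    return seqs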
Example #3
    def loss_fn(params):
        """loss function used for training."""
        logits = models.TransformerLM(config).apply(
            {"params": params},
            inputs,
            inputs_positions=inputs_positions,
            inputs_segmentation=inputs_segmentation,
            rngs={"dropout": dropout_rng})

        loss, weight_sum = compute_weighted_cross_entropy(
            logits, inputs, weights, label_smoothing)
        mean_loss = loss / weight_sum
        return mean_loss, logits
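
For context, a minimal sketch of how a loss_fn like this is typically consumed inside the surrounding train step; the state object (a flax.training.train_state.TrainState) and the pmap axis name "batch" are assumptions carried over from Example #5 below.

# Sketch only: differentiate loss_fn, average gradients across the pmap axis,
# and apply the optimizer update held in the train state.
import jax
from jax import lax

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(mean_loss, logits), grads = grad_fn(state.params)  # logits usable for metrics.
grads = lax.pmean(grads, axis_name="batch")  # Cross-device gradient average.
new_state = state.apply_gradients(grads=grads)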
Example #4
 def tokens_ids_to_logits(flat_ids, flat_cache):
     """Token slice to logits from decoder model."""
     # --> [batch * beam, 1, vocab]
     flat_logits, new_vars = models.TransformerLM(config).apply(
         {
             "params": params,
             "cache": flat_cache
         },
         flat_ids,
         mutable=["cache"])
     new_flat_cache = new_vars["cache"]
     # Remove singleton sequence-length dimension:
     # [batch, 1, vocab] --> [batch, vocab]
     flat_logits = flat_logits.squeeze(axis=1)
     return flat_logits, new_flat_cache
Example #5
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    def encode_strings(strs, max_len):
        tokenized_batch = np.zeros((len(strs), max_len), np.int32)
        for i, s in enumerate(strs):
            toks = encoder.tokenize(s).numpy()
            # Remove EOS token in prompt.
            tokenized_batch[i, :toks.shape[0] - 1] = toks[:-1]
        return tokenized_batch

    tokenized_prompts = encode_strings([config.prompts],
                                       config.max_predict_length)

    logging.info("Initializing model, optimizer, and step functions.")
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
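    # Three variants of one configuration are derived here: training (dropout
    # enabled), evaluation (deterministic), and prediction (deterministic with
    # the autoregressive decoding cache enabled).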
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    rng, inference_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.TransformerLM(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32))

    learning_rate_fn = create_learning_rate_schedule(
        learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

    optimizer = optax.adamw(learning_rate_fn,
                            b1=0.9,
                            b2=0.98,
                            eps=1e-9,
                            weight_decay=config.weight_decay)
    state = train_state.TrainState.create(apply_fn=m.apply,
                                          params=initial_variables["params"],
                                          tx=optimizer)
    # We access the model params only via the train state below.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        state = checkpoints.restore_checkpoint(workdir, state)
        # Grab last step.
        start_step = int(state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate the train state across local devices.
    state = jax_utils.replicate(state)

    # Compile multi-device versions of the train/eval/predict step functions.
    p_train_step = jax.pmap(functools.partial(
        train_step, config=train_config, learning_rate_fn=learning_rate_fn),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name="batch")

    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          temperature=config.sampling_temperature,
                          top_k=config.sampling_top_k),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos_id and max_decode_len are constant.

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We initialize the first set of dropout PRNG keys here, but update them
    # afterwards inside the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                state, metrics = p_train_step(state,
                                              batch,
                                              dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]),
                                                     a_max=1.0e4)
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        params=state.params,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    # Report the (clipped) perplexity after averaging the log-perplexities.
                    eval_results["perplexity"] = jnp.clip(
                        jnp.exp(eval_results["loss"]), a_max=1.0e4)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("generate_text"):
                    exemplars = generate_prediction(
                        p_pred_step=p_pred_step,
                        params=state.params,
                        tokenized_prompts=tokenized_prompts,
                        eos_id=eos_id,
                        inference_rng=inference_rng,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_every_steps steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if (config.save_checkpoints and save_checkpoint
                    and jax.process_index() == 0):
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(workdir,
                                                jax_utils.unreplicate(state),
                                                step)
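
Finally, a hypothetical entry point showing how train_and_evaluate might be launched; the flag names and the config_flags usage are assumptions in the style of typical Flax examples, not taken from the code above.

# Hypothetical launcher (flag names are illustrative).
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
flags.DEFINE_string("workdir", None, "Directory for checkpoints and summaries.")
config_flags.DEFINE_config_file("config", None, "Path to a training config file.")

def main(argv):
    del argv  # Unused.
    train_and_evaluate(FLAGS.config, FLAGS.workdir)

if __name__ == "__main__":
    app.run(main)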