Exemplo n.º 1
0
def predict_step(inputs,
                 params,
                 eos_id,
                 max_decode_len,
                 config,
                 beam_size=4,
                 return_entire_beam=False):
    """Decode a batch of inputs with cached beam search.

    Args:
        inputs: batch of source inputs; axis 0 is used as the batch axis when
            sizing the dummy target used to build the cache.
        params: parameters handed to `models.Transformer(config).apply`.
        eos_id: end-of-sequence id forwarded to `decode.beam_search`.
        max_decode_len: length of the dummy target used to build the cache,
            also forwarded to `decode.beam_search` as the decode limit.
        config: configuration for `models.Transformer`; `config.dtype` is the
            dtype of the dummy arrays used to initialize the cache.
        beam_size: number of beams kept per batch item.
        return_entire_beam: if true, return all beams instead of only the
            last (best) one.

    Returns:
        The decoded sequences with the leading dummy token removed: every
        beam when `return_entire_beam` is set, otherwise only the final beam
        along the beam axis.
    """
    model = models.Transformer(config)

    # Build the autoregressive cache by initializing the model on dummy
    # arrays: one shaped like the real inputs, one shaped like a target of
    # length max_decode_len with the same batch and trailing dimensions.
    dummy_inputs = jnp.ones(inputs.shape, config.dtype)
    dummy_targets = jnp.ones(
        (inputs.shape[0], max_decode_len) + inputs.shape[2:], config.dtype)
    init_variables = model.init(
        jax.random.PRNGKey(0), dummy_inputs, dummy_targets)
    cache = init_variables['cache']

    # Beam search runs the decoder on a batch of size batch_size * beam_size
    # in which every batch item is repeated in place rather than tiled:
    #   [el0, el1, el2] --> beam_size=2 --> [el0, el0, el1, el1, el2, el2]
    # Both the encoder output and the raw inputs are expanded that way.
    encoder_output = model.apply({'params': params},
                                 inputs,
                                 method=models.Transformer.encode)
    encoded_inputs = decode.flat_batch_beam_expand(encoder_output, beam_size)
    raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)

    def decode_step(flat_ids, flat_cache):
        """Turn the current token slice into logits plus an updated cache."""
        step_variables = {'params': params, 'cache': flat_cache}
        # raw_inputs is only needed for the input padding mask.
        # flat_logits: [batch * beam, vocab]
        flat_logits, mutated = model.apply(
            step_variables,
            encoded_inputs,
            raw_inputs,
            flat_ids,
            mutable=['cache'],
            method=models.Transformer.decode)
        return flat_logits, mutated['cache']

    # Search over candidate sequences with the single-step decoder above.
    search_kwargs = dict(beam_size=beam_size,
                         alpha=0.6,
                         eos_id=eos_id,
                         max_decode_len=max_decode_len)
    beam_seqs, _ = decode.beam_search(inputs, cache, decode_step,
                                      **search_kwargs)

    # beam_seqs is [n_batch, n_beam, n_length + 1], with beams ordered by
    # increasing log-probability, so the last beam is the best one. Position
    # 0 along the length axis is a dummy token and is always dropped.
    beam_index = slice(None) if return_entire_beam else -1
    return beam_seqs[:, beam_index, 1:]
Exemplo n.º 2
0
def eval_step(params, batch, config, label_smoothing=0.0):
    """Run a forward pass on one batch and compute its evaluation metrics.

    Args:
        params: parameters handed to `models.Transformer(config).apply`.
        batch: mapping with at least the keys 'inputs' and 'targets'.
        config: configuration used to build `models.Transformer`.
        label_smoothing: value forwarded unchanged to `compute_metrics`.

    Returns:
        Whatever `compute_metrics` returns for this batch.
    """
    source = batch['inputs']
    reference = batch['targets']

    model = models.Transformer(config)
    predictions = model.apply({'params': params}, source, reference)

    # Positions whose target id is not positive get weight 0, all others 1.
    mask = jnp.where(reference > 0, 1.0, 0.0)

    return compute_metrics(predictions, reference, mask, label_smoothing)
Exemplo n.º 3
0
def get_initial_params(rng, config, transformer_config, optimizer_def):
    """Initialize the model and wrap its parameters with `optimizer_def`.

    Args:
        rng: PRNG key passed to the model's `init`.
        config: object providing `batch_size` for the dummy batch.
        transformer_config: configuration used to build `models.Transformer`.
        optimizer_def: object whose `create` method receives the freshly
            initialized parameter tree.

    Returns:
        The result of `optimizer_def.create` applied to the 'params'
        collection of the initialized variables.
    """
    # Dummy float32 batches; sequence lengths come from the module-level CFG.
    dummy_inputs = jnp.ones(
        (config.batch_size, CFG.max_input_length), jnp.float32)
    dummy_targets = jnp.ones(
        (config.batch_size, CFG.max_target_length), jnp.float32)

    model = models.Transformer(transformer_config)
    variables = model.init(rng, dummy_inputs, dummy_targets)

    # Only the 'params' collection is handed to the optimizer.
    return optimizer_def.create(variables['params'])
Exemplo n.º 4
0
 def tokens_ids_to_logits(flat_ids, flat_cache):
     """Run one decoder step: token slice in, logits and new cache out.

     Relies on `config`, `params`, `encoded_inputs` and `raw_inputs` from the
     enclosing scope.
     """
     decoder_variables = {'params': params, 'cache': flat_cache}
     model = models.Transformer(config)
     # raw_inputs is only needed for the input padding mask.
     # flat_logits: [batch * beam, vocab]
     flat_logits, updated_vars = model.apply(
         decoder_variables,
         encoded_inputs,
         raw_inputs,
         flat_ids,
         mutable=['cache'],
         method=models.Transformer.decode)
     return flat_logits, updated_vars['cache']
Exemplo n.º 5
0
    def loss_fn(params):
        """Training loss: weight-normalized cross entropy, plus the logits.

        Relies on `config`, `inputs`, `targets`, `weights`, the position and
        segmentation arrays, `dropout_rng`, `label_smoothing` and `z_loss`
        from the enclosing scope.
        """
        packing_kwargs = {
            'inputs_positions': inputs_positions,
            'targets_positions': targets_positions,
            'inputs_segmentation': inputs_segmentation,
            'targets_segmentation': targets_segmentation,
        }
        model = models.Transformer(config)
        logits = model.apply({'params': params},
                             inputs,
                             targets,
                             rngs={'dropout': dropout_rng},
                             **packing_kwargs)

        total_loss, normalizer = compute_weighted_cross_entropy(
            logits, targets, weights, label_smoothing, z_loss)
        # Divide the summed loss by the weight sum to get a mean; the logits
        # are returned as the auxiliary output.
        return total_loss / normalizer, logits
Exemplo n.º 6
0
 def loss_fn(params, batch, dropout_rng):
     """Training loss for one batch, returned without normalization.

     `batch` must unpack into exactly six items: inputs, targets, input
     positions, target positions, input segmentation, target segmentation.
     Relies on `config`, `label_smoothing` and `z_loss` from the enclosing
     scope.

     Returns:
         `(loss, (weight_sum, metrics))`, where `loss` and `weight_sum` come
         from `compute_weighted_cross_entropy` and `metrics` from
         `compute_metrics`.
     """
     src, tgt, src_pos, tgt_pos, src_seg, tgt_seg = batch

     model = models.Transformer(config)
     logits = model.apply({'params': params},
                          src,
                          tgt,
                          inputs_positions=src_pos,
                          targets_positions=tgt_pos,
                          inputs_segmentation=src_seg,
                          targets_segmentation=tgt_seg,
                          rngs={'dropout': dropout_rng})

     # Float32 mask: 1 where the target id is positive, 0 elsewhere.
     mask = jnp.where(tgt > 0, 1, 0).astype(jnp.float32)

     loss, weight_sum = compute_weighted_cross_entropy(
         logits, tgt, mask, label_smoothing, z_loss)
     batch_metrics = compute_metrics(logits, tgt, mask)

     # The loss is deliberately not divided by weight_sum here; weight_sum is
     # returned alongside it instead.
     return loss, (weight_sum, batch_metrics)