Example #1
  def cond_fn(step_num, prev_ids, *unused_states):
    """Should we run another loop iteration?"""
    overflow = mtf.equal(step_num, num_steps)
    has_eos = mtf.reduce_any(
        mtf.equal(prev_ids, eos_id), reduced_dim=length_dim)
    all_has_eos = mtf.reduce_all(has_eos)
    return mtf.logical_not(mtf.logical_or(overflow, all_has_eos))
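For intuition, here is a plain-NumPy sketch of the same stopping rule: decoding continues until either the step budget is exhausted or every sequence in the batch already contains the EOS token. The function and variable names below are illustrative, not part of the mtf API.

import numpy as np

def should_continue(step_num, prev_ids, num_steps, eos_id):
  # Stop when the step budget is exhausted...
  overflow = step_num == num_steps
  # ...or when every sequence in the batch already contains EOS.
  has_eos = (prev_ids == eos_id).any(axis=-1)  # reduce over the length axis
  return not (overflow or has_eos.all())

# Batch of 2: the first sequence has hit EOS (= 1), the second has not,
# so decoding should continue.
ids = np.array([[5, 7, 1, 0],
                [5, 9, 2, 4]])
print(should_continue(3, ids, num_steps=8, eos_id=1))  # True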
Example #2
def attention_mask_ignore_padding(inputs, dtype=tf.float32):
    """Bias for encoder-decoder attention.

  Args:
    inputs: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [..., memory_length_dim]
  """
  inputs = rename_length_to_memory_length(inputs)
  return mtf.cast(mtf.equal(inputs, 0), dtype) * -1e9
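The convention here: padding positions (token id 0) receive a large negative bias that gets added to the attention logits, so the softmax assigns them effectively zero weight. A NumPy sketch of the same masking logic (illustrative only, not the mtf API):

import numpy as np

def bias_ignore_padding(token_ids):
  # Token id 0 marks padding; adding -1e9 to the attention logits
  # drives those positions' softmax weights to ~0.
  return (token_ids == 0).astype(np.float32) * -1e9

tokens = np.array([[12, 45, 7, 0, 0]])
print(bias_ignore_padding(tokens))
# bias of -1e9 at the two padding slots, (effectively) 0 elsewhere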
Example #3
def masked_local_attention_1d_incremental(x,
                                          prev_k,
                                          prev_v,
                                          step_num,
                                          master_dtype,
                                          slice_dtype,
                                          name=None):
    """Incremental local self-attention (one decode step).

  Incremental version of masked_local_attention_1d()

  Args:
    x: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]
    prev_v: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]
    new_v: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = x.shape.dims[:-1]
  io_channels = x.shape.dims[-1]
  heads, window_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        x.mesh, heads, io_channels, kv_channels, master_dtype, slice_dtype,
        x.dtype)
    q = mtf.einsum([x, q_var],
                   mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum([x, k_var],
                   mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum([x, v_var],
                   mtf.Shape(batch_dims + [heads, kv_channels]))
    current_position = mtf.equal(
        mtf.range(x.mesh, window_length, dtype=tf.int32),
        mtf.mod(step_num, window_length.size))
    k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
    v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
    o = dot_product_attention(q, k, v, mask=None)
    y = mtf.einsum([o, o_var], x.shape)
    return y, k, v
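The key trick above is that prev_k and prev_v act as ring buffers over the window axis: the slot at step_num mod window_length is overwritten with the current step's key/value and every other slot is kept, so the cache never grows. A NumPy sketch of that update (shapes simplified, names illustrative):

import numpy as np

window_length, kv_channels = 4, 2
prev_k = np.zeros((window_length, kv_channels))  # cached keys for the window
k = np.ones(kv_channels)                         # this step's new key
step_num = 5

# Slot (step_num mod window_length) is overwritten; all other slots are kept.
current_position = np.arange(window_length) == step_num % window_length
new_k = np.where(current_position[:, None], k, prev_k)
print(new_k)  # only row 1 (= 5 mod 4) is ones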
Example #4
  def body_fn(step_num, ids, *states):
    """Body function for greedy decoding.

    Args:
      step_num: a mtf.Tensor
      ids: a mtf.Tensor
      *states: additional mtf.Tensors
    Returns:
      new_step_num, new_ids, *new_states
    """
    logits, new_states = logits_fn(step_num, ids, states)
    vocab_dim = logits.shape.dims[-1]
    new_ids = mtf.sample_with_temperature(logits, vocab_dim, temperature)
    if forced_ids is not None:
      # force the new ids to equal the partial targets where specified
      # (positions where partial_targets contain nonzero values)
      forced = mtf.gather(forced_ids, step_num, length_dim)
      new_ids = forced + new_ids * mtf.to_int32(mtf.equal(forced, 0))
    ids += new_ids * mtf.one_hot(step_num, length_dim, dtype=tf.int32)
    new_step_num = step_num + 1
    return [new_step_num, ids] + new_states
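Two details worth noting: a nonzero forced id overrides the sampled token via forced + sampled * (forced == 0), and the chosen id is written into the output at position step_num with a one-hot multiply. A NumPy sketch of one such step (values illustrative):

import numpy as np

length, step = 6, 2
ids = np.array([4, 9, 0, 0, 0, 0])    # ids decoded so far
sampled, forced = 7, 3                # forced != 0, so it overrides the sample
new_id = forced + sampled * int(forced == 0)      # -> 3
ids = ids + new_id * (np.arange(length) == step)  # one-hot write at `step`
print(ids)  # [4 9 3 0 0 0]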
Example #5
  def grow_topk(i, alive_seq, alive_log_probs, states=None):
    r"""Inner beam search loop.

    This function takes the current alive sequences and grows them to the
    top-k sequences, where k = 2*beam. We use 2*beam because all beam_size
    sequences could hit <EOS>, leaving no alive sequences to continue.
    With 2*beam_size that cannot happen, provided vocab size > beam size:
    among the next top 2*beam words there are then at least beam_size
    non-<EOS> extensions.
    The length penalty is ((5 + len(decode)) / 6) ^ \alpha, and scores are
    log probs divided by it; see https://arxiv.org/abs/1609.08144.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch, beam, length]
      alive_log_probs: log probabilities of these sequences. [batch, beam]
      states: optional list of mtf.Tensor
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         list of transformed decoding states)
    """
    logits, new_states = logits_fn(i, alive_seq, states)
    batch_dim, beam_dim, vocab_dim = logits.shape.dims

    # Convert logits to normalized log probs
    candidate_log_probs = mtf.log_softmax(logits, vocab_dim)

    # Add the log probs to the beam's running log probs (equivalent to
    # multiplying the underlying probabilities):
    # [batch, beam, vocab] + [batch, beam] -> [batch, beam, vocab]
    log_probs = candidate_log_probs + alive_log_probs

    length_penalty = mtf.pow(((5. + mtf.cast(i + 1, logits.dtype)) / 6.), alpha)

    curr_scores = log_probs / length_penalty

    # scores have shape [batch, beam, vocab]
    beam_and_vocab_dim = mtf.Dimension(
        "beam_and_vocab", beam_dim.size * vocab_dim.size)
    flat_shape = mtf.Shape([batch_dim, beam_and_vocab_dim])
    double_beam = mtf.Dimension("double_beam", beam_dim.size * 2)
    # Flatten out the (beam_size, vocab_size) probs into a list of possibilities
    flat_curr_scores = mtf.reshape(curr_scores, flat_shape)

    top_ids, top_scores = mtf.top_k(
        flat_curr_scores, reduced_dim=beam_and_vocab_dim, new_dim=double_beam)

    # Recovering the log probs because we will need to send them back
    top_log_probs = top_scores * length_penalty

    # Work out what beam the top probs are in.
    top_beam_index = top_ids // vocab_dim.size
    top_ids %= vocab_dim.size  # Unflatten the ids

    def my_gather(tensor):
      return mtf.gather(
          tensor, top_beam_index, beam_dim,
          output_shape=mtf.Shape(
              [double_beam if d == beam_dim else d for d in tensor.shape.dims]))

    # Gather up the most probable 2*beams, both for the ids and the
    # finished_in_alive bools.
    top_seq = my_gather(alive_seq)

    if states:
      states = [my_gather(state) for state in new_states]

    # Append the most probable alive
    top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
    top_finished = mtf.equal(top_ids, eos_id)

    return top_seq, top_log_probs, top_scores, top_finished, states
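The flatten/unflatten arithmetic is the heart of this step: after reshaping the [beam, vocab] scores into a single axis, integer division by vocab_size recovers the source beam, the remainder recovers the appended token id, and multiplying the top scores by the length penalty recovers the raw log probs. A NumPy sketch of the same bookkeeping (shapes and values illustrative):

import numpy as np

beam_size, vocab_size, alpha = 3, 5, 0.6
rng = np.random.default_rng(0)
log_probs = np.log(rng.dirichlet(np.ones(vocab_size), size=beam_size))
length_penalty = ((5. + 4.) / 6.) ** alpha          # i = 3 -> len(decode) = 4
scores = (log_probs / length_penalty).reshape(-1)   # flatten [beam, vocab]

top = np.argsort(scores)[::-1][:2 * beam_size]      # top 2*beam candidates
top_beam_index = top // vocab_size                  # which beam each came from
top_ids = top % vocab_size                          # which token id to append
top_log_probs = scores[top] * length_penalty        # undo the length penalty
print(top_beam_index, top_ids, top_log_probs)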