def map_data_for_transformer_fn(x, y):
    """Maps data for training, and handles weried behaviors for different vers."""
    # Will transform input x and targets y into tuple(x, y) as new model inputs.
    if misc.is_v2():
        # For TF v2, the 2nd parameter is omitted to make Keras training work.
        return ((x, y), )
    else:
        # For TF v1, Keras requires a dummy placeholder as the 2nd parameter.
        return ((x, y), tf.constant(0.0))
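A minimal usage sketch (not part of the original module): applying the mapper to a toy tf.data pipeline. The tensors below are hypothetical, and `tf` plus the module-level `misc` helper are assumed to be importable as elsewhere in this file.

import tensorflow as tf

# Hypothetical toy token-id tensors standing in for real Transformer data.
inputs = tf.constant([[1, 2, 3], [4, 5, 6]])
targets = tf.constant([[7, 8, 9], [1, 2, 0]])

dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
# Each (x, y) element is repacked so Keras receives ((x, y),) under TF v2,
# or ((x, y), dummy_constant) under TF v1.
dataset = dataset.map(map_data_for_transformer_fn)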
def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         padded_decode=False,
                         dtype="float32"):
    """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> A tensor with shape [batch_size * beam_size, index].
        index -> A scalar.
        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
      The function must return a tuple of logits and new cache:
        logits -> A tensor with shape [batch * beam_size, vocab_size].
        new cache -> A nested dictionary with the same shape/structure as the
          inputted cache.
    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
      each batch item.
    initial_cache: A dictionary, containing starting decoder variables
      information.
    vocab_size: An integer, the size of the vocabulary.
    beam_size: An integer, the number of beams.
    alpha: A float, defining the strength of length normalization.
    max_decode_length: An integer, the maximum length to decode a sequence.
    eos_id: An integer, ID of eos token, used to determine when a sequence has
      finished.
    padded_decode: A bool, indicating if max_sequence_length padding is used
      for beam search.
    dtype: A tensorflow data type used for score computation. The default is
      tf.float32.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
    batch_size = (initial_ids.shape.as_list()[0]
                  if padded_decode else tf.shape(initial_ids)[0])
    if misc.is_v2():
        sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size,
                                   batch_size, beam_size, alpha,
                                   max_decode_length, eos_id, padded_decode,
                                   dtype)
    else:
        sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size,
                                    batch_size, beam_size, alpha,
                                    max_decode_length, eos_id, padded_decode,
                                    dtype)
    return sbs.search(initial_ids, initial_cache)
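A minimal call sketch (not from the original source): driving sequence_beam_search with a toy symbols_to_logits_fn that returns uniform logits. The constants and the empty cache below are hypothetical; real callers pass the decoder's cached attention states.

import tensorflow as tf

_VOCAB_SIZE = 8
_BATCH_SIZE = 2

def _toy_symbols_to_logits_fn(ids, index, cache):
    """Stands in for one decoder step: uniform logits, cache returned unchanged."""
    batch_times_beam = tf.shape(ids)[0]
    logits = tf.zeros([batch_times_beam, _VOCAB_SIZE], dtype=tf.float32)
    return logits, cache

decoded_ids, scores = sequence_beam_search(
    symbols_to_logits_fn=_toy_symbols_to_logits_fn,
    initial_ids=tf.zeros([_BATCH_SIZE], dtype=tf.int32),
    initial_cache={},
    vocab_size=_VOCAB_SIZE,
    beam_size=3,
    alpha=0.6,
    max_decode_length=5,
    eos_id=1,
    padded_decode=False,
    dtype="float32")
# decoded_ids holds the top beams per batch item; scores their log-probabilities
# (shapes as documented in the docstring above).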
def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         dtype="float32"):
    """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> [batch_size * beam_size, index]
        index -> [] (scalar)
        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
      The function must return logits and new cache.
        logits -> [batch * beam_size, vocab_size]
        new cache -> same shape/structure as inputted cache
    initial_ids: Starting ids for each batch item.
      int32 tensor with shape [batch_size]
    initial_cache: dict containing starting decoder variables information
    vocab_size: int size of the vocabulary
    beam_size: int number of beams
    alpha: float defining the strength of length normalization
    max_decode_length: maximum length of the decoded sequence
    eos_id: int id of eos token, used to determine when a sequence has finished.
    dtype: The dtype to use.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
    batch_size = tf.shape(initial_ids)[0]
    if misc.is_v2():
        sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size,
                                   batch_size, beam_size, alpha,
                                   max_decode_length, eos_id, dtype)
    else:
        sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size,
                                    batch_size, beam_size, alpha,
                                    max_decode_length, eos_id, dtype)
    return sbs.search(initial_ids, initial_cache)