def map_data_for_transformer_fn(x, y):
  """Maps data for training and handles weird behaviors across TF versions."""
  # Will transform input x and targets y into tuple(x, y) as new model inputs.
  if misc.is_v2():
    # For TF v2, the 2nd parameter is omitted to make Keras training work.
    return ((x, y),)
  else:
    # For TF v1, Keras requires a dummy placeholder as the 2nd parameter.
    return ((x, y), tf.constant(0.0))
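# Illustrative usage sketch (not part of the original module): applies
# map_data_for_transformer_fn to a tf.data.Dataset of (inputs, targets)
# pairs before Keras training. The toy tensors are hypothetical stand-ins
# for tokenized source/target ids.
def _example_map_data_usage():
  dataset = tf.data.Dataset.from_tensor_slices(
      (tf.ones([8, 16], tf.int64), tf.ones([8, 16], tf.int64)))
  dataset = dataset.batch(4)
  # Each element becomes ((x, y),) under TF v2, or ((x, y), 0.0) under TF v1,
  # matching what Keras expects for this model.
  return dataset.map(map_data_for_transformer_fn)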
def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         padded_decode=False,
                         dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> A tensor with shape [batch_size * beam_size, index].
        index -> A scalar.
        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
      The function must return a tuple of logits and new cache:
        logits -> A tensor with shape [batch_size * beam_size, vocab_size].
        new cache -> A nested dictionary with the same shape/structure as the
          inputted cache.
    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
      each batch item.
    initial_cache: A dictionary, containing starting decoder variables
      information.
    vocab_size: An integer, the size of the vocabulary.
    beam_size: An integer, the number of beams.
    alpha: A float, defining the strength of length normalization.
    max_decode_length: An integer, the maximum length to decode a sequence.
    eos_id: An integer, ID of eos token, used to determine when a sequence
      has finished.
    padded_decode: A bool, indicating if max_sequence_length padding is used
      for beam search.
    dtype: A tensorflow data type used for score computation. The default is
      tf.float32.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
  batch_size = (
      initial_ids.shape.as_list()[0] if padded_decode else
      tf.shape(initial_ids)[0])
  if misc.is_v2():
    sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                               beam_size, alpha, max_decode_length, eos_id,
                               padded_decode, dtype)
  else:
    sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
                                beam_size, alpha, max_decode_length, eos_id,
                                padded_decode, dtype)
  return sbs.search(initial_ids, initial_cache)
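# Illustrative sketch (not from the original module): a toy
# symbols_to_logits_fn satisfying the contract documented above. A real
# decoder would run one model step and update the cache; this toy returns
# uniform logits and passes the cache through unchanged.
def _example_beam_search_usage():
  vocab_size, batch_size, beam_size = 32, 2, 4

  def toy_symbols_to_logits_fn(ids, index, cache):
    del index  # The toy model ignores the decoding step index.
    # ids: [batch_size * beam_size, current_length] decoded so far.
    logits = tf.zeros([tf.shape(ids)[0], vocab_size])  # uniform scores
    return logits, cache

  decoded_ids, scores = sequence_beam_search(
      symbols_to_logits_fn=toy_symbols_to_logits_fn,
      initial_ids=tf.zeros([batch_size], tf.int32),
      initial_cache={},  # a real decoder carries per-layer key/value state
      vocab_size=vocab_size,
      beam_size=beam_size,
      alpha=0.6,
      max_decode_length=10,
      eos_id=1)
  # decoded_ids: [batch_size, beam_size, decode_length]
  # scores: [batch_size, beam_size]
  return decoded_ids, scores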
def sequence_beam_search(symbols_to_logits_fn, initial_ids, initial_cache,
                         vocab_size, beam_size, alpha, max_decode_length,
                         eos_id, dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> [batch_size * beam_size, index]
        index -> [] (scalar)
        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
      The function must return logits and new cache.
        logits -> [batch_size * beam_size, vocab_size]
        new cache -> same shape/structure as inputted cache
    initial_ids: Starting ids for each batch item. int32 tensor with shape
      [batch_size]
    initial_cache: dict containing starting decoder variables information
    vocab_size: int size of the vocabulary
    beam_size: int number of beams
    alpha: float defining the strength of length normalization
    max_decode_length: maximum length to decode a sequence
    eos_id: int id of eos token, used to determine when a sequence has
      finished
    dtype: The dtype to use.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
  batch_size = tf.shape(initial_ids)[0]
  if misc.is_v2():
    sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                               beam_size, alpha, max_decode_length, eos_id,
                               dtype)
  else:
    sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
                                beam_size, alpha, max_decode_length, eos_id,
                                dtype)
  return sbs.search(initial_ids, initial_cache)
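# Note on alpha (illustrative, assuming the GNMT-style length penalty used
# by the upstream beam search implementation): candidates are ranked by
# log_prob / length_penalty with length_penalty = ((5 + length) / 6) ** alpha.
# alpha = 0 disables normalization; larger alpha favors longer sequences.
def _example_length_penalty(alpha, length):
  """Hypothetical helper mirroring the assumed length normalization."""
  return ((5.0 + float(length)) / 6.0) ** alpha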