Example #1

import tensorflow as tf

from tensorflow.python.util import nest

# `utils` (gather_2d, merge_first_two_dims, split_first_two_dims) and the
# SamplerState / BeamSearchState namedtuples are defined elsewhere in the
# surrounding module.

def _sampling_step(time, func, state, min_length, max_length, pad_id, eos_id):
    # Compute log probabilities
    seqs = state.inputs
    seq_probs = state.seq_probs
    # [batch_size * num_samples, vocab_size]
    step_log_probs, next_state = func(seqs, state.state)

    # Suppress <eos> until min_length is reached
    batch_size = tf.shape(step_log_probs)[0]
    vocab_size = step_log_probs.shape[-1].value or tf.shape(step_log_probs)[1]
    add_mask = tf.one_hot(eos_id,
                          vocab_size,
                          dtype=step_log_probs.dtype,
                          on_value=step_log_probs.dtype.min,
                          off_value=0.0)
    add_mask = tf.tile(tf.reshape(add_mask, [1, -1]), [batch_size, 1])
    add_mask = tf.where(time < min_length, add_mask, tf.zeros_like(add_mask))
    step_log_probs = step_log_probs + add_mask

    # Sample the next symbol from the categorical distribution over the vocabulary
    symbol_indices = tf.multinomial(step_log_probs, 1, output_dtype=tf.int32)
    symbol_scores = tf.squeeze(
        utils.gather_2d(step_log_probs, symbol_indices), axis=1)
    curr_flags = tf.squeeze(tf.equal(symbol_indices, eos_id), axis=1)
    curr_flags = tf.logical_or(state.flags, curr_flags)

    # Append <pad> to finished samples
    symbol_indices = tf.where(state.flags, tf.fill([batch_size, 1], pad_id),
                              symbol_indices)
    symbol_scores = tf.where(state.flags, tf.zeros([batch_size]),
                             symbol_scores)

    # Force the sampler to generate <eos> once the length exceeds max_length
    eos_flags = tf.where(time > max_length, tf.ones([batch_size], tf.bool),
                         tf.zeros([batch_size], tf.bool))
    eos_scores = tf.squeeze(
        utils.gather_2d(step_log_probs, tf.fill([batch_size, 1], eos_id)),
        axis=1)
    eos_indices = tf.fill([batch_size, 1], eos_id)
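    # Only force <eos> on samples that are not already finished; the score of
    # the forced step is the log probability of <eos> under the model.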
    cond = tf.logical_and(tf.logical_not(curr_flags), eos_flags)
    curr_flags = tf.logical_or(curr_flags, eos_flags)
    symbol_indices = tf.where(cond, eos_indices, symbol_indices)
    symbol_scores = tf.where(cond, eos_scores, symbol_scores)
    step_symbol_scores = tf.expand_dims(symbol_scores, -1)

    new_state = SamplerState(inputs=tf.concat([seqs, symbol_indices], axis=1),
                             state=next_state,
                             scores=state.scores + symbol_scores,
                             flags=curr_flags,
                             seq_probs=tf.concat(
                                 [seq_probs, step_symbol_scores], axis=1))

    return time + 1, new_state
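
# How the step function above is typically driven: a minimal sketch, assuming
# SamplerState is a namedtuple and that a tf.while_loop wrapper like the
# hypothetical _sample_sequences below exists (it is not part of this module).
# The loop runs until every sample has emitted <eos>, which the forced-<eos>
# branch above guarantees happens by max_length.
def _sample_sequences(func, initial_state, min_length, max_length, pad_id,
                      eos_id):
    def _cond(time, state):
        # Continue while at least one sample is still unfinished.
        return tf.logical_not(tf.reduce_all(state.flags))

    def _body(time, state):
        return _sampling_step(time, func, state, min_length, max_length,
                              pad_id, eos_id)

    # Sequences grow along the time axis, so shapes are left unspecified.
    invariants = SamplerState(
        inputs=tf.TensorShape([None, None]),
        state=nest.map_structure(lambda x: tf.TensorShape(None),
                                 initial_state.state),
        scores=tf.TensorShape([None]),
        flags=tf.TensorShape([None]),
        seq_probs=tf.TensorShape([None, None]))
    _, final_state = tf.while_loop(
        _cond, _body, [tf.constant(0, tf.int32), initial_state],
        shape_invariants=[tf.TensorShape([]), invariants])
    return final_state.inputs, final_state.scores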


def _beam_search_step(time, func, state, batch_size, beam_size, alpha, pad_id,
                      eos_id):
    # Compute log probabilities
    seqs, log_probs = state.inputs[:2]
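    # The decoder `func` expects a flat batch, so the batch and beam axes are
    # merged before the call and split back apart afterwards.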
    flat_seqs = utils.merge_first_two_dims(seqs)
    flat_state = nest.map_structure(lambda x: utils.merge_first_two_dims(x),
                                    state.state)
    step_log_probs, next_state = func(flat_seqs, flat_state)
    step_log_probs = utils.split_first_two_dims(step_log_probs, batch_size,
                                                beam_size)
    next_state = nest.map_structure(
        lambda x: utils.split_first_two_dims(x, batch_size, beam_size),
        next_state)
    curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs

    # Apply length penalty
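    # GNMT-style penalty ((5 + time + 1) / 6) ** alpha; dividing by it offsets
    # the bias toward short hypotheses when alpha > 0.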
    length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha)
    curr_scores = curr_log_probs / length_penalty
    vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1]

    # Select top-k candidates
    # [batch_size, beam_size * vocab_size]
    curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])
    # [batch_size, 2 * beam_size]
    top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
    # Shape: [batch_size, 2 * beam_size]
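    # top_indices ranges over the flattened beam x vocab axis, so integer
    # division recovers the source beam and the remainder the emitted symbol.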
    beam_indices = top_indices // vocab_size
    symbol_indices = top_indices % vocab_size
    # Expand sequences
    # [batch_size, 2 * beam_size, time]
    candidate_seqs = utils.gather_2d(seqs, beam_indices)
    candidate_seqs = tf.concat(
        [candidate_seqs, tf.expand_dims(symbol_indices, 2)], 2)

    # Suppress finished sequences
    flags = tf.equal(symbol_indices, eos_id)
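    # Finished candidates are pushed to -inf so the alive top-k ignores them;
    # they compete separately in the finished pool below.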
    # [batch, 2 * beam_size]
    alive_scores = top_scores + tf.to_float(flags) * tf.float32.min
    # [batch, beam_size]
    alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
    alive_symbols = utils.gather_2d(symbol_indices, alive_indices)
    alive_indices = utils.gather_2d(beam_indices, alive_indices)
    alive_seqs = utils.gather_2d(seqs, alive_indices)
    # [batch_size, beam_size, time + 1]
    alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2)
    alive_state = nest.map_structure(
        lambda x: utils.gather_2d(x, alive_indices), next_state)
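    # Multiplying by the penalty undoes the normalization, so raw cumulative
    # log probabilities are carried into the next step.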
    alive_log_probs = alive_scores * length_penalty

    # Select finished sequences
    prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
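    # Mirror image of the alive case: unfinished candidates are pushed to
    # -inf so only hypotheses that just emitted <eos> can join the pool.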
    # [batch, 2 * beam_size]
    step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min
    # [batch, 3 * beam_size]
    fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
    fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
    # [batch, beam_size]
    fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
    fin_flags = utils.gather_2d(fin_flags, fin_indices)
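    # Pad previously finished sequences by one step so their length matches
    # the freshly extended candidates before concatenation.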
    pad_seqs = tf.fill([batch_size, beam_size, 1],
                       tf.constant(pad_id, tf.int32))
    prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
    fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
    fin_seqs = utils.gather_2d(fin_seqs, fin_indices)

    new_state = BeamSearchState(
        inputs=(alive_seqs, alive_log_probs, alive_scores),
        state=alive_state,
        finish=(fin_flags, fin_seqs, fin_scores),
    )

    return time + 1, new_state
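
# After the loop a caller extracts the outputs. A minimal sketch under the
# assumption that BeamSearchState is a namedtuple; _select_outputs below is
# hypothetical, not this module's API. Sentences where no hypothesis ever
# emitted <eos> fall back to the alive beam.
def _select_outputs(final_state):
    alive_seqs, _, alive_scores = final_state.inputs
    fin_flags, fin_seqs, fin_scores = final_state.finish
    # True for sentences with at least one finished hypothesis.
    found = tf.reduce_any(fin_flags, axis=1)
    # tf.where broadcasts a rank-1 condition over the leading dimension.
    seqs = tf.where(found, fin_seqs, alive_seqs)
    scores = tf.where(found, fin_scores, alive_scores)
    return seqs, scores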