Example #1
import torch

def _beam_search_step(time,
                      func,
                      state,
                      batch_size,
                      beam_size,
                      alpha,
                      pad_id,
                      eos_id,
                      max_length,
                      inf=-1e9):
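    # One decoding step: expand every alive hypothesis over the vocabulary,
    # keep the best beam_size continuations, and fold newly finished
    # hypotheses (those emitting eos_id) into the finished set. `func` maps
    # (flat_seqs, flat_state) to (step_log_probs, next_state); `state` is a
    # BeamSearchState of (inputs, state, finish).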
    # Compute log probabilities
    seqs, log_probs = state.inputs[:2]
    flat_seqs = _merge_first_two_dims(seqs)
    flat_state = map_structure(lambda x: _merge_first_two_dims(x), state.state)
    step_log_probs, next_state = func(flat_seqs, flat_state)
    step_log_probs = _split_first_two_dims(step_log_probs, batch_size,
                                           beam_size)
    next_state = map_structure(
        lambda x: _split_first_two_dims(x, batch_size, beam_size), next_state)
    curr_log_probs = torch.unsqueeze(log_probs, 2) + step_log_probs

    # Apply length penalty
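    # GNMT-style penalty: lp(t) = ((5 + t) / 6) ** alpha, so longer
    # hypotheses are not unfairly dominated by shorter ones.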
    length_penalty = ((5.0 + float(time + 1)) / 6.0)**alpha
    curr_scores = curr_log_probs / length_penalty
    vocab_size = curr_scores.shape[-1]

    # Select top-k candidates
    # [batch_size, beam_size * vocab_size]
    curr_scores = torch.reshape(curr_scores, [-1, beam_size * vocab_size])
    # [batch_size, 2 * beam_size]
    top_scores, top_indices = torch.topk(curr_scores, k=2 * beam_size)
    # Shape: [batch_size, 2 * beam_size]
    beam_indices = top_indices // vocab_size
    symbol_indices = top_indices % vocab_size
    # Expand sequences
    # [batch_size, 2 * beam_size, time]
    candidate_seqs = _gather_2d(seqs, beam_indices)
    candidate_seqs = torch.cat(
        [candidate_seqs, torch.unsqueeze(symbol_indices, 2)], 2)

    # Suppress finished sequences: candidates that just emitted <eos> get a
    # score of -inf so they cannot remain in the alive beam.
    flags = torch.eq(symbol_indices, eos_id)
    # [batch, 2 * beam_size]
    alive_scores = top_scores + flags.to(torch.float32) * inf
    # [batch, beam_size]
    alive_scores, alive_indices = torch.topk(alive_scores, beam_size)
    alive_symbols = _gather_2d(symbol_indices, alive_indices)
    alive_indices = _gather_2d(beam_indices, alive_indices)
    alive_seqs = _gather_2d(seqs, alive_indices)
    # [batch_size, beam_size, time + 1]
    alive_seqs = torch.cat([alive_seqs, torch.unsqueeze(alive_symbols, 2)], 2)
    alive_state = map_structure(lambda x: _gather_2d(x, alive_indices),
                                next_state)
    alive_log_probs = alive_scores * length_penalty
    # Check length constraint
    length_flags = torch.le(max_length, time + 1).float()
    alive_log_probs = alive_log_probs + length_flags * inf
    alive_scores = alive_scores + length_flags * inf

    # Select finished sequences
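    # Merge previously finished hypotheses with candidates that just emitted
    # <eos>, then keep the best beam_size of the combined pool.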
    prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
    # [batch, 2 * beam_size]
    step_fin_scores = top_scores + (1.0 - flags.to(torch.float32)) * inf
    # [batch, 3 * beam_size]
    fin_flags = torch.cat([prev_fin_flags, flags], dim=1)
    fin_scores = torch.cat([prev_fin_scores, step_fin_scores], dim=1)
    # [batch, beam_size]
    fin_scores, fin_indices = torch.topk(fin_scores, beam_size)
    fin_flags = _gather_2d(fin_flags, fin_indices)
    pad_seqs = prev_fin_seqs.new_full([batch_size, beam_size, 1], pad_id)
    prev_fin_seqs = torch.cat([prev_fin_seqs, pad_seqs], dim=2)
    fin_seqs = torch.cat([prev_fin_seqs, candidate_seqs], dim=1)
    fin_seqs = _gather_2d(fin_seqs, fin_indices)

    new_state = BeamSearchState(
        inputs=(alive_seqs, alive_log_probs, alive_scores),
        state=alive_state,
        finish=(fin_flags, fin_seqs, fin_scores),
    )

    return new_state
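
The step function above relies on a few utilities that are not shown in this
snippet. Below is a minimal sketch of what they are assumed to do, inferred
from how they are called in _beam_search_step; the original implementations
may differ in detail.

import collections

import torch

# Beam search state: (alive seqs, log probs, scores), per-model decoder
# states, and the finished (flags, seqs, scores) triple.
BeamSearchState = collections.namedtuple(
    "BeamSearchState", ["inputs", "state", "finish"])


def _merge_first_two_dims(tensor):
    # [batch, beam, ...] -> [batch * beam, ...]
    shape = list(tensor.shape)
    return tensor.reshape([shape[0] * shape[1]] + shape[2:])


def _split_first_two_dims(tensor, dim_0, dim_1):
    # [batch * beam, ...] -> [dim_0, dim_1, ...]
    shape = list(tensor.shape)
    return tensor.reshape([dim_0, dim_1] + shape[1:])


def _gather_2d(params, indices):
    # Select beams per batch item: params is [batch, beam, ...] and indices
    # is [batch, k]; the result is [batch, k, ...].
    batch_pos = torch.arange(params.shape[0],
                             device=indices.device).unsqueeze(1)
    batch_pos = batch_pos.expand_as(indices)
    return params[batch_pos, indices]


def map_structure(func, structure):
    # Apply func to every tensor in a nested list/tuple/dict (simplified;
    # other container types are passed to func directly).
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(func, x) for x in structure)
    if isinstance(structure, dict):
        return {k: map_structure(func, v) for k, v in structure.items()}
    return func(structure)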
Example #2
import torch

def beam_search(models, features, params):
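    # Beam search decoding with one or more models. `features` must provide
    # "source" and "source_mask" tensors of shape [batch, length]; `params`
    # carries the search hyper-parameters and the target vocabulary lookup.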
    if not isinstance(models, (list, tuple)):
        raise ValueError("'models' must be a list or tuple")

    beam_size = params.beam_size
    top_beams = params.top_beams
    alpha = params.decode_alpha
    decode_ratio = params.decode_ratio
    decode_length = params.decode_length

    pad_id = params.lookup["target"][params.pad.encode("utf-8")]
    bos_id = params.lookup["target"][params.bos.encode("utf-8")]
    eos_id = params.lookup["target"][params.eos.encode("utf-8")]

    min_val = -1e9
    shape = features["source"].shape
    device = features["source"].device
    batch_size = shape[0]
    seq_length = shape[1]

    # Compute initial state if necessary
    states = []
    funcs = []

    for model in models:
        state = model.empty_state(batch_size, device)
        states.append(model.encode(features, state))
        funcs.append(model.decode)

    # Derive the maximum decoding length from the source length
    max_length = features["source_mask"].sum(1) * decode_ratio
    max_length = max_length.long() + decode_length
    max_step = int(max_length.max())  # plain Python int for range() below
    # [batch, beam_size]
    max_length = torch.unsqueeze(max_length, 1).repeat([1, beam_size])

    # Expand the inputs
    # [batch, length] => [batch * beam_size, length]
    features["source"] = torch.unsqueeze(features["source"], 1)
    features["source"] = features["source"].repeat([1, beam_size, 1])
    features["source"] = torch.reshape(features["source"],
                                       [batch_size * beam_size, seq_length])
    features["source_mask"] = torch.unsqueeze(features["source_mask"], 1)
    features["source_mask"] = features["source_mask"].repeat([1, beam_size, 1])
    features["source_mask"] = torch.reshape(
        features["source_mask"], [batch_size * beam_size, seq_length])

    decoding_fn = _get_inference_fn(funcs, features)

    states = map_structure(lambda x: _tile_to_beam_size(x, beam_size), states)

    # Initial beam search state
    init_seqs = torch.full([batch_size, beam_size, 1],
                           bos_id,
                           device=device,
                           dtype=torch.long)
    init_log_probs = init_seqs.new_tensor([[0.] + [min_val] * (beam_size - 1)],
                                          dtype=torch.float32)
    init_log_probs = init_log_probs.repeat([batch_size, 1])
    init_scores = torch.zeros_like(init_log_probs)
    fin_seqs = torch.zeros([batch_size, beam_size, 1],
                           dtype=torch.int64,
                           device=device)
    fin_scores = torch.full([batch_size, beam_size],
                            min_val,
                            dtype=torch.float32,
                            device=device)
    fin_flags = torch.zeros([batch_size, beam_size],
                            dtype=torch.bool,
                            device=device)

    state = BeamSearchState(
        inputs=(init_seqs, init_log_probs, init_scores),
        state=states,
        finish=(fin_flags, fin_seqs, fin_scores),
    )

    for time in range(max_step):
        state = _beam_search_step(time, decoding_fn, state, batch_size,
                                  beam_size, alpha, pad_id, eos_id, max_length)
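        # Early stopping: max_penalty is the most favorable length penalty
        # any alive hypothesis can still receive; if even that optimistic
        # bound cannot beat the worst finished score, decoding is done.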
        max_penalty = ((5.0 + max_step) / 6.0)**alpha
        best_alive_score = torch.max(state.inputs[1][:, 0] / max_penalty)
        worst_finished_score = torch.min(state.finish[2])
        cond = torch.gt(worst_finished_score, best_alive_score)
        is_finished = bool(cond)

        if is_finished:
            break

    final_state = state
    alive_seqs = final_state.inputs[0]
    alive_scores = final_state.inputs[2]
    final_flags = final_state.finish[0].to(torch.bool)  # torch.where needs a bool mask
    final_seqs = final_state.finish[1]
    final_scores = final_state.finish[2]

    final_seqs = torch.where(final_flags[:, :, None], final_seqs, alive_seqs)
    final_scores = torch.where(final_flags, final_scores, alive_scores)

    final_mask = final_seqs[:, 0, 1:].ne(pad_id)
    # Append extra <eos>
    final_seqs = torch.nn.functional.pad(final_seqs, (0, 1, 0, 0, 0, 0),
                                         value=eos_id)

    first_beam_states = map_structure(lambda x: x[:, 0], final_state.state)
    # Update target masks according to the actual decoded lengths
    for fb_state in first_beam_states:
        fb_state["target_mask"] = final_mask

    return (final_seqs[:, :top_beams, 1:], final_scores[:, :top_beams],
            first_beam_states)
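
beam_search additionally assumes _tile_to_beam_size and _get_inference_fn.
The former can be sketched as below; the latter wraps the per-model decode
functions into a single step function with the (inputs, state) ->
(log_probs, next_state) signature that _beam_search_step expects, and its
exact form depends on the model interface, so it is not reproduced here.

import torch


def _tile_to_beam_size(tensor, beam_size):
    # [batch, ...] -> [batch, beam_size, ...]: insert a beam axis and repeat
    # the tensor along it.
    tensor = tensor.unsqueeze(1)
    tile_dims = [1] * tensor.dim()
    tile_dims[1] = beam_size
    return tensor.repeat(tile_dims)


beam_search returns the top_beams best sequences per batch item (with the
leading <bos> stripped and an extra <eos> appended), their length-normalized
scores, and the first-beam decoder states whose "target_mask" marks the
non-padding positions of the best hypothesis.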