Example #1
0
def beam_search(decoder: nn.Module,
                enc_out: th.Tensor,
                lm: Optional[LmType] = None,
                ctc_prob: Optional[th.Tensor] = None,
                lm_weight: float = 0,
                beam_size: int = 8,
                nbest: int = 1,
                max_len: int = -1,
                max_len_ratio: float = 1,
                min_len: int = 0,
                min_len_ratio: float = 0,
                sos: int = -1,
                eos: int = -1,
                unk: int = -1,
                len_norm: bool = True,
                ctc_weight: float = 0,
                end_detect: bool = False,
                len_penalty: float = 0,
                cov_penalty: float = 0,
                temperature: float = 1,
                allow_partial: bool = False,
                cov_threshold: float = 0.5,
                eos_threshold: float = 0) -> List[Dict]:
    """
    Vectorized beam search algorithm for a transformer decoder (one utterance)

    Args
        decoder (nn.Module): decoder network, must provide ``step`` and
            ``vocab_size``
        enc_out (Tensor): T x 1 x F, encoder output
        lm (optional): language model, only used when lm_weight > 0
        ctc_prob (Tensor, optional): forwarded to BeamTracker
        beam_size (int): number of beams, 1 <= beam_size <= vocab_size
        nbest (int): number of hypotheses returned (clipped to beam_size)
        max_len (int): explicit upper bound of the decoded length; a value
            <= 0 (the default) means "not set"
        max_len_ratio (float): if > 0, the upper bound is also capped by
            int(max_len_ratio * T), otherwise it is capped by T
        min_len / min_len_ratio: lower bound is
            max(min_len, int(min_len_ratio * T))
        temperature (float): decoder logits are divided by it before
            log_softmax, must be > 0
        cov_penalty / cov_threshold: kept for interface compatibility,
            not used by this decoder
    Return
        nbest hypotheses as returned by BeamTracker.nbest_hypos
    Raises
        RuntimeError: invalid sos/eos, batch size != 1, decoder without
            ``step``, beam_size out of range or non-positive temperature
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    T, N, D_enc = enc_out.shape
    if N != 1:
        raise RuntimeError(
            f"Got batch size {N:d}, now only support one utterance")
    if not hasattr(decoder, "step"):
        raise RuntimeError("Function step should defined in decoder network")
    if beam_size < 1:
        raise RuntimeError(f"Beam size({beam_size}) should be positive")
    if beam_size > decoder.vocab_size:
        raise RuntimeError(f"Beam size({beam_size}) > vocabulary size")
    if temperature <= 0:
        # x / t with t <= 0 yields inf/nan or flips the distribution
        raise RuntimeError(f"Temperature({temperature}) should be positive")

    min_len = max(min_len, int(min_len_ratio * T))
    # upper bound derived from the encoder length
    len_cap = int(max_len_ratio * T) if max_len_ratio > 0 else T
    # max_len <= 0 means "no explicit limit": min(-1, len_cap) would give an
    # unusable bound of -1 for the default value
    max_len = len_cap if max_len <= 0 else min(max_len, len_cap)
    logger.info(f"--- shape of the encoder output: {T} x {D_enc}")
    logger.info("--- length constraint of the decoding " +
                f"sequence: ({min_len}, {max_len})")
    nbest = min(beam_size, nbest)
    device = enc_out.device

    # cov_* are disabled for this decoder
    beam_param = BeamSearchParam(beam_size=beam_size,
                                 sos=sos,
                                 eos=eos,
                                 unk=unk,
                                 device=device,
                                 min_len=min_len,
                                 max_len=max_len,
                                 len_norm=len_norm,
                                 lm_weight=lm_weight,
                                 ctc_weight=ctc_weight,
                                 end_detect=end_detect,
                                 len_penalty=len_penalty,
                                 allow_partial=allow_partial,
                                 eos_threshold=eos_threshold,
                                 # fixed 1.5x ratio of the decoder beam
                                 ctc_beam_size=int(beam_size * 1.5))
    beam_tracker = BeamTracker(beam_param, ctc_prob=ctc_prob)
    pre_emb = None
    lm_state = None
    # T x 1 x D => T x beam x D
    enc_out = th.repeat_interleave(enc_out, beam_size, 1)
    # step by step
    stop = False
    while not stop:
        # last tokens of each beam and the beam indices used to gather
        # the cached per-beam states
        pre_tok, point = beam_tracker[-1]
        # beam x V
        dec_out, pre_emb = decoder.step(
            enc_out,
            pre_tok[:, None],
            out_idx=-1,
            # reorder cached embeddings along the beam axis (dim 1)
            pre_emb=None if pre_emb is None else pre_emb[:, point])

        # compute prob: beam x V, log-probabilities (non-positive)
        am_prob = tf.log_softmax(dec_out / temperature, dim=-1)
        if lm is not None and beam_param.lm_weight > 0:
            # beam x V
            lm_prob, lm_state = lm_score_impl(lm, point, pre_tok, lm_state)
        else:
            lm_prob = 0
        # one beam search step
        stop = beam_tracker.step(am_prob, lm_prob)
    # return nbest
    return beam_tracker.nbest_hypos(nbest)
Example #2
0
def beam_search(decoder: nn.Module,
                att_net: nn.Module,
                enc_out: th.Tensor,
                lm: Optional[LmType] = None,
                lm_weight: float = 0,
                beam_size: int = 8,
                nbest: int = 1,
                max_len: int = -1,
                max_len_ratio: float = 1,
                min_len: int = 0,
                min_len_ratio: float = 0,
                sos: int = -1,
                eos: int = -1,
                len_norm: bool = True,
                len_penalty: float = 0,
                cov_penalty: float = 0,
                temperature: float = 1,
                cov_threshold: float = 0.5,
                eos_threshold: float = 1) -> List[Dict]:
    """
    Vectorized beam search algorithm for an attention based decoder on one
    utterance (see batch version beam_search_batch)

    Args
        decoder (nn.Module): decoder network, must provide ``step`` and
            ``vocab_size``
        att_net (nn.Module): attention network
        enc_out (Tensor): 1 x T x F, encoder output
        lm (optional): language model, only used when lm_weight > 0
        beam_size (int): number of beams, 1 <= beam_size <= vocab_size
        nbest (int): number of hypotheses returned (clipped to beam_size)
        max_len (int): explicit upper bound of the decoded length; a value
            <= 0 (the default) means "not set"
        max_len_ratio (float): if > 0, the upper bound is also capped by
            int(max_len_ratio * T), otherwise it is capped by T
        min_len / min_len_ratio: lower bound is
            max(min_len, int(min_len_ratio * T))
        temperature (float): decoder logits are divided by it before
            log_softmax, must be > 0
    Return
        nbest hypotheses as returned by BeamTracker.nbest_hypos
    Raises
        RuntimeError: invalid sos/eos, batch size != 1, decoder without
            ``step``, beam_size out of range or non-positive temperature
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    N, T, D_enc = enc_out.shape
    if N != 1:
        raise RuntimeError(
            f"Got batch size {N:d}, now only support one utterance")
    if not hasattr(decoder, "step"):
        raise RuntimeError("Function step should defined in decoder network")
    if beam_size < 1:
        raise RuntimeError(f"Beam size({beam_size}) should be positive")
    if beam_size > decoder.vocab_size:
        raise RuntimeError(f"Beam size({beam_size}) > vocabulary size")
    if temperature <= 0:
        # x / t with t <= 0 yields inf/nan or flips the distribution
        raise RuntimeError(f"Temperature({temperature}) should be positive")

    min_len = max(min_len, int(min_len_ratio * T))
    # upper bound derived from the encoder length
    len_cap = int(max_len_ratio * T) if max_len_ratio > 0 else T
    # max_len <= 0 means "no explicit limit": min(-1, len_cap) would give an
    # unusable bound of -1 for the default value
    max_len = len_cap if max_len <= 0 else min(max_len, len_cap)
    logger.info(f"--- shape of the encoder output: {T} x {D_enc}")
    logger.info("--- length constraint of the decoding " +
                f"sequence: ({min_len}, {max_len})")
    nbest = min(beam_size, nbest)
    device = enc_out.device
    att_ali = None
    dec_hid = None
    # N x T x F => N*beam x T x F
    enc_out = th.repeat_interleave(enc_out, beam_size, 0)
    # zero initial attention context and projection for every beam
    att_ctx = th.zeros([N * beam_size, D_enc], device=device)
    proj = th.zeros([N * beam_size, D_enc], device=device)

    beam_param = BeamSearchParam(beam_size=beam_size,
                                 sos=sos,
                                 eos=eos,
                                 device=device,
                                 min_len=min_len,
                                 max_len=max_len,
                                 len_norm=len_norm,
                                 lm_weight=lm_weight,
                                 len_penalty=len_penalty,
                                 cov_penalty=cov_penalty,
                                 cov_threshold=cov_threshold,
                                 eos_threshold=eos_threshold)
    beam_tracker = BeamTracker(beam_param)

    lm_state = None
    # clear states of the attention network before decoding
    att_net.clear()
    # step by step
    stop = False
    while not stop:
        # last tokens of each beam and the beam indices used to gather
        # the cached per-beam states
        pre_tok, point = beam_tracker[-1]

        # step forward, reordering every per-beam state with `point`
        dec_hid = adjust_hidden(point, dec_hid)
        att_ali = None if att_ali is None else att_ali[point]
        dec_out, att_ctx, dec_hid, att_ali, proj = decoder.step(
            att_net,
            pre_tok,
            enc_out,
            att_ctx[point],
            dec_hid=dec_hid,
            att_ali=att_ali,
            proj=proj[point])
        # compute prob: beam x V, log-probabilities (non-positive)
        am_prob = tf.log_softmax(dec_out / temperature, dim=-1)
        if lm is not None and beam_param.lm_weight > 0:
            # beam x V
            lm_prob, lm_state = lm_score_impl(lm, point, pre_tok, lm_state)
        else:
            lm_prob = 0
        # one beam search step
        stop = beam_tracker.step(am_prob, lm_prob, att_ali=att_ali)
    # return nbest
    return beam_tracker.nbest_hypos(nbest)