from typing import Dict, List, Optional

import torch as th
import torch.nn as nn
import torch.nn.functional as tf

# BeamSearchParam, BeamTracker, LmType, adjust_hidden, lm_score_impl and
# logger are assumed to be provided by the surrounding beam-search package.


def beam_search(decoder: nn.Module,
                enc_out: th.Tensor,
                lm: Optional[LmType] = None,
                ctc_prob: Optional[th.Tensor] = None,
                lm_weight: float = 0,
                beam_size: int = 8,
                nbest: int = 1,
                max_len: int = -1,
                max_len_ratio: float = 1,
                min_len: int = 0,
                min_len_ratio: float = 0,
                sos: int = -1,
                eos: int = -1,
                unk: int = -1,
                len_norm: bool = True,
                ctc_weight: float = 0,
                end_detect: bool = False,
                len_penalty: float = 0,
                cov_penalty: float = 0,
                temperature: float = 1,
                allow_partial: bool = False,
                cov_threshold: float = 0.5,
                eos_threshold: float = 0) -> List[Dict]:
    """
    Vectorized beam search algorithm for the transformer decoder

    Args:
        enc_out (Tensor): T x 1 x F, encoder output
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    T, N, D_enc = enc_out.shape
    if N != 1:
        raise RuntimeError(
            f"Got batch size {N:d}, now only support one utterance")
    if not hasattr(decoder, "step"):
        raise RuntimeError("Function step should be defined in decoder network")
    if beam_size > decoder.vocab_size:
        raise RuntimeError(f"Beam size ({beam_size}) > vocabulary size")

    min_len = max(min_len, int(min_len_ratio * T))
    # treat a non-positive max_len as "no explicit cap" so that the default
    # max_len=-1 does not shrink the length bound to -1
    max_len_bound = int(max_len_ratio * T) if max_len_ratio > 0 else T
    max_len = min(max_len, max_len_bound) if max_len > 0 else max_len_bound
    logger.info(f"--- shape of the encoder output: {T} x {D_enc}")
    logger.info("--- length constraint of the decoding " +
                f"sequence: ({min_len}, {max_len})")
    nbest = min(beam_size, nbest)
    device = enc_out.device
    # cov_penalty / cov_threshold are disabled in this variant
    beam_param = BeamSearchParam(beam_size=beam_size,
                                 sos=sos,
                                 eos=eos,
                                 unk=unk,
                                 device=device,
                                 min_len=min_len,
                                 max_len=max_len,
                                 len_norm=len_norm,
                                 lm_weight=lm_weight,
                                 ctc_weight=ctc_weight,
                                 end_detect=end_detect,
                                 len_penalty=len_penalty,
                                 allow_partial=allow_partial,
                                 eos_threshold=eos_threshold,
                                 ctc_beam_size=int(beam_size * 1.5))
    beam_tracker = BeamTracker(beam_param, ctc_prob=ctc_prob)

    pre_emb = None
    lm_state = None
    # T x 1 x D => T x beam x D
    enc_out = th.repeat_interleave(enc_out, beam_size, 1)
    # decode step by step
    stop = False
    while not stop:
        # previous tokens & back-pointers of the current beam
        pre_tok, point = beam_tracker[-1]
        # beam x V
        dec_out, pre_emb = decoder.step(
            enc_out,
            pre_tok[:, None],
            out_idx=-1,
            pre_emb=None if pre_emb is None else pre_emb[:, point])
        # compute log-prob: beam x V (negative values)
        am_prob = tf.log_softmax(dec_out / temperature, dim=-1)
        if lm and beam_param.lm_weight > 0:
            # beam x V
            lm_prob, lm_state = lm_score_impl(lm, point, pre_tok, lm_state)
        else:
            lm_prob = 0
        # one beam search step
        stop = beam_tracker.step(am_prob, lm_prob)
    # return nbest hypotheses
    return beam_tracker.nbest_hypos(nbest)
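# A minimal usage sketch, not part of the source: it assumes `decoder` is a
# trained transformer decoder from this package (exposing step()/vocab_size)
# and uses made-up shapes and token IDs purely for illustration.
#
#   decoder = ...                    # trained transformer decoder (assumed)
#   enc_out = th.rand(100, 1, 512)   # T x 1 x F, as the docstring requires
#   hypos = beam_search(decoder, enc_out, beam_size=8, nbest=4,
#                       sos=1, eos=2, unk=3, len_norm=True)
#   for hyp in hypos:                # each entry is a dict from nbest_hypos()
#       print(hyp)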
def beam_search(decoder: nn.Module,
                att_net: nn.Module,
                enc_out: th.Tensor,
                lm: Optional[LmType] = None,
                lm_weight: float = 0,
                beam_size: int = 8,
                nbest: int = 1,
                max_len: int = -1,
                max_len_ratio: float = 1,
                min_len: int = 0,
                min_len_ratio: float = 0,
                sos: int = -1,
                eos: int = -1,
                len_norm: bool = True,
                len_penalty: float = 0,
                cov_penalty: float = 0,
                temperature: float = 1,
                cov_threshold: float = 0.5,
                eos_threshold: float = 1) -> List[Dict]:
    """
    Vectorized beam search algorithm (see the batch version: beam_search_batch)

    Args:
        att_net (nn.Module): attention network
        enc_out (Tensor): 1 x T x F, encoder output
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    N, T, D_enc = enc_out.shape
    if N != 1:
        raise RuntimeError(
            f"Got batch size {N:d}, now only support one utterance")
    if not hasattr(decoder, "step"):
        raise RuntimeError("Function step should be defined in decoder network")
    if beam_size > decoder.vocab_size:
        raise RuntimeError(f"Beam size ({beam_size}) > vocabulary size")

    min_len = max(min_len, int(min_len_ratio * T))
    # treat a non-positive max_len as "no explicit cap" so that the default
    # max_len=-1 does not shrink the length bound to -1
    max_len_bound = int(max_len_ratio * T) if max_len_ratio > 0 else T
    max_len = min(max_len, max_len_bound) if max_len > 0 else max_len_bound
    logger.info(f"--- shape of the encoder output: {T} x {D_enc}")
    logger.info("--- length constraint of the decoding " +
                f"sequence: ({min_len}, {max_len})")
    nbest = min(beam_size, nbest)
    device = enc_out.device
    att_ali = None
    dec_hid = None
    # N x T x F => N*beam x T x F
    enc_out = th.repeat_interleave(enc_out, beam_size, 0)
    att_ctx = th.zeros([N * beam_size, D_enc], device=device)
    proj = th.zeros([N * beam_size, D_enc], device=device)
    beam_param = BeamSearchParam(beam_size=beam_size,
                                 sos=sos,
                                 eos=eos,
                                 device=device,
                                 min_len=min_len,
                                 max_len=max_len,
                                 len_norm=len_norm,
                                 lm_weight=lm_weight,
                                 len_penalty=len_penalty,
                                 cov_penalty=cov_penalty,
                                 cov_threshold=cov_threshold,
                                 eos_threshold=eos_threshold)
    beam_tracker = BeamTracker(beam_param)

    lm_state = None
    # clear attention states
    att_net.clear()
    # decode step by step
    stop = False
    while not stop:
        # previous tokens & back-pointers of the current beam
        pre_tok, point = beam_tracker[-1]
        # reorder the decoder states to follow the surviving beams
        dec_hid = adjust_hidden(point, dec_hid)
        att_ali = None if att_ali is None else att_ali[point]
        # step forward
        dec_out, att_ctx, dec_hid, att_ali, proj = decoder.step(
            att_net,
            pre_tok,
            enc_out,
            att_ctx[point],
            dec_hid=dec_hid,
            att_ali=att_ali,
            proj=proj[point])
        # compute log-prob: beam x V (negative values)
        am_prob = tf.log_softmax(dec_out / temperature, dim=-1)
        if lm and beam_param.lm_weight > 0:
            # beam x V
            lm_prob, lm_state = lm_score_impl(lm, point, pre_tok, lm_state)
        else:
            lm_prob = 0
        # one beam search step
        stop = beam_tracker.step(am_prob, lm_prob, att_ali=att_ali)
    # return nbest hypotheses
    return beam_tracker.nbest_hypos(nbest)
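# A minimal usage sketch, not part of the source: `decoder` and `att_net` are
# assumed to be the trained attention-based decoder and attention module from
# this package. Note enc_out is N x T x F here, while the transformer variant
# above expects T x N x F; shapes and token IDs are made-up placeholders.
#
#   decoder, att_net = ..., ...      # trained modules (assumed)
#   enc_out = th.rand(1, 100, 512)   # 1 x T x F, as the docstring requires
#   hypos = beam_search(decoder, att_net, enc_out, beam_size=8, nbest=4,
#                       sos=1, eos=2, len_norm=True)
#   best = hypos[0]                  # top hypothesis dict from nbest_hypos()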