def beam_search(decoder: nn.Module,
                att_net: nn.Module,
                enc_out: th.Tensor,
                lm: Optional[LmType] = None,
                lm_weight: float = 0,
                beam_size: int = 8,
                nbest: int = 1,
                max_len: int = -1,
                max_len_ratio: float = 1,
                min_len: int = 0,
                min_len_ratio: float = 0,
                sos: int = -1,
                eos: int = -1,
                len_norm: bool = True,
                len_penalty: float = 0,
                cov_penalty: float = 0,
                temperature: float = 1,
                cov_threshold: float = 0.5,
                eos_threshold: float = 1) -> List[Dict]:
    """
    Vectorized beam search algorithm (see batch version beam_search_batch)
    Args:
        decoder (nn.Module): decoder network
        att_net (nn.Module): attention network
        enc_out (Tensor): 1 x T x F, encoder output
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    N, T, D_enc = enc_out.shape
    if N != 1:
        raise RuntimeError(
            f"Got batch size {N:d}, now only support one utterance")
    if not hasattr(decoder, "step"):
        raise RuntimeError(
            "Method step should be defined in the decoder network")
    if beam_size > decoder.vocab_size:
        raise RuntimeError(
            f"Beam size ({beam_size}) > vocabulary size ({decoder.vocab_size})")
    min_len = max(min_len, int(min_len_ratio * T))
    # take the smaller of the absolute and ratio-based length limits,
    # falling back to T when either one is unset (<= 0)
    max_len = min(max_len if max_len > 0 else T,
                  int(max_len_ratio * T) if max_len_ratio > 0 else T)
    logger.info(f"--- shape of the encoder output: {T} x {D_enc}")
    logger.info("--- length constraint of the decoding " +
                f"sequence: ({min_len}, {max_len})")
    nbest = min(beam_size, nbest)
    device = enc_out.device
    att_ali = None
    dec_hid = None
    # N x T x F => N*beam x T x F
    enc_out = th.repeat_interleave(enc_out, beam_size, 0)
    att_ctx = th.zeros([N * beam_size, D_enc], device=device)
    proj = th.zeros([N * beam_size, D_enc], device=device)
    beam_param = BeamSearchParam(beam_size=beam_size,
                                 sos=sos,
                                 eos=eos,
                                 device=device,
                                 min_len=min_len,
                                 max_len=max_len,
                                 len_norm=len_norm,
                                 lm_weight=lm_weight,
                                 len_penalty=len_penalty,
                                 cov_penalty=cov_penalty,
                                 cov_threshold=cov_threshold,
                                 eos_threshold=eos_threshold)
    beam_tracker = BeamTracker(beam_param)
    lm_state = None
    # clear attention states
    att_net.clear()
    # decode step by step
    stop = False
    while not stop:
        # beam: tokens & backpointers of the last step
        pre_tok, point = beam_tracker[-1]
        # reorder decoder states using the backpointer & step forward
        dec_hid = adjust_hidden(point, dec_hid)
        att_ali = None if att_ali is None else att_ali[point]
        dec_out, att_ctx, dec_hid, att_ali, proj = decoder.step(
            att_net,
            pre_tok,
            enc_out,
            att_ctx[point],
            dec_hid=dec_hid,
            att_ali=att_ali,
            proj=proj[point])
        # compute log-prob: beam x V (negative values)
        am_prob = tf.log_softmax(dec_out / temperature, dim=-1)
        if lm and beam_param.lm_weight > 0:
            # beam x V
            lm_prob, lm_state = lm_score_impl(lm, point, pre_tok, lm_state)
        else:
            lm_prob = 0
        # one beam search step
        stop = beam_tracker.step(am_prob, lm_prob, att_ali=att_ali)
    # return nbest
    return beam_tracker.nbest_hypos(nbest)
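# Illustrative helper (an assumption, not part of the original module): a
# minimal sketch of the flatten-and-topk beam expansion that BeamTracker.step
# performs internally, showing where the backpointer `point` used above comes
# from. The real tracker additionally handles EOS, length normalization and
# the coverage/length penalties.
def _topk_expand_sketch(score: th.Tensor, am_prob: th.Tensor):
    """
    Args:
        score (Tensor): beam x 1, accumulated log-probs of the live hypotheses
        am_prob (Tensor): beam x V, log-probs of the current decoder step
    """
    beam, vocab = am_prob.shape
    # flatten the beam x V scores and keep the global top-`beam` candidates
    topk_score, topk_index = (score + am_prob).view(-1).topk(beam)
    # which previous beam each survivor extends (this is the `point` used to
    # reorder dec_hid / att_ctx / proj) and which token it emits
    point = th.div(topk_index, vocab, rounding_mode="floor")
    token = topk_index % vocab
    return topk_score, token, point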
def beam_search_batch(decoder: nn.Module,
                      att_net: nn.Module,
                      enc_out: th.Tensor,
                      enc_len: th.Tensor,
                      lm: Optional[LmType] = None,
                      lm_weight: float = 0,
                      beam_size: int = 8,
                      nbest: int = 1,
                      max_len: int = -1,
                      max_len_ratio: float = 1,
                      min_len: int = 0,
                      min_len_ratio: float = 0,
                      sos: int = -1,
                      eos: int = -1,
                      len_norm: bool = True,
                      end_detect: bool = False,
                      ctc_weight: float = 0,
                      len_penalty: float = 0,
                      cov_penalty: float = 0,
                      temperature: float = 1,
                      allow_partial: bool = False,
                      cov_threshold: float = 0.5,
                      eos_threshold: float = 1) -> List[List[Dict]]:
    """
    Batch level vectorized beam search algorithm
    Args:
        decoder (nn.Module): decoder network
        att_net (nn.Module): attention network
        enc_out (Tensor): N x T x F, encoder output
        enc_len (Tensor): N, length of the encoder output
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    if not hasattr(decoder, "step"):
        raise RuntimeError(
            "Method step should be defined in the decoder network")
    if beam_size > decoder.vocab_size:
        raise RuntimeError(
            f"Beam size ({beam_size}) > vocabulary size ({decoder.vocab_size})")
    N, T, D_enc = enc_out.shape
    min_len = [
        max(min_len, int(min_len_ratio * elen.item())) for elen in enc_len
    ]
    # per utterance: take the smaller of the absolute and ratio-based length
    # limits, falling back to the encoder length when either one is unset
    max_len = [
        min(max_len if max_len > 0 else elen.item(),
            int(max_len_ratio *
                elen.item()) if max_len_ratio > 0 else elen.item())
        for elen in enc_len
    ]
    logger.info(f"--- shape of the encoder output: {T} x {D_enc}")
    logger.info("--- length constraint of the decoding " +
                f"sequence: {[(i, j) for i, j in zip(min_len, max_len)]}")
    nbest = min(beam_size, nbest)
    device = enc_out.device
    att_ali = None
    dec_hid = None
    # N x T x F => N*beam x T x F
    enc_out = th.repeat_interleave(enc_out, beam_size, 0)
    enc_len = th.repeat_interleave(enc_len, beam_size, 0)
    att_ctx = th.zeros([N * beam_size, D_enc], device=device)
    proj = th.zeros([N * beam_size, D_enc], device=device)
    lm_state = None
    beam_param = BeamSearchParam(beam_size=beam_size,
                                 sos=sos,
                                 eos=eos,
                                 device=device,
                                 min_len=min_len,
                                 max_len=max_len,
                                 len_norm=len_norm,
                                 lm_weight=lm_weight,
                                 end_detect=end_detect,
                                 ctc_weight=ctc_weight,
                                 len_penalty=len_penalty,
                                 cov_penalty=cov_penalty,
                                 allow_partial=allow_partial,
                                 cov_threshold=cov_threshold,
                                 eos_threshold=eos_threshold)
    beam_tracker = BatchBeamTracker(N, beam_param)
    # clear attention states
    att_net.clear()
    # decode step by step
    stop = False
    while not stop:
        # N*beam: tokens & backpointers of the last step
        pre_tok, point = beam_tracker[-1]
        # reorder decoder states using the backpointer & step forward
        dec_hid = adjust_hidden(point, dec_hid)
        att_ali = None if att_ali is None else att_ali[point]
        dec_out, att_ctx, dec_hid, att_ali, proj = decoder.step(
            att_net,
            pre_tok,
            enc_out,
            att_ctx[point],
            dec_hid=dec_hid,
            att_ali=att_ali,
            proj=proj[point])
        # compute log-prob: N*beam x V (negative values)
        am_prob = tf.log_softmax(dec_out / temperature, dim=-1)
        if lm and beam_param.lm_weight > 0:
            # N*beam x V
            lm_prob, lm_state = lm_score_impl(lm, point, pre_tok, lm_state)
        else:
            lm_prob = 0
        # one beam search step
        stop = beam_tracker.step(am_prob, lm_prob, att_ali=att_ali)
    # return nbest
    return beam_tracker.nbest_hypos(nbest, auto_stop=stop)
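# Usage sketch (hypothetical names: assumes an encoder-decoder ASR model
# exposing `encoder`, `decoder` and `att_net` attributes; the attribute names
# and the sos/eos IDs are illustrative, not guaranteed by this module):
#
#   enc_out, enc_len = model.encoder(feats, feat_len)  # N x T x F, N
#   nbest_list = beam_search_batch(model.decoder, model.att_net,
#                                  enc_out, enc_len,
#                                  beam_size=8, nbest=4,
#                                  sos=sos_id, eos=eos_id)
#   # nbest_list[i]: n-best hypothesis dicts for utterance i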