def _target_mask(self, olens: torch.Tensor) -> torch.Tensor:
    """Make masks for masked self-attention.

    Args:
        olens (LongTensor): Batch of lengths (B,).

    Returns:
        Tensor: Mask tensor for masked self-attention.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        >>> olens = [5, 3]
        >>> self._target_mask(olens)
        tensor([[[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1]],
                [[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0]]], dtype=torch.uint8)

    """
    y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
    s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
    return y_masks.unsqueeze(-2) & s_masks
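# ---------------------------------------------------------------------------
# NOTE (added for reference, not part of the original source): every snippet
# in this file relies on espnet's `subsequent_mask` helper. A minimal sketch,
# assuming only PyTorch, is a lower-triangular boolean matrix; the real helper
# additionally returns dtype=torch.uint8 on PyTorch < 1.2.
# ---------------------------------------------------------------------------
import torch


def subsequent_mask_sketch(size, device=torch.device("cpu")):
    # Position i may attend to positions 0..i (inclusive), never to the future.
    return torch.tril(torch.ones(size, size, device=device, dtype=torch.bool))


# subsequent_mask_sketch(3) ->
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])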
def test_encoder_cache(normalize_before):
    adim = 4
    idim = 5
    encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0,
        input_layer="embed",
    )
    elayer = encoder.encoders[0]
    x = torch.randn(2, 5, adim)
    mask = subsequent_mask(x.shape[1]).unsqueeze(0)
    prev_mask = mask[:, :-1, :-1]
    encoder.eval()
    with torch.no_grad():
        # layer-level test
        y = elayer(x, mask, None)[0]
        cache = elayer(x[:, :-1], prev_mask, None)[0]
        y_fast = elayer(x, mask, cache=cache)[0]
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=RTOL)

        # encoder-level test
        x = torch.randint(0, idim, x.shape[:2])
        y = encoder.forward_one_step(x, mask)[0]
        y_, _, cache = encoder.forward_one_step(x[:, :-1], prev_mask)
        y_fast, _, _ = encoder.forward_one_step(x, mask, cache=cache)
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=RTOL)
def score(self, ys, state, x):
    """Score."""
    ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
    logp, state = self.forward_one_step(
        ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
    )
    return logp.squeeze(0), state
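# NOTE (illustrative sketch, not part of the original source): `score()` above
# is meant to be called once per emitted token, threading the returned cache
# back in. A hypothetical greedy loop over that interface:
import torch


def greedy_decode_sketch(scorer, enc_out, sos_id, max_len=50):
    # `scorer` is assumed to expose the score() method shown above.
    ys = torch.tensor([sos_id], device=enc_out.device)
    state = None
    for _ in range(max_len):
        logp, state = scorer.score(ys, state, enc_out)  # logp: (n_vocab,)
        ys = torch.cat([ys, logp.argmax().view(1)])
    return ys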
def test_decoder_cache(normalize_before):
    adim = 4
    odim = 5
    decoder = Decoder(
        odim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0,
    )
    dlayer = decoder.decoders[0]
    memory = torch.randn(2, 5, adim)

    x = torch.randn(2, 5, adim) * 100
    mask = subsequent_mask(x.shape[1]).unsqueeze(0)
    prev_mask = mask[:, :-1, :-1]
    decoder.eval()
    with torch.no_grad():
        # layer-level test
        y = dlayer(x, mask, memory, None)[0]
        cache = dlayer(x[:, :-1], prev_mask, memory, None)[0]
        y_fast = dlayer(x, mask, memory, None, cache=cache)[0]
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=RTOL)

        # decoder-level test
        x = torch.randint(0, odim, x.shape[:2])
        y, _ = decoder.forward_one_step(x, mask, memory)
        y_, cache = decoder.forward_one_step(
            x[:, :-1], prev_mask, memory, cache=decoder.init_state(None)
        )
        y_fast, _ = decoder.forward_one_step(x, mask, memory, cache=cache)
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=RTOL)
def _target_mask(self, ys_in_pad):
    ys_mask = ys_in_pad != 0
    if self.export_mode:
        m = _subsequent_mask(ys_mask).unsqueeze(0)
    else:
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
    # NOTE: unsqueeze(-2) yields the (B, 1, L) padding mask that broadcasts
    # against the (1, L, L) causal mask; the original unsqueeze(0) only
    # broadcasts correctly when B == 1.
    return ys_mask.unsqueeze(-2) & m
def compute_hyps(self, current_hyps, current_frame, total_frame,
                 enc_output, hat_att, enc_mask, chunk=True):
    for length, hyps_t in current_hyps.items():
        ys_mask = subsequent_mask(length).unsqueeze(0).cuda()
        ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1)
        l_id = [hyp_t['yseq'] for hyp_t in hyps_t]
        ys4use = torch.tensor(l_id).cuda()
        enc_output4use = enc_output.repeat(len(hyps_t), 1, 1)
        if hyps_t[0]["cache"] is None:
            cache4use = None
        else:
            cache4use = []
            for decode_num in range(len(hyps_t[0]["cache"])):
                current_cache = []
                for hyp_t in hyps_t:
                    current_cache.append(hyp_t["cache"][decode_num].squeeze(0))
                cache4use.append(torch.stack(current_cache))

        partial_mask4use = []
        for hyp_t in hyps_t:
            align = [0] * length
            align[:length - 1] = hyp_t['last_time'][:]
            align[-1] = current_frame
            align_tensor = torch.tensor(align).unsqueeze(0)
            if chunk:
                partial_mask = enc_mask[0][align_tensor]
            else:
                partial_mask = trigger_mask(1, total_frame, align_tensor,
                                            self.left_window, self.right_window)
            partial_mask4use.append(partial_mask)
        partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1)

        local_att_scores_b, new_cache_b = self.decoder.forward_one_step(
            ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use)
        for idx, hyp_t in enumerate(hyps_t):
            hyp_t['tmp_cache'] = [new_cache_b[decode_num][idx].unsqueeze(0)
                                  for decode_num in range(len(new_cache_b))]
            hyp_t['tmp_att'] = local_att_scores_b[idx].unsqueeze(0)
            hat_att[hyp_t['seq']] = {}
            hat_att[hyp_t['seq']]['cache'] = hyp_t['tmp_cache']
            hat_att[hyp_t['seq']]['att_scores'] = hyp_t['tmp_att']
def score(self, ys, state, x):
    """Score."""
    ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
    if self.selfattention_layer_type != "selfattn":
        state = None
    logp, state = self.forward_one_step(
        ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
    )
    return logp.squeeze(0), state
def compute_hyps_ctc(self, hyps_ctc_cluster, total_frame, enc_output,
                     hat_att, enc_mask, chunk=True):
    for length, hyps_t in hyps_ctc_cluster.items():
        ys_mask = subsequent_mask(length - 1).unsqueeze(0).cuda()
        ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1)
        l_id = [hyp_t['yseq'][:-1] for hyp_t in hyps_t]
        ys4use = torch.tensor(l_id).cuda()
        enc_output4use = enc_output.repeat(len(hyps_t), 1, 1)
        if "precache" not in hyps_t[0] or hyps_t[0]["precache"] is None:
            cache4use = None
        else:
            cache4use = []
            for decode_num in range(len(hyps_t[0]["precache"])):
                current_cache = []
                for hyp_t in hyps_t:
                    current_cache.append(hyp_t["precache"][decode_num].squeeze(0))
                cache4use.append(torch.stack(current_cache))

        partial_mask4use = []
        for hyp_t in hyps_t:
            align = hyp_t['last_time']
            align_tensor = torch.tensor(align).unsqueeze(0)
            if chunk:
                partial_mask = enc_mask[0][align_tensor]
            else:
                partial_mask = trigger_mask(1, total_frame, align_tensor,
                                            self.left_window, self.right_window)
            partial_mask4use.append(partial_mask)
        partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1)

        local_att_scores_b, new_cache_b = self.decoder.forward_one_step(
            ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use)
        for idx, hyp_t in enumerate(hyps_t):
            hyp_t['tmp_cur_new_cache'] = [new_cache_b[decode_num][idx].unsqueeze(0)
                                          for decode_num in range(len(new_cache_b))]
            hyp_t['tmp_cur_att_scores'] = local_att_scores_b[idx].unsqueeze(0)
            l_minus = ' '.join(hyp_t['seq'].split()[:-1])
            hat_att[l_minus] = {}
            hat_att[l_minus]['att_scores'] = hyp_t['tmp_cur_att_scores']
            hat_att[l_minus]['cache'] = hyp_t['tmp_cur_new_cache']
def forward(
    self,
    hs_pad: torch.Tensor,
    hlens: torch.Tensor,
    ys_in_pad: torch.Tensor,
    ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward decoder.

    Args:
        hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
        hlens: (batch)
        ys_in_pad: input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed";
            input tensor (batch, maxlen_out, #mels) in the other cases
        ys_in_lens: (batch)

    Returns:
        (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True
            olens: (batch,)
    """
    tgt = ys_in_pad
    # tgt_mask: (B, 1, L)
    tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
    # m: (1, L, L)
    m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
    # tgt_mask: (B, L, L)
    tgt_mask = tgt_mask & m

    memory = hs_pad
    memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
        memory.device
    )
    # Padding for Longformer
    if memory_mask.shape[-1] != memory.shape[1]:
        padlen = memory.shape[1] - memory_mask.shape[-1]
        memory_mask = torch.nn.functional.pad(
            memory_mask, (0, padlen), "constant", False
        )

    x = self.embed(tgt)
    x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
    if self.normalize_before:
        x = self.after_norm(x)
    if self.output_layer is not None:
        x = self.output_layer(x)

    olens = tgt_mask.sum(1)
    return x, olens
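# NOTE (illustrative sketch, not part of the original source): the Longformer
# branch above right-pads the (B, 1, T') memory mask with False so it matches
# the attention-window-padded memory length T. In isolation, with toy shapes:
import torch
import torch.nn.functional as F

memory_mask = torch.ones(1, 1, 5, dtype=torch.bool)  # (B=1, 1, T'=5)
padlen = 8 - memory_mask.shape[-1]                   # memory was padded to T=8
memory_mask = F.pad(memory_mask, (0, padlen), "constant", False)  # (1, 1, 8)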
def score(self, ys, state, x):
    """Score."""
    ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
    if self.selfattention_layer_type != "selfattn":
        # TODO(karita): implement cache
        logging.warning(
            f"{self.selfattention_layer_type} does not support cached decoding."
        )
        state = None
    logp, state = self.forward_one_step(
        ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
    )
    return logp.squeeze(0), state
def forward_ilm(
    self,
    hs_pad: torch.Tensor,
    hlens: torch.Tensor,
    ys_in_pad: torch.Tensor,
    ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward decoder.

    Args:
        hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
        hlens: (batch)
        ys_in_pad: input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed";
            input tensor (batch, maxlen_out, #mels) in the other cases
        ys_in_lens: (batch)

    Returns:
        (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True
            olens: (batch,)
    """
    tgt = ys_in_pad
    # tgt_mask: (B, 1, L)
    tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
    # m: (1, L, L)
    m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
    # tgt_mask: (B, L, L)
    tgt_mask = tgt_mask & m

    x = self.embed(tgt)
    x, tgt_mask, _, _ = self.decoders(x, tgt_mask, -1, -1, ilm=True)
    if self.normalize_before:
        x = self.after_norm(x)
    if self.output_layer is not None:
        x = self.output_layer(x)

    olens = tgt_mask.sum(1)
    return x, olens
def score(self, hyp, cache):
    """Forward one step.

    Args:
        hyp (dataclass): hypothesis
        cache (dict): states cache

    Returns:
        y (torch.Tensor): decoder outputs (1, dec_dim)
        (list): decoder states [L x (1, max_len, dec_dim)]
        lm_tokens (torch.Tensor): token id for LM (1)
    """
    device = next(self.parameters()).device

    tgt = torch.tensor(hyp.yseq).unsqueeze(0).to(device=device)
    lm_tokens = tgt[:, -1]

    # NOTE: keys built without a separator can collide (e.g. [1, 12] and
    # [11, 2] both map to "112"); the "_"-joined variant below avoids this.
    str_yseq = "".join([str(x) for x in hyp.yseq])

    if str_yseq in cache:
        y, new_state = cache[str_yseq]
    else:
        tgt_mask = subsequent_mask(len(hyp.yseq)).unsqueeze(0).to(device=device)

        state = check_state(hyp.dec_state, (tgt.size(1) - 1), self.blank)

        tgt = self.embed(tgt)

        new_state = []
        for s, decoder in zip(state, self.decoders):
            tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
            new_state.append(tgt)

        y = self.after_norm(tgt[:, -1])

        cache[str_yseq] = (y, new_state)

    return y[0], new_state, lm_tokens
def score(
    self, hyp: Hypothesis, cache: Dict[str, Any]
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
    """One-step forward hypothesis.

    Args:
        hyp: Hypothesis.
        cache: Pairs of (dec_out, dec_state) for each label sequence. (key)

    Returns:
        dec_out: Decoder output sequence. (1, D_dec)
        dec_state: Decoder hidden states. [N x (1, U, D_dec)]
        lm_label: Label ID for LM. (1,)
    """
    labels = torch.tensor([hyp.yseq], device=self.device)
    lm_label = labels[:, -1]

    str_labels = "_".join(list(map(str, hyp.yseq)))

    if str_labels in cache:
        dec_out, dec_state = cache[str_labels]
    else:
        dec_out_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0)

        new_state = check_state(hyp.dec_state, (labels.size(1) - 1), self.blank_id)

        dec_out = self.embed(labels)

        dec_state = []
        for s, decoder in zip(new_state, self.decoders):
            dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
            dec_state.append(dec_out)

        dec_out = self.after_norm(dec_out[:, -1])

        cache[str_labels] = (dec_out, dec_state)

    return dec_out[0], dec_state, lm_label
def batch_score(
    self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
    """Score new token batch.

    Args:
        ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
        states (List[Any]): Scorer states for prefix tokens.
        xs (torch.Tensor):
            The encoder feature that generates ys (n_batch, xlen, n_feat).

    Returns:
        tuple[torch.Tensor, List[Any]]: Tuple of
            batchified scores for next token with shape of `(n_batch, n_vocab)`
            and next state list for ys.
    """
    # merge states
    n_batch = len(ys)
    n_layers = len(self.decoders)
    if states[0] is None:
        batch_state = None
    else:
        # transpose state of [batch, layer] into [layer, batch]
        batch_state = [
            torch.stack([states[b][i] for b in range(n_batch)])
            for i in range(n_layers)
        ]

    # batch decoding
    ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
    logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

    # transpose state of [layer, batch] into [batch, layer]
    state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
    return logp, state_list
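# NOTE (illustrative sketch, not part of the original source): the two state
# transposes in `batch_score` above, [batch][layer] -> [layer] of stacked
# tensors and back, with toy shapes:
import torch

states = [[torch.zeros(1, 4) for _ in range(3)] for _ in range(2)]  # [batch=2][layer=3]
batch_state = [torch.stack([states[b][i] for b in range(2)]) for i in range(3)]
# batch_state: [layer=3] of (2, 1, 4) -- one stacked tensor per layer
state_list = [[batch_state[i][b] for i in range(3)] for b in range(2)]
# state_list: back to [batch=2][layer=3] of (1, 4)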
def score(self, ys, state, x):
    # TODO(karita) cache previous attentions in state
    ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
    logp = self.recognize(ys.unsqueeze(0), ys_mask, x.unsqueeze(0))
    return logp.squeeze(0), None
def translate(
    self,
    x,
    trans_args,
    char_list=None,
    rnnlm=None,
    use_jit=False,
):
    """Translate input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    # prepare sos
    if getattr(trans_args, "tgt_lang", False):
        if self.replace_sos:
            y = char_list.index(trans_args.tgt_lang)
    else:
        y = self.sos
    logging.info("<sos> index: " + str(y))
    logging.info("<sos> mark: " + char_list[y])

    enc_output = self.encode(x).unsqueeze(0)
    h = enc_output.squeeze(0)

    logging.info("input lengths: " + str(h.size(0)))

    # search params
    beam = trans_args.beam_size
    penalty = trans_args.penalty
    vy = h.new_zeros(1).long()

    if trans_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))
    minlen = int(trans_args.minlenratio * h.size(0))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                    )
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output
                )[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + trans_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += trans_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remained hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), trans_args.nbest)
    ]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform translation "
            "again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        return self.translate(x, trans_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
def translate(
    self,
    x,
    trans_args,
    char_list=None,
):
    """Translate input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :return: N-best decoding results
    :rtype: list
    """
    # prepare sos
    if getattr(trans_args, "tgt_lang", False):
        if self.replace_sos:
            y = char_list.index(trans_args.tgt_lang)
    else:
        y = self.sos
    logging.info("<sos> index: " + str(y))
    logging.info("<sos> mark: " + char_list[y])
    logging.info("input lengths: " + str(x.shape[0]))

    enc_output = self.encode(x).unsqueeze(0)
    h = enc_output
    logging.info("encoder output lengths: " + str(h.size(1)))

    # search params
    beam = trans_args.beam_size
    penalty = trans_args.penalty
    if trans_args.maxlenratio == 0:
        maxlen = h.size(1)
    else:
        # maxlen >= 1
        maxlen = max(1, int(trans_args.maxlenratio * h.size(1)))
    minlen = int(trans_args.minlenratio * h.size(1))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    hyp = {"score": 0.0, "yseq": [y]}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        logging.debug("position " + str(i))

        # batchify
        ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64)
        for j, hyp in enumerate(hyps):
            ys[j, :] = torch.tensor(hyp["yseq"])
        ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device)

        local_scores = self.decoder.forward_one_step(
            ys, ys_mask, h.repeat([len(hyps), 1, 1])
        )[0]

        hyps_best_kept = []
        for j, hyp in enumerate(hyps):
            local_best_scores, local_best_ids = torch.topk(
                local_scores[j : j + 1], beam, dim=1
            )
            # NOTE: inner loop variable renamed to k to avoid shadowing the
            # hypothesis index j above.
            for k in range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, k])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, k])
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remained hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), trans_args.nbest)
    ]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform translation "
            "again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        return self.translate(x, trans_args, char_list)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
def inference(
    self,
    text: torch.Tensor,
    feats: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    threshold: float = 0.5,
    minlenratio: float = 0.0,
    maxlenratio: float = 10.0,
    use_teacher_forcing: bool = False,
) -> Dict[str, torch.Tensor]:
    """Generate the sequence of features given the sequences of characters.

    Args:
        text (LongTensor): Input sequence of characters (T_text,).
        feats (Optional[Tensor]): Feature sequence to extract style embedding
            (T_feats', idim).
        spembs (Optional[Tensor]): Speaker embedding (spk_embed_dim,).
        sids (Optional[Tensor]): Speaker ID (1,).
        lids (Optional[Tensor]): Language ID (1,).
        threshold (float): Threshold in inference.
        minlenratio (float): Minimum length ratio in inference.
        maxlenratio (float): Maximum length ratio in inference.
        use_teacher_forcing (bool): Whether to use teacher forcing.

    Returns:
        Dict[str, Tensor]: Output dict including the following items:
            * feat_gen (Tensor): Output sequence of features (T_feats, odim).
            * prob (Tensor): Output sequence of stop probabilities (T_feats,).
            * att_w (Tensor): Source attn weight (#layers, #heads, T_feats, T_text).
    """
    x = text
    y = feats
    spemb = spembs

    # add eos at the last of sequence
    x = F.pad(x, [0, 1], "constant", self.eos)

    # inference with teacher forcing
    if use_teacher_forcing:
        assert feats is not None, "feats must be provided with teacher forcing."

        # get teacher forcing outputs
        xs, ys = x.unsqueeze(0), y.unsqueeze(0)
        spembs = None if spemb is None else spemb.unsqueeze(0)
        ilens = x.new_tensor([xs.size(1)]).long()
        olens = y.new_tensor([ys.size(1)]).long()
        outs, *_ = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spembs=spembs,
            sids=sids,
            lids=lids,
        )

        # get attention weights
        att_ws = []
        for i in range(len(self.decoder.decoders)):
            att_ws += [self.decoder.decoders[i].src_attn.attn]
        att_ws = torch.stack(att_ws, dim=1)  # (B, L, H, T_feats, T_text)

        return dict(feat_gen=outs[0], att_w=att_ws[0])

    # forward encoder
    xs = x.unsqueeze(0)
    hs, _ = self.encoder(xs, None)

    # integrate GST
    if self.use_gst:
        style_embs = self.gst(y.unsqueeze(0))
        hs = hs + style_embs.unsqueeze(1)

    # integrate spk & lang embeddings
    if self.spks is not None:
        sid_embs = self.sid_emb(sids.view(-1))
        hs = hs + sid_embs.unsqueeze(1)
    if self.langs is not None:
        lid_embs = self.lid_emb(lids.view(-1))
        hs = hs + lid_embs.unsqueeze(1)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        spembs = spemb.unsqueeze(0)
        hs = self._integrate_with_spk_embed(hs, spembs)

    # set limits of length
    maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
    minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

    # initialize
    idx = 0
    ys = hs.new_zeros(1, 1, self.odim)
    outs, probs = [], []

    # forward decoder step-by-step
    z_cache = self.decoder.init_state(x)
    while True:
        # update index
        idx += 1

        # calculate output and stop prob at idx-th step
        y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
        z, z_cache = self.decoder.forward_one_step(
            ys, y_masks, hs, cache=z_cache
        )  # (B, adim)
        outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]  # [(r, odim), ...]
        probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

        # update next inputs
        ys = torch.cat(
            (ys, outs[-1][-1].view(1, 1, self.odim)), dim=1
        )  # (1, idx + 1, odim)

        # get attention weights
        att_ws_ = []
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention) and "src" in name:
                att_ws_ += [m.attn[0, :, -1].unsqueeze(1)]  # [(#heads, 1, T), ...]
        if idx == 1:
            att_ws = att_ws_
        else:
            # [(#heads, l, T), ...]
            att_ws = [
                torch.cat([att_w, att_w_], dim=1)
                for att_w, att_w_ in zip(att_ws, att_ws_)
            ]

        # check whether to finish generation
        if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
            # check minimum length
            if idx < minlen:
                continue
            outs = (
                torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)
            )  # (T_feats, odim) -> (1, T_feats, odim) -> (1, odim, T_feats)
            if self.postnet is not None:
                outs = outs + self.postnet(outs)  # (1, odim, T_feats)
            outs = outs.transpose(2, 1).squeeze(0)  # (T_feats, odim)
            probs = torch.cat(probs, dim=0)
            break

    # concatenate attention weights -> (#layers, #heads, T_feats, T_text)
    att_ws = torch.stack(att_ws, dim=0)

    return dict(feat_gen=outs, prob=probs, att_w=att_ws)
def recognize_beam(self, h, recog_args, rnnlm=None):
    """Beam search implementation for transformer-transducer.

    Args:
        h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language model module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results
    """
    beam = recog_args.beam_size
    k_range = min(beam, self.odim)
    nbest = recog_args.nbest
    normscore = recog_args.score_norm_transducer

    if rnnlm:
        kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None,
                      'lm_state': None}]
    else:
        kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]

    for i, hi in enumerate(h):
        hyps = kept_hyps
        kept_hyps = []

        while True:
            new_hyp = max(hyps, key=lambda x: x['score'])
            hyps.remove(new_hyp)

            ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
            ys_mask = to_device(
                self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
            y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])

            ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)

            if rnnlm:
                rnnlm_state, rnnlm_scores = rnnlm.predict(
                    new_hyp['lm_state'], ys[:, -1])

            for k in six.moves.range(self.odim):
                beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                            'yseq': new_hyp['yseq'][:],
                            'cache': new_hyp['cache']}

                if rnnlm:
                    beam_hyp['lm_state'] = new_hyp['lm_state']

                if k == self.blank:
                    kept_hyps.append(beam_hyp)
                else:
                    beam_hyp['yseq'].append(int(k))
                    beam_hyp['cache'] = c

                    if rnnlm:
                        beam_hyp['lm_state'] = rnnlm_state
                        beam_hyp['score'] += recog_args.lm_weight * rnnlm_scores[0][k]

                    hyps.append(beam_hyp)

            if len(kept_hyps) >= k_range:
                break

    if normscore:
        nbest_hyps = sorted(
            kept_hyps, key=lambda x: x['score'] / len(x['yseq']),
            reverse=True)[:nbest]
    else:
        nbest_hyps = sorted(
            kept_hyps, key=lambda x: x['score'], reverse=True)[:nbest]

    return nbest_hyps
def batch_score(
    self,
    hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
    dec_states: List[Optional[torch.Tensor]],
    cache: Dict[str, Any],
    use_lm: bool,
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
    """One-step forward hypotheses.

    Args:
        hyps: Hypotheses.
        dec_states: Decoder hidden states. [N x (B, U, D_dec)]
        cache: Pairs of (h_dec, dec_states) for each label sequence. (keys)
        use_lm: Whether to compute label ID sequences for LM.

    Returns:
        dec_out: Decoder output sequences. (B, D_dec)
        dec_states: Decoder hidden states. [N x (B, U, D_dec)]
        lm_labels: Label ID sequences for LM. (B,)
    """
    final_batch = len(hyps)

    process = []
    done = [None] * final_batch

    for i, hyp in enumerate(hyps):
        str_labels = "_".join(list(map(str, hyp.yseq)))

        if str_labels in cache:
            done[i] = cache[str_labels]
        else:
            process.append((str_labels, hyp.yseq, hyp.dec_state))

    if process:
        labels = pad_sequence([p[1] for p in process], self.blank_id)
        # NOTE: the legacy torch.LongTensor(...) constructor does not accept a
        # CUDA device argument; build the tensor with torch.tensor() instead.
        labels = torch.tensor(labels, dtype=torch.long, device=self.device)

        p_dec_states = self.create_batch_states(
            self.init_state(),
            [p[2] for p in process],
            labels,
        )
        dec_out = self.embed(labels)

        dec_out_mask = (
            subsequent_mask(labels.size(-1))
            .unsqueeze_(0)
            .expand(len(process), -1, -1)
        )

        new_states = []
        for s, decoder in zip(p_dec_states, self.decoders):
            dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
            new_states.append(dec_out)

        dec_out = self.after_norm(dec_out[:, -1])

    j = 0
    for i in range(final_batch):
        if done[i] is None:
            state = self.select_state(new_states, j)

            done[i] = (dec_out[j], state)
            cache[process[j][0]] = (dec_out[j], state)

            j += 1

    dec_out = torch.stack([d[0] for d in done])
    dec_states = self.create_batch_states(
        dec_states, [d[1] for d in done], [[0] + h.yseq for h in hyps]
    )

    if use_lm:
        lm_labels = torch.tensor(
            [hyp.yseq[-1] for hyp in hyps], dtype=torch.long, device=self.device
        )

        return dec_out, dec_states, lm_labels

    return dec_out, dec_states, None
def _target_mask(self, ys_in_pad):
    ys_mask = ys_in_pad != 0
    m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
    return ys_mask.unsqueeze(-2) & m
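# NOTE (illustrative sketch, not part of the original source): `_target_mask`
# above combines a (B, 1, L) padding mask with a (1, L, L) causal mask; the
# bitwise AND broadcasts them into a per-sample (B, L, L) attention mask:
import torch

ys_in_pad = torch.tensor([[7, 8, 9], [7, 8, 0]])  # 0 = padding id
pad_mask = (ys_in_pad != 0).unsqueeze(-2)         # (2, 1, 3)
causal = torch.tril(torch.ones(3, 3, dtype=torch.bool)).unsqueeze(0)  # (1, 3, 3)
tgt_mask = pad_mask & causal                      # (2, 3, 3)
# tgt_mask[1] additionally masks out the padded third position in every row.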
def recognize_and_translate_sum(self, x, trans_args, char_list=None, rnnlm=None,
                                use_jit=False, decode_asr_weight=1.0,
                                score_is_prob=False, ratio_diverse_st=0.0,
                                ratio_diverse_asr=0.0, debug=False):
    """Recognize and translate input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    assert self.do_asr, "Recognize and translate are performed simultaneously."

    logging.info(f'| ratio_diverse_st = {ratio_diverse_st}')
    logging.info(f'| ratio_diverse_asr = {ratio_diverse_asr}')

    # prepare sos
    if getattr(trans_args, "tgt_lang", False):
        if self.replace_sos:
            y = char_list.index(trans_args.tgt_lang)
    else:
        y = self.sos

    if self.one_to_many and self.lang_tok == 'decoder-pre':
        tgt_lang_id = '<2{}>'.format(trans_args.config.split('.')[-2].split('-')[-1])
        y = char_list.index(tgt_lang_id)
        logging.info(f'tgt_lang_id: {tgt_lang_id} - y: {y}')

        src_lang_id = '<2{}>'.format(trans_args.config.split('.')[-2].split('-')[0])
        y_asr = char_list.index(src_lang_id)
        logging.info(f'src_lang_id: {src_lang_id} - y_asr: {y_asr}')
    else:
        y = self.sos
        y_asr = self.sos

    logging.info(f'<sos> index: {str(y)}; <sos> mark: {char_list[y]}')
    logging.info(f'<sos> index asr: {str(y_asr)}; <sos> mark asr: {char_list[y_asr]}')

    enc_output = self.encode(x).unsqueeze(0)
    h = enc_output.squeeze(0)

    logging.info('input lengths: ' + str(h.size(0)))

    # search params
    beam = trans_args.beam_size
    penalty = trans_args.penalty
    vy = h.new_zeros(1).long()

    if trans_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))

    if trans_args.maxlenratio_asr == 0:
        maxlen_asr = h.shape[0]
    else:
        maxlen_asr = max(1, int(trans_args.maxlenratio_asr * h.size(0)))

    minlen = int(trans_args.minlenratio * h.size(0))
    minlen_asr = int(trans_args.minlenratio_asr * h.size(0))
    logging.info(f'max output length: {str(maxlen)}; min output length: {str(minlen)}')
    logging.info(f'max output length asr: {str(maxlen_asr)}; '
                 f'min output length asr: {str(minlen_asr)}')

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        logging.info('initializing hypothesis...')
        hyp = {'score': 0.0, 'yseq': [y], 'yseq_asr': [y_asr]}

    hyps = [hyp]
    ended_hyps = []

    traced_decoder = None
    for i in six.moves.range(max(maxlen, maxlen_asr)):
        logging.info('position ' + str(i))

        hyps_best_kept = []

        for idx, hyp in enumerate(hyps):
            if self.wait_k_asr > 0:
                if i < self.wait_k_asr:
                    ys_mask = subsequent_mask(1).unsqueeze(0)
                else:
                    ys_mask = subsequent_mask(i - self.wait_k_asr + 1).unsqueeze(0)
            else:
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)

            if self.wait_k_st > 0:
                if i < self.wait_k_st:
                    ys_mask_asr = subsequent_mask(1).unsqueeze(0)
                else:
                    ys_mask_asr = subsequent_mask(i - self.wait_k_st + 1).unsqueeze(0)
            else:
                ys_mask_asr = subsequent_mask(i + 1).unsqueeze(0)
            ys_asr = torch.tensor(hyp['yseq_asr']).unsqueeze(0)

            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                if hyp['yseq'][-1] != self.eos or hyp['yseq_asr'][-1] != self.eos or i < 2:
                    cross_mask = create_cross_mask(
                        ys, ys_asr, self.ignore_id, wait_k_cross=self.wait_k_asr)
                    cross_mask_asr = create_cross_mask(
                        ys_asr, ys, self.ignore_id, wait_k_cross=self.wait_k_st)
                    local_att_scores, _, local_att_scores_asr, _ = \
                        self.dual_decoder.forward_one_step(
                            ys, ys_mask, ys_asr, ys_mask_asr, enc_output,
                            cross_mask=cross_mask,
                            cross_mask_asr=cross_mask_asr,
                            cross_self=self.cross_self,
                            cross_src=self.cross_src,
                            cross_self_from=self.cross_self_from,
                            cross_src_from=self.cross_src_from)

                if (hyp['yseq'][-1] == self.eos and i > 2) or i < self.wait_k_asr:
                    local_att_scores = None
                if (hyp['yseq_asr'][-1] == self.eos and i > 2) or i < self.wait_k_st:
                    local_att_scores_asr = None

            if local_att_scores is not None and local_att_scores_asr is not None:
                # local_att_scores_asr = decode_asr_weight * local_att_scores_asr
                xk, ixk = local_att_scores.topk(beam)
                yk, iyk = local_att_scores_asr.topk(beam)
                S = (torch.mm(torch.t(xk), torch.ones_like(xk))
                     + torch.mm(torch.t(torch.ones_like(yk)), yk))
                s2v = torch.LongTensor(
                    [[si, sj] for si in ixk.squeeze(0) for sj in iyk.squeeze(0)])  # (k^2) x 2

                # Do not force diversity
                if ratio_diverse_st <= 0 and ratio_diverse_asr <= 0:
                    local_best_scores, id2k = S.view(-1).topk(beam)
                    I = s2v[id2k]
                    local_best_ids_st = I[:, 0]
                    local_best_ids_asr = I[:, 1]

                # Force diversity for ST only
                if ratio_diverse_st > 0 and ratio_diverse_asr <= 0:
                    ct = int((1 - ratio_diverse_st) * beam)
                    s2v = s2v.reshape(beam, beam, 2)
                    Sc = S[:, :ct]
                    local_best_scores, id2k = Sc.flatten().topk(beam)
                    I = s2v[:, :ct]
                    I = I.reshape(-1, 2)
                    I = I[id2k]
                    local_best_ids_st = I[:, 0]
                    local_best_ids_asr = I[:, 1]

                # Force diversity for ASR only
                if ratio_diverse_asr > 0 and ratio_diverse_st <= 0:
                    cr = int((1 - ratio_diverse_asr) * beam)
                    s2v = s2v.reshape(beam, beam, 2)
                    Sc = S[:cr, :]
                    local_best_scores, id2k = Sc.view(-1).topk(beam)
                    I = s2v[:cr, :]
                    I = I.reshape(-1, 2)
                    I = I[id2k]
                    local_best_ids_st = I[:, 0]
                    local_best_ids_asr = I[:, 1]

                # Force diversity for both ST and ASR
                if ratio_diverse_st > 0 and ratio_diverse_asr > 0:
                    cr = int((1 - ratio_diverse_asr) * beam)
                    ct = int((1 - ratio_diverse_st) * beam)
                    ct = max(ct, math.ceil(beam // cr))
                    s2v = s2v.reshape(beam, beam, 2)
                    Sc = S[:cr, :ct]
                    local_best_scores, id2k = Sc.flatten().topk(beam)
                    I = s2v[:cr, :ct]
                    I = I.reshape(-1, 2)
                    I = I[id2k]
                    local_best_ids_st = I[:, 0]
                    local_best_ids_asr = I[:, 1]

            elif local_att_scores is not None:
                local_best_scores, local_best_ids_st = torch.topk(
                    local_att_scores, beam, dim=1)
                local_best_scores = local_best_scores.squeeze(0)
                local_best_ids_st = local_best_ids_st.squeeze(0)
            elif local_att_scores_asr is not None:
                local_best_scores, local_best_ids_asr = torch.topk(
                    local_att_scores_asr, beam, dim=1)
                local_best_ids_asr = local_best_ids_asr.squeeze(0)
                local_best_scores = local_best_scores.squeeze(0)
            else:
                raise NotImplementedError

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[j])

                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq_asr'] = [0] * (1 + len(hyp['yseq_asr']))
                new_hyp['yseq_asr'][:len(hyp['yseq_asr'])] = hyp['yseq_asr']

                if local_att_scores is not None:
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids_st[j])
                else:
                    if i >= self.wait_k_asr:
                        new_hyp['yseq'][len(hyp['yseq'])] = self.eos
                    else:
                        new_hyp['yseq'] = hyp['yseq']  # v3

                if local_att_scores_asr is not None:
                    new_hyp['yseq_asr'][len(hyp['yseq_asr'])] = int(local_best_ids_asr[j])
                else:
                    if i >= self.wait_k_st:
                        new_hyp['yseq_asr'][len(hyp['yseq_asr'])] = self.eos
                    else:
                        new_hyp['yseq_asr'] = hyp['yseq_asr']  # v3

                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))

        if char_list is not None:
            logging.info('best hypo: ' + ''.join(
                [char_list[int(x)] for x in hyps[0]['yseq']]))
            logging.info('best hypo asr: ' + ''.join(
                [char_list[int(x)] for x in hyps[0]['yseq_asr']]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                if hyp['yseq'][-1] != self.eos:
                    hyp['yseq'].append(self.eos)
        if i == maxlen_asr - 1:
            logging.info('adding <eos> in the last position in the loop for asr')
            for hyp in hyps:
                if hyp['yseq_asr'][-1] != self.eos:
                    hyp['yseq_asr'].append(self.eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []

        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos and hyp['yseq_asr'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen and len(hyp['yseq_asr']) > minlen_asr:
                    hyp['score'] += (i + 1) * penalty
                    # if rnnlm:  # Word LM needs to add final <eos> score
                    #     hyp['score'] += trans_args.lm_weight * rnnlm.final(
                    #         hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.info('remained hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.info('hypo: ' + ''.join(
                    [char_list[int(x)] for x in hyp['yseq'][1:]]))
                logging.info('hypo asr: ' + ''.join(
                    [char_list[int(x)] for x in hyp['yseq_asr'][1:]]))

        logging.info('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning('there is no N-best results, '
                        'perform recognition again with smaller minlenratio.')
        # should copy because Namespace will be overwritten globally
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        trans_args.minlenratio_asr = max(0.0, trans_args.minlenratio_asr - 0.1)
        return self.recognize_and_translate_sum(x, trans_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: '
                 + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    return nbest_hyps
        dropout_rate=0.0)
    decoder.eval()
else:
    encoder = Encoder(
        idim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        dropout_rate=0.0,
        input_layer="embed")
    encoder.eval()

xlen = 100
xs = torch.randint(0, odim, (1, xlen))
memory = torch.randn(2, 500, adim)
mask = subsequent_mask(xlen).unsqueeze(0)
result = {"cached": [], "baseline": []}
n_avg = 10
for key, value in result.items():
    cache = None
    print(key)
    for i in range(xlen):
        x = xs[:, :i + 1]
        m = mask[:, :i + 1, :i + 1]
        start = time()
        for _ in range(n_avg):
            with torch.no_grad():
                if key == "baseline":
                    cache = None
                if model == "decoder":
def forward(self, x):
    """Recognize input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    ctc_weight = 0.5
    beam_size = 1
    penalty = 0.0
    maxlenratio = 0.0
    nbest = 1
    sos = eos = 7442

    import onnxruntime

    # enc_output = self.encode(x).unsqueeze(0)
    encoder_rt = onnxruntime.InferenceSession("encoder.onnx")
    ort_inputs = {"x": x.numpy()}
    enc_output = encoder_rt.run(None, ort_inputs)
    enc_output = torch.tensor(enc_output)
    enc_output = enc_output.squeeze(0)

    # lpz = self.ctc.log_softmax(enc_output)
    # lpz = lpz.squeeze(0)
    ctc_softmax_rt = onnxruntime.InferenceSession("ctc_softmax.onnx")
    ort_inputs = {"enc_output": enc_output.numpy()}
    lpz = ctc_softmax_rt.run(None, ort_inputs)
    lpz = torch.tensor(lpz)
    lpz = lpz.squeeze(0).squeeze(0)

    h = enc_output.squeeze(0)

    # prepare sos
    y = sos
    maxlen = h.shape[0]
    minlen = 0.0

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y]}

    ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
    hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
    hyp['ctc_score_prev'] = 0.0

    # pre-pruning based on attention scores
    # ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
    ctc_beam = 1

    hyps = [hyp]
    ended_hyps = []

    decoder_fos_rt = onnxruntime.InferenceSession("decoder_fos.onnx")

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)

            ort_inputs = {
                "ys": ys.numpy(),
                "ys_mask": ys_mask.numpy(),
                "enc_output": enc_output.numpy(),
            }
            local_att_scores = decoder_fos_rt.run(None, ort_inputs)
            local_att_scores = torch.tensor(local_att_scores[0])
            # local_att_scores = self.decoder.forward_one_step(ys, ys_mask, enc_output)[0]

            local_scores = local_att_scores  # (1, 7443)

            local_best_scores, local_best_ids = torch.topk(
                local_att_scores, ctc_beam, dim=1)  # (1, 1)
            ctc_scores, ctc_states = ctc_prefix_score(
                hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
            local_scores = \
                (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
            local_best_scores, joint_best_ids = torch.topk(
                local_scores, beam_size, dim=1)
            local_best_ids = local_best_ids[:, joint_best_ids[0]]

            for j in six.moves.range(beam_size):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam_size]

        # sort and get nbest
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            break

        hyps = remained_hyps
        if len(hyps) == 0:
            break

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), nbest)]

    # return nbest_hyps
    return torch.tensor(nbest_hyps[0]['yseq'])
def batch_score(self, hyps, batch_states, cache):
    """Forward batch one step.

    Args:
        hyps (list): batch of hypotheses
        batch_states (list): decoder states [L x (B, max_len, dec_dim)]
        cache (dict): states cache

    Returns:
        batch_y (torch.Tensor): decoder output (B, dec_dim)
        batch_states (list): decoder states [L x (B, max_len, dec_dim)]
        lm_tokens (torch.Tensor): batch of token ids for LM (B)
    """
    final_batch = len(hyps)
    device = next(self.parameters()).device

    process = []
    done = [None for _ in range(final_batch)]

    for i, hyp in enumerate(hyps):
        str_yseq = "".join([str(x) for x in hyp.yseq])

        if str_yseq in cache:
            done[i] = (*cache[str_yseq], hyp.yseq)
        else:
            process.append((str_yseq, hyp.yseq, hyp.dec_state))

    if process:
        batch = len(process)

        _tokens = pad_sequence([p[1] for p in process], self.blank)
        _states = [p[2] for p in process]

        batch_tokens = torch.LongTensor(_tokens).view(batch, -1).to(device=device)
        tgt_mask = (
            subsequent_mask(batch_tokens.size(-1))
            .unsqueeze(0)
            .expand(batch, -1, -1)
            .to(device=device)
        )

        dec_state = self.init_state()
        dec_state = self.create_batch_states(
            dec_state,
            _states,
            _tokens,
        )

        tgt = self.embed(batch_tokens)

        next_state = []
        for s, decoder in zip(dec_state, self.decoders):
            tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
            next_state.append(tgt)

        tgt = self.after_norm(tgt[:, -1])

    j = 0
    for i in range(final_batch):
        if done[i] is None:
            new_state = self.select_state(next_state, j)

            done[i] = (tgt[j], new_state, process[j][2])
            cache[process[j][0]] = (tgt[j], new_state)

            j += 1

    batch_states = self.create_batch_states(
        batch_states, [d[1] for d in done], [d[2] for d in done]
    )
    batch_y = torch.stack([d[0] for d in done])

    lm_tokens = (
        torch.LongTensor([hyp.yseq[-1] for hyp in hyps])
        .view(final_batch)
        .to(device=device)
    )

    return batch_y, batch_states, lm_tokens
def translate(self, x, trans_args, char_list=None, rnnlm=None, use_jit=False):
    """Translate source text.

    :param list x: input source text feature (T,)
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    self.eval()  # NOTE: this is important because self.encode() is not used
    assert isinstance(x, list)

    # make a utt list (1) to use the same interface for encoder
    if self.multilingual:
        x = to_device(
            self,
            torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
    else:
        x = to_device(
            self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))

    xs_pad = x.unsqueeze(0)
    tgt_lang = None
    if trans_args.tgt_lang:
        tgt_lang = char_list.index(trans_args.tgt_lang)
    xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang)
    enc_output, _ = self.encoder(xs_pad, None)
    h = enc_output.squeeze(0)

    logging.info('input lengths: ' + str(h.size(0)))

    # search params
    beam = trans_args.beam_size
    penalty = trans_args.penalty

    # prepare sos
    y = self.sos
    vy = h.new_zeros(1).long()

    if trans_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))
    minlen = int(trans_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output)[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + trans_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += trans_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        from espnet.nets.e2e_asr_common import end_detect
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remained hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again with smaller minlenratio.'
        )
        # should copy because Namespace will be overwritten globally
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        return self.translate(x, trans_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: '
                 + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
    return nbest_hyps
def recognize_jca(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    enc_output = self.encode(x).unsqueeze(0)  # (1, T, D)
    if recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)  # shape of (T, D)
    else:
        lpz = None

    h = enc_output.squeeze(0)  # (B, T, D), B=1

    logging.info('input lengths: ' + str(h.size(0)))
    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    if lpz is not None:
        import numpy

        from espnet.nets.ctc_prefix_score import CTCPrefixScore

        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO

            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            if self.remove_blank_in_ctc_mode:
                ctc_beam = lpz.shape[-1] - 1  # except blank
            else:
                ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)  # mask scores of future state
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output)[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                if self.remove_blank_in_ctc_mode:
                    # here we need to filter out <blank> in local_best_ids;
                    # it happens in pure ctc-mode, when ctc_beam equals #vocab
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores[:, 1:], ctc_beam, dim=1)
                    local_best_ids += 1  # hack
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        # from espnet.nets.e2e_asr_common import end_detect
        # if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
        from espnet.nets.e2e_asr_common import end_detect_yzl23
        if end_detect_yzl23(ended_hyps, remained_hyps, penalty) \
                and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remained hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again with smaller minlenratio.'
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize(x, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: '
                 + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    return nbest_hyps
def _word_target_mask(self, ys_in_pad, aver_mask):
    """Combine padding, causal, and word-level masks for the decoder."""
    ys_mask = ys_in_pad != 0
    m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
    word_m = (aver_mask != 0) | m
    return ys_mask.unsqueeze(-2) & word_m
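# A small, self-contained check of the OR-combination used above. The
# `aver_mask` here is a made-up word-boundary mask (not one produced by the
# surrounding model); it only demonstrates how nonzero entries widen the
# causal mask. The padding AND with ys_mask is omitted for brevity.
import torch
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

L = 4
aver_mask = torch.zeros(1, L, L)
aver_mask[0, 2, 3] = 1.0  # pretend position 2 may also attend ahead to 3

causal = subsequent_mask(L).unsqueeze(0).bool()  # (1, L, L) lower-triangular
word_m = (aver_mask != 0) | causal               # causal OR extra word links
print(word_m[0].int())
# row 2 now has a 1 in column 3 in addition to its causal prefix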
def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    enc_output = self.encode(x).unsqueeze(0)
    if self.mtlalpha == 1.0:
        recog_args.ctc_weight = 1.0
        logging.info("Set to pure CTC decoding mode.")

    if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0:
        from itertools import groupby

        lpz = self.ctc.argmax(enc_output)
        collapsed_indices = [x[0] for x in groupby(lpz[0])]
        hyp = [x for x in filter(lambda x: x != self.blank, collapsed_indices)]
        nbest_hyps = [{"score": 0.0, "yseq": hyp}]
        if recog_args.beam_size > 1:
            raise NotImplementedError("Pure CTC beam search is not implemented.")
        # TODO(hirofumi0810): Implement beam search
        return nbest_hyps
    elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)
    else:
        lpz = None

    h = enc_output.squeeze(0)

    logging.info("input lengths: " + str(h.size(0)))
    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
        hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
        hyp["ctc_score_prev"] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                    )
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output
                )[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids[0]
                ] + ctc_weight * torch.from_numpy(
                    ctc_scores - hyp["ctc_score_prev"]
                )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    )
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz is not None:
                    new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                    new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]

                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list, and remove them from the
        # current hypotheses (this may become a problem: number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remained hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there are no N-best results, perform recognition "
            "again with a smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize(x, recog_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
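# The pure-CTC branch of recognize() reduces to "collapse repeated frame
# labels, then drop blanks". A standalone illustration of that step; the
# frame-level argmax ids below are made up, and 0 stands in for the model's
# `self.blank` index.
from itertools import groupby

blank = 0
frame_argmax = [0, 3, 3, 0, 0, 5, 5, 5, 0, 2]
collapsed = [k for k, _ in groupby(frame_argmax)]  # [0, 3, 0, 5, 0, 2]
hyp = [t for t in collapsed if t != blank]         # [3, 5, 2]
print(hyp)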
def infer(x, encoder_rt, ctc_softmax_rt, decoder_fos_rt):
    """Decode one utterance with ONNX Runtime sessions for the encoder,
    the CTC log-softmax, and the decoder's forward_one_step graph."""
    ctc_weight = 0.5
    beam_size = 1
    penalty = 0.0
    maxlenratio = 0.0
    nbest = 1
    sos = eos = 7442

    # enc_output = self.encode(x).unsqueeze(0)
    ort_inputs = {"x": x.numpy()}
    enc_output = encoder_rt.run(None, ort_inputs)
    enc_output = torch.tensor(enc_output)
    enc_output = enc_output.squeeze(0)

    # lpz = self.ctc.log_softmax(enc_output)
    # lpz = lpz.squeeze(0)
    ort_inputs = {"enc_output": enc_output.numpy()}
    lpz = ctc_softmax_rt.run(None, ort_inputs)
    lpz = torch.tensor(lpz)
    lpz = lpz.squeeze(0).squeeze(0)

    h = enc_output.squeeze(0)

    # prepare sos
    y = sos
    maxlen = h.shape[0]
    minlen = 0
    ctc_beam = 1

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y]}
    ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
    hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
    hyp['ctc_score_prev'] = 0.0

    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)

            ort_inputs = {"ys": ys.numpy(), "ys_mask": ys_mask.numpy(),
                          "enc_output": enc_output.numpy()}
            local_att_scores = decoder_fos_rt.run(None, ort_inputs)
            local_att_scores = torch.tensor(local_att_scores[0])

            local_best_scores, local_best_ids = torch.topk(
                local_att_scores, ctc_beam, dim=1)
            ctc_scores, ctc_states = ctc_prefix_score(
                hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
            local_scores = \
                (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
            local_best_scores, joint_best_ids = torch.topk(
                local_scores, beam_size, dim=1)
            local_best_ids = local_best_ids[:, joint_best_ids[0]]

            for j in six.moves.range(beam_size):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam_size]

        # sort and get nbest
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(eos)

        # add ended hypotheses to a final list, and remove them from the
        # current hypotheses (this may become a problem: number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            break

        hyps = remained_hyps
        if len(hyps) == 0:
            break

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), nbest)]

    return torch.tensor(nbest_hyps[0]['yseq'])
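# A hedged sketch of how the three runtime sessions consumed by infer() might
# be created. The .onnx file names are placeholders for whatever the export
# step actually produced; only InferenceSession and run() are standard
# onnxruntime API, and `feats` is a hypothetical (T, D) feature tensor.
import onnxruntime
import torch

encoder_rt = onnxruntime.InferenceSession("encoder.onnx")
ctc_softmax_rt = onnxruntime.InferenceSession("ctc_softmax.onnx")
decoder_fos_rt = onnxruntime.InferenceSession("decoder_fos.onnx")

# feats = torch.randn(200, 83)  # one utterance of acoustic features
# yseq = infer(feats, encoder_rt, ctc_softmax_rt, decoder_fos_rt)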