def beam_search_decode(self, encoder_output, dot_attention_mask, beam_size):
    """Batched beam search: tile the encoder output per beam and step a BeamSteper."""
    batch_size = encoder_output.size(0)
    device = encoder_output.device
    feature_length = encoder_output.size(1)
    self.beam_steper = BeamSteper(batch_size, beam_size, self.bos_id,
                                  self.eos_id, self.vocab_size, device)
    # tile encoder output and cross-attention mask along the beam dimension:
    # (B, T, D) -> (B * beam, T, D), (B, 1, T) -> (B * beam, 1, T)
    encoder_output = encoder_output.unsqueeze(1).repeat(
        1, beam_size, 1, 1).view(batch_size * beam_size, feature_length, -1)
    dot_attention_mask = dot_attention_mask.unsqueeze(1).repeat(
        1, beam_size, 1, 1).view(batch_size * beam_size, 1, feature_length)
    with t.no_grad():
        for i in range(self.max_length):
            try:
                length = self.beam_steper.length_container
                token_mask = Masker.get_mask(length)
                self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                # first step decodes from <bos> only; later steps reuse the beam containers
                token_id = self.beam_steper.get_first_step_token() if i == 0 \
                    else self.beam_steper.token_container.view(batch_size * beam_size, -1)
                last_prob = self.beam_decode_step(token_id,
                                                  encoder_output,
                                                  token_mask,
                                                  self_attention_mask,
                                                  dot_attention_mask,
                                                  topk=beam_size,
                                                  return_last=True)
                self.beam_steper.step(last_prob)
            except Exception:
                # any failure inside a step terminates decoding early
                break
    return self.beam_steper.batch_best_saver
def greedy_decode(self, encoder_output, feature_mask):
    """Batched greedy decoding: feed back the argmax token until max_length is reached."""
    batch_size = encoder_output.size(0)
    device = encoder_output.device
    token_id = t.full((batch_size, 1), fill_value=self.bos_id, dtype=t.long, device=device)
    length = t.LongTensor([1] * batch_size).to(device)
    with t.no_grad():
        for i in range(self.max_length):
            try:
                token_mask = Masker.get_mask(length)
                self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                last_prob, last_token_id = self.decode_step(
                    token_id, encoder_output, token_mask, self_attention_mask,
                    dot_attention_mask, topk=1, return_last=True)
                token_id = t.cat([token_id, last_token_id], dim=1)
                # only grow the effective length of sequences that have not emitted <eos>
                for index, last_id in enumerate(last_token_id.squeeze(1)):
                    if last_id != self.eos_id:
                        length[index] += 1
            except Exception:
                # TODO: handle decoding failures more carefully
                break
    return token_id
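# --- Usage sketch (illustration only, not part of the model code) ------------
# Minimal example of driving the batched greedy decoder. Obtaining
# (encoder_output, feature_mask) via `model._encode` is an assumption taken
# from how `recognize` below calls it; only the `greedy_decode` signature
# above is from the source.
import torch

def transcribe_greedy(model, feature, feature_length):
    model.eval()
    with torch.no_grad():
        encoder_output, feature_mask = model._encode(feature, feature_length)
        token_ids = model.greedy_decode(encoder_output, feature_mask)  # (B, <= max_length + 1)
    return token_ids
# -----------------------------------------------------------------------------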
def _prepare_token(self, token, token_length):
    input_token, output_token, token_length = self._rebuild_target(token, token_length)
    token_mask = Masker.get_mask(token_length)
    token_self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
    token_self_attention_mask = Masker.get_forward_mask(token_self_attention_mask)
    return input_token, output_token, token_length, token_mask, token_self_attention_mask
def _prepare_feature(self, feature, feature_length,
                     restrict_left_length=None, restrict_right_length=None):
    """Apply SpecAugment (if enabled) and build the feature padding / self-attention masks."""
    if self.enable_spec_augment:
        feature = self.spec_augment(feature, feature_length)
    feature_mask = Masker.get_mask(feature_length)
    self_attention_mask = Masker.get_dot_mask(feature_mask, feature_mask)
    self_attention_mask = Masker.get_restricted_mask(
        self_attention_mask, restrict_left_length, restrict_right_length)
    return feature, feature_mask, self_attention_mask
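# --- Assumed Masker behaviour (sketch, not the repo implementation) -----------
# `Masker.get_mask` and `Masker.get_dot_mask` are repo-internal helpers whose
# code is not shown here. The sketch below shows the behaviour they are assumed
# to have (length-based padding mask and its outer product); the real helpers
# may differ in dtype or shape conventions.
import torch

def get_mask_sketch(lengths, max_len=None):
    """(B,) lengths -> (B, T) mask with 1 for valid positions, 0 for padding."""
    max_len = max_len or int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, T)
    return (positions < lengths.unsqueeze(1)).long()                        # (B, T)

def get_dot_mask_sketch(query_mask, key_mask):
    """(B, Tq) x (B, Tk) -> (B, Tq, Tk) attention mask."""
    return query_mask.unsqueeze(2) * key_mask.unsqueeze(1)
# -----------------------------------------------------------------------------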
def forward(self, feature, feature_length, ori_token, ori_token_length, cal_ce_loss=True):
    # build feature masks (and apply SpecAugment in training)
    feature, feature_mask, feature_self_attention_mask = self._prepare_feature(
        feature, feature_length,
        restrict_left_length=self.restrict_left_length,
        restrict_right_length=self.restrict_right_length)
    # build decoder inputs/targets, masks, and the language-switch targets
    input_token, output_token, token_length, token_mask, \
        token_self_attention_mask, switch_target = self._prepare_token(ori_token, ori_token_length)
    # encode the acoustic features and project them for the CTC branch
    spec_feature = self.spec_encoder(feature, feature_mask, feature_self_attention_mask)
    spec_output = self.encoder_linear(spec_feature)
    # decode with cross-attention over the encoder output
    dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
    output, switch = self.token_decoder(input_token, spec_feature, token_mask,
                                        token_self_attention_mask, dot_attention_mask)
    switch_loss = self.switch_loss(switch, switch_target)
    if cal_ce_loss:
        ce_loss = self.cal_ce_loss(output, output_token, type='lbce')
    else:
        ce_loss = None
    return output, output_token, spec_output, feature_length, ori_token, ori_token_length, ce_loss, switch_loss
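# --- Training-step sketch (illustration only) ---------------------------------
# Hedged example of how `forward` might be driven in a training loop, based
# only on its signature and return values. The loss weighting and the optimizer
# are assumptions, not taken from the source.
import torch

def train_step(model, optimizer, feature, feature_length, token, token_length,
               switch_weight=0.1):          # switch_weight is an assumed hyper-parameter
    output, output_token, spec_output, feature_length, token, token_length, \
        ce_loss, switch_loss = model(feature, feature_length, token, token_length,
                                     cal_ce_loss=True)
    loss = ce_loss + switch_weight * switch_loss   # weighting is illustrative
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
# -----------------------------------------------------------------------------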
def _prepare_token(self, token, token_length):
    """Build decoder targets, masks, and the per-token language-switch labels."""
    input_token, output_token, token_length = self._rebuild_target(token, token_length)
    token_mask = Masker.get_mask(token_length)
    token_self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
    token_self_attention_mask = Masker.get_forward_mask(token_self_attention_mask)
    # switch labels: 0 = padding, 1 = English (default), 2 = Chinese, 3 = other symbols
    switch = t.ones_like(output_token, device=token.device).long()            # eng = 1
    switch.masked_fill_(output_token.eq(0), 0)                                # pad = 0
    switch.masked_fill_(output_token.ge(12) & output_token.le(4211), 2)       # ch = 2
    switch.masked_fill_(output_token.ge(1) & output_token.le(10), 3)          # other = 3
    return input_token, output_token, token_length, token_mask, token_self_attention_mask, switch
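# --- Worked example of the switch-label assignment (illustration only) --------
# The id ranges (0 = pad, 1..10 = other symbols, 12..4211 = Chinese, everything
# else = English) are taken from the code above; the concrete ids below are
# made up for the example.
import torch

output_token_example = torch.tensor([[5000, 200, 4, 0]])   # eng, chinese, other, pad
switch_example = torch.ones_like(output_token_example).long()
switch_example.masked_fill_(output_token_example.eq(0), 0)
switch_example.masked_fill_(output_token_example.ge(12) & output_token_example.le(4211), 2)
switch_example.masked_fill_(output_token_example.ge(1) & output_token_example.le(10), 3)
print(switch_example)   # tensor([[1, 2, 3, 0]])
# -----------------------------------------------------------------------------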
def beam_search_decode(self, encoder_output, feature_mask, beam_size, best_k=5, lp_eps=0.0):
    """Batched beam search, keeping the best_k finished hypotheses per utterance."""
    batch_size, feature_length, _ = encoder_output.size()
    device = encoder_output.device
    self.beam_steper = BeamSteper(
        batch_size=batch_size, beam_size=beam_size, bos_id=self.bos_id,
        eos_id=self.eos_id, vocab_size=self.vocab_size, device=device,
        k_best=best_k, lp_eps=lp_eps)
    # tile encoder output and feature mask along the beam dimension: (B, ...) -> (B * beam, ...)
    beam_feature_mask = feature_mask.unsqueeze(1).repeat(
        1, beam_size, 1).view(batch_size * beam_size, -1)
    beam_encoder_output = encoder_output.unsqueeze(1).repeat(
        1, beam_size, 1, 1).view(batch_size * beam_size, feature_length, -1)
    with t.no_grad():
        for i in range(self.max_length):
            if i == 0:
                # first step: every utterance holds a single <bos> hypothesis
                token_id = self.beam_steper.get_first_step_token()
                length = self.beam_steper.get_first_step_length()
                token_mask = Masker.get_mask(length)
                self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                last_prob = self.beam_decode_step(
                    token_id, encoder_output, token_mask,
                    self_attention_mask, dot_attention_mask)
                if_continue = self.beam_steper.first_step(last_prob)
                if not if_continue:
                    break
            else:
                # later steps: expand every beam hypothesis against the tiled encoder output
                token_id = self.beam_steper.token_container.view(batch_size * beam_size, -1)
                length = self.beam_steper.length_container.view(batch_size * beam_size)
                token_mask = Masker.get_mask(length)
                self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                dot_attention_mask = Masker.get_dot_mask(token_mask, beam_feature_mask)
                last_prob = self.beam_decode_step(
                    token_id, beam_encoder_output, token_mask,
                    self_attention_mask, dot_attention_mask)
                if_continue = self.beam_steper.step(
                    last_prob.view(batch_size, beam_size, -1))
                if not if_continue:
                    break
    output_token = self.beam_steper.batch_best_saver.batch
    return output_token
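# --- Shape check for the beam-tiling pattern (illustration only) --------------
# The unsqueeze/repeat/view pattern above tiles the encoder output once per
# beam so that all hypotheses can be decoded in a single batched call. This is
# a standalone check of that pattern in plain PyTorch; the sizes are arbitrary.
import torch

batch_size, beam_size, feature_length, d_model = 2, 4, 50, 256
enc = torch.randn(batch_size, feature_length, d_model)
mask = torch.ones(batch_size, feature_length, dtype=torch.long)

beam_enc = enc.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
    batch_size * beam_size, feature_length, -1)
beam_mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(batch_size * beam_size, -1)

print(beam_enc.shape)   # torch.Size([8, 50, 256])
print(beam_mask.shape)  # torch.Size([8, 50])
# -----------------------------------------------------------------------------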
def recognize(self, feature, feature_length, beam=5, penalty=0, ctc_weight=0.3,
              maxlenratio=0, minlenratio=0, char_list=None, rnnlm=None,
              lm_weight=0.1, nbest=1):
    """Recognize a single utterance with joint attention/CTC beam search.

    :param torch.Tensor feature: input acoustic feature (B, T, D)
    :param torch.Tensor feature_length: feature lengths (B,)
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    import six

    assert feature.size(0) == 1
    enc_output, feature_mask = self._encode(feature, feature_length)
    if ctc_weight > 0.0:
        lpz = t.nn.functional.log_softmax(self.encoder_linear(enc_output), -1).squeeze(0)
    else:
        lpz = None

    h = enc_output.squeeze(0)

    # prepare <sos>
    y = self.vocab.bos_id
    vy = h.new_zeros(1).long()

    if maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(maxlenratio * h.size(0)))
    minlen = int(minlenratio * h.size(0))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), self.vocab.blank_id,
                                          self.vocab.eos_id, np)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]

            # get nbest local scores and their ids
            ys_mask = Masker.get_mask(t.LongTensor([i + 1]))
            ys_self_attention_mask = Masker.get_dot_mask(ys_mask, ys_mask)
            ys_self_attention_mask = Masker.get_forward_mask(ys_self_attention_mask)
            dot_attention_mask = Masker.get_dot_mask(ys_mask, feature_mask)
            ys = t.tensor(hyp['yseq']).unsqueeze(0)
            local_att_scores = self.token_decoder.forward_one_step(
                ys, enc_output, ys_mask, ys_self_attention_mask, dot_attention_mask)[0]
            local_att_scores = t.nn.functional.log_softmax(local_att_scores, -1)

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # rescore the attention top-k with the CTC prefix score
                local_best_scores, local_best_ids = t.topk(local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * t.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += lm_weight * local_lm_scores[:, local_best_ids[0]]
                local_best_scores, joint_best_ids = t.topk(local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = t.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)
            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept

        # add <eos> in the final loop iteration so that at least one hypothesis can end
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(self.vocab.eos_id)

        # move ended hypotheses to the final list and keep the rest
        # (the number of remaining hyps may drop below beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.vocab.eos_id:
                # only store sequences longer than minlen, and apply the length penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:
                        # word LM needs to add the final <eos> score
                        hyp['score'] += lm_weight * rnnlm.final(hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        hyps = remained_hyps

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        # no N-best results; the caller may retry with a smaller minlenratio
        return None

    return nbest_hyps
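# --- Usage sketch for recognize (illustration only) ---------------------------
# Hedged example of decoding one utterance; the `recognize` signature and the
# batch-size-1 assert are from the code above, while the char_list conversion
# and the surrounding wrapper are assumptions.
import torch

def decode_utterance(model, feature, feature_length, char_list, rnnlm=None):
    model.eval()
    nbest_hyps = model.recognize(feature, feature_length, beam=5, ctc_weight=0.3,
                                 rnnlm=rnnlm, lm_weight=0.1, nbest=1)
    if nbest_hyps is None:
        return ""                               # no ended hypothesis survived
    token_ids = nbest_hyps[0]['yseq'][1:]       # drop the leading <bos>; trailing <eos> kept
    return ''.join(char_list[int(idx)] for idx in token_ids)
# -----------------------------------------------------------------------------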
def _prepare_feature(self, feature, feature_length):
    if self.enable_spec_augment:
        feature = self.spec_augment(feature, feature_length)
    feature_mask = Masker.get_mask(feature_length)
    self_attention_mask = Masker.get_dot_mask(feature_mask, feature_mask)
    return feature, feature_mask, self_attention_mask