def recognize_beam(self, h, lpz, recog_args, char_list):
    '''Beam search decoding, optionally combined with CTC prefix scores.

    :param Variable h: encoder output for one utterance; h.size(0) is taken
        as the number of frames
    :param lpz: CTC log-probability matrix for the same utterance, or None to
        decode from the attention scores alone
    :param Namespace recog_args: decoding options (beam_size, penalty,
        ctc_weight, maxlenratio, minlenratio)
    :param char_list: id-to-symbol table, used only for logging hypotheses
    :return: best label id sequence as a list (leading <sos> removed)
    '''
    logging.info('input lengths: ' + str(h.size(0)))
    # initialization: zero (hidden, cell) states for every decoder layer
    c_list = [self.zero_state(h.unsqueeze(0))]
    z_list = [self.zero_state(h.unsqueeze(0))]
    for l in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(h.unsqueeze(0)))
        z_list.append(self.zero_state(h.unsqueeze(0)))
    a = None
    self.att.reset()  # reset pre-computation of h

    # search parms
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos; vy is the reused 1-element previous-token input tensor
    y = self.sos
    vy = Variable(h.data.new(1).zero_().long(), volatile=True)

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
           'z_prev': z_list, 'a_prev': a}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.numpy(), 0, self.eos, np)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        # pre-prune with attention top-k before CTC prefix rescoring
        ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # NOTE(fix): the original called vy.unsqueeze(1) and
            # ey.unsqueeze(0) and discarded the results; Tensor.unsqueeze is
            # not in-place, so both statements were no-ops and were removed.
            vy[0] = hyp['yseq'][i]
            ey = self.embed(vy)  # utt list (1) x zdim
            att_c, att_w = self.att(h.unsqueeze(0), [h.size(0)],
                                    hyp['z_prev'][0], hyp['a_prev'])
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            # run the stacked decoder LSTM from this hypothesis' saved state
            z_list[0], c_list[0] = self.decoder[0](
                ey, (hyp['z_prev'][0], hyp['c_prev'][0]))
            for l in six.moves.range(1, self.dlayers):
                z_list[l], c_list[l] = self.decoder[l](
                    z_list[l - 1], (hyp['z_prev'][l], hyp['c_prev'][l]))

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(self.output(z_list[-1]), dim=1).data
            if lpz is not None:
                # attention top-(ctc_beam), rescored jointly with the CTC
                # prefix score delta, then reduced to the joint top-beam
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['z_prev'] = [z_list[0]]
                new_hyp['c_prev'] = [c_list[0]]
                for l in six.moves.range(1, self.dlayers):
                    new_hyp['z_prev'].append(z_list[l])
                    new_hyp['c_prev'].append(c_list[l])
                new_hyp['a_prev'] = att_w
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = local_best_ids[0, j]
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypothes: ' + str(len(hyps)))
        logging.debug(
            'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last postion in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypothes to a final list, and removed them from current hypothes
        # (this will be a probmlem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remeined hypothes: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug(
                'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

    logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

    best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
    logging.info('total log probability: ' + str(best_hyp['score']))
    logging.info('normalized log probability: ' +
                 str(best_hyp['score'] / len(best_hyp['yseq'])))

    # remove sos
    return best_hyp['yseq'][1:]
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
    '''Beam search decoding with optional CTC and RNNLM joint scoring.

    :param h: encoder output for one utterance; h.shape[0] is taken as the
        number of frames
    :param lpz: CTC log-probability matrix for the same utterance, or None to
        decode from the attention (+ LM) scores alone
    :param recog_args: decoding options (beam_size, penalty, ctc_weight,
        lm_weight, maxlenratio, minlenratio, nbest)
    :param char_list: id-to-symbol table, used only for logging hypotheses
    :param rnnlm: optional language model exposing predict() and final()
    :return: n-best list of hypothesis dicts sorted by descending score
    '''
    logging.info('input lengths: ' + str(h.shape[0]))
    # initialization
    c_list = [None]  # list of cell state of each layer
    z_list = [None]  # list of hidden state of each layer
    for l in six.moves.range(1, self.dlayers):
        c_list.append(None)
        z_list.append(None)
    a = None
    self.att.reset()  # reset pre-computation of h

    # search parms
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos as a 1-element int array
    y = self.xp.full(1, self.sos, 'i')
    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
    minlen = int(recog_args.minlenratio * h.shape[0])
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
               'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
               'z_prev': z_list, 'a_prev': a}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            ey = self.embed(hyp['yseq'][i])  # utt list (1) x zdim
            att_c, att_w = self.att([h], hyp['z_prev'][0], hyp['a_prev'])
            ey = F.hstack((ey, att_c))  # utt(1) x (zdim + hdim)
            # run the stacked decoder LSTM from this hypothesis' saved state
            c_list[0], z_list[0] = self.lstm0(hyp['c_prev'][0], hyp['z_prev'][0], ey)
            for l in six.moves.range(1, self.dlayers):
                c_list[l], z_list[l] = self['lstm%d' % l](
                    hyp['c_prev'][l], hyp['z_prev'][l], z_list[l - 1])

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(self.output(z_list[-1])).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp['rnnlm_prev'], hyp['yseq'][i])
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # attention top-(ctc_beam), rescored jointly with the CTC
                # prefix score delta, then reduced to the joint top-beam
                local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                    + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                # do not copy {z,c}_list directly
                new_hyp['z_prev'] = z_list[:]
                new_hyp['c_prev'] = c_list[:]
                new_hyp['a_prev'] = att_w
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = self.xp.full(
                    1, local_best_ids[j], 'i')
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypothes: ' + str(len(hyps)))
        logging.debug('best hypo: ' + ''.join(
            [char_list[int(x)] for x in hyps[0]['yseq'][1:]]).replace('<space>', ' '))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last postion in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.xp.full(1, self.eos, 'i'))

        # add ended hypothes to a final list, and removed them from current hypothes
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remeined hypothes: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug('hypo: ' + ''.join(
                [char_list[int(x)] for x in hyp['yseq'][1:]]).replace('<space>', ' '))

    logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check number of hypotheis
    if len(nbest_hyps) == 0:
        # NOTE(fix): logging.warn is a deprecated alias; use logging.warning
        logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
        # should copy becasuse Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' +
                 str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    return nbest_hyps
def recognize_beam(self, h, recog_args, char_list):
    '''Beam search decoding over the attention decoder scores only.

    :param h: encoder output for one utterance; h.shape[0] is taken as the
        number of frames
    :param recog_args: decoding options (beam_size, penalty, maxlenratio,
        minlenratio)
    :param char_list: id-to-symbol table, used only for logging hypotheses
    :return: best label id sequence as a list (leading <sos> removed)
    '''
    logging.info('input lengths: ' + str(h.shape[0]))
    # initialization
    c_list = [None]  # list of cell state of each layer
    z_list = [None]  # list of hidden state of each layer
    for l in six.moves.range(1, self.dlayers):
        c_list.append(None)
        z_list.append(None)
    a = None
    self.att.reset()  # reset pre-computation of h

    # search parms
    beam = recog_args.beam_size
    penalty = recog_args.penalty

    # prepare sos as a 1-element int array
    y = self.xp.full(1, self.sos, 'i')
    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
    minlen = int(recog_args.minlenratio * h.shape[0])
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
           'z_prev': z_list, 'a_prev': a}
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            ey = self.embed(hyp['yseq'][i])  # utt list (1) x zdim
            att_c, att_w = self.att([h], hyp['z_prev'][0], hyp['a_prev'])
            ey = F.hstack((ey, att_c))  # utt(1) x (zdim + hdim)
            # run the stacked decoder LSTM from this hypothesis' saved state
            c_list[0], z_list[0] = self.lstm0(hyp['c_prev'][0], hyp['z_prev'][0], ey)
            for l in six.moves.range(1, self.dlayers):
                c_list[l], z_list[l] = self['lstm%d' % l](
                    hyp['c_prev'][l], hyp['z_prev'][l], z_list[l - 1])
            # NOTE(fix): the original reassigned hyp['c_prev'], hyp['z_prev']
            # and hyp['a_prev'] here; those stores were dead (hyp is discarded
            # after the new_hyp expansion below) and aliased the mutable
            # shared c_list/z_list, so they have been removed.

            # get nbest local scores and their ids
            local_scores = F.log_softmax(self.output(z_list[-1])).data
            local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]

            for j in six.moves.range(beam):
                new_hyp = {}
                # do not copy {z,c}_list directly
                new_hyp['z_prev'] = [z_list[0]]
                new_hyp['c_prev'] = [c_list[0]]
                for l in six.moves.range(1, self.dlayers):
                    new_hyp['z_prev'].append(z_list[l])
                    new_hyp['c_prev'].append(c_list[l])
                new_hyp['a_prev'] = att_w
                new_hyp['score'] = hyp['score'] + \
                    local_scores[0, local_best_ids[j]]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = self.xp.full(
                    1, local_best_ids[j], 'i')
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypothes: ' + str(len(hyps)))
        logging.debug('best hypo: ' + ''.join(
            [char_list[int(x)] for x in hyps[0]['yseq'][1:]]).replace('<space>', ' '))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last postion in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.xp.full(1, self.eos, 'i'))

        # add ended hypothes to a final list, and removed them from current hypothes
        # (this will be a probmlem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remeined hypothes: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug('hypo: ' + ''.join(
                [char_list[int(x)] for x in hyp['yseq'][1:]]).replace('<space>', ' '))

    logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

    best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
    logging.info('total log probability: ' + str(best_hyp['score']))
    logging.info('normalized log probability: ' +
                 str(best_hyp['score'] / len(best_hyp['yseq'])))

    # remove sos
    return best_hyp['yseq'][1:]