Example #1
    def calculate_all_attentions(self, inputs, targets, input_sizes,
                                 target_sizes):
        '''E2E attention calculation

        :param torch.Tensor inputs: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor targets: concatenated target id sequences
        :param torch.Tensor input_sizes: lengths of the input sequences (B)
        :param torch.Tensor target_sizes: lengths of the target sequences (B)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        '''
        torch.set_grad_enabled(False)
        xpad = to_cuda(self, inputs)
        ilens = to_cuda(self, input_sizes)
        ys = []
        offset = 0
        for size in target_sizes:
            ys.append(targets[offset:offset + size])
            offset += size
        ys = [to_cuda(self, y) for y in ys]
        hpad, hlens = self.enc(xpad, ilens)

        # decoder
        att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys)

        torch.set_grad_enabled(True)

        return att_ws
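A minimal usage sketch for the example above; `model` and the batch tensors are hypothetical placeholders, not names from the source:

    import matplotlib.pyplot as plt

    att_ws = model.calculate_all_attentions(inputs, targets, input_sizes, target_sizes)
    # (B, Lmax, Tmax), or (B, H, Lmax, Tmax) in the multi-head case
    w = att_ws[0] if att_ws.ndim == 3 else att_ws[0, 0]
    plt.imshow(w, aspect='auto', origin='lower')
    plt.xlabel('encoder frames (Tmax)')
    plt.ylabel('decoder steps (Lmax)')
    plt.savefig('att_utt0.png')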
Example #2
 def forward(self, state, x):
     if state is None:
         state = {
             'F_h': to_cuda(self, self.zero_state(x.size(0), self.F_size)),
             'F_c': to_cuda(self, self.zero_state(x.size(0), self.F_size)),
             'S_h': to_cuda(self, self.zero_state(x.size(0), self.S_size)),
             'S_c': to_cuda(self, self.zero_state(x.size(0), self.S_size))
         }
     # embed the input token and apply input dropout
     inputs = self.embed(x)
     inputs = self.d0(inputs)

     # first fast cell, with zoneout applied to hidden and cell states
     F_h = state['F_h']
     F_c = state['F_c']
     F_h_new, F_c_new = self.fast_cells[0](inputs, (F_h, F_c))
     F_h, F_c = zoneout(F_h_new, F_c_new, F_h, F_c, self.zoneout_keep_h, self.zoneout_keep_c, self.training)
     F_output_drop = self.d1(F_h)

     # slow cell consumes the (dropped-out) output of the first fast cell
     S_h = state['S_h']
     S_c = state['S_c']
     S_h_new, S_c_new = self.slow_cell(F_output_drop, (S_h, S_c))
     S_h, S_c = zoneout(S_h_new, S_c_new, S_h, S_c, self.zoneout_keep_h, self.zoneout_keep_c, self.training)
     S_output_drop = self.d2(S_h)

     # second fast cell consumes the slow cell's output
     F_h_new, F_c_new = self.fast_cells[1](S_output_drop, (F_h, F_c))
     F_h, F_c = zoneout(F_h_new, F_c_new, F_h, F_c, self.zoneout_keep_h, self.zoneout_keep_c, self.training)
     # the remaining fast cells receive a zero input (F_h * 0.0) and only
     # update the recurrent state
     for i in range(2, self.fast_layers):
         F_h_new, F_c_new = self.fast_cells[i](F_h * 0.0, (F_h, F_c))
         F_h, F_c = zoneout(F_h_new, F_c_new, F_h, F_c, self.zoneout_keep_h, self.zoneout_keep_c, self.training)

     F_output_drop = self.d3(F_h)
     logits = self.out(F_output_drop)
     state = {'F_h': F_h, 'F_c': F_c, 'S_h': S_h, 'S_c': S_c}
     return state, logits
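The `zoneout` helper called above is not shown. A minimal sketch of one common formulation (keep the new state with probability `keep_*` during training, use the expected mixture at inference); this is an assumption, not necessarily the source's exact implementation:

    import torch

    def zoneout(h_new, c_new, h_prev, c_prev, keep_h, keep_c, training):
        if training:
            # per-element mask: 1 -> take the new state, 0 -> keep the old one
            mask_h = torch.bernoulli(torch.full_like(h_new, keep_h))
            mask_c = torch.bernoulli(torch.full_like(c_new, keep_c))
            h = mask_h * h_new + (1.0 - mask_h) * h_prev
            c = mask_c * c_new + (1.0 - mask_c) * c_prev
        else:
            # at inference, use the expected value of the stochastic mixture
            h = keep_h * h_new + (1.0 - keep_h) * h_prev
            c = keep_c * c_new + (1.0 - keep_c) * c_prev
        return h, c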
Example #3
    def calculate_all_specgram(self, mix_inputs, mix_log_inputs, input_sizes):
        torch.set_grad_enabled(False)
        mix_inputs = to_cuda(self, mix_inputs)
        mix_log_inputs = to_cuda(self, mix_log_inputs)
        ilens = to_cuda(self, input_sizes)

        if self.enhance_type in ('blstm', 'blstmp', 'unet_128', 'unet_256'):
            xs, hlens = self.enc1(mix_log_inputs, ilens)
        elif self.enhance_type in ('vggblstmp', 'vggblstm'):
            xs, hlens = self.enc1(mix_log_inputs, ilens)
            xs, hlens = self.enc2(xs, hlens)
        else:
            logging.error(
                "Error: need to specify an appropriate enhance architecture")
            sys.exit()

        linear_out = self.fc(xs)
        out = torch.sigmoid(linear_out)
        enhanced_out = out * mix_inputs
        torch.set_grad_enabled(True)

        return enhanced_out
Example #4
 def forward(self, data, fbank_cmvn=None, scheduled_sample_rate=0.0):
     utt_ids, spk_ids, inputs, targets, input_sizes, target_sizes = data
     inputs = to_cuda(self, inputs)
     ilens = to_cuda(self, input_sizes)
     fbank_cmvn = to_cuda(self, fbank_cmvn)
     fbank_features = self.fbank_model(inputs)
     if fbank_cmvn is not None:
         fbank_features = (fbank_features + fbank_cmvn[0, :]) * fbank_cmvn[1, :]
     loss_ctc, loss_att, acc = self.e2e(fbank_features, targets, input_sizes, target_sizes, scheduled_sample_rate)
 
     return loss_ctc, loss_att, acc
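The CMVN layout is implied by `(x + fbank_cmvn[0, :]) * fbank_cmvn[1, :]`: row 0 holds the negated mean, row 1 the inverse standard deviation. A sketch of computing such a tensor from training features; `feats` is a hypothetical (N, D) tensor:

    import torch

    def make_cmvn(feats, eps=1e-8):
        mean = feats.mean(dim=0)
        std = feats.std(dim=0)
        return torch.stack([-mean, 1.0 / (std + eps)], dim=0)  # (2, D)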
Example #5
 def forward(self, xs, fft_cmvn=None):
     '''FbankModel forward

     :param xs: padded FFT feature tensor (B, Tmax, D)
     :param fft_cmvn: optional CMVN statistics (2, D)
     :return: log features, CMVN-normalized when stats are given
     '''
     xs = to_cuda(self, xs)
     # clamp to avoid log(0)
     xs[xs <= 1e-7] = 1e-7
     out = torch.log(xs)
     if fft_cmvn is not None:
         fft_cmvn = to_cuda(self, fft_cmvn)
         out = (out + fft_cmvn[0, :]) * fft_cmvn[1, :]
     return out
Example #6
 def forward(self, state, x):
     if state is None:
         state = {
             'c1': to_cuda(self, self.zero_state(x.size(0))),
             'h1': to_cuda(self, self.zero_state(x.size(0))),
             'c2': to_cuda(self, self.zero_state(x.size(0))),
             'h2': to_cuda(self, self.zero_state(x.size(0)))
         }
     # embed the input token, then two LSTMCell layers with dropout in between
     h0 = self.embed(x)
     h1, c1 = self.l1(self.d0(h0), (state['h1'], state['c1']))
     h2, c2 = self.l2(self.d1(h1), (state['h2'], state['c2']))
     y = self.lo(self.d2(h2))
     state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2}
     return state, y
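The constructor is not part of this example; a minimal sketch consistent with the attributes the forward uses (`n_vocab`, `n_units`, and the dropout rate are hypothetical parameters):

    import torch
    import torch.nn as nn

    class RNNLM(nn.Module):
        def __init__(self, n_vocab, n_units, dropout=0.5):
            super(RNNLM, self).__init__()
            self.embed = nn.Embedding(n_vocab, n_units)
            self.d0, self.d1, self.d2 = (nn.Dropout(dropout) for _ in range(3))
            self.l1 = nn.LSTMCell(n_units, n_units)
            self.l2 = nn.LSTMCell(n_units, n_units)
            self.lo = nn.Linear(n_units, n_vocab)
            self.n_units = n_units

        def zero_state(self, batchsize):
            return torch.zeros(batchsize, self.n_units)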
Example #7
    def forward(self, hs_pad, hlens, ys_pad):
        '''CTC forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        '''
        # strip the ignore_id padding from ys
        # TODO(kan-bayashi): find a smarter way to do this
        ys = [y[y != self.ignore_id] for y in ys_pad]

        self.loss = None
        hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
        olens = torch.from_numpy(
            np.fromiter((x.size(0) for x in ys), dtype=np.int32))

        # project hidden states to the output vocabulary (with dropout)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))

        # concatenate the target sequences into one flat int tensor
        ys_true = torch.cat(ys).cpu().int()  # batch x olen


        # get ctc loss
        # expected shape of seqLength x batchSize x alphabet_size
        ys_hat = ys_hat.transpose(0, 1)
        self.loss = to_cuda(self, self.loss_fn(ys_hat, ys_true, hlens, olens))

        return self.loss
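`self.loss_fn` is built elsewhere; the call signature used above (raw (T, B, C) activations, a flat int target tensor, input lengths, target lengths) matches the warp-ctc binding. A minimal equivalent sketch with PyTorch's built-in CTC loss, assuming blank id 0 (an assumption):

    import torch.nn.functional as F

    def ctc_loss_fn(ys_hat, ys_true, hlens, olens):
        # the built-in loss wants log-probabilities; warp-ctc takes raw activations
        log_probs = F.log_softmax(ys_hat, dim=2)  # (T, B, C)
        return F.ctc_loss(log_probs, ys_true, hlens, olens,
                          blank=0, reduction='mean')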
Example #8
    def recognize(self, x, recog_args, char_list, rnnlm=None, fstlm=None):
        '''E2E beam search

        :param x: input feature tensor
        :param Namespace recog_args: beam-search options
        :param char_list: list of characters
        :param rnnlm: optional RNN language model
        :param fstlm: optional FST language model
        :return: n-best hypotheses
        '''
        prev = self.training
        self.eval()
        
        ilen = [x.shape[1]]
        h = to_cuda(self, x)

        # 1. encoder
        # make a utt list (1) to use the same interface for encoder
        h, _ = self.enc(h, ilen)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(h).data[0]
        else:
            lpz = None

        # 2. decoder
        # decode the first utterance
        y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list, rnnlm, fstlm)

        if prev:
            self.train()
        return y
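A minimal usage sketch; the `recog_args` fields are exactly those read by the beam search, while `model`, `x`, and `char_list` are hypothetical placeholders:

    from argparse import Namespace

    recog_args = Namespace(beam_size=10, penalty=0.0, ctc_weight=0.3,
                           lm_weight=0.3, maxlenratio=0.0, minlenratio=0.0,
                           nbest=1)
    nbest_hyps = model.recognize(x, recog_args, char_list)
    best_text = ''.join(char_list[int(t)] for t in nbest_hyps[0]['yseq'][1:])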
Example #9
 def forward(self, xs, fbank_cmvn=None):
     '''FbankModel forward

     :param xs: padded FFT magnitude tensor (B, Tmax, D)
     :param fbank_cmvn: optional CMVN statistics (2, n_mels)
     :return: log Mel-filterbank features
     '''
     xs = to_cuda(self, xs)
     # power spectrum
     xs = xs ** 2
     # apply the learnable filterbank matrix to every frame
     n, t = xs.size(0), xs.size(1)
     xs = xs.view(n * t, -1)
     xs = torch.mm(xs, self.fc)
     out = xs.view(n, t, -1)
     # clamp to avoid log(0)
     out[out <= 1e-7] = 1e-7
     out = torch.log(out)
     if fbank_cmvn is not None:
         fbank_cmvn = to_cuda(self, fbank_cmvn)
         out = (out + fbank_cmvn[0, :]) * fbank_cmvn[1, :]
     return out
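Here `self.fc` acts as an (n_freq x n_mels) filterbank matrix applied to each frame's power spectrum. One possible initialization, sketched with librosa (an assumption; sample rate, FFT size, and Mel count are illustrative):

    import torch
    import librosa

    mel = librosa.filters.mel(sr=16000, n_fft=512, n_mels=80)  # (80, 257)
    fc_init = torch.from_numpy(mel.T).float()                  # (257, 80)
    # e.g. in __init__: self.fc = torch.nn.Parameter(fc_init)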
Example #10
    def calculate_all_attentions(self, data, fbank_cmvn=None):
        '''E2E attention calculation

        :param tuple data: batch of (utt_ids, spk_ids, inputs, targets,
            input_sizes, target_sizes)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        '''
        torch.set_grad_enabled(False)

        utt_ids, spk_ids, inputs, targets, input_sizes, target_sizes = data
        inputs = to_cuda(self, inputs)
        fbank_features = self.fbank_model(inputs)
        if fbank_cmvn is not None:
            fbank_cmvn = to_cuda(self, fbank_cmvn)
            fbank_features = (fbank_features + fbank_cmvn[0, :]) * fbank_cmvn[1, :]
        data = (utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes)
        
        att_ws = self.e2e.calculate_all_attentions(data)
        return att_ws
Example #11
    def recognize(self, x, fbank_cmvn, recog_args, char_list, rnnlm=None, kenlm=None):
        '''E2E beam search

        :param x: input feature tensor
        :param fbank_cmvn: CMVN statistics (2, n_mels)
        :param Namespace recog_args: beam-search options
        :param char_list: list of characters
        :return: n-best hypotheses
        '''
        prev = self.training
        self.eval()
        
        x = to_cuda(self, x)
        fbank_cmvn = to_cuda(self, fbank_cmvn)
        fbank_features = self.fbank_model(x)
        fbank_features = (fbank_features + fbank_cmvn[0, :]) * fbank_cmvn[1, :]   
        y = self.e2e.recognize(fbank_features, recog_args, char_list, rnnlm, kenlm)
        
        if prev:
            self.train()
        return y
Example #12
    def forward(self,
                inputs,
                targets,
                input_sizes,
                target_sizes,
                scheduled_sampling_rate=0.0):
        '''E2E forward

        :param inputs: batch of padded input sequences (B, Tmax, D)
        :param targets: concatenated target id sequences
        :param input_sizes: lengths of the input sequences (B)
        :param target_sizes: lengths of the target sequences (B)
        :return: ctc loss, attention loss, and accuracy
        '''
        xpad = to_cuda(self, inputs)
        ilens = to_cuda(self, input_sizes)
        ys = []
        offset = 0
        for size in target_sizes:
            ys.append(targets[offset:offset + size])
            offset += size
        ys = [to_cuda(self, y) for y in ys]

        # 1. encoder
        hpad, hlens = self.enc(xpad, ilens)

        # 2. CTC loss
        if self.mtlalpha == 0:
            loss_ctc = None
        else:
            loss_ctc = self.ctc(hpad, hlens, ys)

        # 3. attention loss
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            loss_att, acc = self.dec(hpad, hlens, ys, scheduled_sampling_rate)

        return loss_ctc, loss_att, acc
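The two losses are combined outside this forward; a sketch of the usual weighting implied by the `mtlalpha` branches above (not a line from the source):

    alpha = model.mtlalpha
    if alpha == 0:
        loss = loss_att
    elif alpha == 1:
        loss = loss_ctc
    else:
        loss = alpha * loss_ctc + (1 - alpha) * loss_att
    loss.backward()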
Example #13
    def forward(self,
                mix_inputs,
                mix_log_inputs,
                input_sizes,
                clean_inputs=None,
                cos_angles=None):
        '''Enhancement forward

        :param mix_inputs: spectrogram of the noisy mixture (B, Tmax, F)
        :param mix_log_inputs: log spectrogram of the mixture
        :param input_sizes: lengths of the input sequences (B)
        :param clean_inputs: optional clean reference; when given, an L1 loss
            against the masked output is returned as well
        :return: enhanced spectrogram, or (loss, enhanced spectrogram)
        '''
        mix_inputs = to_cuda(self, mix_inputs)
        mix_log_inputs = to_cuda(self, mix_log_inputs)
        ilens = to_cuda(self, input_sizes)

        if self.enhance_type in ('blstm', 'blstmp', 'unet_128', 'unet_256'):
            xs, hlens = self.enc1(mix_log_inputs, ilens)
        elif self.enhance_type in ('vggblstmp', 'vggblstm'):
            xs, hlens = self.enc1(mix_log_inputs, ilens)
            xs, hlens = self.enc2(xs, hlens)
        else:
            logging.error(
                "Error: need to specify an appropriate enhance architecture")
            sys.exit()

        if self.enhance_type == 'unet_128' or self.enhance_type == 'unet_256':
            linear_out = xs.squeeze(1)
        else:
            linear_out = self.fc(xs)
        out = torch.sigmoid(linear_out)
        # zero out frames beyond each utterance's true length
        mask = to_cuda(self, torch.ByteTensor(out.size()).fill_(0))
        for i, length in enumerate(ilens):
            length = length.item()
            if (mask[i].size(0) - length) > 0:
                mask[i].narrow(0, length, mask[i].size(0) - length).fill_(1)
        out = out.masked_fill(mask, 0)
        enhance_out = out * mix_inputs

        if clean_inputs is not None:
            clean_inputs = to_cuda(self, clean_inputs)
            cos_angles = to_cuda(self, cos_angles)
            loss = F.l1_loss(enhance_out,
                             clean_inputs * cos_angles,
                             size_average=False)
            loss /= torch.sum(ilens).float()
            return loss, enhance_out
        else:
            return enhance_out
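The Python loop that builds the padding mask above can be vectorized; a minimal equivalent sketch (same behaviour: positions at or beyond each utterance's length are zeroed):

    import torch

    def length_mask(ilens, max_len):
        # True where t >= length, i.e. in the padded region
        t = torch.arange(max_len, device=ilens.device).unsqueeze(0)  # (1, T)
        return t >= ilens.unsqueeze(1).long()                        # (B, T)

    # usage: out = out.masked_fill(length_mask(ilens, out.size(1)).unsqueeze(-1), 0)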
Example #14
 def forward(self, input):
     input = to_cuda(self, input)
     # add a channel dimension for the conv stack if needed
     if len(input.shape) == 3:
         input = input.unsqueeze(1)
     return self.model(input)
Example #15
    def forward(self, hpad, hlen, ys, scheduled_sampling_rate):
        '''Decoder forward

        :param hpad: batch of padded hidden state sequences (B, Tmax, D)
        :param hlen: lengths of the hidden state sequences (B)
        :param ys: list of target id sequences
        :param scheduled_sampling_rate: probability of feeding back the
            model's own prediction instead of the ground-truth token
        :return: attention loss and accuracy
        '''
        hpad = mask_by_length(hpad, hlen, 0)
        hlen = list(map(int, hlen))

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = Variable(ys[0].data.new([self.eos]))
        sos = Variable(ys[0].data.new([self.sos]))
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        pad_ys_in = pad_list(ys_in, self.eos)
        pad_ys_out = pad_list(ys_out, self.ignore_id)
        # get dim, length info
        batch = pad_ys_out.size(0)
        olength = pad_ys_out.size(1)

        # initialization
        c_list = [self.zero_state(hpad)]
        z_list = [self.zero_state(hpad)]
        for l in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hpad))
            z_list.append(self.zero_state(hpad))
        att_w = None
        z_all = []
        y_all = []
        self.att.reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.embed(pad_ys_in)  # utt x olen x zdim

        rnnlm_state_prev = None
        # loop for an output sequence
        for i in six.moves.range(olength):
            att_c, att_w = self.att(hpad, hlen, z_list[0], att_w)
            if random.random() < scheduled_sampling_rate and i > 0:
                # scheduled sampling: feed back the model's own prediction
                topv, topi = y_i.topk(1)
                topi = topi.squeeze(1)
                ey_top = self.embed(topi)  # utt x zdim
                ey = torch.cat((ey_top, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                # teacher forcing: feed the ground-truth token
                topi = pad_ys_in[:, i]
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list[0], c_list[0] = self.decoder[0](ey, (z_list[0], c_list[0]))
            for l in six.moves.range(1, self.dlayers):
                z_list[l], c_list[l] = self.decoder[l](z_list[l - 1],
                                                       (z_list[l], c_list[l]))

            if self.fusion == 'deep_fusion' and self.rnnlm is not None:
                rnnlm_state, lm_scores = self.rnnlm.predict(
                    rnnlm_state_prev, topi)
                lm_state = rnnlm_state['h2']
                gi = torch.sigmoid(self.gate_linear(lm_state))
                output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                rnnlm_state_prev = rnnlm_state
            elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
                rnnlm_state, lm_scores = self.rnnlm.predict(
                    rnnlm_state_prev, topi)
                lm_state = F.relu(self.lm_linear(lm_scores))
                gi = torch.sigmoid(
                    self.gate_linear(torch.cat((lm_state, z_list[-1]), dim=1)))
                output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                rnnlm_state_prev = rnnlm_state
            else:
                output_in = z_list[-1]
            y_i = self.output(output_in)
            y_all.append(y_i)
            z_all.append(z_list[-1])

        y_all = torch.stack(y_all, dim=0).transpose(0, 1).contiguous().view(
            batch * olength, -1)
        self.loss = F.cross_entropy(y_all,
                                    pad_ys_out.view(-1),
                                    ignore_index=self.ignore_id,
                                    size_average=True)
        # -1: eos, which is removed in the loss computation
        self.loss *= (np.mean([len(x) for x in ys_in]) - 1)

        acc = th_accuracy(y_all, pad_ys_out, ignore_label=self.ignore_id)
        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_cuda(
                    self, Variable(torch.from_numpy(self.labeldist)))
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1),
                dim=0) / len(ys_in)
            self.loss = (
                1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc
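`pad_list`, used above to pad the target sequences, is not shown in this example; a minimal sketch consistent with its usage (pad a list of tensors to the longest one along dim 0 with `pad_value`):

    import torch

    def pad_list(xs, pad_value):
        n_batch = len(xs)
        max_len = max(x.size(0) for x in xs)
        pad = xs[0].new_full((n_batch, max_len) + xs[0].size()[1:], pad_value)
        for i, x in enumerate(xs):
            pad[i, :x.size(0)] = x
        return pad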
Example #16
    def recognize_beam(self,
                       h,
                       lpz,
                       recog_args,
                       char_list,
                       rnnlm=None,
                       fstlm=None):
        '''beam search implementation

        :param Variable h: encoder hidden states (Tmax, D)
        :param lpz: CTC log probabilities, or None to disable CTC scoring
        :param Namespace recog_args: beam-search options
        :param char_list: list of characters
        :param rnnlm: optional RNN language model
        :param fstlm: optional FST language model
        :return: n-best hypotheses
        '''
        logging.info('input lengths: ' + str(h.size(0)))
        # initialization
        c_list = [self.zero_state(h.unsqueeze(0))]
        z_list = [self.zero_state(h.unsqueeze(0))]
        for l in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(h.unsqueeze(0)))
            z_list.append(self.zero_state(h.unsqueeze(0)))
        a = None
        self.att.reset()  # reset pre-computation of h

        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {
                'score': 0.0,
                'yseq': [y],
                'c_prev': c_list,
                'z_prev': z_list,
                'a_prev': a,
                'rnnlm_prev': None
            }
        else:
            hyp = {
                'score': 0.0,
                'yseq': [y],
                'c_prev': c_list,
                'z_prev': z_list,
                'a_prev': a
            }
        if fstlm is not None:
            hyp['fstlm_prev'] = None

        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.cpu().numpy(), 0, self.eos,
                                              np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        rnnlm_state_prev = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]
                ey = self.embed(vy)  # utt list (1) x zdim
                att_c, att_w = self.att(h.unsqueeze(0), [h.size(0)],
                                        hyp['z_prev'][0], hyp['a_prev'])
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list[0], c_list[0] = self.decoder[0](
                    ey, (hyp['z_prev'][0], hyp['c_prev'][0]))
                for l in six.moves.range(1, self.dlayers):
                    z_list[l], c_list[l] = self.decoder[l](
                        z_list[l - 1], (hyp['z_prev'][l], hyp['c_prev'][l]))

                if self.fusion == 'deep_fusion' and self.rnnlm is not None:
                    rnnlm_state, lm_scores = self.rnnlm.predict(
                        rnnlm_state_prev, vy)
                    lm_state = rnnlm_state['h2']
                    gi = torch.sigmoid(self.gate_linear(lm_state))
                    output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                    rnnlm_state_prev = rnnlm_state
                elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
                    rnnlm_state, lm_scores = self.rnnlm.predict(
                        rnnlm_state_prev, vy)
                    lm_state = F.relu(self.lm_linear(lm_scores))
                    gi = torch.sigmoid(
                        self.gate_linear(
                            torch.cat((lm_state, z_list[-1]), dim=1)))
                    output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                    rnnlm_state_prev = rnnlm_state
                else:
                    output_in = z_list[-1]

                # get nbest local scores and their ids
                local_att_scores = F.log_softmax(self.output(output_in),
                                                 dim=1).data
                if fstlm:
                    fstlm_state, local_lm_scores = fstlm.predict(
                        hyp['fstlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                elif rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * to_cuda(self, torch.from_numpy(ctc_scores - hyp['ctc_score_prev']))
                    if rnnlm or fstlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    # [:] is needed!
                    new_hyp['z_prev'] = z_list[:]
                    new_hyp['c_prev'] = c_list[:]
                    new_hyp['a_prev'] = att_w[:]
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if fstlm:
                        new_hyp['fstlm_prev'] = fstlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            logging.debug(
                'best hypo: ' +
                ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add <eos> in the final loop so that at least one hypothesis ends
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # move ended hypotheses to the final list and drop them from the
            # active set (note: the number of active hyps can fall below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                logging.debug(
                    'hypo: ' +
                    ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]
        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

        # note: yseq still begins with <sos> (and ends with <eos>)
        return nbest_hyps
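A short sketch of consuming the result; as the logging above suggests, `yseq` begins with <sos>, and ended hypotheses carry a trailing <eos>:

    best = nbest_hyps[0]
    text = ''.join(char_list[int(idx)] for idx in best['yseq'][1:-1])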