コード例 #1
0
class Solver(BaseSolver):
    ''' Solver for training language models'''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        # A validation loss has to drop below this value before the first
        # checkpoint is written (see validate()).
        self.best_loss = 10

    def fetch_data(self, data):
        ''' Move data to device, insert <sos> and compute text seq. length

        Args:
            data: 2-D tensor of token ids, indexed as (batch, time) here.
                  Id 0 is treated as padding when counting lengths.
                  (assumes an integer dtype compatible with torch.long
                  -- TODO confirm against the data loader)
        Returns:
            txt: `data` with one column of zeros (<sos>) prepended,
                 moved to self.device.
            txt_len: number of non-zero entries in each row of `data`.
        '''
        sos = torch.zeros((data.shape[0], 1), dtype=torch.long)
        txt = torch.cat((sos, data), dim=1).to(self.device)
        txt_len = torch.sum(data != 0, dim=-1)
        return txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
            load_textset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup language model and optimizer '''

        # Model: the output of `Prediction` is fed to `RNNLM` to obtain the
        # token predictions (see the forward pass in exec()/validate()).
        self.model = Prediction(self.vocab_size,
                                **self.config['model']).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size,
                           **self.config['model']).to(self.device)

        self.verbose(self.rnnlm.create_msg())
        # Losses (targets equal to 0 are ignored, i.e. padding)
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Optimizer over the parameters of both modules
        self.optimizer = Optimizer(
            list(self.model.parameters()) + list(self.rnnlm.parameters()),
            **self.config['hparas'])
        # Enable AMP if needed
        self.enable_apex()
        # load pre-trained model
        if self.paras.load:
            # NOTE(review): load_ckpt() is defined outside this class; the
            # explicit reload below may duplicate what it already does.
            # Only self.model is restored here -- self.rnnlm weights are not
            # read from the checkpoint. Confirm this is intended.
            self.load_ckpt()
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Training loop of the language model

        Iterates over the training set until `self.step` reaches
        `self.max_step`, logging every PROGRESS_STEP steps and validating at
        step 1 and then every `self.valid_step` steps.
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self.timer.set()

        while self.step < self.max_step:
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model: inputs are all tokens but the last one,
                # targets are the same sequence shifted by one position.
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)

                # Compute all objectives
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(lm_loss)
                self.step += 1

                # Logger
                if self.step % self.PROGRESS_STEP == 0:
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(lm_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('entropy', {'tr': lm_loss})
                    self.write_log('perplexity',
                                   {'tr': torch.exp(lm_loss).cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                # Stop as soon as the budget is reached; `>` would let one
                # extra step run when the epoch still has batches left.
                if self.step >= self.max_step:
                    break
        self.log.close()

    def validate(self):
        ''' Evaluate on the dev set, log metrics and checkpoint on improvement

        The mean batch loss is compared against `self.best_loss`; when it is
        lower, 'best_ppx.pth' is saved. Examples from the last dev batch are
        logged as text. Both modules are put back into train mode at the end.
        '''
        # Eval mode
        self.model.eval()
        self.rnnlm.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            txt, txt_len = self.fetch_data(data)

            # Forward model (no gradients needed for evaluation)
            with torch.no_grad():
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
            dev_loss.append(lm_loss)

        if not dev_loss:
            # Empty dev set: there is nothing to average, log or decode.
            self.verbose('Validation set is empty, skip validation.')
            self.model.train()
            self.rnnlm.train()
            return

        # Ckpt if performance improves
        dev_loss = sum(dev_loss) / len(dev_loss)
        dev_ppx = torch.exp(dev_loss).cpu().item()
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint('best_ppx.pth', 'perplexity', dev_ppx)
        self.write_log('entropy', {'dv': dev_loss})
        self.write_log('perplexity', {'dv': dev_ppx})

        # Show some example of last batch on tensorboard
        for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
            # The reference text does not change, so it is logged only once.
            if self.step == 1:
                self.write_log('true_text{}'.format(i),
                               self.tokenizer.decode(txt[i].tolist()))
            self.write_log(
                'pred_text{}'.format(i),
                self.tokenizer.decode(pred[i].argmax(dim=-1).tolist()))

        # Resume training
        self.model.train()
        self.rnnlm.train()
コード例 #2
0
class BeamDecoder(nn.Module):
    ''' Beam decoder for ASR '''

    def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
                 lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
        ''' Store decoding hyper-parameters and optional scoring modules

        Args:
            asr: ASR model providing encoder / attention / decoder / ctc_layer.
            emb_decoder: embedding fusion module, or None to disable it.
            beam_size: number of hypotheses kept per decoding step.
            min_len_ratio, max_len_ratio: output length bounds, as a ratio of
                the input feature length.
            lm_path: checkpoint of the RNNLM (read only when lm_weight > 0).
            lm_config: YAML config of the RNNLM (read only when lm_weight > 0).
            lm_weight: weight of the LM score; 0 disables LM fusion.
            ctc_weight: weight of the CTC score; 0 disables CTC scoring and
                1.0 selects the pure CTC prefix beam search in forward().
        '''
        super().__init__()
        # Setup
        self.beam_size = beam_size
        self.min_len_ratio = min_len_ratio
        self.max_len_ratio = max_len_ratio
        self.asr = asr

        # Additional decoding modules
        self.apply_ctc = ctc_weight > 0
        if self.apply_ctc:
            assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
            self.ctc_w = ctc_weight
            self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

        self.apply_lm = lm_weight > 0
        if self.apply_lm:
            self.lm_w = lm_weight
            self.lm_path = lm_path
            # Close the config file once parsed instead of leaking the handle.
            with open(lm_config, 'r') as cfg_file:
                lm_config = yaml.load(cfg_file, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
            self.lm.load_state_dict(torch.load(
                self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

        self.apply_emb = emb_decoder is not None
        if self.apply_emb:
            self.emb_decoder = emb_decoder

    def create_msg(self):
        ''' Return a list of strings describing the enabled decoding options '''
        msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
            self.beam_size, self.min_len_ratio, self.max_len_ratio)]
        if self.apply_ctc:
            msg.append(
                '           |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(self.ctc_w))
        if self.apply_lm:
            msg.append('           |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format(
                self.lm_w, self.lm_path))
        if self.apply_emb:
            # Report the fusion weight itself; self.lm_w is unrelated and is
            # not even defined when LM decoding is disabled.
            msg.append('           |Joint Emb. decoding enabled \t| weight = {:.2f}'.format(
                self.emb_decoder.fuse_lambda.mean().cpu().item()))

        return msg

    def _ctc_prefix_beam_search(self, ctc_output, batch_size, device):
        ''' Pure CTC prefix beam search, returns the best prefix as a tuple

        Args:
            ctc_output: CTC log-probabilities with a leading batch dim of 1
                (it is squeezed away here); index 0 of the last dim is blank.
            batch_size: batch size used to build the LM length argument.
            device: device the LM input token is moved to.

        Every beam entry is `(prefix, (p_b, p_nb, p_text, applied_lm, lm_state))`
        with the log-probability of the prefix ending in blank / non-blank,
        the accumulated weighted LM score, whether the LM score was already
        added for this prefix in the current frame, and the LM hidden state.
        '''
        import collections

        def make_new_beam():
            # Unseen prefixes start without any probability mass (log 0).
            return collections.defaultdict(
                lambda: (-float('inf'), -float('inf'), 0, False, None))

        def log_sum(*args):
            # log(sum(exp(a))) computed relative to the max for stability
            if all(a == -float('inf') for a in args):
                return -float('inf')
            a_max = max(args)
            # Accumulate on the same device as the CTC scores; an in-place
            # add into a CPU tensor fails when the scores live on GPU.
            lsp = torch.zeros((), dtype=torch.float32,
                              device=ctc_output.device)
            for a in args:
                lsp += torch.exp(a - a_max)
            return a_max + torch.log(lsp)

        def get_final_prob(*args):
            prob_blank, prob_non_blank, prob_text, applied_lm, lm_state = args
            prob_total = log_sum(prob_blank, prob_non_blank)
            return prob_total + prob_text

        # init beam
        # prefix, (p_b, p_nb, p_text, applied_lm, lm_state)
        beam = [(tuple(), (0, -float('inf'), 0, False, None))]

        # assume batch size is always 1 when decoding
        ctc_output = ctc_output.squeeze(0)

        # iterate all timestamp
        for t in range(ctc_output.shape[0]):
            # iterate all vocab candidate
            next_beam = make_new_beam()
            for vocab in range(ctc_output.shape[1]):
                p = ctc_output[t, vocab]
                for prefix, (prob_blank, prob_non_blank, prob_text, applied_lm, lm_state) in beam:

                    # blank case (input: -) *a => *a
                    if vocab == 0:
                        next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state = next_beam[prefix]
                        next_prob_blank = log_sum(next_prob_blank, prob_blank + p, prob_non_blank + p)
                        # text not changed, so the prob_text and lm_state also the same
                        next_beam[prefix] = (next_prob_blank, next_prob_non_blank, prob_text, applied_lm, lm_state)
                        continue

                    # non blank case
                    prev_vocab = prefix[-1] if prefix else None
                    next_prefix = prefix + (vocab, )
                    next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state = next_beam[next_prefix]
                    if vocab != prev_vocab:
                        # (input: b) *a => *ab
                        next_prob_non_blank = log_sum(next_prob_non_blank, prob_blank + p, prob_non_blank + p)
                    else:
                        # (input: a) *a- => *aa
                        next_prob_non_blank = log_sum(next_prob_non_blank, prob_blank + p)

                    # apply language model
                    if self.apply_lm and not next_applied_lm:
                        if prev_vocab is None:
                            # Empty prefix: feed token 0 to the LM
                            prev_vocab = 0
                        lm_input = torch.LongTensor([prev_vocab]).to(device)
                        lm_input = lm_input.unsqueeze(0)
                        lm_output, next_lm_state = self.lm(lm_input, torch.ones([batch_size]), hidden=lm_state)
                        lm_output = lm_output.squeeze(0)
                        next_prob_text = prob_text + self.lm_w * lm_output.log_softmax(dim=-1).squeeze(0)[vocab]
                        next_applied_lm = True

                    next_beam[next_prefix] = (next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state)

                    # (input: a) *a => *a
                    if vocab == prev_vocab:
                        # merging all prob of *a to single one
                        next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state = next_beam[prefix]
                        next_prob_non_blank = log_sum(next_prob_non_blank, prob_non_blank + p)
                        next_beam[prefix] = (next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state)

            # Keep the best prefixes only
            beam = sorted(next_beam.items(), key=lambda x: get_final_prob(*x[1]), reverse=True)
            beam = beam[:self.ctc_beam_size]

        return beam[0][0]

    def forward(self, audio_feature, feature_len):
        ''' Decode a single utterance with beam search

        Args:
            audio_feature: input features; the first dim must be 1.
            feature_len: feature length; a single-element tensor
                (`.item()` is called on it).
        Returns:
            List of hypotheses sorted by average score, at most `beam_size`
            long. With `ctc_w == 1.0` a single hypothesis from pure CTC
            prefix beam search is returned; with `beam_size == 1` the list
            holding the first finished hypothesis is returned immediately.
        '''
        # Init.
        assert audio_feature.shape[0] == 1, "Batchsize == 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(
                batch_size)                           # Init zero states
        self.asr.attention.reset_mem()            # Flush attention mem
        # Max output len set w/ hyper param.
        max_output_len = int(
            np.ceil(feature_len.cpu().item()*self.max_len_ratio))
        # Min output len set w/ hyper param.
        min_output_len = int(
            np.ceil(feature_len.cpu().item()*self.min_len_ratio))
        # Store attention map if location-aware
        store_att = self.asr.attention.mode == 'loc'
        prev_token = torch.zeros(
            (batch_size, 1), dtype=torch.long, device=device)     # Start w/ <sos>
        # Cache of beam search
        final_hypothesis, next_top_hypothesis = [], []
        # Incase ctc is disabled
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode
        encode_feature, encode_len = self.asr.encoder(
            audio_feature, feature_len)

        # CTC decoding
        if self.apply_ctc:
            ctc_output = F.log_softmax(
                self.asr.ctc_layer(encode_feature), dim=-1)
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()
            if self.ctc_w == 1.0:
                # Pure CTC decoding: the attention decoder is skipped entirely
                output_seq = torch.LongTensor(
                    self._ctc_prefix_beam_search(ctc_output, batch_size, device))
                hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=output_seq,
                               output_scores=[0]*len(output_seq), lm_state=None, ctc_prob=0,
                               ctc_state=ctc_state, att_map=None)]

                return hypothesis

        # Start w/ empty hypothesis
        prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                          output_scores=[], lm_state=None, ctc_prob=0,
                                          ctc_state=ctc_state, att_map=None)]
        # Attention decoding
        for t in range(max_output_len):
            for hypothesis in prev_top_hypothesis:
                # Resume previous step
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
                    device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # Normal asr forward
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding fusion (output shape 1xV)
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder( d_state, cur_prob, return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)

                # Perform CTC prefix scoring on limited candidates (else OOM easily)
                if self.apply_ctc:
                    # TODO : Check the performance drop for computing part of candidates only
                    _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)
                    # TODO : study why ctc_char (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)

                    # Combine CTC score and Attention score (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (1-self.ctc_w)*cur_prob + self.ctc_w*hack_ctc_char  # ctc_char
                    cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.apply_lm:
                    # assuming batch size always 1, resulting 1x1
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(
                        lm_input, torch.ones([batch_size]), hidden=prev_lm_state)
                    # assuming batch size always 1,  resulting 1xV
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w*lm_output.log_softmax(dim=-1)

                # Beam search
                # Note: Ignored batch dim.
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
                final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(), att_map=prev_attn,
                                                lm_state=lm_state, ctc_state=ctc_state, ctc_prob=ctc_prob,
                                                ctc_candidates=candidates)
                # Move complete hyps. out
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)
                    if self.beam_size == 1:
                        return final_hypothesis
                next_top_hypothesis.extend(top)

            # Sort for top N beams
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Rescore all hyp (finished/unfinished)
        final_hypothesis += prev_top_hypothesis
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)

        return final_hypothesis[:self.beam_size]
コード例 #3
0
class BeamDecoder(nn.Module):
    ''' Beam decoder for ASR '''
    def __init__(self,
                 asr,
                 emb_decoder,
                 beam_size,
                 min_len_ratio,
                 max_len_ratio,
                 lm_path='',
                 lm_config='',
                 lm_weight=0.0,
                 ctc_weight=0.0):
        ''' Store decoding hyper-parameters and optional scoring modules

        Args:
            asr: ASR model; it must have attention enabled (`enable_att`).
            emb_decoder: embedding fusion module, or None to disable it.
            beam_size: number of hypotheses kept per decoding step.
            min_len_ratio, max_len_ratio: output length bounds, as a ratio of
                the input feature length.
            lm_path: checkpoint of the RNNLM (read only when lm_weight > 0).
            lm_config: YAML config of the RNNLM (read only when lm_weight > 0).
            lm_weight: weight of the LM score; 0 disables LM fusion.
            ctc_weight: weight of the CTC score; 0 disables CTC scoring.
        '''
        super().__init__()
        # Setup
        self.beam_size = beam_size
        self.min_len_ratio = min_len_ratio
        self.max_len_ratio = max_len_ratio
        self.asr = asr

        # ToDo : implement pure ctc decode
        assert self.asr.enable_att

        # Additional decoding modules
        self.apply_ctc = ctc_weight > 0
        if self.apply_ctc:
            assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
            self.ctc_w = ctc_weight
            self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

        self.apply_lm = lm_weight > 0
        if self.apply_lm:
            self.lm_w = lm_weight
            self.lm_path = lm_path
            # Close the config file once parsed instead of leaking the handle.
            with open(lm_config, 'r') as cfg_file:
                lm_config = yaml.load(cfg_file, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
            self.lm.load_state_dict(
                torch.load(self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

        self.apply_emb = emb_decoder is not None
        if self.apply_emb:
            self.emb_decoder = emb_decoder

    def create_msg(self):
        ''' Return a list of strings describing the enabled decoding options '''
        msg = [
            'Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
                self.beam_size, self.min_len_ratio, self.max_len_ratio)
        ]
        if self.apply_ctc:
            msg.append(
                '           |Joint CTC decoding enabled \t| weight = {:.2f}\t'.
                format(self.ctc_w))
        if self.apply_lm:
            msg.append(
                '           |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'
                .format(self.lm_w, self.lm_path))
        if self.apply_emb:
            # Report the fusion weight itself; self.lm_w is unrelated and is
            # not even defined when LM decoding is disabled.
            msg.append(
                '           |Joint Emb. decoding enabled \t| weight = {:.2f}'.
                format(self.emb_decoder.fuse_lambda.mean().cpu().item()))

        return msg

    def forward(self, audio_feature, feature_len):
        ''' Decode a single utterance with attention-based beam search

        Args:
            audio_feature: input features; the first dim must be 1.
            feature_len: feature length; a single-element tensor
                (`.item()` is called on it).
        Returns:
            List of hypotheses sorted by average score, at most `beam_size`
            long. With `beam_size == 1` the list holding the first finished
            hypothesis is returned immediately.
        '''
        # Init.
        assert audio_feature.shape[
            0] == 1, "Batchsize == 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(batch_size)  # Init zero states
        self.asr.attention.reset_mem()  # Flush attention mem
        # Max output len set w/ hyper param.
        max_output_len = int(
            np.ceil(feature_len.cpu().item() * self.max_len_ratio))
        # Min output len set w/ hyper param.
        min_output_len = int(
            np.ceil(feature_len.cpu().item() * self.min_len_ratio))
        # Store attention map if location-aware
        store_att = self.asr.attention.mode == 'loc'
        prev_token = torch.zeros((batch_size, 1),
                                 dtype=torch.long,
                                 device=device)  # Start w/ <sos>
        # Cache of beam search
        final_hypothesis, next_top_hypothesis = [], []
        # Incase ctc is disabled
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode
        encode_feature, encode_len = self.asr.encoder(audio_feature,
                                                      feature_len)

        # CTC decoding
        if self.apply_ctc:
            ctc_output = F.log_softmax(self.asr.ctc_layer(encode_feature),
                                       dim=-1)
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()

        # Start w/ empty hypothesis
        prev_top_hypothesis = [
            Hypothesis(decoder_state=dec_state,
                       output_seq=[],
                       output_scores=[],
                       lm_state=None,
                       ctc_prob=0,
                       ctc_state=ctc_state,
                       att_map=None)
        ]
        # Attention decoding
        for t in range(max_output_len):
            # For each hypothesis, generate its top-B candidates
            for hypothesis in prev_top_hypothesis:
                # Resume previous step
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
                    device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # Normal asr forward
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding fusion (output shape 1xV)
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder(d_state,
                                                   cur_prob,
                                                   return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)
                # Attention-only scores, taken before CTC / LM fusion
                att_prob = cur_prob.squeeze(0)

                # Perform CTC prefix scoring on limited candidates (else OOM easily)
                if self.apply_ctc:
                    # TODO : Check the performance drop for computing part of candidates only
                    _, ctc_candidates = cur_prob.squeeze(0).topk(
                        self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)
                    # TODO : study why ctc_char (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(
                        ctc_prob - hypothesis.ctc_prob).to(device)

                    # Combine CTC score and Attention score (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(
                        LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (
                        1 - self.ctc_w
                    ) * cur_prob + self.ctc_w * hack_ctc_char  # ctc_char
                    cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.apply_lm:
                    # assuming batch size always 1, resulting 1x1
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(lm_input,
                                                  torch.ones([batch_size]),
                                                  hidden=prev_lm_state)
                    # assuming batch size always 1,  resulting 1xV
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)
                    # Open question: is there no other constraint that
                    # encourages longer transcripts?

                # Beam search
                # Note: Ignored batch dim.
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu(
                ) if store_att else None
                # `top` holds the new (extended) hypotheses
                final, top = hypothesis.addTopk(topi,
                                                topv,
                                                self.asr.decoder.get_state(),
                                                att_map=prev_attn,
                                                lm_state=lm_state,
                                                ctc_state=ctc_state,
                                                ctc_prob=ctc_prob,
                                                ctc_candidates=candidates,
                                                att_prob=att_prob)

                # Move complete hyps. out
                # `final` is not None once <eos> is detected for this beam
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)  # finish one beam
                    if self.beam_size == 1:
                        return final_hypothesis
                # Collect every hypothesis' top-B candidates; the best B of
                # these B*B candidates are selected below.
                next_top_hypothesis.extend(top)

            # Sort for top N beams
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Rescore all hyp (finished/unfinished)
        final_hypothesis += prev_top_hypothesis  # include unfinished beams
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)

        return final_hypothesis[:self.beam_size]
コード例 #4
0
ファイル: ctc.py プロジェクト: zge/End-to-end-ASR-Pytorch
class CTCBeamDecoder(nn.Module):
    ''' Beam decoder for ASR (CTC only) '''
    def __init__(self,
                 asr,
                 vocab_range,
                 beam_size,
                 vocab_candidate,
                 lm_path='',
                 lm_config='',
                 lm_weight=0.0,
                 device=None):
        ''' Store decoding hyper-parameters and the optional RNNLM

        Args:
            asr: ASR model; it must have CTC enabled (`enable_ctc`).
            vocab_range: token ids considered when extending a hypothesis.
            beam_size: number of hypotheses kept after every frame.
            vocab_candidate: number of best-scoring tokens from `vocab_range`
                expanded per hypothesis and frame; must not exceed
                len(vocab_range).
            lm_path: checkpoint of the RNNLM (read only when lm_weight > 0).
            lm_config: YAML config of the RNNLM (read only when lm_weight > 0).
            lm_weight: weight of the LM score; 0 disables the LM.
            device: device the RNNLM and its inputs are placed on.
        '''
        super().__init__()
        # Setup
        self.asr = asr
        self.vocab_range = vocab_range
        self.beam_size = beam_size
        self.vocab_cand = vocab_candidate
        assert self.vocab_cand <= len(self.vocab_range)

        assert self.asr.enable_ctc

        # Setup RNNLM
        self.apply_lm = lm_weight > 0
        self.lm_w = 0
        if self.apply_lm:
            self.device = device
            self.lm_w = lm_weight
            self.lm_path = lm_path
            # Close the config file once parsed instead of leaking the handle.
            with open(lm_config, 'r') as cfg_file:
                lm_config = yaml.load(cfg_file, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size,
                            **lm_config['model']).to(self.device)
            self.lm.load_state_dict(
                torch.load(self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

    def create_msg(self):
        ''' Return a list of strings describing the decoding setup '''
        msg = [
            'Decode spec| CTC decoding \t| Beam size = {} \t| LM weight = {}'.
            format(self.beam_size, self.lm_w)
        ]
        return msg

    def forward(self, feat, feat_len):
        ''' CTC beam search over a single utterance

        Args:
            feat: input features; the first dim must be 1.
            feat_len: feature length, passed through to the ASR model.
        Returns:
            List with the token sequence (`y`) of every hypothesis left in
            the beam after the last frame, best first.
        '''
        # Init.
        assert feat.shape[0] == 1, "Batchsize == 1 is required for beam search"

        # Calculate CTC output probability
        # NOTE(review): the meaning of the literal 10 is defined by the ASR
        # model's forward signature -- presumably a decode step count; confirm.
        ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.asr(feat, feat_len, 10)
        # Only the CTC output is needed, release everything else right away
        del encode_len, att_output, att_align, dec_state, feat_len
        ctc_output = F.log_softmax(ctc_output[0],
                                   dim=-1).cpu().detach().numpy()
        T = len(ctc_output)  # ctc_output = Pr(k,t|x) / dim: T x Vocab

        # Best W probable sequences
        B = [CTCHypothesis()]
        if self.apply_lm:
            # 0 == <sos> for RNNLM
            output, hidden = \
                self.lm(torch.zeros((1,1),dtype=torch.long).to(self.device), torch.ones(1,dtype=torch.long).to(self.device), None)
            # detach(): numpy() raises on tensors that require grad, which is
            # the case whenever forward() runs outside torch.no_grad().
            B[0].update_lm(
                output.log_softmax(dim=-1).squeeze().detach().cpu().numpy(),
                hidden)

        start = True
        for t in range(T):
            # greedily ignoring pads at the beginning of the sequence
            if np.argmax(ctc_output[t]) == 0 and start:
                continue
            else:
                start = False
            B_new = []
            for i in range(len(B)):  # For y in B
                B_i_new = copy.deepcopy(B[i])
                if B_i_new.get_len() > 0:  # If y is not empty
                    if B_i_new.y[-1] == 1:
                        # <eos> = 1 (reached the end)
                        B_new.append(B_i_new)
                        continue
                    B_i_new.update_Pr_nblank(ctc_output[t, B_i_new.y[-1]])
                    # Find the same prefix
                    for j in range(len(B)):
                        if i != j and B[j].check_same(B_i_new.y[:-1]):
                            lm_prob = 0.0
                            if self.apply_lm:
                                lm_prob = self.lm_w * B[j].lm_output[
                                    B_i_new.y[-1]]
                            B_i_new.update_Pr_nblank_prefix(
                                ctc_output[t, B_i_new.y[-1]],
                                B[j].Pr_y_t_blank, B[j].Pr_y_t_nblank, lm_prob)
                            break
                B_i_new.update_Pr_blank(ctc_output[t, 0])  # 0 == <pad>
                if self.apply_lm:
                    lm_hidden = B_i_new.lm_hidden
                    lm_probs = B_i_new.lm_output
                else:
                    lm_hidden = None
                    lm_probs = None

                # Sort the next possible output symbol by CTC (and LM) score
                if self.apply_lm:
                    ctc_vocab_cand = sorted(zip(
                        self.vocab_range, ctc_output[t, self.vocab_range] +
                        self.lm_w * lm_probs[self.vocab_range]),
                                            reverse=True,
                                            key=lambda x: x[1])
                else:
                    ctc_vocab_cand = sorted(zip(
                        self.vocab_range, ctc_output[t, self.vocab_range]),
                                            reverse=True,
                                            key=lambda x: x[1])
                # Select top K possible symbols to calculate the probabilities
                for j in range(self.vocab_cand):
                    # <pad>=0, <eos>=1, <unk>=2
                    k = ctc_vocab_cand[j][0]
                    # Pr(k,t|x)
                    hyp_yk = copy.deepcopy(B_i_new)
                    lm_prob = 0.0 if not self.apply_lm else self.lm_w * lm_probs[
                        k]
                    hyp_yk.add_token(k, ctc_output[t, k], lm_prob)
                    # The LM state of the extended hypothesis is stale until
                    # it is refreshed at the end of this frame.
                    hyp_yk.updated_lm = False
                    B_new.append(hyp_yk)
                B_i_new.orig_backup(
                )  # Retrieve origin prob. before add_token()
                B_new.append(B_i_new)
            del B
            B = []

            # Remove duplicated sequences by sorting first (O(NlogN))
            B_new = sorted(B_new, key=lambda x: x.get_string())
            B.append(B_new[0])  # First Hyp always unique
            for i in range(1, len(B_new)):
                if B_new[i].check_same(B[-1].y):
                    # Next Hyp is duplicated, pick the higher one
                    if B_new[i].get_score() > B[-1].get_score():
                        B[-1] = B_new[i]
                    continue
                else:
                    # Next Hyp is different, hence valid
                    B.append(B_new[i])
            del B_new

            # Find top W possible sequences
            if t == T - 1:
                B = sorted(B, reverse=True, key=lambda x: x.get_final_score())
            else:
                B = sorted(B, reverse=True, key=lambda x: x.get_score())
            if len(B) > self.beam_size:
                B = B[:self.beam_size]

            # Update LM states (not needed after the last frame)
            if self.apply_lm and t < T - 1:
                for hyp in B:
                    if hyp.get_len() > 0 and not hyp.updated_lm:
                        output, hidden = \
                            self.lm(hyp.y[-1] * torch.ones((1,1), dtype=torch.long).to(self.device),
                                torch.ones(1,dtype=torch.long).to(self.device), hyp.lm_hidden)
                        hyp.update_lm(
                            output.log_softmax(
                                dim=-1).squeeze().detach().cpu().numpy(),
                            hidden)

        return [b.y for b in B]