Example 1
    def beam_decode(self, audio_feature, decode_step, state_len,
                    decode_beam_size):
        '''Beam decoding: returns the top-N hypotheses for each input sequence.'''
        assert not self.training
        assert audio_feature.shape[0] == 1
        self.decode_beam_size = decode_beam_size
        ctc_beam_size = int(CTC_BEAM_RATIO * self.decode_beam_size)

        # Encode
        encode_feature, encode_len = self.encoder(audio_feature, state_len)
        if decode_step == 0:
            decode_step = int(encode_len[0])

        # Init.
        cur_device = next(self.decoder.parameters()).device
        ctc_output = None
        ctc_state = None
        ctc_prob = 0.0
        ctc_candidates = []
        candidates = None
        att_output = None
        att_maps = None
        lm_hidden = None

        # CTC based decoding
        if self.joint_ctc:
            ctc_output = F.log_softmax(self.ctc_layer(encode_feature), dim=-1)
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()

        # Attention based decoding
        if self.joint_att:
            # Store attention map if location-aware
            store_att = self.attention.mode == 'loc'

            # Init (init char = <SOS>, reset all rnn state and cell)
            self.decoder.init_rnn(encode_feature)
            self.attention.reset_enc_mem()
            last_char = self.embed(
                torch.zeros((1), dtype=torch.long).to(cur_device))
            last_char_idx = torch.LongTensor([[0]])

            # beam search init
            final_outputs, prev_top_outputs, next_top_outputs = [], [], []
            prev_top_outputs.append(
                Hypothesis(self.decoder.hidden_state,
                           self.embed,
                           output_seq=[],
                           output_scores=[],
                           lm_state=None,
                           ctc_prob=0,
                           ctc_state=ctc_state,
                           att_map=None)
            )  # WEIRD BUG here if not all args are passed...
            # Decode
            for t in range(decode_step):
                for prev_output in prev_top_outputs:

                    # Attention
                    self.decoder.hidden_state = prev_output.decoder_state
                    self.attention.prev_att = None if prev_output.att_map is None else prev_output.att_map.to(
                        cur_device)
                    attention_score, context = self.attention(
                        self.decoder.state_list[0], encode_feature, encode_len)
                    decoder_input = torch.cat([prev_output.last_char, context],
                                              dim=-1)
                    dec_out = self.decoder(decoder_input)
                    cur_char = F.log_softmax(self.char_trans(dec_out), dim=-1)

                    # Perform CTC prefix scoring on limited candidates (else OOM easily)
                    if self.joint_ctc:
                        # TODO : Check the performance drop for computing part of candidates only
                        _, ctc_candidates = cur_char.topk(ctc_beam_size)
                        candidates = list(ctc_candidates[0].cpu().numpy())

                        #ctc_prob, ctc_state = ctc_prefix.full_compute(prev_output.outIndex,prev_output.ctc_state,candidates)
                        ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                            prev_output.outIndex, prev_output.ctc_state,
                            candidates)

                        # TODO : study why ctc_char (slightly) > 0 sometimes
                        ctc_char = torch.FloatTensor(
                            ctc_prob - prev_output.ctc_prob).to(cur_device)

                        # Combine CTC score and Attention score (focus on candidates)
                        hack_ctc_char = torch.zeros_like(cur_char).data.fill_(
                            -1000000.0)
                        for idx, char in enumerate(candidates):
                            hack_ctc_char[0, char] = ctc_char[idx]
                        cur_char = (
                            1 - self.ctc_weight
                        ) * cur_char + self.ctc_weight * hack_ctc_char  #ctc_char#
                        cur_char[0, 0] = -10000000.0  # Hack to ignore <sos>

                    # Joint RNN-LM decoding
                    if self.decode_lm_weight > 0:
                        last_char_idx = prev_output.last_char_idx.to(
                            cur_device)
                        lm_hidden, lm_output = self.rnn_lm(
                            last_char_idx, [1], prev_output.lm_state)
                        cur_char += self.decode_lm_weight * F.log_softmax(
                            lm_output.squeeze(1), dim=-1)

                    # Beam search
                    topv, topi = cur_char.topk(self.decode_beam_size)
                    prev_att_map = self.attention.prev_att.clone().detach(
                    ).cpu() if store_att else None
                    final, top = prev_output.addTopk(topi,
                                                     topv,
                                                     self.decoder.hidden_state,
                                                     att_map=prev_att_map,
                                                     lm_state=lm_hidden,
                                                     ctc_state=ctc_state,
                                                     ctc_prob=ctc_prob,
                                                     ctc_candidates=candidates)
                    # Move complete hyps. out
                    if final is not None:
                        final_outputs.append(final)
                        if self.decode_beam_size == 1:
                            return final_outputs
                    next_top_outputs.extend(top)

                # Sort for top N beams
                next_top_outputs.sort(key=lambda o: o.avgScore(), reverse=True)
                prev_top_outputs = next_top_outputs[:self.decode_beam_size]
                next_top_outputs = []

            final_outputs += prev_top_outputs
            final_outputs.sort(key=lambda o: o.avgScore(), reverse=True)

        return final_outputs[:self.decode_beam_size]
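
For orientation, a minimal driver sketch for the `beam_decode` method above. The names `model`, `audio_feature`, `state_len`, and `idx2char` are hypothetical placeholders; only `outIndex` and `avgScore()` are taken from the hypothesis objects used in the code.

    # Hedged usage sketch (not from the original code base).
    import torch

    model.eval()                      # beam_decode asserts not self.training
    with torch.no_grad():
        hyps = model.beam_decode(audio_feature,        # single utterance, batch size 1
                                 decode_step=0,        # 0 -> decode up to encoder length
                                 state_len=state_len,
                                 decode_beam_size=5)

    # Hypotheses come back sorted by average log-probability.
    best = hyps[0]
    print(best.avgScore(), [idx2char[i] for i in best.outIndex])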
Example 2
    def forward(self, audio_feature, feature_len):
        # Init.
        assert audio_feature.shape[
            0] == 1, "Batchsize == 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(batch_size)  # Init zero states
        self.asr.attention.reset_mem()  # Flush attention mem
        # Max output len set w/ hyper param.
        max_output_len = int(
            np.ceil(feature_len.cpu().item() * self.max_len_ratio))
        # Min output len set w/ hyper param.
        min_output_len = int(
            np.ceil(feature_len.cpu().item() * self.min_len_ratio))
        # Store attention map if location-aware
        store_att = self.asr.attention.mode == 'loc'
        prev_token = torch.zeros((batch_size, 1),
                                 dtype=torch.long,
                                 device=device)  # Start w/ <sos>
        # Cache of beam search
        final_hypothesis, next_top_hypothesis = [], []
        # In case CTC is disabled
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode
        encode_feature, encode_len = self.asr.encoder(audio_feature,
                                                      feature_len)

        # CTC decoding
        if self.apply_ctc:
            ctc_output = F.log_softmax(self.asr.ctc_layer(encode_feature),
                                       dim=-1)
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()

        # Start w/ empty hypothesis
        prev_top_hypothesis = [
            Hypothesis(decoder_state=dec_state,
                       output_seq=[],
                       output_scores=[],
                       lm_state=None,
                       ctc_prob=0,
                       ctc_state=ctc_state,
                       att_map=None)
        ]
        # Attention decoding
        for t in range(max_output_len):
            for hypothesis in prev_top_hypothesis:  # For each hypothesis, generate its top-B candidates
                # Resume previous step
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
                    device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # Normal asr forward
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding fusion (output shape 1xV)
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder(d_state,
                                                   cur_prob,
                                                   return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)
                att_prob = cur_prob.squeeze(0)  # shape (V,), e.g. torch.Size([31])

                # Perform CTC prefix scoring on limited candidates (else OOM easily)
                if self.apply_ctc:
                    # TODO : Check the performance drop for computing part of candidates only
                    _, ctc_candidates = cur_prob.squeeze(0).topk(
                        self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)
                    # TODO : study why ctc_char (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(
                        ctc_prob - hypothesis.ctc_prob).to(device)

                    # Combine CTC score and Attention score (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(
                        LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (
                        1 - self.ctc_w
                    ) * cur_prob + self.ctc_w * hack_ctc_char  # ctc_char
                    cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.apply_lm:
                    # Assuming batch size is always 1; result has shape 1x1
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(lm_input,
                                                  torch.ones([batch_size]),
                                                  hidden=prev_lm_state)
                    # Assuming batch size is always 1; result has shape 1xV
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)
                    # No other constraint to lengthen the transcript?

                # Beam search
                # Note: Ignored batch dim.
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                #print(topv)
                #print(topi)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu(
                ) if store_att else None
                final, top = hypothesis.addTopk(topi,
                                                topv,
                                                self.asr.decoder.get_state(),
                                                att_map=prev_attn,
                                                lm_state=lm_state,
                                                ctc_state=ctc_state,
                                                ctc_prob=ctc_prob,
                                                ctc_candidates=candidates,
                                                att_prob=att_prob)
                # top: new partial hypotheses expanded from this beam

                # Move complete hypotheses out
                # (final is not None once <eos> has been detected)
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)  # One beam finished
                    if self.beam_size == 1:
                        return final_hypothesis
                # Keep expanding this hypothesis: collect each hypothesis' top-B
                # candidates; the top B of the resulting B*B pool are kept below.
                next_top_hypothesis.extend(top)

            # Sort for top N beams
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Rescore all hyp (finished/unfinished)
        final_hypothesis += prev_top_hypothesis  # Also keep unfinished hypotheses from the last step
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)

        return final_hypothesis[:self.beam_size]
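
The per-step score fusion used above (attention log-probabilities, CTC prefix scores, and shallow-fusion LM) can be written in isolation. The sketch below re-implements only that weighting for batch size 1; `combine_scores`, its arguments, and the assumed value of `LOG_ZERO` are illustrative and not part of the original class.

    import torch

    LOG_ZERO = -1e7   # assumed value of the "practically impossible" constant

    def combine_scores(att_logp, ctc_delta, candidates, ctc_w, lm_logp=None, lm_w=0.0):
        """Fuse the scores for one beam step.

        att_logp   : (1, V) attention decoder log-softmax
        ctc_delta  : per-candidate CTC prefix-score increments
        candidates : token ids the CTC scorer was evaluated on
        """
        # CTC scores exist only for the candidate tokens; block everything else.
        ctc_full = torch.full_like(att_logp, LOG_ZERO)
        for i, tok in enumerate(candidates):
            ctc_full[0, tok] = ctc_delta[i]
        fused = (1.0 - ctc_w) * att_logp + ctc_w * ctc_full
        fused[0, 0] = LOG_ZERO                      # never emit <sos> again
        if lm_logp is not None:
            fused = fused + lm_w * lm_logp          # shallow-fusion RNN-LM term
        return fused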
Example 3
    def forward(self, audio_feature, feature_len):
        # Init.
        assert audio_feature.shape[0] == 1, "Batchsize == 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(
                batch_size)                           # Init zero states
        self.asr.attention.reset_mem()            # Flush attention mem
        # Max output len set w/ hyper param.
        max_output_len = int(
            np.ceil(feature_len.cpu().item()*self.max_len_ratio))
        # Min output len set w/ hyper param.
        min_output_len = int(
            np.ceil(feature_len.cpu().item()*self.min_len_ratio))
        # Store attention map if location-aware
        store_att = self.asr.attention.mode == 'loc'
        prev_token = torch.zeros(
            (batch_size, 1), dtype=torch.long, device=device)     # Start w/ <sos>
        # Cache of beam search
        final_hypothesis, next_top_hypothesis = [], []
        # In case CTC is disabled
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode
        encode_feature, encode_len = self.asr.encoder(
            audio_feature, feature_len)
        
        # CTC decoding
        if self.apply_ctc:
            ctc_output = F.log_softmax(
                self.asr.ctc_layer(encode_feature), dim=-1)
            # ctc_output shape: (B, T, V), e.g. torch.Size([1, 155, 45])
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()
            if self.ctc_w == 1.0:
                # output_seq = ctc_output[0].argmax(dim=-1)
                # hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=output_seq,
                              # output_scores=[0]*len(output_seq), lm_state=None, ctc_prob=0,
                              # ctc_state=ctc_state, att_map=None)]
                # return hypothesis
                
                # custom beam for pure CTC beam decode
                import collections
                def make_new_beam():
                    fn = lambda: (-float('inf'), -float('inf'), 0, False, None)
                    return collections.defaultdict(fn)
                
                def log_sum(*args):
                    if all(a == -float('inf') for a in args):
                        return -float('inf')
                    a_max = max(args)
                    lsp = torch.tensor(0, dtype=torch.float32)
                    for a in args:
                        lsp += torch.exp(a - a_max)
                    return a_max + torch.log(lsp)
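                # Illustrative sanity check for log_sum (values made up):
                #   log_sum(log(0.3), log(0.2)) == log(0.5)
                # i.e. probabilities are accumulated in log space, and an
                # all -inf input stays -inf.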
                
                def get_final_prob(*args):
                    prob_blank, prob_non_blank, prob_text, applied_lm, lm_state = args
                    prob_total = log_sum(prob_blank, prob_non_blank)
                    return prob_total + prob_text
                
                # init beam                
                # prefix, (p_b, p_nb, p_text, applied_lm, lm_state)
                beam = [(tuple(), (0, -float('inf'), 0, False, None))]
                
                # assume batch size is always 1 when decoding
                ctc_output = ctc_output.squeeze(0)
                
                # Iterate over all time steps
                for t in range(ctc_output.shape[0]):
                    # Iterate over all vocabulary candidates
                    next_beam = make_new_beam()
                    for vocab in range(ctc_output.shape[1]):
                        p = ctc_output[t, vocab]
                        for prefix, (prob_blank, prob_non_blank, prob_text, applied_lm, lm_state) in beam:
                            
                            # blank case (input: -) *a => *a  
                            if vocab == 0:
                                next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state = next_beam[prefix]
                                next_prob_blank = log_sum(next_prob_blank, prob_blank + p, prob_non_blank + p)
                                # Text unchanged, so prob_text and lm_state stay the same
                                next_beam[prefix] = (next_prob_blank, next_prob_non_blank, prob_text, applied_lm, lm_state)
                                continue
                            
                            # non blank case
                            prev_vocab = prefix[-1] if prefix else None
                            next_prefix = prefix + (vocab, )
                            next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state = next_beam[next_prefix]
                            if vocab != prev_vocab:
                                # (input: b) *a => *ab
                                next_prob_non_blank = log_sum(next_prob_non_blank, prob_blank + p, prob_non_blank + p)
                            else:
                                # (input: a) *a- => *aa
                                next_prob_non_blank = log_sum(next_prob_non_blank, prob_blank + p)
                            
                            
                            # apply language model
                            if self.apply_lm and not next_applied_lm:
                                if prev_vocab is None:
                                    prev_vocab = 0
                                lm_input = torch.LongTensor([prev_vocab]).to(device)
                                lm_input = lm_input.unsqueeze(0)
                                lm_output, next_lm_state = self.lm(lm_input, torch.ones([batch_size]), hidden=lm_state)
                                lm_output = lm_output.squeeze(0)
                                next_prob_text = prob_text + self.lm_w * lm_output.log_softmax(dim=-1).squeeze(0)[vocab]
                                next_applied_lm = True
                            
                            next_beam[next_prefix] = (next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state)
                            
                            # (input: a) *a => *a
                            if vocab == prev_vocab:
                                # Merge all probabilities of *a into a single beam entry
                                next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state = next_beam[prefix]
                                next_prob_non_blank = log_sum(next_prob_non_blank, prob_non_blank + p)
                                next_beam[prefix] = (next_prob_blank, next_prob_non_blank, next_prob_text, next_applied_lm, next_lm_state)
                                
                    beam = sorted(next_beam.items(), key=lambda x: get_final_prob(*x[1]), reverse=True)
                    beam = beam[:self.ctc_beam_size]                    
                
                output_seq = torch.LongTensor(beam[0][0])
                # print(output_seq)
                hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=output_seq,
                               output_scores=[0]*len(output_seq), lm_state=None, ctc_prob=0,
                               ctc_state=ctc_state, att_map=None)]
                
                return hypothesis

        # Start w/ empty hypothesis
        prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                          output_scores=[], lm_state=None, ctc_prob=0,
                                          ctc_state=ctc_state, att_map=None)]
        # Attention decoding
        for t in range(max_output_len):
            for hypothesis in prev_top_hypothesis:
                # Resume previous step
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
                    device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # Normal asr forward
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding fusion (output shape 1xV)
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder( d_state, cur_prob, return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)

                # Perform CTC prefix scoring on limited candidates (else OOM easily)
                if self.apply_ctc:
                    # TODO : Check the performance drop for computing part of candidates only
                    _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)
                    # TODO : study why ctc_char (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)

                    # Combine CTC score and Attention score (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (1-self.ctc_w)*cur_prob + self.ctc_w*hack_ctc_char  # ctc_char
                    cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.apply_lm:
                    # Assuming batch size is always 1; result has shape 1x1
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(
                        lm_input, torch.ones([batch_size]), hidden=prev_lm_state)
                    # Assuming batch size is always 1; result has shape 1xV
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w*lm_output.log_softmax(dim=-1)

                # Beam search
                # Note: Ignored batch dim.
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
                final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(), att_map=prev_attn,
                                                lm_state=lm_state, ctc_state=ctc_state, ctc_prob=ctc_prob,
                                                ctc_candidates=candidates)
                # Move complete hyps. out
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)
                    if self.beam_size == 1:
                        return final_hypothesis
                next_top_hypothesis.extend(top)

            # Sort for top N beams
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Rescore all hyp (finished/unfinished)
        final_hypothesis += prev_top_hypothesis
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)

        return final_hypothesis[:self.beam_size]
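
To make the pure-CTC branch of Example 3 easier to follow in isolation, here is a self-contained sketch of the same blank/non-blank prefix recursion (LM term omitted) run on a toy log-probability matrix. `ctc_prefix_beam` and the toy values are illustrative only, not part of the original code.

    import collections
    import math

    def log_sum(*args):
        # Log-space accumulation, as in Example 3.
        if all(a == -math.inf for a in args):
            return -math.inf
        a_max = max(args)
        return a_max + math.log(sum(math.exp(a - a_max) for a in args))

    def ctc_prefix_beam(log_probs, beam_size=4, blank=0):
        """log_probs: T x V nested list of log-probabilities; returns the best prefix."""
        # prefix -> (log P(..., ends in blank), log P(..., ends in non-blank))
        beam = [((), (0.0, -math.inf))]
        for frame in log_probs:
            next_beam = collections.defaultdict(lambda: (-math.inf, -math.inf))
            for vocab, p in enumerate(frame):
                for prefix, (p_b, p_nb) in beam:
                    if vocab == blank:
                        # Blank keeps the prefix unchanged: *a => *a
                        n_b, n_nb = next_beam[prefix]
                        next_beam[prefix] = (log_sum(n_b, p_b + p, p_nb + p), n_nb)
                        continue
                    prev = prefix[-1] if prefix else None
                    ext = prefix + (vocab,)
                    n_b, n_nb = next_beam[ext]
                    if vocab != prev:
                        # A new symbol extends the prefix: *a => *ab
                        next_beam[ext] = (n_b, log_sum(n_nb, p_b + p, p_nb + p))
                    else:
                        # A repeat extends only when separated by a blank: *a- => *aa
                        next_beam[ext] = (n_b, log_sum(n_nb, p_b + p))
                        # Otherwise the repeat collapses back onto *a
                        m_b, m_nb = next_beam[prefix]
                        next_beam[prefix] = (m_b, log_sum(m_nb, p_nb + p))
            beam = sorted(next_beam.items(),
                          key=lambda kv: log_sum(*kv[1]), reverse=True)[:beam_size]
        return beam[0][0]

    # Toy 3-frame, 3-symbol example (symbol 0 is the blank); values are made up.
    toy = [[math.log(0.6), math.log(0.3), math.log(0.1)],
           [math.log(0.2), math.log(0.7), math.log(0.1)],
           [math.log(0.5), math.log(0.2), math.log(0.3)]]
    print(ctc_prefix_beam(toy))   # -> (1,)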