def beam_decode(self, audio_feature, decode_step, state_len, decode_beam_size):
    '''Beam decode: returns the top N hypotheses for each input sequence.'''
    assert not self.training
    assert audio_feature.shape[0] == 1
    self.decode_beam_size = decode_beam_size
    ctc_beam_size = int(CTC_BEAM_RATIO * self.decode_beam_size)

    # Encode
    encode_feature, encode_len = self.encoder(audio_feature, state_len)
    if decode_step == 0:
        decode_step = int(encode_len[0])

    # Init.
    cur_device = next(self.decoder.parameters()).device
    ctc_output = None
    ctc_state = None
    ctc_prob = 0.0
    ctc_candidates = []
    candidates = None
    att_output = None
    att_maps = None
    lm_hidden = None

    # CTC-based decoding
    if self.joint_ctc:
        ctc_output = F.log_softmax(self.ctc_layer(encode_feature), dim=-1)
        ctc_prefix = CTCPrefixScore(ctc_output)
        ctc_state = ctc_prefix.init_state()

    # Attention-based decoding
    if self.joint_att:
        # Store attention map if location-aware
        store_att = self.attention.mode == 'loc'

        # Init (init char = <SOS>, reset all RNN states and cells)
        self.decoder.init_rnn(encode_feature)
        self.attention.reset_enc_mem()
        last_char = self.embed(torch.zeros((1), dtype=torch.long).to(cur_device))
        last_char_idx = torch.LongTensor([[0]])

        # Beam search init
        # Note: weird bug here if not all arguments are passed...
        final_outputs, prev_top_outputs, next_top_outputs = [], [], []
        prev_top_outputs.append(Hypothesis(self.decoder.hidden_state, self.embed,
                                           output_seq=[], output_scores=[],
                                           lm_state=None, ctc_prob=0,
                                           ctc_state=ctc_state, att_map=None))

        # Decode
        for t in range(decode_step):
            for prev_output in prev_top_outputs:
                # Attention
                self.decoder.hidden_state = prev_output.decoder_state
                self.attention.prev_att = None if prev_output.att_map is None \
                    else prev_output.att_map.to(cur_device)
                attention_score, context = self.attention(
                    self.decoder.state_list[0], encode_feature, encode_len)
                decoder_input = torch.cat([prev_output.last_char, context], dim=-1)
                dec_out = self.decoder(decoder_input)
                cur_char = F.log_softmax(self.char_trans(dec_out), dim=-1)

                # Perform CTC prefix scoring on a limited set of candidates
                # (scoring the full vocabulary easily runs out of memory)
                if self.joint_ctc:
                    # TODO: check the performance drop from scoring only part of the candidates
                    _, ctc_candidates = cur_char.topk(ctc_beam_size)
                    candidates = list(ctc_candidates[0].cpu().numpy())
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        prev_output.outIndex, prev_output.ctc_state, candidates)

                    # TODO: study why ctc_char is (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(ctc_prob - prev_output.ctc_prob).to(cur_device)

                    # Combine CTC score and attention score (focus on candidates)
                    hack_ctc_char = torch.zeros_like(cur_char).data.fill_(-1000000.0)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_char = (1 - self.ctc_weight) * cur_char + self.ctc_weight * hack_ctc_char
                    cur_char[0, 0] = -10000000.0  # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.decode_lm_weight > 0:
                    last_char_idx = prev_output.last_char_idx.to(cur_device)
                    lm_hidden, lm_output = self.rnn_lm(last_char_idx, [1], prev_output.lm_state)
                    cur_char += self.decode_lm_weight * F.log_softmax(lm_output.squeeze(1), dim=-1)

                # Beam search
                topv, topi = cur_char.topk(self.decode_beam_size)
                prev_att_map = self.attention.prev_att.clone().detach().cpu() if store_att else None
                final, top = prev_output.addTopk(topi, topv, self.decoder.hidden_state,
                                                 att_map=prev_att_map, lm_state=lm_hidden,
                                                 ctc_state=ctc_state, ctc_prob=ctc_prob,
                                                 ctc_candidates=candidates)
                # Move complete hypotheses out
                if final is not None:
                    final_outputs.append(final)
                    if self.decode_beam_size == 1:
                        return final_outputs
                next_top_outputs.extend(top)

            # Sort for top N beams
            next_top_outputs.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_outputs = next_top_outputs[:self.decode_beam_size]
            next_top_outputs = []

        final_outputs += prev_top_outputs
        final_outputs.sort(key=lambda o: o.avgScore(), reverse=True)
        return final_outputs[:self.decode_beam_size]
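
# The helper below is a hypothetical usage sketch, not part of the original
# file: it shows how beam_decode above might be called. `model`, `feat`, and
# `feat_len` are assumed names for an instance of the class defining
# beam_decode and a single pre-extracted utterance (batch size must be 1).
def _example_beam_decode(model, feat, feat_len, beam_size=5):
    model.eval()  # beam_decode asserts the model is not in training mode
    with torch.no_grad():
        # decode_step=0 lets the encoder output length set the step count
        hyps = model.beam_decode(feat, decode_step=0, state_len=feat_len,
                                 decode_beam_size=beam_size)
    return hyps[0]  # hypotheses come back sorted by avgScore()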
def forward(self, audio_feature, feature_len):
    # Init.
    assert audio_feature.shape[0] == 1, "Batch size == 1 is required for beam search"
    batch_size = audio_feature.shape[0]
    device = audio_feature.device
    dec_state = self.asr.decoder.init_state(batch_size)  # Init zero states
    self.asr.attention.reset_mem()  # Flush attention memory
    # Max/min output lengths are set with hyperparameters
    max_output_len = int(np.ceil(feature_len.cpu().item() * self.max_len_ratio))
    min_output_len = int(np.ceil(feature_len.cpu().item() * self.min_len_ratio))
    # Store attention map if location-aware
    store_att = self.asr.attention.mode == 'loc'
    prev_token = torch.zeros((batch_size, 1), dtype=torch.long, device=device)  # Start w/ <sos>
    # Cache of beam search
    final_hypothesis, next_top_hypothesis = [], []
    # In case CTC is disabled
    ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

    # Encode
    encode_feature, encode_len = self.asr.encoder(audio_feature, feature_len)

    # CTC decoding
    if self.apply_ctc:
        ctc_output = F.log_softmax(self.asr.ctc_layer(encode_feature), dim=-1)
        ctc_prefix = CTCPrefixScore(ctc_output)
        ctc_state = ctc_prefix.init_state()

    # Start w/ empty hypothesis
    prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                      output_scores=[], lm_state=None, ctc_prob=0,
                                      ctc_state=ctc_state, att_map=None)]

    # Attention decoding
    for t in range(max_output_len):
        # For each hypothesis, generate its top B candidates
        for hypothesis in prev_top_hypothesis:
            # Resume previous step
            prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = \
                hypothesis.get_state(device)
            self.asr.set_state(prev_dec_state, prev_attn)

            # Normal ASR forward
            attn, context = self.asr.attention(self.asr.decoder.get_query(),
                                               encode_feature, encode_len)
            asr_prev_token = self.asr.pre_embed(prev_token)
            decoder_input = torch.cat([asr_prev_token, context], dim=-1)
            cur_prob, d_state = self.asr.decoder(decoder_input)

            # Embedding fusion (output shape 1 x V)
            if self.apply_emb:
                _, cur_prob = self.emb_decoder(d_state, cur_prob, return_loss=False)
            else:
                cur_prob = F.log_softmax(cur_prob, dim=-1)
            att_prob = cur_prob.squeeze(0)  # attention-only log-probs, shape (V,)

            # Perform CTC prefix scoring on a limited set of candidates
            # (scoring the full vocabulary easily runs out of memory)
            if self.apply_ctc:
                # TODO: check the performance drop from scoring only part of the candidates
                _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
                candidates = ctc_candidates.cpu().tolist()
                ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                    hypothesis.outIndex, prev_ctc_state, candidates)

                # TODO: study why ctc_char is (slightly) > 0 sometimes
                ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)

                # Combine CTC score and attention score
                # (HACK: focus on candidates, block all others)
                hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                for idx, char in enumerate(candidates):
                    hack_ctc_char[0, char] = ctc_char[idx]
                cur_prob = (1 - self.ctc_w) * cur_prob + self.ctc_w * hack_ctc_char
                cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>

            # Joint RNN-LM decoding
            if self.apply_lm:
                # Assuming batch size is always 1, giving a 1x1 input
                lm_input = prev_token.unsqueeze(1)
                lm_output, lm_state = self.lm(lm_input, torch.ones([batch_size]),
                                              hidden=prev_lm_state)
                # Assuming batch size is always 1, giving a 1xV output
                lm_output = lm_output.squeeze(0)
                cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)
                # Note: is there no other constraint to lengthen the transcript?

            # Beam search (note: batch dim. ignored)
            topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
            prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
            final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(),
                                            att_map=prev_attn, lm_state=lm_state,
                                            ctc_state=ctc_state, ctc_prob=ctc_prob,
                                            ctc_candidates=candidates, att_prob=att_prob)
            # Move complete hypotheses out (if <eos> was detected, `final` is not None)
            if final is not None and (t >= min_output_len):
                final_hypothesis.append(final)
                if self.beam_size == 1:
                    return final_hypothesis
            # Collect each hypothesis' top B candidates; later keep the best B of the B*B
            next_top_hypothesis.extend(top)

        # Sort for top N beams
        next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
        prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
        next_top_hypothesis = []

    # Rescore all hypotheses (finished and unfinished)
    final_hypothesis += prev_top_hypothesis  # keep the unfinished beams as well
    final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
    return final_hypothesis[:self.beam_size]
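
# The helper below is a hypothetical sketch, not part of the original file:
# it isolates the joint CTC/attention scoring rule used in forward() above.
# Only the tokens in `candidates` receive a real incremental CTC prefix
# score; every other vocabulary entry is masked to `log_zero`, so the
# interpolated score effectively restricts the beam to the CTC candidates.
# The names `att_logp`, `ctc_logp`, and `log_zero` are assumptions.
def _example_joint_score(att_logp, ctc_logp, candidates, ctc_w, log_zero=-1e30):
    """att_logp: (1, V) attention log-probs; ctc_logp: sequence of
    len(candidates) incremental CTC prefix scores."""
    hack_ctc = torch.full_like(att_logp, log_zero)  # block non-candidates
    for idx, char in enumerate(candidates):
        hack_ctc[0, char] = ctc_logp[idx]
    return (1 - ctc_w) * att_logp + ctc_w * hack_ctc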
def forward(self, audio_feature, feature_len):
    # Init.
    assert audio_feature.shape[0] == 1, "Batch size == 1 is required for beam search"
    batch_size = audio_feature.shape[0]
    device = audio_feature.device
    dec_state = self.asr.decoder.init_state(batch_size)  # Init zero states
    self.asr.attention.reset_mem()  # Flush attention memory
    # Max/min output lengths are set with hyperparameters
    max_output_len = int(np.ceil(feature_len.cpu().item() * self.max_len_ratio))
    min_output_len = int(np.ceil(feature_len.cpu().item() * self.min_len_ratio))
    # Store attention map if location-aware
    store_att = self.asr.attention.mode == 'loc'
    prev_token = torch.zeros((batch_size, 1), dtype=torch.long, device=device)  # Start w/ <sos>
    # Cache of beam search
    final_hypothesis, next_top_hypothesis = [], []
    # In case CTC is disabled
    ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

    # Encode
    encode_feature, encode_len = self.asr.encoder(audio_feature, feature_len)

    # CTC decoding
    if self.apply_ctc:
        ctc_output = F.log_softmax(self.asr.ctc_layer(encode_feature), dim=-1)
        # ctc_output shape: (1, T, V)
        ctc_prefix = CTCPrefixScore(ctc_output)
        ctc_state = ctc_prefix.init_state()

        if self.ctc_w == 1.0:
            # Custom beam search for pure CTC decoding
            import collections

            def make_new_beam():
                fn = lambda: (-float('inf'), -float('inf'), 0, False, None)
                return collections.defaultdict(fn)

            def log_sum(*args):
                # Numerically stable log-sum-exp over the arguments
                if all(a == -float('inf') for a in args):
                    return -float('inf')
                a_max = max(args)
                lsp = torch.tensor(0, dtype=torch.float32)
                for a in args:
                    lsp += torch.exp(a - a_max)
                return a_max + torch.log(lsp)

            def get_final_prob(*args):
                prob_blank, prob_non_blank, prob_text, applied_lm, lm_state = args
                prob_total = log_sum(prob_blank, prob_non_blank)
                return prob_total + prob_text

            # Init beam: prefix -> (p_b, p_nb, p_text, applied_lm, lm_state)
            beam = [(tuple(), (0, -float('inf'), 0, False, None))]
            # Assume batch size is always 1 when decoding
            ctc_output = ctc_output.squeeze(0)

            # Iterate over all timesteps
            for t in range(ctc_output.shape[0]):
                next_beam = make_new_beam()
                # Iterate over all vocabulary candidates
                for vocab in range(ctc_output.shape[1]):
                    p = ctc_output[t, vocab]
                    for prefix, (prob_blank, prob_non_blank, prob_text,
                                 applied_lm, lm_state) in beam:
                        # Blank case (input: -): *a => *a
                        if vocab == 0:
                            (next_prob_blank, next_prob_non_blank, next_prob_text,
                             next_applied_lm, next_lm_state) = next_beam[prefix]
                            next_prob_blank = log_sum(next_prob_blank, prob_blank + p,
                                                      prob_non_blank + p)
                            # Text unchanged, so prob_text and lm_state stay the same
                            next_beam[prefix] = (next_prob_blank, next_prob_non_blank,
                                                 prob_text, applied_lm, lm_state)
                            continue

                        # Non-blank case
                        prev_vocab = prefix[-1] if prefix else None
                        next_prefix = prefix + (vocab,)
                        (next_prob_blank, next_prob_non_blank, next_prob_text,
                         next_applied_lm, next_lm_state) = next_beam[next_prefix]
                        if vocab != prev_vocab:
                            # (input: b): *a => *ab
                            next_prob_non_blank = log_sum(next_prob_non_blank,
                                                          prob_blank + p, prob_non_blank + p)
                        else:
                            # (input: a): *a- => *aa
                            next_prob_non_blank = log_sum(next_prob_non_blank, prob_blank + p)

                        # Apply language model
                        if self.apply_lm and not next_applied_lm:
                            if prev_vocab is None:
                                prev_vocab = 0
                            lm_input = torch.LongTensor([prev_vocab]).to(device)
                            lm_input = lm_input.unsqueeze(0)
                            lm_output, next_lm_state = self.lm(
                                lm_input, torch.ones([batch_size]), hidden=lm_state)
                            lm_output = lm_output.squeeze(0)
                            next_prob_text = prob_text + self.lm_w * \
                                lm_output.log_softmax(dim=-1).squeeze(0)[vocab]
                            next_applied_lm = True
                        next_beam[next_prefix] = (next_prob_blank, next_prob_non_blank,
                                                  next_prob_text, next_applied_lm,
                                                  next_lm_state)

                        # (input: a): *a => *a
                        if vocab == prev_vocab:
                            # Merge all probability mass of *a into a single entry
                            (next_prob_blank, next_prob_non_blank, next_prob_text,
                             next_applied_lm, next_lm_state) = next_beam[prefix]
                            next_prob_non_blank = log_sum(next_prob_non_blank,
                                                          prob_non_blank + p)
                            next_beam[prefix] = (next_prob_blank, next_prob_non_blank,
                                                 next_prob_text, next_applied_lm,
                                                 next_lm_state)

                # Keep the best ctc_beam_size prefixes for the next timestep
                beam = sorted(next_beam.items(),
                              key=lambda x: get_final_prob(*x[1]), reverse=True)
                beam = beam[:self.ctc_beam_size]

            output_seq = torch.LongTensor(beam[0][0])
            hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=output_seq,
                                     output_scores=[0] * len(output_seq), lm_state=None,
                                     ctc_prob=0, ctc_state=ctc_state, att_map=None)]
            return hypothesis

    # Start w/ empty hypothesis
    prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                      output_scores=[], lm_state=None, ctc_prob=0,
                                      ctc_state=ctc_state, att_map=None)]

    # Attention decoding
    for t in range(max_output_len):
        for hypothesis in prev_top_hypothesis:
            # Resume previous step
            prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = \
                hypothesis.get_state(device)
            self.asr.set_state(prev_dec_state, prev_attn)

            # Normal ASR forward
            attn, context = self.asr.attention(self.asr.decoder.get_query(),
                                               encode_feature, encode_len)
            asr_prev_token = self.asr.pre_embed(prev_token)
            decoder_input = torch.cat([asr_prev_token, context], dim=-1)
            cur_prob, d_state = self.asr.decoder(decoder_input)

            # Embedding fusion (output shape 1 x V)
            if self.apply_emb:
                _, cur_prob = self.emb_decoder(d_state, cur_prob, return_loss=False)
            else:
                cur_prob = F.log_softmax(cur_prob, dim=-1)

            # Perform CTC prefix scoring on a limited set of candidates
            # (scoring the full vocabulary easily runs out of memory)
            if self.apply_ctc:
                # TODO: check the performance drop from scoring only part of the candidates
                _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
                candidates = ctc_candidates.cpu().tolist()
                ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                    hypothesis.outIndex, prev_ctc_state, candidates)

                # TODO: study why ctc_char is (slightly) > 0 sometimes
                ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)

                # Combine CTC score and attention score
                # (HACK: focus on candidates, block all others)
                hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                for idx, char in enumerate(candidates):
                    hack_ctc_char[0, char] = ctc_char[idx]
                cur_prob = (1 - self.ctc_w) * cur_prob + self.ctc_w * hack_ctc_char
                cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>

            # Joint RNN-LM decoding
            if self.apply_lm:
                # Assuming batch size is always 1, giving a 1x1 input
                lm_input = prev_token.unsqueeze(1)
                lm_output, lm_state = self.lm(lm_input, torch.ones([batch_size]),
                                              hidden=prev_lm_state)
                # Assuming batch size is always 1, giving a 1xV output
                lm_output = lm_output.squeeze(0)
                cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)

            # Beam search (note: batch dim. ignored)
            topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
            prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
            final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(),
                                            att_map=prev_attn, lm_state=lm_state,
                                            ctc_state=ctc_state, ctc_prob=ctc_prob,
                                            ctc_candidates=candidates)
            # Move complete hypotheses out
            if final is not None and (t >= min_output_len):
                final_hypothesis.append(final)
                if self.beam_size == 1:
                    return final_hypothesis
            next_top_hypothesis.extend(top)

        # Sort for top N beams
        next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
        prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
        next_top_hypothesis = []

    # Rescore all hypotheses (finished and unfinished)
    final_hypothesis += prev_top_hypothesis
    final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
    return final_hypothesis[:self.beam_size]
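
# The helper below is a hypothetical note, not part of the original file:
# the hand-rolled log_sum() in the pure-CTC branch above computes a
# log-sum-exp with the usual max-shift trick. torch.logsumexp implements
# the same numerically stable reduction and could serve as a drop-in
# alternative once the inputs are packed into a tensor.
def _example_log_sum(*args):
    vals = torch.tensor([float(a) for a in args], dtype=torch.float32)
    return torch.logsumexp(vals, dim=0)  # returns -inf when all inputs are -inf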