class Solver(BaseSolver):
    ''' Solver for training language models '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        self.best_loss = 10

    def fetch_data(self, data):
        ''' Move data to device, insert <sos> and compute text seq. length '''
        txt = torch.cat((torch.zeros((data.shape[0], 1), dtype=torch.long),
                         data), dim=1).to(self.device)
        txt_len = torch.sum(data != 0, dim=-1)
        return txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape '''
        self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
            load_textset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup LM (Prediction network + RNNLM output layer) and optimizer '''
        # Model
        # self.model = RNNLM(self.vocab_size, **self.config['model']).to(self.device)
        self.model = Prediction(self.vocab_size,
                                **self.config['model']).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size,
                           **self.config['model']).to(self.device)
        self.verbose(self.rnnlm.create_msg())

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

        # Optimizer
        self.optimizer = Optimizer(
            list(self.model.parameters()) + list(self.rnnlm.parameters()),
            **self.config['hparas'])

        # Enable AMP if needed
        self.enable_apex()

        # Load pre-trained model
        if self.paras.load:
            self.load_ckpt()
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Training language model '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self.timer.set()

        while self.step < self.max_step:
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)

                # Compute all objectives
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(lm_loss)
                self.step += 1

                # Logger
                if self.step % self.PROGRESS_STEP == 0:
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.format(
                            lm_loss.cpu().item(), grad_norm, self.timer.show()))
                    self.write_log('entropy', {'tr': lm_loss})
                    self.write_log('perplexity',
                                   {'tr': torch.exp(lm_loss).cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                if self.step > self.max_step:
                    break
        self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        self.rnnlm.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)
            lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                    txt[:, 1:].reshape(-1))
            dev_loss.append(lm_loss)

        # Ckpt if performance improves
        dev_loss = sum(dev_loss) / len(dev_loss)
        dev_ppx = torch.exp(dev_loss).cpu().item()
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint('best_ppx.pth', 'perplexity', dev_ppx)
        self.write_log('entropy', {'dv': dev_loss})
        self.write_log('perplexity', {'dv': dev_ppx})

        # Show some examples of the last batch on tensorboard
        for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
            if self.step == 1:
                self.write_log('true_text{}'.format(i),
                               self.tokenizer.decode(txt[i].tolist()))
            self.write_log('pred_text{}'.format(i),
                           self.tokenizer.decode(pred[i].argmax(dim=-1).tolist()))

        # Resume training
        self.model.train()
        self.rnnlm.train()
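# --- Illustrative sketch (not part of the original solver) -----------------
# The solver above logs 'perplexity' as exp() of the per-token cross-entropy,
# with <pad>=0 ignored via ignore_index. The small demo below shows that
# relationship in isolation; the logits, targets, and helper name are made up
# for illustration only.
def _demo_perplexity_from_entropy():
    import torch
    torch.manual_seed(0)
    vocab_size = 8
    logits = torch.randn(4, vocab_size)       # 4 token positions, fake logits
    targets = torch.tensor([3, 5, 0, 2])      # the position with 0 (<pad>) is ignored
    seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
    entropy = seq_loss(logits, targets)       # mean cross-entropy over non-pad tokens
    perplexity = torch.exp(entropy)           # what the solver writes to the log
    return entropy.item(), perplexity.item()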
class BeamDecoder(nn.Module):
    ''' Beam decoder for ASR '''

    def __init__(self, asr, emb_decoder, beam_size, min_len_ratio,
                 max_len_ratio, lm_path='', lm_config='', lm_weight=0.0,
                 ctc_weight=0.0):
        super().__init__()
        # Setup
        self.beam_size = beam_size
        self.min_len_ratio = min_len_ratio
        self.max_len_ratio = max_len_ratio
        self.asr = asr

        # ToDo : implement pure ctc decode
        assert self.asr.enable_att

        # Additional decoding modules
        self.apply_ctc = ctc_weight > 0
        if self.apply_ctc:
            assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
            self.ctc_w = ctc_weight
            self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

        self.apply_lm = lm_weight > 0
        if self.apply_lm:
            self.lm_w = lm_weight
            self.lm_path = lm_path
            lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
            self.lm.load_state_dict(
                torch.load(self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

        self.apply_emb = emb_decoder is not None
        if self.apply_emb:
            self.emb_decoder = emb_decoder

    def create_msg(self):
        msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
            self.beam_size, self.min_len_ratio, self.max_len_ratio)]
        if self.apply_ctc:
            msg.append(' |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(
                self.ctc_w))
        if self.apply_lm:
            msg.append(' |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format(
                self.lm_w, self.lm_path))
        if self.apply_emb:
            msg.append(' |Joint Emb. decoding enabled \t| weight = {:.2f}'.format(
                self.emb_decoder.fuse_lambda.mean().cpu().item()))
        return msg

    def forward(self, audio_feature, feature_len):
        # Init.
        assert audio_feature.shape[0] == 1, "Batchsize == 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(batch_size)   # Init zero states
        self.asr.attention.reset_mem()                        # Flush attention mem
        # Max/min output len set w/ hyper param.
        max_output_len = int(
            np.ceil(feature_len.cpu().item() * self.max_len_ratio))
        min_output_len = int(
            np.ceil(feature_len.cpu().item() * self.min_len_ratio))
        # Store attention map if location-aware
        store_att = self.asr.attention.mode == 'loc'
        # Start w/ <sos>
        prev_token = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
        # Cache of beam search
        final_hypothesis, next_top_hypothesis = [], []
        # In case CTC is disabled
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode
        encode_feature, encode_len = self.asr.encoder(audio_feature, feature_len)

        # CTC decoding
        if self.apply_ctc:
            ctc_output = F.log_softmax(self.asr.ctc_layer(encode_feature), dim=-1)
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()

        # Start w/ empty hypothesis
        prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                          output_scores=[], lm_state=None,
                                          ctc_prob=0, ctc_state=ctc_state,
                                          att_map=None)]

        # Attention decoding
        for t in range(max_output_len):
            # For each hypothesis, generate its top B candidates
            for hypothesis in prev_top_hypothesis:
                # Resume previous step
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = \
                    hypothesis.get_state(device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # Normal ASR forward
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding fusion (output shape 1xV)
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder(d_state, cur_prob, return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)
                att_prob = cur_prob.squeeze(0)   # shape (V,), e.g. torch.Size([31])

                # Perform CTC prefix scoring on limited candidates (else OOM easily)
                if self.apply_ctc:
                    # TODO : Check the performance drop for computing part of candidates only
                    _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)

                    # TODO : study why ctc_char (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)

                    # Combine CTC score and attention score
                    # (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (1 - self.ctc_w) * cur_prob + self.ctc_w * hack_ctc_char
                    cur_prob[0, 0] = LOG_ZERO   # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.apply_lm:
                    # Assuming batch size is always 1, resulting in 1x1
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(lm_input, torch.ones([batch_size]),
                                                  hidden=prev_lm_state)
                    # Assuming batch size is always 1, resulting in 1xV
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)

                # Note: no other constraint to lengthen the transcript?
                # Beam search (batch dim. ignored)
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
                final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(),
                                                att_map=prev_attn, lm_state=lm_state,
                                                ctc_state=ctc_state, ctc_prob=ctc_prob,
                                                ctc_candidates=candidates,
                                                att_prob=att_prob)   # top: new hypotheses

                # Move finished hypotheses (those that emitted <eos>) out
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)   # one beam finished
                    if self.beam_size == 1:
                        return final_hypothesis
                # Collect each hypothesis' top B candidates;
                # the top B of all B*B candidates are kept below
                next_top_hypothesis.extend(top)

            # Sort for top N beams
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Rescore all hypotheses (finished/unfinished); remaining beams are added as well
        final_hypothesis += prev_top_hypothesis
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
        return final_hypothesis[:self.beam_size]
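# --- Illustrative sketch (not part of the original decoder) ----------------
# The beam search above mixes attention and CTC scores only on the top CTC
# candidates and blocks every other symbol with LOG_ZERO before re-ranking.
# The toy function below replays that combination step on made-up
# log-probabilities; the vocabulary size, weights, and values are assumptions.
def _demo_joint_ctc_att_scores():
    import torch
    LOG_ZERO = -1e10                               # same role as the module constant
    vocab_size, ctc_beam_size, ctc_w = 6, 3, 0.3
    att_logp = torch.log_softmax(torch.randn(1, vocab_size), dim=-1)
    ctc_char = torch.randn(ctc_beam_size)          # stand-in for CTC prefix score deltas
    _, candidates = att_logp.squeeze(0).topk(ctc_beam_size, dim=-1)
    hack_ctc_char = torch.zeros_like(att_logp).fill_(LOG_ZERO)
    for idx, char in enumerate(candidates.tolist()):
        hack_ctc_char[0, char] = ctc_char[idx]     # CTC score only for the candidates
    combined = (1 - ctc_w) * att_logp + ctc_w * hack_ctc_char
    combined[0, 0] = LOG_ZERO                      # ignore <sos>, as in the decoder
    return combined.squeeze(0).topk(2)             # beam expansion candidates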
class BeamDecoder(nn.Module):
    ''' Beam decoder for ASR '''

    def __init__(self, asr, emb_decoder, beam_size, min_len_ratio,
                 max_len_ratio, lm_path='', lm_config='', lm_weight=0.0,
                 ctc_weight=0.0):
        super().__init__()
        # Setup
        self.beam_size = beam_size
        self.min_len_ratio = min_len_ratio
        self.max_len_ratio = max_len_ratio
        self.asr = asr

        # ToDo : implement pure ctc decode
        # assert self.asr.enable_att

        # Additional decoding modules
        self.apply_ctc = ctc_weight > 0
        if self.apply_ctc:
            assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
            self.ctc_w = ctc_weight
            self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

        self.apply_lm = lm_weight > 0
        if self.apply_lm:
            self.lm_w = lm_weight
            self.lm_path = lm_path
            lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
            self.lm.load_state_dict(
                torch.load(self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

        self.apply_emb = emb_decoder is not None
        if self.apply_emb:
            self.emb_decoder = emb_decoder

    def create_msg(self):
        msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
            self.beam_size, self.min_len_ratio, self.max_len_ratio)]
        if self.apply_ctc:
            msg.append(' |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(
                self.ctc_w))
        if self.apply_lm:
            msg.append(' |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format(
                self.lm_w, self.lm_path))
        if self.apply_emb:
            msg.append(' |Joint Emb. decoding enabled \t| weight = {:.2f}'.format(
                self.emb_decoder.fuse_lambda.mean().cpu().item()))
        return msg

    def forward(self, audio_feature, feature_len):
        # Init.
        assert audio_feature.shape[0] == 1, "Batchsize == 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(batch_size)   # Init zero states
        self.asr.attention.reset_mem()                        # Flush attention mem
        # Max/min output len set w/ hyper param.
        max_output_len = int(
            np.ceil(feature_len.cpu().item() * self.max_len_ratio))
        min_output_len = int(
            np.ceil(feature_len.cpu().item() * self.min_len_ratio))
        # Store attention map if location-aware
        store_att = self.asr.attention.mode == 'loc'
        # Start w/ <sos>
        prev_token = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
        # Cache of beam search
        final_hypothesis, next_top_hypothesis = [], []
        # In case CTC is disabled
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode
        encode_feature, encode_len = self.asr.encoder(audio_feature, feature_len)

        # CTC decoding
        if self.apply_ctc:
            ctc_output = F.log_softmax(self.asr.ctc_layer(encode_feature), dim=-1)
            # ctc_output shape: (1, T, V), e.g. torch.Size([1, 155, 45])
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()

            if self.ctc_w == 1.0:
                # output_seq = ctc_output[0].argmax(dim=-1)
                # hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=output_seq,
                #                          output_scores=[0]*len(output_seq), lm_state=None,
                #                          ctc_prob=0, ctc_state=ctc_state, att_map=None)]
                # return hypothesis

                # Custom beam for pure CTC beam decoding
                import collections

                def make_new_beam():
                    fn = lambda: (-float('inf'), -float('inf'), 0, False, None)
                    return collections.defaultdict(fn)

                def log_sum(*args):
                    if all(a == -float('inf') for a in args):
                        return -float('inf')
                    a_max = max(args)
                    lsp = torch.tensor(0, dtype=torch.float32)
                    for a in args:
                        lsp += torch.exp(a - a_max)
                    return a_max + torch.log(lsp)

                def get_final_prob(*args):
                    prob_blank, prob_non_blank, prob_text, applied_lm, lm_state = args
                    prob_total = log_sum(prob_blank, prob_non_blank)
                    return prob_total + prob_text

                # Init beam
                # prefix, (p_b, p_nb, p_text, applied_lm, lm_state)
                beam = [(tuple(), (0, -float('inf'), 0, False, None))]

                # Assume batch size is always 1 when decoding
                ctc_output = ctc_output.squeeze(0)

                # Iterate over all timestamps
                for t in range(ctc_output.shape[0]):
                    # Iterate over all vocab candidates
                    next_beam = make_new_beam()
                    for vocab in range(ctc_output.shape[1]):
                        p = ctc_output[t, vocab]
                        for prefix, (prob_blank, prob_non_blank, prob_text,
                                     applied_lm, lm_state) in beam:
                            # Blank case (input: -)  *a => *a
                            if vocab == 0:
                                next_prob_blank, next_prob_non_blank, next_prob_text, \
                                    next_applied_lm, next_lm_state = next_beam[prefix]
                                next_prob_blank = log_sum(next_prob_blank,
                                                          prob_blank + p,
                                                          prob_non_blank + p)
                                # Text unchanged, so prob_text and lm_state stay the same
                                next_beam[prefix] = (next_prob_blank, next_prob_non_blank,
                                                     prob_text, applied_lm, lm_state)
                                continue

                            # Non-blank case
                            prev_vocab = prefix[-1] if prefix else None
                            next_prefix = prefix + (vocab, )
                            next_prob_blank, next_prob_non_blank, next_prob_text, \
                                next_applied_lm, next_lm_state = next_beam[next_prefix]
                            if vocab != prev_vocab:
                                # (input: b)  *a => *ab
                                next_prob_non_blank = log_sum(next_prob_non_blank,
                                                              prob_blank + p,
                                                              prob_non_blank + p)
                            else:
                                # (input: a)  *a- => *aa
                                next_prob_non_blank = log_sum(next_prob_non_blank,
                                                              prob_blank + p)

                            # Apply language model
                            if self.apply_lm and not next_applied_lm:
                                if prev_vocab is None:
                                    prev_vocab = 0
                                lm_input = torch.LongTensor([prev_vocab]).to(device)
                                lm_input = lm_input.unsqueeze(0)
                                lm_output, next_lm_state = self.lm(
                                    lm_input, torch.ones([batch_size]), hidden=lm_state)
                                lm_output = lm_output.squeeze(0)
                                next_prob_text = prob_text + self.lm_w * \
                                    lm_output.log_softmax(dim=-1).squeeze(0)[vocab]
                                next_applied_lm = True

                            next_beam[next_prefix] = (next_prob_blank, next_prob_non_blank,
                                                      next_prob_text, next_applied_lm,
                                                      next_lm_state)

                            # (input: a)  *a => *a
                            if vocab == prev_vocab:
                                # Merge all prob. of *a into a single entry
                                next_prob_blank, next_prob_non_blank, next_prob_text, \
                                    next_applied_lm, next_lm_state = next_beam[prefix]
                                next_prob_non_blank = log_sum(next_prob_non_blank,
                                                              prob_non_blank + p)
                                next_beam[prefix] = (next_prob_blank, next_prob_non_blank,
                                                     next_prob_text, next_applied_lm,
                                                     next_lm_state)

                    beam = sorted(next_beam.items(),
                                  key=lambda x: get_final_prob(*x[1]), reverse=True)
                    beam = beam[:self.ctc_beam_size]

                output_seq = torch.LongTensor(beam[0][0])
                hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=output_seq,
                                         output_scores=[0] * len(output_seq), lm_state=None,
                                         ctc_prob=0, ctc_state=ctc_state, att_map=None)]
                return hypothesis

        # Start w/ empty hypothesis
        prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                          output_scores=[], lm_state=None,
                                          ctc_prob=0, ctc_state=ctc_state,
                                          att_map=None)]

        # Attention decoding
        for t in range(max_output_len):
            for hypothesis in prev_top_hypothesis:
                # Resume previous step
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = \
                    hypothesis.get_state(device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # Normal ASR forward
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding fusion (output shape 1xV)
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder(d_state, cur_prob, return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)

                # Perform CTC prefix scoring on limited candidates (else OOM easily)
                if self.apply_ctc:
                    # TODO : Check the performance drop for computing part of candidates only
                    _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)

                    # TODO : study why ctc_char (slightly) > 0 sometimes
                    ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)

                    # Combine CTC score and attention score
                    # (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (1 - self.ctc_w) * cur_prob + self.ctc_w * hack_ctc_char
                    cur_prob[0, 0] = LOG_ZERO   # Hack to ignore <sos>

                # Joint RNN-LM decoding
                if self.apply_lm:
                    # Assuming batch size is always 1, resulting in 1x1
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(lm_input, torch.ones([batch_size]),
                                                  hidden=prev_lm_state)
                    # Assuming batch size is always 1, resulting in 1xV
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)

                # Beam search (batch dim. ignored)
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
                final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(),
                                                att_map=prev_attn, lm_state=lm_state,
                                                ctc_state=ctc_state, ctc_prob=ctc_prob,
                                                ctc_candidates=candidates)

                # Move finished hypotheses out
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)
                    if self.beam_size == 1:
                        return final_hypothesis
                next_top_hypothesis.extend(top)

            # Sort for top N beams
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Rescore all hyp (finished/unfinished)
        final_hypothesis += prev_top_hypothesis
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
        return final_hypothesis[:self.beam_size]
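# --- Illustrative sketch (not part of the original decoder) ----------------
# The pure-CTC branch above merges prefix probabilities in log space with a
# small log_sum() helper (a log-sum-exp over its arguments). The check below
# mirrors that idea for plain floats and verifies log_sum(log a, log b) equals
# log(a + b) on made-up values; it is a sanity sketch, not library code.
def _demo_ctc_log_sum():
    import math

    def log_sum(*args):
        # Numerically stable log-sum-exp over plain floats
        if all(a == -float('inf') for a in args):
            return -float('inf')
        a_max = max(args)
        return a_max + math.log(sum(math.exp(a - a_max) for a in args))

    a, b = 0.3, 0.2                                # probabilities of two merged paths
    merged = log_sum(math.log(a), math.log(b))     # log-space merge used by the beam
    assert abs(merged - math.log(a + b)) < 1e-9
    return merged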
class CTCBeamDecoder(nn.Module):
    ''' Beam decoder for ASR (CTC only) '''

    def __init__(self, asr, vocab_range, beam_size, vocab_candidate,
                 lm_path='', lm_config='', lm_weight=0.0, device=None):
        super().__init__()
        # Setup
        self.asr = asr
        self.vocab_range = vocab_range
        self.beam_size = beam_size
        self.vocab_cand = vocab_candidate
        assert self.vocab_cand <= len(self.vocab_range)
        assert self.asr.enable_ctc

        # Setup RNNLM
        self.apply_lm = lm_weight > 0
        self.lm_w = 0
        if self.apply_lm:
            self.device = device
            self.lm_w = lm_weight
            self.lm_path = lm_path
            lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model']).to(self.device)
            self.lm.load_state_dict(
                torch.load(self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

    def create_msg(self):
        msg = ['Decode spec| CTC decoding \t| Beam size = {} \t| LM weight = {}'.format(
            self.beam_size, self.lm_w)]
        return msg

    def forward(self, feat, feat_len):
        # Init.
        assert feat.shape[0] == 1, "Batchsize == 1 is required for beam search"

        # Calculate CTC output probability
        ctc_output, encode_len, att_output, att_align, dec_state = \
            self.asr(feat, feat_len, 10)
        del encode_len, att_output, att_align, dec_state, feat_len
        ctc_output = F.log_softmax(ctc_output[0], dim=-1).cpu().detach().numpy()
        T = len(ctc_output)   # ctc_output = Pr(k,t|x) / dim: T x Vocab

        # Best W probable sequences
        B = [CTCHypothesis()]
        if self.apply_lm:
            # 0 == <sos> for RNNLM
            output, hidden = \
                self.lm(torch.zeros((1, 1), dtype=torch.long).to(self.device),
                        torch.ones(1, dtype=torch.long).to(self.device), None)
            B[0].update_lm(
                (output).log_softmax(dim=-1).squeeze().cpu().numpy(), hidden)

        start = True
        for t in range(T):
            # Greedily ignore pads at the beginning of the sequence
            if np.argmax(ctc_output[t]) == 0 and start:
                continue
            else:
                start = False

            B_new = []
            for i in range(len(B)):   # For y in B
                B_i_new = copy.deepcopy(B[i])
                if B_i_new.get_len() > 0:   # If y is not empty
                    if B_i_new.y[-1] == 1:
                        # <eos> = 1 (reached the end)
                        B_new.append(B_i_new)
                        continue
                    B_i_new.update_Pr_nblank(ctc_output[t, B_i_new.y[-1]])
                    # Find the same prefix
                    for j in range(len(B)):
                        if i != j and B[j].check_same(B_i_new.y[:-1]):
                            lm_prob = 0.0
                            if self.apply_lm:
                                lm_prob = self.lm_w * B[j].lm_output[B_i_new.y[-1]]
                            B_i_new.update_Pr_nblank_prefix(
                                ctc_output[t, B_i_new.y[-1]],
                                B[j].Pr_y_t_blank, B[j].Pr_y_t_nblank, lm_prob)
                            break
                B_i_new.update_Pr_blank(ctc_output[t, 0])   # 0 == <pad>
                if self.apply_lm:
                    lm_hidden = B_i_new.lm_hidden
                    lm_probs = B_i_new.lm_output
                else:
                    lm_hidden = None
                    lm_probs = None

                # Sort the next possible output symbols by CTC (and LM) score
                if self.apply_lm:
                    ctc_vocab_cand = sorted(zip(
                        self.vocab_range,
                        ctc_output[t, self.vocab_range] +
                        self.lm_w * lm_probs[self.vocab_range]),
                        reverse=True, key=lambda x: x[1])
                else:
                    ctc_vocab_cand = sorted(zip(
                        self.vocab_range, ctc_output[t, self.vocab_range]),
                        reverse=True, key=lambda x: x[1])

                # Select top K possible symbols to calculate the probabilities
                for j in range(self.vocab_cand):
                    # <pad>=0, <eos>=1, <unk>=2
                    k = ctc_vocab_cand[j][0]
                    # Pr(k,t|x)
                    hyp_yk = copy.deepcopy(B_i_new)
                    lm_prob = 0.0 if not self.apply_lm else self.lm_w * lm_probs[k]
                    hyp_yk.add_token(k, ctc_output[t, k], lm_prob)
                    hyp_yk.updated_lm = False
                    B_new.append(hyp_yk)
                B_i_new.orig_backup()   # Retrieve original prob. before add_token()
                B_new.append(B_i_new)
            del B
            B = []

            # Remove duplicated sequences by sorting first (O(NlogN))
            B_new = sorted(B_new, key=lambda x: x.get_string())
            B.append(B_new[0])   # First Hyp is always unique
            for i in range(1, len(B_new)):
                if B_new[i].check_same(B[-1].y):
                    # Next Hyp is duplicated, pick the higher one
                    if B_new[i].get_score() > B[-1].get_score():
                        B[-1] = B_new[i]
                    continue
                else:
                    # Next Hyp is different, hence valid
                    B.append(B_new[i])
            del B_new

            # Find top W possible sequences
            if t == T - 1:
                B = sorted(B, reverse=True, key=lambda x: x.get_final_score())
            else:
                B = sorted(B, reverse=True, key=lambda x: x.get_score())
            if len(B) > self.beam_size:
                B = B[:self.beam_size]

            # Update LM states
            if self.apply_lm and t < T - 1:
                for i in range(len(B)):
                    if B[i].get_len() > 0 and not B[i].updated_lm:
                        output, hidden = \
                            self.lm(B[i].y[-1] * torch.ones((1, 1), dtype=torch.long).to(self.device),
                                    torch.ones(1, dtype=torch.long).to(self.device),
                                    B[i].lm_hidden)
                        B[i].update_lm(
                            (output).log_softmax(dim=-1).squeeze().cpu().numpy(),
                            hidden)

        return [b.y for b in B]
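# --- Illustrative sketch (not part of the original decoder) ----------------
# CTCBeamDecoder.forward only expands the vocab_candidate best symbols inside
# vocab_range, ranked by CTC (optionally plus weighted LM) log-probability.
# The snippet below replays that candidate-selection step on made-up scores;
# the array sizes, weights, and values are assumptions for illustration.
def _demo_ctc_candidate_selection():
    import numpy as np
    rng = np.random.default_rng(0)
    vocab_size, lm_w, vocab_cand = 10, 0.5, 3
    vocab_range = list(range(3, vocab_size))       # skip <pad>=0, <eos>=1, <unk>=2
    ctc_logp_t = rng.standard_normal(vocab_size)   # fake CTC log-probs at one frame
    lm_logp = rng.standard_normal(vocab_size)      # fake LM log-probs for the prefix
    ctc_vocab_cand = sorted(zip(vocab_range,
                                ctc_logp_t[vocab_range] + lm_w * lm_logp[vocab_range]),
                            reverse=True, key=lambda x: x[1])
    return [k for k, _ in ctc_vocab_cand[:vocab_cand]]   # symbols the beam will try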