def valid(self):
    '''Perform validation step (!!!NOTE!!! greedy decoding on Attention decoder only)'''
    val_cer = 0.0
    val_len = 0
    all_pred, all_true = [], []
    ctc_results = []
    with torch.no_grad():
        for cur_b, (x, y, state_len) in enumerate(self.dev_set):
            self.progress(' '.join(['Valid step - (', str(cur_b), '/', str(len(self.dev_set)), ')']))

            # Prepare data
            if len(x.shape) == 4: x = x.squeeze(0)
            if len(y.shape) == 3: y = y.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            # state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
            # state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # Forward
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len + VAL_STEP, state_len=state_len)
            ctc_pred = torch.argmax(ctc_pred, dim=-1).cpu() if ctc_pred is not None else None
            ctc_results.append(ctc_pred)

            # Result
            label = y[:, 1:ans_len + 1].contiguous()
            t1, t2 = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
            all_pred += t1
            all_true += t2
            val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(x.shape[0])
            val_len += int(x.shape[0])

    # Dump decoding results so the model can be sanity-checked
    self.verbose('Validation Error Rate of Current model : {:.4f} '.format(val_cer / val_len))
    self.verbose('See {} for validation results.'.format(os.path.join(self.ckpdir, 'dev_att_decode.txt')))
    with open(os.path.join(self.ckpdir, 'dev_att_decode.txt'), 'w') as f:
        for hyp, gt in zip(all_pred, all_true):
            f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')

    # Also dump CTC results if available
    if ctc_results[0] is not None:
        ctc_results = [i for ins in ctc_results for i in ins]
        ctc_text = []
        for pred in ctc_results:
            # CTC greedy collapse: merge consecutive repeats first, then drop the
            # blank/padding symbol (index 0), so blank-separated repeats survive
            p = [k for k, g in itertools.groupby(pred.tolist())]
            p = [i for i in p if i != 0]
            ctc_text.append(self.mapper.translate(p, return_string=True))
        self.verbose('Also, see {} for CTC validation results.'.format(os.path.join(self.ckpdir, 'dev_ctc_decode.txt')))
        with open(os.path.join(self.ckpdir, 'dev_ctc_decode.txt'), 'w') as f:
            for hyp, gt in zip(ctc_text, all_true):
                f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')
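# A minimal standalone sketch (not part of the original code) of the CTC greedy
# collapse used in valid() above, assuming index 0 is the blank/padding symbol.
# Order matters: merging repeats before dropping blanks lets blank-separated
# repeats survive, e.g. [5, 5, 0, 5, 3, 3] -> [5, 5, 3] rather than [5, 3].
import itertools

def _ctc_collapse_sketch(frame_ids, blank=0):
    merged = [k for k, _ in itertools.groupby(frame_ids)]  # [5, 5, 0, 5, 3, 3] -> [5, 0, 5, 3]
    return [i for i in merged if i != blank]               # [5, 0, 5, 3] -> [5, 5, 3]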
def valid(self):
    '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)'''
    self.asr_model.eval()

    # Init stats
    val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
    val_len = 0
    all_pred, all_true = [], []

    # Perform validation
    for cur_b, (x, y) in enumerate(self.dev_set):
        self.progress(' '.join(['Valid step -', str(self.step), '(', str(cur_b), '/', str(len(self.dev_set)), ')']))

        # Prepare data
        if len(x.shape) == 4: x = x.squeeze(0)
        if len(y.shape) == 3: y = y.squeeze(0)
        x = x.to(device=self.device, dtype=torch.float32)
        y = y.to(device=self.device, dtype=torch.long)
        state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
        state_len = [int(sl) for sl in state_len]
        ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

        # Forward
        ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len + VAL_STEP, state_len=state_len)

        # Compute attention loss & get decoding results
        label = y[:, 1:ans_len + 1].contiguous()
        if self.ctc_weight < 1:
            seq_loss = self.seq_loss(att_pred[:, :ans_len, :].contiguous().view(-1, att_pred.shape[-1]), label.view(-1))
            # Sum over each utterance and divide by its length, then average over the batch
            seq_loss = torch.sum(seq_loss.view(x.shape[0], -1), dim=-1) \
                       / torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)
            seq_loss = torch.mean(seq_loss)
            val_att += seq_loss.detach() * int(x.shape[0])
            t1, t2 = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
            all_pred += t1
            all_true += t2
            val_acc += cal_acc(att_pred, label) * int(x.shape[0])
            val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(x.shape[0])

        # Compute CTC loss
        if self.ctc_weight > 0:
            target_len = torch.sum(y != 0, dim=-1)
            ctc_loss = self.ctc_loss(F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                                     torch.LongTensor(state_len), target_len)
            val_ctc += ctc_loss.detach() * int(x.shape[0])

        val_len += int(x.shape[0])

    # Logger
    val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
    loss_log = {}
    for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'], [val_loss, val_ctc, val_att]):
        if v > 0.0: loss_log[k] = v / val_len
    self.write_log('loss', loss_log)

    if self.ctc_weight < 1:
        # Plot attention maps to log
        val_hyp, val_txt = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
        val_attmap = draw_att(att_maps, att_pred)

        # Record loss
        self.write_log('error rate', {'dev': val_cer / val_len})
        self.write_log('acc', {'dev': val_acc / val_len})
        for idx, attmap in enumerate(val_attmap):
            self.write_log('att_' + str(idx), attmap)
            self.write_log('hyp_' + str(idx), val_hyp[idx])
            self.write_log('txt_' + str(idx), val_txt[idx])

        # Save model by validation error rate
        if val_cer / val_len < self.best_val_ed:
            self.best_val_ed = val_cer / val_len
            self.verbose('Best val er : {:.4f} @ step {}'.format(self.best_val_ed, self.step))
            torch.save(self.asr_model, os.path.join(self.ckpdir, 'asr'))
            if self.apply_clm:
                torch.save(self.clm.clm, os.path.join(self.ckpdir, 'clm'))
            # Save hypotheses
            with open(os.path.join(self.ckpdir, 'best_hyp.txt'), 'w') as f:
                for t1, t2 in zip(all_pred, all_true):
                    f.write(t1 + ',' + t2 + '\n')

    self.asr_model.train()
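# Illustrative standalone sketch of the length-normalized attention CE computed
# in valid() above. The shapes and the assumption that self.seq_loss is
# nn.CrossEntropyLoss(ignore_index=0, reduction='none') are inferred, not
# confirmed by this codebase.
import torch
import torch.nn as nn

def _length_normalized_ce_sketch(logits, labels):
    """logits: (B, T, C) decoder outputs; labels: (B, T) with 0 = padding."""
    ce = nn.CrossEntropyLoss(ignore_index=0, reduction='none')
    per_token = ce(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)).view(labels.shape)
    per_uttr = per_token.sum(dim=-1) / (labels != 0).sum(dim=-1).float()  # sum / true length
    return per_uttr.mean()  # average over the batch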
def exec(self):
    ''' Training End-to-end ASR system'''
    self.verbose('Training set total ' + str(len(self.train_set)) + ' batches.')

    while self.step < self.max_step:
        for x, y in self.train_set:
            self.progress('Training step - ' + str(self.step))

            # Perform teacher forcing rate decay
            tf_rate = self.tf_start - self.step * (self.tf_start - self.tf_end) / self.max_step

            # Hack bucket: record state length for each utterance, get longest label seq for decode step
            assert len(x.shape) == 4, 'Bucketing should cause acoustic feature to have shape 1xBxTxD'
            assert len(y.shape) == 3, 'Bucketing should cause label to have shape 1xBxT'
            x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
            y = y.squeeze(0).to(device=self.device, dtype=torch.long)
            state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0, axis=-1)
            state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # ASR forwarding
            self.asr_opt.zero_grad()
            ctc_pred, state_len, att_pred, _ = self.asr_model(x, ans_len, tf_rate=tf_rate, teacher=y, state_len=state_len)

            # Calculate loss function
            loss_log = {}
            label = y[:, 1:ans_len + 1].contiguous()
            ctc_loss = 0
            att_loss = 0

            # CE loss on attention decoder
            if self.ctc_weight < 1:
                b, t, c = att_pred.shape
                att_loss = self.seq_loss(att_pred.view(b * t, c), label.view(-1))
                # Sum over each utterance and divide by its length, then average over the batch
                att_loss = torch.sum(att_loss.view(b, t), dim=-1) \
                           / torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)
                att_loss = torch.mean(att_loss)
                loss_log['train_att'] = att_loss

            # CTC loss on CTC decoder
            if self.ctc_weight > 0:
                target_len = torch.sum(y != 0, dim=-1)
                ctc_loss = self.ctc_loss(F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                                         torch.LongTensor(state_len), target_len)
                loss_log['train_ctc'] = ctc_loss

            asr_loss = (1 - self.ctc_weight) * att_loss + self.ctc_weight * ctc_loss
            loss_log['train_full'] = asr_loss

            # Adversarial loss from CLM
            if self.apply_clm and att_pred.shape[1] >= CLM_MIN_SEQ_LEN:
                if (self.step % self.clm.update_freq) == 0:
                    # Update the CLM once in a while
                    clm_log, gp = self.clm.train(att_pred.detach(), CLM_MIN_SEQ_LEN)
                    self.write_log('clm_score', clm_log)
                    self.write_log('clm_gp', gp)
                adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                asr_loss -= adv_feedback

            # Backprop
            asr_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.asr_model.parameters(), GRAD_CLIP)
            if math.isnan(grad_norm):
                self.verbose('Error : grad norm is NaN @ step ' + str(self.step))
            else:
                self.asr_opt.step()

            # Logger
            self.write_log('loss', loss_log)
            if self.ctc_weight < 1:
                self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {'train': cal_cer(att_pred, label, mapper=self.mapper)})

            # Validation
            if self.step % self.valid_step == 0:
                self.asr_opt.zero_grad()
                self.valid()

            self.step += 1
            if self.step > self.max_step: break
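# The linear teacher-forcing decay used in exec() above, shown standalone.
# The settings here (tf_start=0.9, tf_end=0.0, max_step=10000) are illustrative
# assumptions, not values from this codebase: the rate falls linearly from
# tf_start at step 0 toward tf_end at max_step.
def _tf_rate_sketch(step, tf_start=0.9, tf_end=0.0, max_step=10000):
    return tf_start - step * (tf_start - tf_end) / max_step

assert _tf_rate_sketch(0) == 0.9
assert abs(_tf_rate_sketch(5000) - 0.45) < 1e-9  # halfway through training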
def greedy_decode(self, split):
    '''Perform greedy decoding on the given split (!!!NOTE!!! greedy decoding on Attention decoder only)'''
    # TODO : Add rnnlm & ctc decode to greedy.
    val_cer = 0.0
    val_len = 0
    all_pred, all_true = [], []
    ctc_results = []
    ds = self.dev_set if split == 'dev' else self.test_set

    # For MWER only
    tf_data = False
    if self.mwer:
        print("TF={}".format(tf_data))
        idx = 0
        mwer_dir = str(self.pred_path) + '_dev' if split == 'dev' else self.pred_path
        if not os.path.exists(mwer_dir):
            os.makedirs(mwer_dir)
            os.makedirs(mwer_dir + '/data')
        f = open(os.path.join(mwer_dir, 'data.csv'), 'a')

    # Original decoding loop starts here
    with torch.no_grad():
        for cur_b, (x, y) in enumerate(tqdm(ds)):
            # self.progress(' '.join(['Decode step - (', str(cur_b), '/', str(len(ds)), ')']))

            # Prepare data
            if len(x.shape) == 4: x = x.squeeze(0)
            if len(y.shape) == 3: y = y.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
            state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # Forward
            if self.mwer and tf_data:
                ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len, tf_rate=1, teacher=y, state_len=state_len)
            else:
                decode_len = ans_len + VAL_STEP if split == 'dev' else \
                             int(np.ceil(max(state_len) * self.decode_step_ratio))
                ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, decode_len, state_len=state_len)
            ctc_pred = torch.argmax(ctc_pred, dim=-1).cpu() if ctc_pred is not None else None
            ctc_results.append(ctc_pred)

            # MWER: dump per-utterance posteriors for minimum-WER training
            if self.mwer:
                att_out = F.softmax(att_pred, dim=-1).cpu().numpy()
                for output, ans in zip(att_out, y):
                    answer = ans.tolist()
                    # Drop <sos>, keep tokens up to the first <eos> (index 1 taken as <eos> here)
                    answer = answer[1:answer.index(1)] + [1]
                    eos_pos = np.where(np.argmax(output, axis=-1) == 1)[0]
                    if len(eos_pos) > 0:
                        output = output[:eos_pos[0] + 1]  # truncate hypothesis at the first <eos>
                    f_name = str(idx) + '.npy'
                    f.write("{},{}\n".format(f_name, '_'.join([str(c) for c in answer])))
                    np.save(str(os.path.join(mwer_dir, 'data', f_name)), output)
                    idx += 1

            # Result
            label = y[:, 1:ans_len + 1].contiguous()
            t1, t2 = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
            all_pred += t1
            all_true += t2
            val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(x.shape[0])
            val_len += int(x.shape[0])

    if split == 'dev':
        # Dump decoding results so the model can be sanity-checked
        decode_path = os.path.join(self.ckpdir, 'dev_att_decode.txt')
        er_msg = 'Validation Error Rate of Current model : {:.4f} '.format(val_cer / val_len)
        save_msg = 'See {} for validation results.'.format(decode_path)
    else:
        decode_path = os.path.join(self.ckpdir, self.decode_file + '.txt')
        er_msg = 'Test Error Rate: {:.4f} '.format(val_cer / val_len)
        save_msg = 'See {} for decoding results.'.format(decode_path)
    self.verbose(er_msg)
    self.verbose(save_msg)

    # MWER
    if self.mwer:
        f.close()
        return 0

    with open(decode_path, 'w') as f:
        for hyp, gt in zip(all_pred, all_true):
            f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')

    # Also dump CTC results if available
    if ctc_results[0] is not None and split == 'dev':
        ctc_results = [i for ins in ctc_results for i in ins]
        ctc_text = []
        for pred in ctc_results:
            # CTC greedy collapse: merge consecutive repeats first, then drop blanks (index 0)
            p = [k for k, g in itertools.groupby(pred.tolist())]
            p = [i for i in p if i != 0]
            ctc_text.append(self.mapper.translate(p, return_string=True))
        self.verbose('Also, see {} for CTC validation results.'.format(os.path.join(self.ckpdir, 'dev_ctc_decode.txt')))
        with open(os.path.join(self.ckpdir, 'dev_ctc_decode.txt'), 'w') as f:
            for hyp, gt in zip(ctc_text, all_true):
                f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')
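# For reference, each row the MWER branch above appends to data.csv looks like
# "42.npy,23_7_15_1": the saved (T x V) posterior file paired with the
# '_'-joined reference token ids. A minimal reader sketch (hypothetical helper,
# not part of this codebase):
import os
import numpy as np

def _read_mwer_row_sketch(mwer_dir, row):
    f_name, answer = row.strip().split(',')
    posterior = np.load(os.path.join(mwer_dir, 'data', f_name))  # (T, V) softmax outputs
    ref_ids = [int(c) for c in answer.split('_')]
    return posterior, ref_ids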
def valid(self):
    '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)'''
    self.asr_model.eval()

    # Init stats
    val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
    val_len = 0
    all_pred, all_true = [], []
    progress_bar = tqdm(self.dev_set, leave=False)

    # Perform validation
    for cur_b, (x, y, state_len) in enumerate(progress_bar):
        progress_bar.set_description("[valid step {} - {}/{}]".format(self.step, cur_b, len(self.dev_set)))

        # Prepare data
        if len(x.shape) == 4: x = x.squeeze(0)
        if len(y.shape) == 3: y = y.squeeze(0)
        if state_len.dim() == 2: state_len = state_len.squeeze(0)
        x = x.to(device=self.device, dtype=torch.float32)
        y = y.to(device=self.device, dtype=torch.long)
        state_len = state_len.to(device=self.device, dtype=torch.long)
        # state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
        # state_len = [int(sl) for sl in state_len]
        ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

        # Forward
        ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len + VAL_STEP, state_len=state_len)

        # Compute attention loss & get decoding results
        label = y[:, 1:ans_len + 1].contiguous()
        if self.ctc_weight < 1:
            seq_loss = self.seq_loss(att_pred[:, :ans_len, :].contiguous().view(-1, att_pred.shape[-1]), label.view(-1))
            # Sum over each utterance and divide by its length, then average over the batch
            seq_loss = torch.sum(seq_loss.view(x.shape[0], -1), dim=-1) \
                       / torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)
            seq_loss = torch.mean(seq_loss)
            val_att += seq_loss.detach() * int(x.shape[0])
            t1, t2 = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
            all_pred += t1
            all_true += t2
            val_acc += cal_acc(att_pred, label) * int(x.shape[0])
            val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(x.shape[0])
        else:
            # CTC only
            t1, t2 = cal_cer(ctc_pred, label, mapper=self.mapper, get_sentence=True)
            all_pred += t1
            all_true += t2
            val_acc += cal_acc(ctc_pred, label) * int(x.shape[0])
            val_cer += cal_cer(ctc_pred, label, mapper=self.mapper) * int(x.shape[0])

        # Compute CTC loss
        if self.ctc_weight > 0:
            target_len = torch.sum(y != 0, dim=-1)
            ctc_loss = self.ctc_loss(F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                                     torch.LongTensor(state_len), target_len)
            val_ctc += ctc_loss.detach() * int(x.shape[0])

        val_len += int(x.shape[0])

    # Logger
    val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
    loss_log = {}
    for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'], [val_loss, val_ctc, val_att]):
        if v > 0.0: loss_log[k] = v / val_len
    self.write_log('loss', loss_log)

    # Attention decoder
    if self.ctc_weight < 1:
        # Plot attention maps to log
        val_hyp, val_txt = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
        val_attmap = draw_att(att_maps, att_pred)

        # Record loss
        self.write_log('error rate', {'dev': val_cer / val_len})
        self.write_log('acc', {'dev': val_acc / val_len})
        for idx, attmap in enumerate(val_attmap):
            self.write_log('att_' + str(idx), attmap)
            self.write_log('hyp_' + str(idx), val_hyp[idx])
            self.write_log('txt_' + str(idx), val_txt[idx])
    else:
        # CTC only
        val_hyp, val_txt = cal_cer(ctc_pred, label, mapper=self.mapper, get_sentence=True)

        # Record loss
        self.write_log('ctc error rate', {'dev': val_cer / val_len})
        self.write_log('ctc acc', {'dev': val_acc / val_len})
        for idx in range(len(val_hyp)):
            self.write_log('hyp_' + str(idx), val_hyp[idx])
            self.write_log('txt_' + str(idx), val_txt[idx])

    # Save model by validation error rate
    self.maybe_dump_checkpoint(val_cer / val_len, all_pred, all_true)
    self.asr_model.train()
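# Shape conventions for the CTC term in valid() above, assuming self.ctc_loss is
# torch.nn.CTCLoss(blank=0) (an assumption; only the call site is visible here):
# log-probs must be (T, B, C), hence the transpose(0, 1); targets are (B, S);
# input and target lengths are (B,) tensors.
import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(2, 50, 30), dim=-1).transpose(0, 1)  # (T=50, B=2, C=30)
targets = torch.randint(1, 30, (2, 12))                                    # index 0 reserved for blank
loss = torch.nn.CTCLoss(blank=0)(log_probs, targets,
                                 torch.tensor([50, 50]), torch.tensor([12, 12]))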
def exec(self):
    ''' Training End-to-end ASR system'''
    self.verbose('Training set total ' + str(len(self.train_set)) + ' batches.')
    progress_bar = tqdm(range(self.max_step))

    while self.step < self.max_step:
        for x, y, state_len in self.train_set:
            progress_bar.set_description("[training {}]".format(self.step))

            # Perform teacher forcing rate decay
            tf_rate = self.tf_start - self.step * (self.tf_start - self.tf_end) / self.max_step

            # Hack bucket: record state length for each utterance, get longest label seq for decode step
            assert len(x.shape) == 4, 'Bucketing should cause acoustic feature to have shape 1xBxTxD'
            assert len(y.shape) == 3, 'Bucketing should cause label to have shape 1xBxT'
            x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
            y = y.squeeze(0).to(device=self.device, dtype=torch.long)
            state_len = state_len.squeeze(0).to(device=self.device, dtype=torch.long)
            # state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0, axis=-1)
            # state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # ASR forwarding
            self.asr_opt.zero_grad()
            ctc_pred, state_len, att_pred, _ = self.asr_model(x, ans_len, tf_rate=tf_rate, teacher=y, state_len=state_len)
            # print()
            # print("inp", x.shape[1] // self.asr_model.encoder.downsample_rate)
            # print("ctc", ctc_pred.shape[1] if ctc_pred is not None else 0)
            # print("att", att_pred.shape[1] if att_pred is not None else 0)
            assert ctc_pred is not None, "for saber, must use ctc_pred!"
            # Encoder must emit exactly one CTC frame per downsampled input frame
            assert x.shape[1] // self.asr_model.encoder.downsample_rate == ctc_pred.shape[1]

            # Calculate loss function
            loss_log = {}
            label = y[:, 1:ans_len + 1].contiguous()
            ctc_loss = 0
            att_loss = 0

            # CE loss on attention decoder
            if self.ctc_weight < 1:
                b, t, c = att_pred.shape
                att_loss = self.seq_loss(att_pred.view(b * t, c), label.view(-1))
                # Sum over each utterance and divide by its length, then average over the batch
                att_loss = torch.sum(att_loss.view(b, t), dim=-1) \
                           / torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)
                att_loss = torch.mean(att_loss)
                loss_log['train_att'] = att_loss

            # CTC loss on CTC decoder
            if self.ctc_weight > 0:
                target_len = torch.sum(y != 0, dim=-1)
                ctc_loss = self.ctc_loss(F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                                         torch.LongTensor(state_len), target_len)
                loss_log['train_ctc'] = ctc_loss

            asr_loss = (1 - self.ctc_weight) * att_loss + self.ctc_weight * ctc_loss
            loss_log['train_full'] = asr_loss

            # Adversarial loss from CLM
            if self.apply_clm and (att_pred is not None) and (att_pred.shape[1] >= CLM_MIN_SEQ_LEN):
                if (self.step % self.clm.update_freq) == 0:
                    # Update the CLM once in a while
                    clm_log, gp = self.clm.train(att_pred.detach(), CLM_MIN_SEQ_LEN)
                    self.write_log('clm_score', clm_log)
                    self.write_log('clm_gp', gp)
                adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                asr_loss -= adv_feedback

            # Backprop
            asr_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.asr_model.parameters(), GRAD_CLIP)
            if math.isnan(grad_norm):
                self.verbose('Error : grad norm is NaN @ step ' + str(self.step))
            else:
                self.asr_opt.step()

            # Logger
            self.write_log('loss', loss_log)
            if self.ctc_weight < 1:
                self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {'train': cal_cer(att_pred, label, mapper=self.mapper)})
            else:
                # CTC only
                self.write_log('ctc acc', {'train': cal_acc(ctc_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('ctc error rate', {'train': cal_cer(ctc_pred, label, mapper=self.mapper)})

            # Visualize inputs
            if self.step % 1000 == 0:
                example_inputs = x[0].cpu().detach().numpy().transpose(1, 0)
                img = visualizer.plot(visualizer.plot_item(example_inputs, "inputs-feature"))
                self.log.add_image("0-inputs", img, global_step=self.step, dataformats="HWC")

            # Validation
            if self.step % self.valid_step == 0:
                self.asr_opt.zero_grad()
                self.valid()

            self.step += 1
            progress_bar.update()
            if self.step > self.max_step: break
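# The NaN guard around the optimizer step in both exec() variants, shown as a
# standalone pattern (a sketch; grad_clip=5.0 is an assumed value, the document
# only names the constant GRAD_CLIP): clip gradients first, and skip the update
# when the returned total norm is NaN so a single bad batch cannot corrupt the weights.
import math
import torch

def _guarded_step_sketch(model, optimizer, grad_clip=5.0):
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if math.isnan(float(grad_norm)):
        return False  # skip the update; caller may log a warning and move on
    optimizer.step()
    return True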