def valid(self):
    '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)

    Runs every batch of `self.dev_set` through `self.asr_model` in eval mode,
    accumulates batch-size-weighted attention / CTC losses, accuracy and error
    rate, and writes the averages to the log. When the attention decoder is
    used (`self.ctc_weight < 1`) and the averaged error rate improves on
    `self.best_val_ed`, the model(s) and the decoded hypotheses are saved
    under `self.ckpdir`. The model is put back in train mode before returning.
    '''
    if len(self.dev_set) == 0:
        # Without a single batch there is nothing to average or to log; the
        # code below would otherwise fail on unbound predictions / zero length.
        self.verbose('Warning : dev set is empty, skip validation @ step ' +
                     str(self.step))
        return

    self.asr_model.eval()

    # Init stats
    val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
    val_len = 0
    all_pred, all_true = [], []
    total_batches = len(self.dev_set)

    # Perform validation (no parameter update happens here, so gradient
    # tracking is disabled to avoid building autograd graphs)
    with torch.no_grad():
        for cur_b, (x, y) in enumerate(self.dev_set):
            self.progress(' '.join([
                'Valid step -', str(self.step), '(', str(cur_b), '/',
                str(total_batches), ')'
            ]))

            # Prepare data (drop the leading bucket dimension when present)
            if len(x.shape) == 4:
                x = x.squeeze(0)
            if len(y.shape) == 3:
                y = y.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            batch_size = int(x.shape[0])
            # Frames whose feature sum is non-zero count as valid input
            state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
            state_len = [int(sl) for sl in state_len]
            # Number of non-zero label tokens per utterance
            target_len = torch.sum(y != 0, dim=-1)
            ans_len = int(torch.max(target_len))

            # Forward
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                x, ans_len + VAL_STEP, state_len=state_len)

            # Compute attention loss & get decoding results
            label = y[:, 1:ans_len + 1].contiguous()
            if self.ctc_weight < 1:
                seq_loss = self.seq_loss(
                    att_pred[:, :ans_len, :].contiguous().view(
                        -1, att_pred.shape[-1]),
                    label.view(-1))
                # Sum each uttr and divide by length
                seq_loss = torch.sum(seq_loss.view(batch_size, -1), dim=-1) \
                    / target_len.to(device=self.device, dtype=torch.float32)
                seq_loss = torch.mean(seq_loss)  # Mean by batch
                val_att += seq_loss.detach() * batch_size
                t1, t2 = cal_cer(att_pred, label, mapper=self.mapper,
                                 get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(att_pred, label) * batch_size
                val_cer += cal_cer(att_pred, label,
                                   mapper=self.mapper) * batch_size

            # Compute CTC loss
            if self.ctc_weight > 0:
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                    torch.LongTensor(state_len), target_len)
                val_ctc += ctc_loss.detach() * batch_size

            val_len += batch_size

    # Logger (only losses that were actually accumulated are written)
    val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
    loss_log = {}
    for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'],
                    [val_loss, val_ctc, val_att]):
        if v > 0.0:
            loss_log[k] = v / val_len
    self.write_log('loss', loss_log)

    if self.ctc_weight < 1:
        # Plot attention map to log (uses the predictions of the last batch)
        val_hyp, val_txt = cal_cer(att_pred, label, mapper=self.mapper,
                                   get_sentence=True)
        val_attmap = draw_att(att_maps, att_pred)

        # Record loss
        val_er = val_cer / val_len
        self.write_log('error rate', {'dev': val_er})
        self.write_log('acc', {'dev': val_acc / val_len})
        for idx, attmap in enumerate(val_attmap):
            self.write_log('att_' + str(idx), attmap)
            self.write_log('hyp_' + str(idx), val_hyp[idx])
            self.write_log('txt_' + str(idx), val_txt[idx])

        # Save model by val er.
        if val_er < self.best_val_ed:
            self.best_val_ed = val_er
            self.verbose(
                'Best val er : {:.4f} @ step {}'.format(
                    self.best_val_ed, self.step))
            torch.save(self.asr_model, os.path.join(self.ckpdir, 'asr'))
            if self.apply_clm:
                torch.save(self.clm.clm, os.path.join(self.ckpdir, 'clm'))
            # Save hyps.
            with open(os.path.join(self.ckpdir, 'best_hyp.txt'), 'w') as f:
                f.writelines(hyp + ',' + ref + '\n'
                             for hyp, ref in zip(all_pred, all_true))

    self.asr_model.train()
def valid(self):
    '''Perform validation step (greedy decoding, attention decoder or CTC-only)

    Runs every batch of `self.dev_set` through `self.asr_model` in eval mode,
    accumulates batch-size-weighted losses, accuracy and error rate, and
    writes the averages to the log. Metrics come from the attention output
    when `self.ctc_weight < 1`, otherwise from the CTC output. The averaged
    error rate and all collected hypothesis/reference sentences are handed to
    `self.maybe_dump_checkpoint`. The model is put back in train mode before
    returning.
    '''
    if len(self.dev_set) == 0:
        # Without a single batch there is nothing to average or to log; the
        # code below would otherwise fail on unbound predictions / zero length.
        self.verbose('Warning : dev set is empty, skip validation @ step ' +
                     str(self.step))
        return

    self.asr_model.eval()

    # Init stats
    val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
    val_len = 0
    all_pred, all_true = [], []
    total_batches = len(self.dev_set)
    progress_bar = tqdm(self.dev_set, leave=False)

    # Perform validation (no parameter update happens here, so gradient
    # tracking is disabled to avoid building autograd graphs)
    with torch.no_grad():
        for cur_b, (x, y, state_len) in enumerate(progress_bar):
            # Show the training step and the position inside the dev set
            progress_bar.set_description(
                "[valid {} ({}/{})]".format(self.step, cur_b, total_batches))

            # Prepare data (drop the leading bucket dimension when present)
            if len(x.shape) == 4:
                x = x.squeeze(0)
            if len(y.shape) == 3:
                y = y.squeeze(0)
            if state_len.dim() == 2:
                state_len = state_len.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            state_len = state_len.to(device=self.device, dtype=torch.long)
            batch_size = int(x.shape[0])
            # Number of non-zero label tokens per utterance
            target_len = torch.sum(y != 0, dim=-1)
            ans_len = int(torch.max(target_len))

            # Forward
            # NOTE(review): state_len is rebound to the value returned by the
            # model; its type is not visible here - confirm it is accepted by
            # torch.LongTensor(...) below.
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                x, ans_len + VAL_STEP, state_len=state_len)

            # Compute attention loss & get decoding results
            label = y[:, 1:ans_len + 1].contiguous()
            if self.ctc_weight < 1:
                seq_loss = self.seq_loss(
                    att_pred[:, :ans_len, :].contiguous().view(
                        -1, att_pred.shape[-1]),
                    label.view(-1))
                # Sum each uttr and divide by length
                seq_loss = torch.sum(seq_loss.view(batch_size, -1), dim=-1) \
                    / target_len.to(device=self.device, dtype=torch.float32)
                seq_loss = torch.mean(seq_loss)  # Mean by batch
                val_att += seq_loss.detach() * batch_size
                metric_pred = att_pred
            else:
                # only ctc
                metric_pred = ctc_pred

            t1, t2 = cal_cer(metric_pred, label, mapper=self.mapper,
                             get_sentence=True)
            all_pred += t1
            all_true += t2
            val_acc += cal_acc(metric_pred, label) * batch_size
            val_cer += cal_cer(metric_pred, label,
                               mapper=self.mapper) * batch_size

            # Compute CTC loss
            if self.ctc_weight > 0:
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1),
                    label,
                    torch.LongTensor(state_len),
                    target_len)
                val_ctc += ctc_loss.detach() * batch_size

            val_len += batch_size

    # Logger (only losses that were actually accumulated are written)
    val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
    loss_log = {}
    for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'],
                    [val_loss, val_ctc, val_att]):
        if v > 0.0:
            loss_log[k] = v / val_len
    self.write_log('loss', loss_log)

    if self.ctc_weight < 1:
        # attention decoder
        # Plot attention map to log (uses the predictions of the last batch)
        val_hyp, val_txt = cal_cer(att_pred, label, mapper=self.mapper,
                                   get_sentence=True)
        val_attmap = draw_att(att_maps, att_pred)

        # Record loss
        self.write_log('error rate', {'dev': val_cer / val_len})
        self.write_log('acc', {'dev': val_acc / val_len})
        for idx, attmap in enumerate(val_attmap):
            self.write_log('att_' + str(idx), attmap)
            self.write_log('hyp_' + str(idx), val_hyp[idx])
            self.write_log('txt_' + str(idx), val_txt[idx])
    else:
        # only use ctc (sentences of the last batch are logged)
        val_hyp, val_txt = cal_cer(ctc_pred, label, mapper=self.mapper,
                                   get_sentence=True)

        # Record loss
        self.write_log('ctc error rate', {'dev': val_cer / val_len})
        self.write_log('ctc acc', {'dev': val_acc / val_len})
        for idx, hyp in enumerate(val_hyp):
            self.write_log('hyp_' + str(idx), hyp)
            self.write_log('txt_' + str(idx), val_txt[idx])

    # Save model by val er.
    self.maybe_dump_checkpoint(val_cer / val_len, all_pred, all_true)

    self.asr_model.train()
def exec(self):
    ''' Training End-to-end ASR system

    Iterates over `self.train_set` until `self.step` passes `self.max_step`.
    Each step decays the teacher forcing rate linearly from `self.tf_start`
    to `self.tf_end`, forwards one batch through `self.asr_model`, combines
    the attention and CTC losses by `self.ctc_weight`, optionally subtracts
    the CLM feedback, back-propagates with gradient clipping and writes the
    training logs. `self.valid()` is called every `self.valid_step` steps.
    '''
    self.verbose('Training set total ' + str(len(self.train_set)) +
                 ' batches.')
    if len(self.train_set) == 0:
        # The outer while loop only advances through batches, so an empty
        # training set would spin forever without ever increasing self.step.
        self.verbose('Error : training set is empty, nothing to train on.')
        return

    while self.step < self.max_step:
        for x, y in self.train_set:
            self.progress('Training step - ' + str(self.step))

            # Perform teacher forcing rate decaying
            tf_rate = self.tf_start - self.step * (
                self.tf_start - self.tf_end) / self.max_step

            # Hack bucket, record state length for each uttr, get longest label seq for decode step
            assert len(
                x.shape
            ) == 4, 'Bucketing should cause acoustic feature to have shape 1xBxTxD'
            assert len(
                y.shape
            ) == 3, 'Bucketing should cause label have to shape 1xBxT'
            x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
            y = y.squeeze(0).to(device=self.device, dtype=torch.long)
            # Frames whose feature sum is non-zero count as valid input
            state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0,
                               axis=-1)
            state_len = [int(sl) for sl in state_len]
            # Number of non-zero label tokens per utterance
            target_len = torch.sum(y != 0, dim=-1)
            ans_len = int(torch.max(target_len))

            # ASR forwarding
            self.asr_opt.zero_grad()
            ctc_pred, state_len, att_pred, _ = self.asr_model(
                x, ans_len, tf_rate=tf_rate, teacher=y, state_len=state_len)

            # Calculate loss function
            loss_log = {}
            label = y[:, 1:ans_len + 1].contiguous()
            ctc_loss = 0
            att_loss = 0

            # CE loss on attention decoder
            if self.ctc_weight < 1:
                b, t, c = att_pred.shape
                att_loss = self.seq_loss(att_pred.view(b * t, c),
                                         label.view(-1))
                # Sum each uttr and divide by length
                att_loss = torch.sum(att_loss.view(b, t), dim=-1) \
                    / target_len.to(device=self.device, dtype=torch.float32)
                att_loss = torch.mean(att_loss)  # Mean by batch
                loss_log['train_att'] = att_loss

            # CTC loss on CTC decoder
            if self.ctc_weight > 0:
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                    torch.LongTensor(state_len), target_len)
                loss_log['train_ctc'] = ctc_loss

            asr_loss = (1 - self.ctc_weight) * att_loss \
                + self.ctc_weight * ctc_loss
            loss_log['train_full'] = asr_loss

            # Adversarial loss from CLM (skipped when there is no attention output)
            if (self.apply_clm and att_pred is not None
                    and att_pred.shape[1] >= CLM_MIN_SEQ_LEN):
                if (self.step % self.clm.update_freq) == 0:
                    # update CLM once in a while
                    clm_log, gp = self.clm.train(att_pred.detach(),
                                                 CLM_MIN_SEQ_LEN)
                    self.write_log('clm_score', clm_log)
                    self.write_log('clm_gp', gp)
                # Normalise over the class axis (last dim of the b x t x c
                # output); an implicit dim would pick the batch axis for 3-D input.
                adv_feedback = self.clm.compute_loss(
                    F.softmax(att_pred, dim=-1))
                asr_loss -= adv_feedback

            # Backprop
            asr_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.asr_model.parameters(), GRAD_CLIP)
            if math.isnan(grad_norm):
                # Skip the update rather than poisoning the weights with NaN
                self.verbose('Error : grad norm is NaN @ step ' +
                             str(self.step))
            else:
                self.asr_opt.step()

            # Logger (attention metrics only exist when the attention decoder is used)
            self.write_log('loss', loss_log)
            if self.ctc_weight < 1:
                self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {
                        'train': cal_cer(att_pred, label, mapper=self.mapper)
                    })

            # Validation
            if self.step % self.valid_step == 0:
                self.asr_opt.zero_grad()
                self.valid()

            self.step += 1
            # NOTE(review): with '>' a batch may still be processed while
            # step == max_step - confirm this extra step is intended.
            if self.step > self.max_step:
                break
def exec(self):
    ''' Training End-to-end ASR system

    Iterates over `self.train_set` (batches of features, labels and state
    lengths) until `self.step` passes `self.max_step`. Each step decays the
    teacher forcing rate linearly from `self.tf_start` to `self.tf_end`,
    forwards one batch through `self.asr_model`, combines the attention and
    CTC losses by `self.ctc_weight`, optionally subtracts the CLM feedback,
    back-propagates with gradient clipping and writes the training logs.
    An input feature image is logged every 1000 steps and `self.valid()` is
    called every `self.valid_step` steps.
    '''
    self.verbose('Training set total ' + str(len(self.train_set)) +
                 ' batches.')
    if len(self.train_set) == 0:
        # The outer while loop only advances through batches, so an empty
        # training set would spin forever without ever increasing self.step.
        self.verbose('Error : training set is empty, nothing to train on.')
        return

    # The bar is driven manually via update(), once per training step
    progress_bar = tqdm(range(self.max_step))

    while self.step < self.max_step:
        for x, y, state_len in self.train_set:
            progress_bar.set_description("[training {}]".format(self.step))

            # Perform teacher forcing rate decaying
            tf_rate = self.tf_start - self.step * (
                self.tf_start - self.tf_end) / self.max_step

            # Hack bucket, record state length for each uttr, get longest label seq for decode step
            assert len(x.shape) == 4, 'Bucketing should cause acoustic feature to have shape 1xBxTxD'
            assert len(y.shape) == 3, 'Bucketing should cause label have to shape 1xBxT'
            x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
            y = y.squeeze(0).to(device=self.device, dtype=torch.long)
            state_len = state_len.squeeze(0).to(device=self.device,
                                                dtype=torch.long)
            # Number of non-zero label tokens per utterance
            target_len = torch.sum(y != 0, dim=-1)
            ans_len = int(torch.max(target_len))

            # ASR forwarding
            self.asr_opt.zero_grad()
            ctc_pred, state_len, att_pred, _ = self.asr_model(
                x, ans_len, tf_rate=tf_rate, teacher=y, state_len=state_len)

            assert ctc_pred is not None, "for saber, must use ctc_pred!"
            # The CTC output length must equal the downsampled input length
            assert x.shape[1] // self.asr_model.encoder.downsample_rate == ctc_pred.shape[1]

            # Calculate loss function
            loss_log = {}
            label = y[:, 1:ans_len + 1].contiguous()
            ctc_loss = 0
            att_loss = 0

            # CE loss on attention decoder
            if self.ctc_weight < 1:
                b, t, c = att_pred.shape
                att_loss = self.seq_loss(att_pred.view(b * t, c),
                                         label.view(-1))
                # Sum each uttr and divide by length
                att_loss = torch.sum(att_loss.view(b, t), dim=-1) \
                    / target_len.to(device=self.device, dtype=torch.float32)
                att_loss = torch.mean(att_loss)  # Mean by batch
                loss_log['train_att'] = att_loss

            # CTC loss on CTC decoder
            if self.ctc_weight > 0:
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1),
                    label,
                    torch.LongTensor(state_len),
                    target_len)
                loss_log['train_ctc'] = ctc_loss

            asr_loss = (1 - self.ctc_weight) * att_loss \
                + self.ctc_weight * ctc_loss
            loss_log['train_full'] = asr_loss

            # Adversarial loss from CLM (skipped when there is no attention output)
            if (self.apply_clm and (att_pred is not None)
                    and (att_pred.shape[1] >= CLM_MIN_SEQ_LEN)):
                if (self.step % self.clm.update_freq) == 0:
                    # update CLM once in a while
                    clm_log, gp = self.clm.train(att_pred.detach(),
                                                 CLM_MIN_SEQ_LEN)
                    self.write_log('clm_score', clm_log)
                    self.write_log('clm_gp', gp)
                # Normalise over the class axis (last dim of the b x t x c
                # output); an implicit dim would pick the batch axis for 3-D input.
                adv_feedback = self.clm.compute_loss(
                    F.softmax(att_pred, dim=-1))
                asr_loss -= adv_feedback

            # Backprop
            asr_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.asr_model.parameters(), GRAD_CLIP)
            if math.isnan(grad_norm):
                # Skip the update rather than poisoning the weights with NaN
                self.verbose('Error : grad norm is NaN @ step ' +
                             str(self.step))
            else:
                self.asr_opt.step()

            # Logger
            self.write_log('loss', loss_log)
            if self.ctc_weight < 1:
                self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {
                        'train': cal_cer(att_pred, label, mapper=self.mapper)
                    })
            else:
                # only ctc
                self.write_log('ctc acc', {'train': cal_acc(ctc_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('ctc error rate', {
                        'train': cal_cer(ctc_pred, label, mapper=self.mapper)
                    })

            # visualize inputs (first utterance of the batch)
            if self.step % 1000 == 0:
                example_inputs = x[0].cpu().detach().numpy().transpose(1, 0)
                img = visualizer.plot(
                    visualizer.plot_item(example_inputs, "inputs-feature"))
                self.log.add_image("0-inputs", img, global_step=self.step,
                                   dataformats="HWC")

            # Validation
            if self.step % self.valid_step == 0:
                self.asr_opt.zero_grad()
                self.valid()

            self.step += 1
            progress_bar.update()
            # NOTE(review): with '>' a batch may still be processed while
            # step == max_step - confirm this extra step is intended.
            if self.step > self.max_step:
                break

    # Release the manually driven progress bar once training is over
    progress_bar.close()