def validate(self):
    # Eval mode
    self.model.eval()
    if self.emb_decoder is not None:
        self.emb_decoder.eval()
    dev_wer = {'att': [], 'ctc': []}

    for i, data in enumerate(self.dv_set):
        self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
        # Fetch data
        feat, feat_len, txt, txt_len = self.fetch_data(data)

        # Forward model
        with torch.no_grad():
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, int(max(txt_len) * self.DEV_STEP_RATIO),
                           emb_decoder=self.emb_decoder)
        dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
        dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

        # Show some example on tensorboard
        if i == len(self.dv_set) // 2:
            for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                if self.step == 1:
                    self.write_log('true_text{}'.format(i),
                                   self.tokenizer.decode(txt[i].tolist()))
                if att_output is not None:
                    self.write_log('att_align{}'.format(i),
                                   feat_to_fig(att_align[i, 0, :, :].cpu().detach()))
                    self.write_log('att_text{}'.format(i),
                                   self.tokenizer.decode(att_output[i].argmax(dim=-1).tolist()))
                if ctc_output is not None:
                    self.write_log('ctc_text{}'.format(i),
                                   self.tokenizer.decode(ctc_output[i].argmax(dim=-1).tolist(),
                                                         ignore_repeat=True))

    # Ckpt if performance improves
    for task in ['att', 'ctc']:
        dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
        if dev_wer[task] < self.best_wer[task]:
            self.best_wer[task] = dev_wer[task]
            self.save_checkpoint('best_{}.pth'.format(task), 'wer', dev_wer[task])
        self.write_log('wer', {'dv_' + task: dev_wer[task]})
    self.save_checkpoint('latest.pth', 'wer', dev_wer['att'], show_msg=False)

    # Resume training
    self.model.train()
    if self.emb_decoder is not None:
        self.emb_decoder.train()
def validate(self):
    # Eval mode
    self.model.eval()
    dev_wer = {'rnnt': []}
    dev_loss = {'rnnt': []}

    for i, data in enumerate(self.dv_set):
        self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
        # Fetch data
        feat, feat_len, txt, txt_len = self.fetch_data(data)

        '''
        # Forward model
        with torch.no_grad():
            logits, targets, enc_len, targets_len = self.model(feat, feat_len, txt, txt_len)
            rnnt_loss = self.rnntloss(logits, targets, enc_len, targets_len)
        dev_loss['rnnt'].append(rnnt_loss.item())
        '''

        with torch.no_grad():
            rnnt_output = self.model.decode(feat, feat_len)
        dev_wer['rnnt'].append(cal_er(self.tokenizer, rnnt_output, txt))

        # Show some example on tensorboard
        if i == len(self.dv_set) // 2:
            for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                if self.step == 1:
                    self.write_log('true_text{}'.format(i),
                                   self.tokenizer.decode(txt[i].tolist()))
                self.write_log('rnnt_text{}'.format(i),
                               self.tokenizer.decode(rnnt_output[i]))

    '''
    # Ckpt if performance improves
    for task in ['rnnt']:
        dev_loss[task] = sum(dev_loss[task]) / len(dev_loss[task])
        if dev_loss[task] < self.best_loss[task]:
            self.best_loss[task] = dev_loss[task]
            self.save_checkpoint('best_{}.pth'.format(task), 'loss', dev_loss[task])
        self.write_log('loss', {'dv_' + task: dev_loss[task]})
    self.save_checkpoint('latest.pth', 'loss', dev_loss['rnnt'], show_msg=False)
    '''

    for task in ['rnnt']:
        dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
        if dev_wer[task] < self.best_wer[task]:
            self.best_wer[task] = dev_wer[task]
            self.save_checkpoint('best_{}.pth'.format(task), 'wer', dev_wer[task])
        self.write_log('wer', {'dv_' + task: dev_wer[task]})
    self.save_checkpoint('latest.pth', 'wer', dev_wer['rnnt'], show_msg=False)

    # Resume training
    self.model.train()
def exec(self):
    ''' Training End-to-end ASR system '''
    self.verbose('Total training steps {}.'.format(human_format(self.max_step)))
    ctc_loss, att_loss, emb_loss = None, None, None
    n_epochs = 0
    self.timer.set()

    while self.step < self.max_step:
        # Renew dataloader to enable random sampling
        if self.curriculum > 0 and n_epochs == self.curriculum:
            self.verbose('Curriculum learning ends after {} epochs, starting random sampling.'
                         .format(n_epochs))
            self.tr_set, _, _, _, _, _ = \
                load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                             False, **self.config['data'])
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            tf_rate = self.optimizer.pre_step(self.step)
            total_loss = 0

            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight * emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt, encode_len, txt_len)
                total_loss += ctc_loss * self.model.ctc_weight
            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(att_output.contiguous().view(b * t, -1),
                                         txt.contiguous().view(-1))
                total_loss += att_loss * (1 - self.model.ctc_weight)
            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                              .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
                self.write_log('loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
                self.write_log('emb_loss', {'tr': emb_loss})
                self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
                                       'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
                if self.emb_fuse:
                    if self.emb_decoder.fuse_learnable:
                        self.write_log('fuse_lambda', {'emb': self.emb_decoder.get_weight()})
                    self.write_log('fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                self.validate()

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            if self.step > self.max_step:
                break
        n_epochs += 1
    self.log.close()
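# --- Illustrative sketch (not from this repo) --------------------------------
# The non-cuDNN branch above hands self.ctc_loss the model output transposed
# to (T, N, C), because torch.nn.CTCLoss expects time-major log-probabilities
# together with per-sequence input and target lengths. Below is a minimal,
# self-contained example of that call pattern; the shapes and the assumption
# that the network output is already log-softmaxed are made up for
# illustration only.
def _ctc_loss_example():
    import torch
    import torch.nn as nn

    T, N, C = 50, 4, 30                                   # frames, batch, vocab (blank = 0)
    log_probs = torch.randn(N, T, C).log_softmax(dim=-1)  # batch-major, like ctc_output
    targets = torch.randint(1, C, (N, 12))                # padded label batch, like txt
    input_lengths = torch.full((N,), T, dtype=torch.long)          # like encode_len
    target_lengths = torch.randint(5, 13, (N,), dtype=torch.long)  # like txt_len

    ctc = nn.CTCLoss(blank=0, zero_infinity=True)
    # Transpose to (T, N, C) before handing the log-probs to CTCLoss
    return ctc(log_probs.transpose(0, 1), targets, input_lengths, target_lengths)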
def finetune(self, filename, fixed_text, max_step=5):
    # Load data for finetune
    self.verbose('Total training steps {}.'.format(human_format(max_step)))
    ctc_loss, att_loss, emb_loss = None, None, None
    n_epochs = 0
    accum_count = 0
    self.timer.set()
    step = 0

    for data in self.tr_set:
        # Pre-step : update tf_rate/lr_rate and do zero_grad
        if max_step == 0:
            break
        tf_rate = self.optimizer.pre_step(400000)
        total_loss = 0

        # Fetch data
        finetune_data = self.fetch_finetune_data(filename, fixed_text)
        main_batch = self.fetch_data(data)
        # Merge the regular batch with the finetuning utterance (see sketch below)
        new_batch = self.merge_batch(main_batch, finetune_data)
        feat, feat_len, txt, txt_len = new_batch
        self.timer.cnt('rd')

        # Forward model
        # Note: txt should NOT start w/ <sos>
        ctc_output, encode_len, att_output, att_align, dec_state = \
            self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                       teacher=txt, get_dec_state=self.emb_reg)

        # Plugins
        if self.emb_reg:
            emb_loss, fuse_output = self.emb_decoder(dec_state, att_output, label=txt)
            total_loss += self.emb_decoder.weight * emb_loss

        # Compute all objectives
        if ctc_output is not None:
            if self.paras.cudnn_ctc:
                ctc_loss = self.ctc_loss(
                    ctc_output.transpose(0, 1),
                    txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                    [ctc_output.shape[1]] * len(ctc_output),
                    txt_len.cpu().tolist())
            else:
                ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt, encode_len, txt_len)
            total_loss += ctc_loss * self.model.ctc_weight
        if att_output is not None:
            b, t, _ = att_output.shape
            att_output = fuse_output if self.emb_fuse else att_output
            att_loss = self.seq_loss(att_output.contiguous().view(b * t, -1),
                                     txt.contiguous().view(-1))
            total_loss += att_loss * (1 - self.model.ctc_weight)
        self.timer.cnt('fw')

        # Backprop
        grad_norm = self.backward(total_loss)
        step += 1

        # Logger
        self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                      .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
        self.write_log('loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
        self.write_log('emb_loss', {'tr': emb_loss})
        self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
                               'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
        if self.emb_fuse:
            if self.emb_decoder.fuse_learnable:
                self.write_log('fuse_lambda', {'emb': self.emb_decoder.get_weight()})
            self.write_log('fuse_temp', {'temp': self.emb_decoder.get_temp()})

        # End of step
        # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
        torch.cuda.empty_cache()
        self.timer.set()
        if step > max_step:
            break

    ret = self.validate()
    self.log.close()
    return ret
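# --- Illustrative sketch (not from this repo) --------------------------------
# merge_batch(), used in finetune() above, is not shown in this excerpt. The
# hypothetical helper below sketches one way such a merge could work: zero-pad
# the feature and label tensors of both batches to a common length, then
# concatenate along the batch dimension. The real implementation may differ.
def merge_batch_sketch(batch_a, batch_b):
    import torch
    import torch.nn.functional as F

    feat_a, feat_len_a, txt_a, txt_len_a = batch_a
    feat_b, feat_len_b, txt_b, txt_len_b = batch_b
    t_max = max(feat_a.size(1), feat_b.size(1))   # longest feature sequence
    s_max = max(txt_a.size(1), txt_b.size(1))     # longest label sequence
    # Pad the time axis of (B, T, D) features and the label axis of (B, S) texts
    feat_a = F.pad(feat_a, (0, 0, 0, t_max - feat_a.size(1)))
    feat_b = F.pad(feat_b, (0, 0, 0, t_max - feat_b.size(1)))
    txt_a = F.pad(txt_a, (0, s_max - txt_a.size(1)))
    txt_b = F.pad(txt_b, (0, s_max - txt_b.size(1)))
    feat = torch.cat([feat_a, feat_b], dim=0)
    txt = torch.cat([txt_a, txt_b], dim=0)
    feat_len = torch.cat([feat_len_a, feat_len_b], dim=0)
    txt_len = torch.cat([txt_len_a, txt_len_b], dim=0)
    return feat, feat_len, txt, txt_len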
def exec(self):
    ''' Training End-to-end ASR system '''
    self.verbose('Total training steps {}.'.format(human_format(self.max_step)))
    ctc_loss = None
    n_epochs = 0
    self.timer.set()

    while self.step < self.max_step:
        # Renew dataloader to enable random sampling
        if self.curriculum > 0 and n_epochs == self.curriculum:
            self.verbose('Curriculum learning ends after {} epochs, starting random sampling.'
                         .format(n_epochs))
            self.tr_set, _, _, _, _, _ = \
                load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                             False, **self.config['data'])
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            # zero grad here
            tf_rate = self.optimizer.pre_step(self.step)
            total_loss = 0

            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len = self.model(feat, feat_len)

            # Compute all objectives
            if self.paras.cudnn_ctc:
                ctc_loss = self.ctc_loss(
                    ctc_output.transpose(0, 1),
                    txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                    [ctc_output.shape[1]] * len(ctc_output),
                    txt_len.cpu().tolist())
            else:
                ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt, encode_len, txt_len)
            total_loss = ctc_loss
            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                              .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
                # self.write_log('wer', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
                ctc_output = [x[:length].argmax(dim=-1)
                              for x, length in zip(ctc_output, encode_len)]
                self.write_log('per', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                                        mode='per', ctc=True)})
                self.write_log('wer', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                                        mode='wer', ctc=True)})
                self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                self.validate()

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            if self.step > self.max_step:
                break
        n_epochs += 1
def validate(self):
    # Eval mode
    self.model.eval()
    dev_per = {'ctc': []}
    dev_wer = {'ctc': []}

    for i, data in enumerate(self.dv_set):
        self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
        # Fetch data
        feat, feat_len, txt, txt_len = self.fetch_data(data)

        # Forward model
        with torch.no_grad():
            ctc_output, encode_len = self.model(feat, feat_len)
            ctc_output = [x[:length].argmax(dim=-1)
                          for x, length in zip(ctc_output, encode_len)]
        dev_per['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, mode='per', ctc=True))
        dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, mode='wer', ctc=True))

        # Show some example on tensorboard
        if i == len(self.dv_set) // 2:
            for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                # if self.step == 1:
                self.write_log('true_text{}'.format(i),
                               self.tokenizer.decode(txt[i].tolist()))
                self.write_log('ctc_text{}'.format(i),
                               self.tokenizer.decode(ctc_output[i].tolist(), ignore_repeat=True))

    # Ckpt if performance improves
    for task in ['ctc']:
        dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
        dev_per[task] = sum(dev_per[task]) / len(dev_per[task])
        if dev_per[task] < self.best_per[task]:
            self.best_per[task] = dev_per[task]
            self.save_checkpoint('best_{}.pth'.format('per'), 'per', dev_per[task])
            self.log.log_other('dv_best_per', self.best_per['ctc'])
        if self.eval_target == 'char' and dev_wer[task] < self.best_wer[task]:
            self.best_wer[task] = dev_wer[task]
            self.save_checkpoint('best_{}.pth'.format('wer'), 'wer', dev_wer[task])
            self.log.log_other('dv_best_wer', self.best_wer['ctc'])
        self.write_log('per', {'dv_' + task: dev_per[task]})
        if self.eval_target == 'char':
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
    self.save_checkpoint('latest.pth', 'per', dev_per['ctc'], show_msg=False)
    if self.paras.save_every:
        self.save_checkpoint(f'{self.step}.pth', 'per', dev_per['ctc'], show_msg=False)

    # Resume training
    self.model.train()
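# --- Illustrative sketch (not from this repo) --------------------------------
# decode(..., ignore_repeat=True) above applies the standard greedy CTC
# collapse to the per-frame argmax ids: merge consecutive repeats, then drop
# blanks. A minimal standalone version of that rule, assuming blank id 0:
def ctc_greedy_collapse(frame_ids, blank=0):
    out, prev = [], None
    for tok in frame_ids:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out

# e.g. ctc_greedy_collapse([0, 3, 3, 0, 5, 5, 5, 0]) -> [3, 5]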
def validate(self, _dv_set, _name):
    # Eval mode
    self.model.eval()
    if self.emb_decoder is not None:
        self.emb_decoder.eval()
    dev_wer = {'att': [], 'ctc': []}
    dev_cer = {'att': [], 'ctc': []}
    dev_er = {'att': [], 'ctc': []}

    for i, data in enumerate(_dv_set):
        self.progress('Valid step - {}/{}'.format(i + 1, len(_dv_set)))
        # Fetch data
        feat, feat_len, txt, txt_len = self.fetch_data(data)

        # Forward model
        with torch.no_grad():
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, int(max(txt_len) * self.DEV_STEP_RATIO),
                           emb_decoder=self.emb_decoder)

        if att_output is not None:
            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt, mode='wer'))
            dev_cer['att'].append(cal_er(self.tokenizer, att_output, txt, mode='cer'))
            dev_er['att'].append(cal_er(self.tokenizer, att_output, txt, mode=self.val_mode))
        if ctc_output is not None:
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, mode='wer', ctc=True))
            dev_cer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, mode='cer', ctc=True))
            dev_er['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, mode=self.val_mode, ctc=True))

        # Show some example on tensorboard
        if i == len(_dv_set) // 2:
            for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                if self.step == 1:
                    self.write_log('true_text_{}_{}'.format(_name, i),
                                   self.tokenizer.decode(txt[i].tolist()))
                if att_output is not None:
                    self.write_log('att_align_{}_{}'.format(_name, i),
                                   feat_to_fig(att_align[i, 0, :, :].cpu().detach()))
                    self.write_log('att_text_{}_{}'.format(_name, i),
                                   self.tokenizer.decode(att_output[i].argmax(dim=-1).tolist()))
                if ctc_output is not None:
                    self.write_log('ctc_text_{}_{}'.format(_name, i),
                                   self.tokenizer.decode(ctc_output[i].argmax(dim=-1).tolist(),
                                                         ignore_repeat=True))

    # Ckpt if performance improves
    tasks = []
    if len(dev_er['att']) > 0:
        tasks.append('att')
    if len(dev_er['ctc']) > 0:
        tasks.append('ctc')
    for task in tasks:
        dev_er[task] = sum(dev_er[task]) / len(dev_er[task])
        dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
        dev_cer[task] = sum(dev_cer[task]) / len(dev_cer[task])
        if dev_er[task] < self.best_wer[task][_name]:
            self.best_wer[task][_name] = dev_er[task]
            self.save_checkpoint(
                'best_{}_{}.pth'.format(
                    task, _name + (self.save_name if self.transfer_learning else '')),
                self.val_mode, dev_er[task], _name)
        if self.step >= self.max_step:
            self.save_checkpoint(
                'last_{}_{}.pth'.format(
                    task, _name + (self.save_name if self.transfer_learning else '')),
                self.val_mode, dev_er[task], _name)
        self.write_log(self.WER, {'dv_' + task + '_' + _name.lower(): dev_wer[task]})
        self.write_log('cer', {'dv_' + task + '_' + _name.lower(): dev_cer[task]})
        # if self.transfer_learning:
        #     print('[{}] WER {:.4f} / CER {:.4f} on {}'.format(
        #         human_format(self.step), dev_wer[task], dev_cer[task], _name))

    # Resume training
    self.model.train()
    if self.transfer_learning:
        self.model.encoder.fix_layers(self.fix_enc)
        if self.fix_dec and self.model.enable_att:
            self.model.decoder.fix_layers()
        if self.fix_dec and self.model.enable_ctc:
            self.model.fix_ctc_layer()
    if self.emb_decoder is not None:
        self.emb_decoder.train()
def exec(self):
    ''' Training End-to-end ASR system '''
    self.verbose('Total training steps {}.'.format(human_format(self.max_step)))

    if self.transfer_learning:
        self.model.encoder.fix_layers(self.fix_enc)
        if self.fix_dec and self.model.enable_att:
            self.model.decoder.fix_layers()
        if self.fix_dec and self.model.enable_ctc:
            self.model.fix_ctc_layer()

    self.n_epochs = 0
    self.timer.set()

    # Early stopping for CTC
    self.early_stopping = self.config['hparas']['early_stopping']
    stop_epoch = 10
    batch_size = self.config['data']['corpus']['batch_size']
    stop_step = len(self.tr_set) * stop_epoch // batch_size

    while self.step < self.max_step:
        ctc_loss, att_loss, emb_loss = None, None, None
        # Renew dataloader to enable random sampling
        if self.curriculum > 0 and self.n_epochs == self.curriculum:
            self.verbose('Curriculum learning ends after {} epochs, starting random sampling.'
                         .format(self.n_epochs))
            self.tr_set, _, _, _, _, _ = \
                load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                             False, **self.config['data'])
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            tf_rate = self.optimizer.pre_step(self.step)
            total_loss = 0

            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data, train=True)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)
            # Clear unused objects
            del att_align

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight * emb_loss
            else:
                del dec_state

            # Early stopping for CTC
            if self.early_stopping:
                if self.step > stop_step:
                    ctc_output = None
                    self.model.ctc_weight = 0
            # print(ctc_output.shape)

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        # [int(encode_len.max()) for _ in encode_len],
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt, encode_len, txt_len)
                total_loss += ctc_loss * self.model.ctc_weight
                del encode_len
            if att_output is not None:
                # print(att_output.shape)
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(att_output.view(b * t, -1), txt.view(-1))
                # Sum over each utterance, divide by its length, then mean over the batch
                # att_loss = torch.mean(torch.sum(att_loss.view(b, t), dim=-1) /
                #                       torch.sum(txt != 0, dim=-1).float())
                total_loss += att_loss * (1 - self.model.ctc_weight)
            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                              .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
                self.write_log('emb_loss', {'tr': emb_loss})
                if att_output is not None:
                    self.write_log('loss', {'tr_att': att_loss})
                    self.write_log(self.WER, {'tr_att': cal_er(self.tokenizer, att_output, txt)})
                    self.write_log('cer', {'tr_att': cal_er(self.tokenizer, att_output, txt, mode='cer')})
                if ctc_output is not None:
                    self.write_log('loss', {'tr_ctc': ctc_loss})
                    self.write_log(self.WER, {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
                    self.write_log('cer', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                                            mode='cer', ctc=True)})
                    self.write_log('ctc_text_train',
                                   self.tokenizer.decode(ctc_output[0].argmax(dim=-1).tolist(),
                                                         ignore_repeat=True))
                # if self.step == 1 or self.step % (self.PROGRESS_STEP * 5) == 0:
                #     self.write_log('spec_train', feat_to_fig(feat[0].transpose(0, 1).cpu().detach(), spec=True))
                # del total_loss
                if self.emb_fuse:
                    if self.emb_decoder.fuse_learnable:
                        self.write_log('fuse_lambda', {'emb': self.emb_decoder.get_weight()})
                    self.write_log('fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                if type(self.dv_set) is list:
                    for dv_id in range(len(self.dv_set)):
                        self.validate(self.dv_set[dv_id], self.dv_names[dv_id])
                else:
                    self.validate(self.dv_set, self.dv_names)

            if self.step % (len(self.tr_set) // batch_size) == 0:  # one epoch
                print('Have finished epoch: ', self.n_epochs)
                self.n_epochs += 1

            if self.lr_scheduler is None:
                lr = self.optimizer.opt.param_groups[0]['lr']
                if self.step == 1:
                    print('[INFO] using lr scheduler defined by Daniel, init lr = ', lr)
                if self.step > 99999 and self.step % 2000 == 0:
                    lr = lr * 0.85
                    for param_group in self.optimizer.opt.param_groups:
                        param_group['lr'] = lr
                    print('[INFO] at step:', self.step)
                    print('[INFO] lr reduce to', lr)
            # self.lr_scheduler.step(total_loss)

            # End of step
            # if self.step % EMPTY_CACHE_STEP == 0:  # Empty cuda cache after every fixed amount of steps
            torch.cuda.empty_cache()  # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            self.timer.set()
            if self.step > self.max_step:
                break
            # update lr_scheduler

    self.log.close()
    print('[INFO] Finished training after', human_format(self.max_step), 'steps.')