def progress(self, msg):
    ''' Overwrite the current stdout line with an epoch-tagged progress message.

    Only prints when ``self.paras.verbose`` is truthy. The line is first cleared
    with the ANSI "erase to end of line" sequence and terminated with a
    carriage return, so the next call draws over it.

    msg - <str> progress text; must NOT contain a newline
    '''
    if not self.paras.verbose:
        return
    sys.stdout.write("\033[K")  # Clear line
    print('[Ep {}] {}'.format(human_format(self.cur_epoch), msg), end='\r')
    # No newline was written, so buffered stdout could hold the message back;
    # flush explicitly so the progress line actually appears.
    sys.stdout.flush()
def save_checkpoint(self, f_name, metric, score, name=''):
    ''' Save model/optimizer state to ``<self.ckpdir>/<f_name>`` (overwrites).

    f_name - <str> file name of the checkpoint (without directory prefix)
    metric - <str> metric name; used as the dict key under which score is stored
    score  - <float> value of metric used to evaluate the model
    name   - <str> optional tag appended to the log message as " on <name>"

    Returns the full path the checkpoint was written to.
    '''
    ckpt_path = os.path.join(self.ckpdir, f_name)
    full_dict = {
        "model": self.model.state_dict(),
        "optimizer": self.optimizer.get_opt_state_dict(),
        "global_step": self.step,
        metric: score
    }
    # Additional modules to save
    if self.emb_decoder is not None:
        full_dict['emb_decoder'] = self.emb_decoder.state_dict()

    torch.save(full_dict, ckpt_path)

    if name:
        name = ' on ' + name
    # Only show the tail of the path in the log by dropping the first six
    # '/'-separated components; fall back to the full path when it is too
    # short for that, instead of printing an empty string.
    shown_path = '/'.join(ckpt_path.split('/')[6:]) or ckpt_path
    self.verbose("Saved ckpt (step = {}, {} = {:.2f}) @ {}{}".format(
        human_format(self.step), metric, score, shown_path, name))
    return ckpt_path
def save_checkpoint(self, f_name, metric, score, show_msg=True):
    ''' Save model weights to ``<self.ckpdir>/<f_name>`` (overwrites).

    Besides the whole model, the encoder, decoder and predictor sub-module
    state dicts are stored as separate entries of the checkpoint.

    f_name   - <str> file name of the checkpoint (without directory prefix)
    metric   - <str> metric name; used as the dict key under which score is stored
    score    - <float> value of metric used to evaluate the model
    show_msg - <bool> log a message after saving when True

    Returns the full path the checkpoint was written to.
    '''
    ckpt_path = os.path.join(self.ckpdir, f_name)
    full_dict = {
        "model": self.model.state_dict(),
        "encoder": self.model.encoder.state_dict(),
        "decoder": self.model.decoder.state_dict(),
        "predictor": self.model.predictor.state_dict(),
        metric: score
    }
    # NOTE(review): optimizer state and global step are not saved here even
    # though the log message mentions "status" — confirm this is intended.
    torch.save(full_dict, ckpt_path)
    if show_msg:
        self.verbose(
            "Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
            format(human_format(self.step), metric, score, ckpt_path))
    return ckpt_path
def verbose(self, msg, display_step=False):
    ''' Print an informational message (or a list/tuple of them) to stdout.

    Nothing is printed unless ``self.paras.verbose`` is truthy. Every line is
    prefixed with ``[INFO]``, or with ``[<human-formatted step>]`` when
    display_step is True. Each message is left-justified to 100 characters
    (presumably to cover any leftover progress line).

    msg          - <str> or list/tuple of <str>
    display_step - <bool> use the current step instead of "INFO" as prefix
    '''
    if not self.paras.verbose:
        return
    header = '[' + human_format(self.step) + ']' if display_step else '[INFO]'
    messages = msg if isinstance(msg, (list, tuple)) else [msg]
    for m in messages:
        print(header, m.ljust(100))
def exec(self):
    ''' Train the downstream classifier (task given by self.task, e.g. 'phn-clf'),
    then evaluate on the test set.

    In 'train' mode, runs self.epoch epochs over self.tr_set. The lr is
    decayed at the start of every epoch after the first, and the model is
    validated after every epoch. Afterwards self.validate(test=True) is
    called and the log is closed.
    '''
    if self.paras.mode == 'train':
        self.verbose('Total training epoch {}.'.format(
            human_format(self.epoch)))
        self.timer.set()
        ep_len = len(self.tr_set)
        for ep in range(self.epoch):
            if ep > 0:
                # Lr decay if needed
                self.optimizer.decay()
            for data in self.tr_set:
                # Pre-step : do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data, then stop the 'rd' timer so that data-loading
                # time is attributed to 'rd' rather than to 'fw'.
                _, audio_feat, audio_len, label = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward
                pred = self.model(audio_feat)
                if self.task == 'phn-clf':
                    pred = pred.permute(0, 2, 1)  # BxCxT for phn clf
                loss = self.loss(pred, label)
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(100 * float(self.step % ep_len) / ep_len,
                                loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    self.write_log(self.task + '_loss', {'tr': loss})
                    if self.task == 'phn-clf':
                        tr_er = cal_per(pred, label, audio_len)[0]
                    else:
                        # Share of argmax predictions differing from label,
                        # assuming one class id per sample in label (confirm)
                        tr_er = (pred.argmax(dim=-1) != label)
                        tr_er = tr_er.sum().detach().cpu().float() / len(label)
                    self.write_log(self.task + '_er', {'tr': tr_er})
                # End of step
                self.timer.set()
            # End of epoch
            self.cur_epoch += 1
            self.validate()
    # Test at the end
    self.validate(test=True)
    self.log.close()
def exec(self):
    ''' Train the language model for exactly self.max_step steps.

    Each step feeds txt[:, :-1] to the model and scores the output against
    txt[:, 1:] (next-token prediction). Training entropy and perplexity are
    logged every PROGRESS_STEP steps, and validation runs at step 1 and every
    valid_step steps.
    '''
    self.verbose('Total training steps {}.'.format(
        human_format(self.max_step)))
    self.timer.set()

    while self.step < self.max_step:
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            self.optimizer.pre_step(self.step)

            # Fetch data
            txt, txt_len = self.fetch_data(data)
            self.timer.cnt('rd')

            # Forward model: predict each next token from its prefix
            outputs, _ = self.model(txt[:, :-1], txt_len)
            pred = self.rnnlm(outputs)

            # Compute all objectives
            lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                    txt[:, 1:].reshape(-1))
            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(lm_loss)
            self.step += 1

            # Logger
            if self.step % self.PROGRESS_STEP == 0:
                self.progress(
                    'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                    format(lm_loss.cpu().item(), grad_norm, self.timer.show()))
                self.write_log('entropy', {'tr': lm_loss})
                self.write_log('perplexity',
                               {'tr': torch.exp(lm_loss).cpu().item()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                self.validate()

            # End of step
            self.timer.set()
            # Stop exactly at max_step (a '>' test would run one extra step)
            if self.step >= self.max_step:
                break
    self.log.close()
def save_checkpoint(self, f_name, metric, score, show_msg=True):
    """Write a training checkpoint to ``<self.ckpdir>/<f_name>``.

    An existing file with the same name is overwritten. The checkpoint holds
    the model and optimizer state, the global step, the current epoch and the
    given score stored under the key ``metric``.

    f_name   - <str> checkpoint file name (without directory prefix)
    metric   - <str> metric name, used as the dictionary key for score
    score    - <float> value of metric used to evaluate the model
    show_msg - <bool> log a message after saving when True

    Returns the path of the written checkpoint.
    """
    target = os.path.join(self.ckpdir, f_name)
    checkpoint = dict(
        model=self.model.state_dict(),
        optimizer=self.optimizer.get_opt_state_dict(),
        global_step=self.step,
        epoch=self.cur_epoch,
    )
    checkpoint[metric] = score
    torch.save(checkpoint, target)

    if show_msg:
        template = "Saved checkpoint (epoch = {}, {} = {:.2f}) and status @ {}"
        self.verbose(template.format(human_format(self.cur_epoch), metric,
                                     score, target))
    return target
def exec(self):
    ''' Train the end-to-end ASR model (CTC and/or attention) for self.max_step steps.

    After self.curriculum passes over the data (when > 0) the training loader
    is rebuilt via load_dataset(..., False, ...), which the log message
    describes as switching to random sampling. When self.emb_reg is set, the
    embedding-decoder loss is added, and with self.emb_fuse its fused output
    replaces the attention output in the attention loss.
    '''
    self.verbose('Total training steps {}.'.format(
        human_format(self.max_step)))
    ctc_loss, att_loss, emb_loss = None, None, None
    n_epochs = 0
    self.timer.set()

    while self.step < self.max_step:
        # Renew dataloader to enable random sampling
        if self.curriculum > 0 and n_epochs == self.curriculum:
            self.verbose(
                'Curriculum learning ends after {} epochs, starting random sampling.'
                .format(n_epochs))
            self.tr_set, _, _, _, _, _ = \
                load_dataset(self.paras.njobs, self.paras.gpu,
                             self.paras.pin_memory, False,
                             **self.config['data'])
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            tf_rate = self.optimizer.pre_step(self.step)
            total_loss = 0

            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                         att_output,
                                                         label=txt)
                total_loss += self.emb_decoder.weight * emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                             txt, encode_len, txt_len)
                total_loss += ctc_loss * self.model.ctc_weight

            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(
                    att_output.contiguous().view(b * t, -1),
                    txt.contiguous().view(-1))
                total_loss += att_loss * (1 - self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress(
                    'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                    format(total_loss.cpu().item(), grad_norm,
                           self.timer.show()))
                self.write_log('loss', {
                    'tr_ctc': ctc_loss,
                    'tr_att': att_loss
                })
                self.write_log('emb_loss', {'tr': emb_loss})
                self.write_log(
                    'wer', {
                        'tr_att': cal_er(self.tokenizer, att_output, txt),
                        'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                         ctc=True)
                    })
                if self.emb_fuse:
                    if self.emb_decoder.fuse_learnable:
                        self.write_log(
                            'fuse_lambda',
                            {'emb': self.emb_decoder.get_weight()})
                    self.write_log('fuse_temp',
                                   {'temp': self.emb_decoder.get_temp()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                self.validate()

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            # Stop exactly at max_step (a '>' test would run one extra step)
            if self.step >= self.max_step:
                break
        n_epochs += 1
    self.log.close()
def exec(self):
    ''' Train the joint ASR/TTS model with paired data plus optional unpaired cycles.

    Every step uses one paired batch: the supervised ASR CTC loss (optionally
    split with an ASR post-net) and the TTS mel + linear spectrogram loss.
    - Even steps run the speech -> text -> speech cycle on an unpaired speech
      batch, once step > self.unpair_speech_start_step and the weight is > 0.
    - Odd steps run the text -> speech -> text cycle on an unpaired text
      batch, once step > self.unpair_text_start_step and the weight is > 0.
    - A RuntimeError containing 'out of memory' drops the gradients, empties
      the CUDA cache and moves on to the next iteration.
    - Any other RuntimeError is printed and errorout() is called.
    '''
    def _scalar(value):
        # Losses here are tensors, None (not computed this step) or the plain
        # int 0 (reset after a NaN/Inf CTC value); only tensors have .item().
        return value.item() if torch.is_tensor(value) else value

    self.verbose(
        ['Total training steps {}.'.format(human_format(self.max_step))])
    self.timer.set()
    unsup_pred, unsup_trans, unsup_align = None, None, None
    tok_usage, gt_usage = [], []
    cnter = {'ctc_nan': 0, 'unp_sph': 0, 'unp_txt': 0}

    while self.step < self.max_step:
        # --------------------- Load data ----------------------- #
        # Unpair setting
        unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
            None, None, None, None, None
        post_pred, asr_post_loss = None, None  # For ASR postnet only
        # Reset per step so a value left over from an earlier step is never logged
        unpair_speech_loss, unpair_text_loss = None, None
        use_unpair_text = self.unpair_text_weight > 0 and \
            self.step > self.unpair_text_start_step
        use_unpair_speech = self.unpair_speech_weight > 0 and \
            self.step > self.unpair_speech_start_step

        # Catch the returned tf_rate if needed
        tf_rate = self.optimizer.pre_step(self.step)
        # ToDo : change # of sup. step = 2 x # of unsup. step ?
        mel, aug_mel, linear, text, sid = self.fetch_data(
            iter_name='pair_iter')

        # Load unpaired data only when use_unpair_xxx == True
        if self.step % 2 == 0:
            # ASR first
            speech_first = True
            if use_unpair_speech:
                unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                    self.fetch_data(iter_name='unpair_iter')
        else:
            # TTS first
            speech_first = False
            if use_unpair_text:
                cnter['unp_txt'] += 1
                unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                    self.fetch_data(iter_name='unpair_iter')

        total_loss = 0
        bs = len(mel)
        self.timer.cnt('rd')
        try:
            # ----------------------- Forward ------------------------ #
            if speech_first:
                # Cycle : speech -> text -> speech
                pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                    self.model.speech_to_text(paired_mel=aug_mel,
                                              unpaired_mel=unpair_aug_mel)

                # Check to involve unsupervised Speech2Speech
                if unpair_latent is not None:
                    # ASR output is the representation for speech2speech
                    cnter['unp_sph'] += 1
                    ignore_speech_cycle = False
                    unpaired_teacher = unpair_mel
                else:
                    # ASR output is all blank (cannot be passed to TTS),
                    # only paired text is used
                    ignore_speech_cycle = True
                    unpaired_teacher = None

                # text -> speech
                pair_mel_pred, pair_linear_pred, pair_align, _, \
                    unpair_mel_pred, unpair_linear_pred, unpair_align, _ = \
                    self.model.text_to_speech(paired_text=text,
                                              paired_sid=sid,
                                              unpaired_sid=unpair_sid,
                                              unpaired_latent=unpair_latent,
                                              unpaired_text=None,
                                              unpaired_latent_len=unpair_latent_len,
                                              paired_teacher=mel,
                                              unpaired_teacher=unpaired_teacher,
                                              tf_rate=tf_rate)
            else:
                # Cycle : text -> speech -> text
                pair_mel_pred, pair_linear_pred, pair_align, _, \
                    unpair_mel_pred, unpair_linear_pred, unpair_align, _ = \
                    self.model.text_to_speech(paired_text=text,
                                              paired_sid=sid,
                                              unpaired_sid=unpair_sid,
                                              unpaired_latent=None,
                                              unpaired_text=unpair_text,
                                              unpaired_latent_len=None,
                                              paired_teacher=mel,
                                              unpaired_teacher=None,
                                              tf_rate=tf_rate)
                if use_unpair_text:
                    # Stop-grad for tts in text2text
                    unpair_mel_pred = unpair_mel_pred.detach()
                pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                    self.model.speech_to_text(paired_mel=aug_mel,
                                              unpaired_mel=unpair_mel_pred,
                                              using_fake_mel=use_unpair_text)

            # Paired ASR loss
            asr_loss = self.compute_ctcloss(aug_mel, pair_prob, text)
            if self.model.use_asr_postnet:
                total_loss = total_loss + self.asr_weight * (
                    1 - self.model.asr_postnet_weight) * asr_loss
                asr_post_loss = self.compute_ctcloss(aug_mel, pair_post_prob,
                                                     text, apply_log=False)
                total_loss = total_loss + self.asr_weight * \
                    self.model.asr_postnet_weight * asr_post_loss
            else:
                total_loss = total_loss + self.asr_weight * asr_loss
            # NOTE(review): asr_loss was already added to total_loss above, so
            # this reset only affects what gets logged.
            if math.isnan(asr_loss) or math.isinf(asr_loss):
                cnter['ctc_nan'] += 1
                asr_loss = 0

            # Paired TTS loss
            mel_loss = self.freq_loss(pair_mel_pred, mel)
            linear_loss = self.freq_loss(pair_linear_pred, linear)
            tts_loss = mel_loss + linear_loss
            total_loss = total_loss + self.tts_weight * tts_loss

            # Unpaired loss
            if speech_first:
                # Unpaired speech reconstruction loss
                if not ignore_speech_cycle:
                    unpair_speech_loss = self.freq_loss(unpair_mel_pred, unpair_mel) + \
                        self.freq_loss(unpair_linear_pred, unpair_linear)
                    if self.step > self.unpair_speech_start_step:
                        total_loss += self.unpair_speech_weight * unpair_speech_loss
            elif use_unpair_text:
                # Unpaired text reconstruction loss
                ctc_input = (unpair_prob + EPS).transpose(0, 1).log()
                if self.paras.actual_len:
                    asr_input_len = (unpair_text != 0).sum(
                        dim=-1) * FRAME_PHN_RATIO
                    asr_input_len = asr_input_len + \
                        asr_input_len % self.model.n_frames_per_step
                    ctc_len = 1 + (asr_input_len //
                                   self.model.time_reduce_factor)
                else:
                    ctc_len = torch.LongTensor(
                        [unpair_prob.shape[1]] *
                        unpair_prob.shape[0]).to(device=self.device)
                unpair_text_loss = self.ctc_loss(
                    ctc_input, unpair_text.to_sparse().values(), ctc_len,
                    torch.sum(unpair_text != 0, dim=-1))
                if math.isnan(unpair_text_loss) or math.isinf(
                        unpair_text_loss):
                    cnter['ctc_nan'] += 1
                    unpair_text_loss = 0
                total_loss += self.unpair_text_weight * unpair_text_loss

            # Statistics (over unsup. speech only)
            if speech_first and use_unpair_speech:
                unsup_pred = unpair_prob.argmax(dim=-1).cpu()
                unsup_trans = unpair_text.cpu()
                tok_usage += unsup_pred.flatten().tolist()
                gt_usage += unsup_trans.flatten().tolist()
                if unpair_align is not None:
                    unsup_align = unpair_align.detach().cpu()
                else:
                    unsup_align = [None] * bs
            self.timer.cnt('fw')

            # ----------------------- Backward ------------------------ #
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Log
            if (self.step == 1) or (self.step % self._PROGRESS_STEP == 0):
                self.progress('Tr stat | Loss - {:.2f} (CTC-nan/unp-sph/unp-txt={}/{}/{}) | Grad. Norm - {:.2f} | {} '
                              .format(total_loss.cpu().item(),
                                      cnter['ctc_nan'], cnter['unp_sph'],
                                      cnter['unp_txt'], grad_norm,
                                      self.timer.show()))
                self.write_log(
                    'txt_loss', {
                        'pair': _scalar(asr_loss),
                        'unpair': _scalar(unpair_text_loss),
                        'post': _scalar(asr_post_loss)
                    })
                self.write_log(
                    'speech_loss', {
                        'pair': _scalar(tts_loss),
                        'unpair': _scalar(unpair_speech_loss)
                    })
                # Counters report activity since the previous progress line
                for k in cnter:
                    cnter[k] = 0

            if (self.step == 1) or (self.step % ATTENTION_PLOT_STEP == 0):
                align = pair_align.cpu()  # align shape BxDsxEs
                sup_pred = pair_prob.argmax(dim=-1).cpu()
                sup_trans = text.cpu()
                if self.model.use_asr_postnet:
                    post_pred = pair_post_prob.argmax(dim=-1).cpu()
                self.write_log(
                    'per', {
                        'pair': cal_per(sup_pred, sup_trans),
                        'unpair': cal_per(unsup_pred, unsup_trans),
                        'post': cal_per(post_pred, sup_trans)
                    })
                self.write_log(
                    'unpair_hist',
                    data_to_bar(tok_usage, gt_usage, self.vocab_size,
                                self.tokenizer._vocab_list))
                for i in range(LISTEN_N_EXAMPLES):
                    self.write_log('pair_align{}'.format(i),
                                   feat_to_fig(align[i].cpu().detach()))
                    if unsup_align is not None and unsup_align[i] is not None:
                        self.write_log(
                            'unpair_align{}'.format(i),
                            feat_to_fig(unsup_align[i].cpu().detach()))
                tok_usage, gt_usage = [], []

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                self.validate()

            # End of step (the while condition stops training at max_step)
            self.timer.set()
        except RuntimeError as e:
            if 'out of memory' in str(e):
                # Drop this batch; the next loop iteration fetches a new one
                self.verbose('WARNING: ran out of memory, retrying batch')
                for p in self.model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
            else:
                print(repr(e))
                errorout()
def finetune(self, filename, fixed_text, max_step=5):
    ''' Fine-tune by merging a fixed sample into regular training batches.

    For each batch of self.tr_set, the data returned by
    self.fetch_finetune_data(filename, fixed_text) is merged into it with
    self.merge_batch and one optimisation step is taken. Exactly max_step
    steps are run (none when max_step <= 0), then the model is validated and
    the log is closed.

    filename   - passed through to self.fetch_finetune_data
    fixed_text - passed through to self.fetch_finetune_data
    max_step   - <int> number of update steps to run

    Returns the value returned by self.validate().
    '''
    self.verbose('Total training steps {}.'.format(
        human_format(max_step)))
    ctc_loss, att_loss, emb_loss = None, None, None
    self.timer.set()
    step = 0
    for data in self.tr_set:
        # Stop after exactly max_step updates (also covers max_step <= 0)
        if step >= max_step:
            break
        # Pre-step : update tf_rate/lr_rate and do zero_grad.
        # NOTE(review): a fixed large step (400000) is passed, presumably to use
        # the end-of-schedule lr/tf_rate — confirm against the optimizer.
        tf_rate = self.optimizer.pre_step(400000)
        total_loss = 0

        # Fetch data
        finetune_data = self.fetch_finetune_data(filename, fixed_text)
        main_batch = self.fetch_data(data)
        new_batch = self.merge_batch(main_batch, finetune_data)
        feat, feat_len, txt, txt_len = new_batch
        self.timer.cnt('rd')

        # Forward model
        # Note: txt should NOT start w/ <sos>
        ctc_output, encode_len, att_output, att_align, dec_state = \
            self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                       teacher=txt, get_dec_state=self.emb_reg)

        # Plugins
        if self.emb_reg:
            emb_loss, fuse_output = self.emb_decoder(
                dec_state, att_output, label=txt)
            total_loss += self.emb_decoder.weight * emb_loss

        # Compute all objectives
        if ctc_output is not None:
            if self.paras.cudnn_ctc:
                ctc_loss = self.ctc_loss(
                    ctc_output.transpose(0, 1),
                    txt.to_sparse().values().to(device='cpu',
                                                dtype=torch.int32),
                    [ctc_output.shape[1]] * len(ctc_output),
                    txt_len.cpu().tolist())
            else:
                ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                         encode_len, txt_len)
            total_loss += ctc_loss * self.model.ctc_weight

        if att_output is not None:
            b, t, _ = att_output.shape
            att_output = fuse_output if self.emb_fuse else att_output
            att_loss = self.seq_loss(
                att_output.contiguous().view(b * t, -1),
                txt.contiguous().view(-1))
            total_loss += att_loss * (1 - self.model.ctc_weight)

        self.timer.cnt('fw')

        # Backprop
        grad_norm = self.backward(total_loss)
        step += 1

        # Logger
        self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                      .format(total_loss.cpu().item(), grad_norm,
                              self.timer.show()))
        self.write_log('loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
        self.write_log('emb_loss', {'tr': emb_loss})
        self.write_log('wer', {
            'tr_att': cal_er(self.tokenizer, att_output, txt),
            'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)
        })
        if self.emb_fuse:
            if self.emb_decoder.fuse_learnable:
                self.write_log('fuse_lambda',
                               {'emb': self.emb_decoder.get_weight()})
            self.write_log('fuse_temp',
                           {'temp': self.emb_decoder.get_temp()})

        # End of step
        # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
        torch.cuda.empty_cache()
        self.timer.set()

    ret = self.validate()
    self.log.close()
    return ret
def exec(self):
    ''' Train the CTC-only ASR model for exactly self.max_step steps.

    After self.curriculum passes over the data (when > 0) the training loader
    is rebuilt via load_dataset(..., False, ...), which the log message
    describes as switching to random sampling. Training PER/WER and the CTC
    loss are logged every PROGRESS_STEP steps; validation runs at step 1 and
    every valid_step steps.
    '''
    self.verbose('Total training steps {}.'.format(
        human_format(self.max_step)))
    ctc_loss = None
    n_epochs = 0
    self.timer.set()

    while self.step < self.max_step:
        # Renew dataloader to enable random sampling
        if self.curriculum > 0 and n_epochs == self.curriculum:
            self.verbose(
                'Curriculum learning ends after {} epochs, starting random sampling.'
                .format(n_epochs))
            self.tr_set, _, _, _, _, _ = \
                load_dataset(self.paras.njobs, self.paras.gpu,
                             self.paras.pin_memory, False,
                             **self.config['data'])
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            tf_rate = self.optimizer.pre_step(self.step)
            total_loss = 0

            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len = self.model(feat, feat_len)

            # Compute all objectives
            if self.paras.cudnn_ctc:
                ctc_loss = self.ctc_loss(
                    ctc_output.transpose(0, 1),
                    txt.to_sparse().values().to(device='cpu',
                                                dtype=torch.int32),
                    [ctc_output.shape[1]] * len(ctc_output),
                    txt_len.cpu().tolist())
            else:
                ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                         encode_len, txt_len)
            total_loss = ctc_loss

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress(
                    'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                    format(total_loss.cpu().item(), grad_norm,
                           self.timer.show()))
                # Greedy per-utterance token ids, trimmed to the encoder length
                ctc_output = [
                    x[:length].argmax(dim=-1)
                    for x, length in zip(ctc_output, encode_len)
                ]
                self.write_log(
                    'per', {
                        'tr_ctc':
                        cal_er(self.tokenizer, ctc_output, txt, mode='per',
                               ctc=True)
                    })
                self.write_log(
                    'wer', {
                        'tr_ctc':
                        cal_er(self.tokenizer, ctc_output, txt, mode='wer',
                               ctc=True)
                    })
                self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                self.validate()

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            # Stop exactly at max_step (a '>' test would run one extra step)
            if self.step >= self.max_step:
                break
        n_epochs += 1
    # Flush and close the training log, as the other training loops do
    self.log.close()
def exec(self):
    ''' Self-supervised pre-training of the speech model for self.epoch epochs.

    - NPC methods ('npc' in self.method) reconstruct their own input. The
      per-frame loss (mean over the feature dim) is summed over each
      utterance's valid frames only and divided by the total number of valid
      frames.
    - Otherwise (APC) the model is fed audio_feat[:, :-n_future] and scored
      against audio_feat[:, n_future:].
    - The lr is decayed at the start of every epoch after the first, and the
      model is validated after every epoch.
    '''
    self.verbose('Total training epoch {}.'.format(
        human_format(self.epoch)))
    self.timer.set()
    ep_len = len(self.tr_set)
    for ep in range(self.epoch):
        # Pre-step, decay
        if ep > 0:
            self.optimizer.decay()
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            self.optimizer.pre_step(self.step)

            # Fetch data
            _, audio_feat, audio_len = self.fetch_data(data)
            self.timer.cnt('rd')

            # Forward real data
            if 'npc' in self.method:
                # NPC: input = target
                pred, _ = self.model(audio_feat)
                loss = self.loss(pred, audio_feat)
                # Compute loss on valid part only
                effective_loss = 0
                for i, a_len in enumerate(audio_len):
                    effective_loss += loss[i, :a_len, :].mean(dim=-1).sum()
                loss = effective_loss / sum(audio_len)
            else:
                # APC: input = shifted target
                audio_len = [length - self.n_future for length in audio_len]
                pred, _ = self.model(audio_feat[:, :-self.n_future, :],
                                     audio_len, testing=False)
                loss = self.loss(pred, audio_feat[:, self.n_future:, :])
            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress(
                    ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                    .format(100 * float(self.step % ep_len) / ep_len,
                            loss.cpu().item(), grad_norm, self.timer.show()))
                self.write_log('loss', {'tr': loss})

            if (self.step == 1) or (self.step % self.PLOT_STEP == 0):
                # Perplexity of P(token)
                g1_ppx, g2_ppx = self.model.report_ppx()
                self.write_log('ppx', {'group 1': g1_ppx, 'group 2': g2_ppx})
                # NOTE(review): originally commented "Empty cache" — this call
                # may also reset the usage statistics; confirm in the model.
                g1_usg, g2_usg = self.model.report_usg()
                # Plots
                if self.paras.draw:
                    g1_hist = draw(g1_usg, hist=True)
                    g2_hist = draw(g2_usg, hist=True)
                    self.write_log('VQ Group 1 Hist.', g1_hist)
                    self.write_log('VQ Group 2 Hist.', g2_hist)
                    # Some spectrograms
                    plt_idx = 0
                    self.write_log('Spectrogram (raw)',
                                   draw(audio_feat[plt_idx]))
                    self.write_log('Spectrogram (pred)', draw(pred[plt_idx]))

            # End of step
            self.timer.set()
        # End of epoch
        self.cur_epoch += 1
        self.validate()
    self.log.close()
def exec(self):
    ''' Train the end-to-end ASR model for exactly self.max_step steps.

    - Transfer learning: optionally freezes encoder, decoder and CTC layers.
    - Curriculum sampling: after self.curriculum epochs the training loader is
      rebuilt once via load_dataset(..., False, ...).
    - Early stopping of the CTC branch: when hparas['early_stopping'] is set,
      after stop_step steps the CTC output is ignored and ctc_weight is set
      to 0, so only the attention loss is trained.
    - When self.lr_scheduler is None, a fixed decay is applied: lr *= 0.85
      every 2000 steps after step 99999.
    '''
    self.verbose('Total training steps {}.'.format(
        human_format(self.max_step)))
    if self.transfer_learning:
        self.model.encoder.fix_layers(self.fix_enc)
        if self.fix_dec and self.model.enable_att:
            self.model.decoder.fix_layers()
        if self.fix_dec and self.model.enable_ctc:
            self.model.fix_ctc_layer()

    self.n_epochs = 0
    self.timer.set()
    # Early stopping for ctc
    self.early_stoping = self.config['hparas']['early_stopping']
    stop_epoch = 10
    batch_size = self.config['data']['corpus']['batch_size']
    # NOTE(review): len(self.tr_set) is divided by batch_size, i.e. treated as
    # a number of utterances; if tr_set is a DataLoader its length is already
    # a number of batches — confirm.
    stop_step = len(self.tr_set) * stop_epoch // batch_size
    curriculum_done = False

    while self.step < self.max_step:
        ctc_loss, att_loss, emb_loss = None, None, None
        # Renew dataloader (once) to enable random sampling. The epoch counter
        # may advance by more than one per pass, so test with '>=' plus a flag.
        if self.curriculum > 0 and not curriculum_done and \
                self.n_epochs >= self.curriculum:
            self.verbose(
                'Curriculum learning ends after {} epochs, starting random sampling.'
                .format(self.n_epochs))
            self.tr_set, _, _, _, _, _ = \
                load_dataset(self.paras.njobs, self.paras.gpu,
                             self.paras.pin_memory, False,
                             **self.config['data'])
            curriculum_done = True
        # Steps per "epoch" for the epoch counter; clamped to 1 so the modulo
        # below cannot divide by zero when len(self.tr_set) < batch_size.
        epoch_len = max(1, len(self.tr_set) // batch_size)

        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            tf_rate = self.optimizer.pre_step(self.step)
            total_loss = 0

            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data, train=True)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)
            # Clear not used objects
            del att_align

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                         att_output,
                                                         label=txt)
                total_loss += self.emb_decoder.weight * emb_loss
            else:
                del dec_state

            # Early stopping ctc: drop the CTC objective after stop_step
            if self.early_stoping and self.step > stop_step:
                ctc_output = None
                self.model.ctc_weight = 0

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                             txt, encode_len, txt_len)
                total_loss += ctc_loss * self.model.ctc_weight
            del encode_len

            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(att_output.view(b * t, -1),
                                         txt.view(-1))
                total_loss += att_loss * (1 - self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            self.step += 1

            # Logger
            if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                              .format(total_loss.cpu().item(), grad_norm,
                                      self.timer.show()))
                self.write_log('emb_loss', {'tr': emb_loss})
                if att_output is not None:
                    self.write_log('loss', {'tr_att': att_loss})
                    self.write_log(self.WER, {
                        'tr_att': cal_er(self.tokenizer, att_output, txt)
                    })
                    self.write_log('cer', {
                        'tr_att': cal_er(self.tokenizer, att_output, txt,
                                         mode='cer')
                    })
                if ctc_output is not None:
                    self.write_log('loss', {'tr_ctc': ctc_loss})
                    self.write_log(self.WER, {
                        'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                         ctc=True)
                    })
                    self.write_log('cer', {
                        'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                         mode='cer', ctc=True)
                    })
                    self.write_log(
                        'ctc_text_train',
                        self.tokenizer.decode(
                            ctc_output[0].argmax(dim=-1).tolist(),
                            ignore_repeat=True))
                if self.emb_fuse:
                    if self.emb_decoder.fuse_learnable:
                        self.write_log(
                            'fuse_lambda',
                            {'emb': self.emb_decoder.get_weight()})
                    self.write_log('fuse_temp',
                                   {'temp': self.emb_decoder.get_temp()})

            # Validation
            if (self.step == 1) or (self.step % self.valid_step == 0):
                if isinstance(self.dv_set, list):
                    for dv_id, dv_set in enumerate(self.dv_set):
                        self.validate(dv_set, self.dv_names[dv_id])
                else:
                    self.validate(self.dv_set, self.dv_names)

            if self.step % epoch_len == 0:
                # one epoch
                print('Have finished epoch: ', self.n_epochs)
                self.n_epochs += 1

            # Fixed step-based lr decay when no scheduler is configured
            if self.lr_scheduler is None:
                lr = self.optimizer.opt.param_groups[0]['lr']
                if self.step == 1:
                    print(
                        '[INFO] using lr schedular defined by Daniel, init lr = ',
                        lr)
                if self.step > 99999 and self.step % 2000 == 0:
                    lr = lr * 0.85
                    for param_group in self.optimizer.opt.param_groups:
                        param_group['lr'] = lr
                    print('[INFO] at step:', self.step)
                    print('[INFO] lr reduce to', lr)

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            # Stop exactly at max_step (a '>' test would run one extra step)
            if self.step >= self.max_step:
                break
    self.log.close()
    print('[INFO] Finished training after', human_format(self.max_step),
          'steps.')
def progress(self, msg):
    """Redraw the current stdout line with a step-tagged progress message.

    Does nothing unless ``self.paras.verbose`` is truthy. ``msg`` must not
    contain a newline: the output ends with a carriage return so that the next
    call draws over the same line.
    """
    if not self.paras.verbose:
        return
    out = sys.stdout
    out.write("\033[K")  # ANSI: erase from cursor to end of line
    tag = human_format(self.step)
    print('[{}] {}'.format(tag, msg), end='\r')
    out.flush()