class Solver(BaseSolver):
    ''' Solver for training language models '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        self.best_loss = 10

    def fetch_data(self, data):
        ''' Move data to device, insert <sos> and compute text seq. length '''
        txt = torch.cat((torch.zeros((data.shape[0], 1), dtype=torch.long),
                         data), dim=1).to(self.device)
        txt_len = torch.sum(data != 0, dim=-1)
        return txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape '''
        self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
            load_textset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup LM and optimizer '''
        # Model
        self.model = Prediction(self.vocab_size,
                                **self.config['model']).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size,
                           **self.config['model']).to(self.device)
        self.verbose(self.rnnlm.create_msg())

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

        # Optimizer
        self.optimizer = Optimizer(
            list(self.model.parameters()) + list(self.rnnlm.parameters()),
            **self.config['hparas'])

        # Enable AMP if needed
        self.enable_apex()

        # Load pre-trained model
        if self.paras.load:
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Training language model '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self.timer.set()

        while self.step < self.max_step:
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)

                # Compute all objectives
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(lm_loss)
                self.step += 1

                # Logger
                if self.step % self.PROGRESS_STEP == 0:
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(lm_loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    self.write_log('entropy', {'tr': lm_loss})
                    self.write_log('perplexity',
                                   {'tr': torch.exp(lm_loss).cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                if self.step > self.max_step:
                    break
        self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        self.rnnlm.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)
            lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                    txt[:, 1:].reshape(-1))
            dev_loss.append(lm_loss)

        # Ckpt if performance improves
        dev_loss = sum(dev_loss) / len(dev_loss)
        dev_ppx = torch.exp(dev_loss).cpu().item()
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint('best_ppx.pth', 'perplexity', dev_ppx)
        self.write_log('entropy', {'dv': dev_loss})
        self.write_log('perplexity', {'dv': dev_ppx})

        # Show some examples of the last batch on tensorboard
        for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
            if self.step == 1:
                self.write_log('true_text{}'.format(i),
                               self.tokenizer.decode(txt[i].tolist()))
            self.write_log('pred_text{}'.format(i),
                           self.tokenizer.decode(pred[i].argmax(dim=-1).tolist()))

        # Resume training
        self.model.train()
        self.rnnlm.train()
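
# A minimal, standalone sketch (not part of the solver above) of the
# input/target shift used by fetch_data/exec: index 0 doubles as both <sos>
# and padding, the model consumes txt[:, :-1], and the loss is computed
# against txt[:, 1:] with ignore_index=0. All shapes and the vocab size
# below are illustrative assumptions, not values from this repo.
def _lm_shift_demo():
    import torch
    vocab_size = 8
    data = torch.tensor([[3, 5, 2, 0]])            # one zero-padded sequence
    sos = torch.zeros((1, 1), dtype=torch.long)    # 0 reused as <sos>
    txt = torch.cat((sos, data), dim=1)            # [[0, 3, 5, 2, 0]]
    logits = torch.randn(1, txt.shape[1] - 1, vocab_size)  # stand-in for model(txt[:, :-1])
    loss = torch.nn.CrossEntropyLoss(ignore_index=0)(
        logits.reshape(-1, vocab_size), txt[:, 1:].reshape(-1))
    return loss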
class VqvaeTrainer(BaseSolver):
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Init settings
        self.step = 0
        self.best_tts_loss = 100.0
        self.best_per = 2.0
        self.asr_weight = self.config['hparas']['asr_weight']
        self.tts_weight = self.config['hparas']['tts_weight']
        self.unpair_text_start_step = config['hparas']['unpair_text_start_step']
        self.unpair_text_weight = self.config['hparas']['unpair_text_weight']
        self.unpair_speech_start_step = config['hparas']['unpair_speech_start_step']
        self.unpair_speech_weight = self.config['hparas']['unpair_speech_weight']

    def fetch_data(self, iter_name):
        # Load from iterator, renewing the iterator when exhausted
        mel = None
        while mel is None:
            try:
                mel, aug_mel, linear, sid, text = next(getattr(self, iter_name))
            except StopIteration:
                setattr(self, iter_name,
                        iter(getattr(self, iter_name.replace('iter', 'set'))))
        # Pad to match n_frames_per_step (at least 1 frame padded)
        pad_len = self.n_frames_per_step - (mel.shape[1] % self.n_frames_per_step)
        mel = torch.cat(
            [mel, SPEC_PAD_VALUE * torch.ones_like(mel)[:, :pad_len, :]], dim=1)
        linear = torch.cat(
            [linear, SPEC_PAD_VALUE * torch.ones_like(linear)[:, :pad_len, :]],
            dim=1)
        return mel.to(self.device), \
            aug_mel.to(self.device), \
            linear.to(self.device), \
            text.to(self.device), \
            sid.to(self.device)

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape '''
        self.verbose(['Loading data... a large corpus may take a while.'])
        self.unpair_set, self.pair_set, self.dev_set, self.test_set, self.audio_converter, self.tokenizer, data_msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.pair_iter = iter(self.pair_set)
        self.unpair_iter = iter(self.unpair_set)
        self.dev_iter = iter(self.dev_set)
        # Feature statistics
        self.n_mels, self.linear_dim = self.audio_converter.feat_dim
        self.vocab_size = self.tokenizer.vocab_size
        self.n_spkr = len(
            json.load(open(self.config['data']['corpus']['spkr_map'])))
        self.verbose(data_msg)

    def set_model(self):
        ''' Setup Audio AE-model and optimizer '''
        # Model
        self.model = VQVAE(self.n_mels, self.linear_dim, self.vocab_size,
                           self.n_spkr, **self.config['model']).to(self.device)
        self.n_frames_per_step = self.model.n_frames_per_step
        self.verbose(self.model.create_msg())

        # Objective
        self.freq_loss = partial(
            freq_loss,
            sample_rate=self.audio_converter.sr,
            n_mels=self.audio_converter.n_mels,
            loss=self.config['hparas']['freq_loss_type'],
            differential_loss=self.config['hparas']['differential_loss'],
            emphasize_linear_low=self.config['hparas']['emphasize_linear_low'])
        self.ctc_loss = torch.nn.CTCLoss()
        self.stop_loss = torch.nn.BCEWithLogitsLoss()

        # Optimizer
        self.optimizer = Optimizer(self.model.parameters(),
                                   **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())
        # ToDo : unsup first?
        self.verbose(' | ASR weight = {}\t| start step = {}'.format(
            self.asr_weight, 0))
        self.verbose(' | TTS weight = {}\t| start step = {}'.format(
            self.tts_weight, 0))
        self.verbose(' | Txt weight = {}\t| start step = {}'.format(
            self.unpair_text_weight, self.unpair_text_start_step))
        self.verbose(' | Sph weight = {}\t| start step = {}'.format(
            self.unpair_speech_weight, self.unpair_speech_start_step))

        # ToDo: load pre-trained model
        if self.paras.load:
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        self.verbose(
            ['Total training steps {}.'.format(human_format(self.max_step))])
        self.timer.set()
        unpair_speech_loss, unpair_text_loss = None, None
        unsup_pred, unsup_trans, unsup_align = None, None, None
        tok_usage, gt_usage = [], []
        cnter = {'ctc_nan': 0, 'unp_sph': 0, 'unp_txt': 0}

        while self.step < self.max_step:
            # --------------------- Load data ----------------------- #
            # Unpair setting
            unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                None, None, None, None, None
            post_pred, asr_post_loss = None, None  # For ASR postnet only
            use_unpair_text = self.unpair_text_weight > 0 and \
                self.step > self.unpair_text_start_step
            use_unpair_speech = self.unpair_speech_weight > 0 and \
                self.step > self.unpair_speech_start_step
            # Catch the returned tf_rate if needed
            tf_rate = self.optimizer.pre_step(self.step)
            # ToDo : change # of sup. step = 2 x # of unsup. step ?
            mel, aug_mel, linear, text, sid = self.fetch_data(iter_name='pair_iter')
            # Load unpaired data only when use_unpair_xxx == True
            if self.step % 2 == 0:
                # ASR first
                speech_first = True
                if use_unpair_speech:
                    unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                        self.fetch_data(iter_name='unpair_iter')
            else:
                # TTS first
                speech_first = False
                if use_unpair_text:
                    cnter['unp_txt'] += 1
                    unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                        self.fetch_data(iter_name='unpair_iter')
            total_loss = 0
            bs = len(mel)
            self.timer.cnt('rd')

            try:
                # ----------------------- Forward ------------------------ #
                if speech_first:
                    # Cycle : speech -> text -> speech
                    pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                        self.model.speech_to_text(paired_mel=aug_mel,
                                                  unpaired_mel=unpair_aug_mel)
                    # Check to involve unsupervised Speech2Speech
                    if unpair_latent is not None:
                        # ASR output is the representation for speech2speech
                        cnter['unp_sph'] += 1
                        ignore_speech_cycle = False
                        unpaired_teacher = unpair_mel
                    else:
                        # ASR output is all blank (cannot be passed to TTS),
                        # only paired text is used
                        ignore_speech_cycle = True
                        unpaired_teacher = None
                    # text -> speech
                    pair_mel_pred, pair_linear_pred, pair_align, _, \
                        unpair_mel_pred, unpair_linear_pred, unpair_align, _ = \
                        self.model.text_to_speech(paired_text=text,
                                                  paired_sid=sid,
                                                  unpaired_sid=unpair_sid,
                                                  unpaired_latent=unpair_latent,
                                                  unpaired_text=None,
                                                  unpaired_latent_len=unpair_latent_len,
                                                  paired_teacher=mel,
                                                  unpaired_teacher=unpaired_teacher,
                                                  tf_rate=tf_rate)
                else:
                    # Cycle : text -> speech -> text
                    pair_mel_pred, pair_linear_pred, pair_align, _, \
                        unpair_mel_pred, unpair_linear_pred, unpair_align, _ = \
                        self.model.text_to_speech(paired_text=text,
                                                  paired_sid=sid,
                                                  unpaired_sid=unpair_sid,
                                                  unpaired_latent=None,
                                                  unpaired_text=unpair_text,
                                                  unpaired_latent_len=None,
                                                  paired_teacher=mel,
                                                  unpaired_teacher=None,
                                                  tf_rate=tf_rate)
                    if use_unpair_text:
                        # Stop-grad for TTS in text2text
                        unpair_mel_pred = unpair_mel_pred.detach()
                    pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                        self.model.speech_to_text(paired_mel=aug_mel,
                                                  unpaired_mel=unpair_mel_pred,
                                                  using_fake_mel=use_unpair_text)

                # Paired ASR loss
                asr_loss = self.compute_ctcloss(aug_mel, pair_prob, text)
                if self.model.use_asr_postnet:
                    total_loss = total_loss + self.asr_weight * \
                        (1 - self.model.asr_postnet_weight) * asr_loss
                    asr_post_loss = self.compute_ctcloss(aug_mel, pair_post_prob,
                                                         text, apply_log=False)
                    total_loss = total_loss + self.asr_weight * \
                        self.model.asr_postnet_weight * asr_post_loss
                else:
                    total_loss = total_loss + self.asr_weight * asr_loss
                if math.isnan(asr_loss) or math.isinf(asr_loss):
                    cnter['ctc_nan'] += 1
                    asr_loss = 0

                # Paired TTS loss
                mel_loss = self.freq_loss(pair_mel_pred, mel)
                linear_loss = self.freq_loss(pair_linear_pred, linear)
                tts_loss = mel_loss + linear_loss
                total_loss = total_loss + self.tts_weight * tts_loss

                # Unpaired loss
                if speech_first:
                    # Unpaired speech reconstruction loss
                    if not ignore_speech_cycle:
                        unpair_speech_loss = \
                            self.freq_loss(unpair_mel_pred, unpair_mel) + \
                            self.freq_loss(unpair_linear_pred, unpair_linear)
                        if self.step > self.unpair_speech_start_step:
                            total_loss += self.unpair_speech_weight * unpair_speech_loss
                elif use_unpair_text:
                    # Unpaired text reconstruction loss
                    ctc_input = (unpair_prob + EPS).transpose(0, 1).log()
                    if self.paras.actual_len:
                        asr_input_len = (unpair_text != 0).sum(dim=-1) * FRAME_PHN_RATIO
                        asr_input_len = asr_input_len + \
                            asr_input_len % self.model.n_frames_per_step
                        ctc_len = 1 + (asr_input_len // self.model.time_reduce_factor)
                    else:
                        ctc_len = torch.LongTensor(
                            [unpair_prob.shape[1]] * unpair_prob.shape[0]).to(
                                device=self.device)
                    unpair_text_loss = self.ctc_loss(
                        ctc_input, unpair_text.to_sparse().values(), ctc_len,
                        torch.sum(unpair_text != 0, dim=-1))
                    if math.isnan(unpair_text_loss) or math.isinf(unpair_text_loss):
                        cnter['ctc_nan'] += 1
                        unpair_text_loss = 0
                    total_loss += self.unpair_text_weight * unpair_text_loss

                # Statistics (over unsup. speech only)
                if speech_first and use_unpair_speech:
                    unsup_pred = unpair_prob.argmax(dim=-1).cpu()
                    unsup_trans = unpair_text.cpu()
                    tok_usage += unsup_pred.flatten().tolist()
                    gt_usage += unsup_trans.flatten().tolist()
                    if unpair_align is not None:
                        unsup_align = unpair_align.detach().cpu()
                    else:
                        unsup_align = [None] * bs
                self.timer.cnt('fw')

                # ----------------------- Backward ------------------------ #
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Log
                if (self.step == 1) or (self.step % self._PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} (CTC-nan/unp-sph/unp-txt={}/{}/{}) | Grad. Norm - {:.2f} | {}'
                        .format(total_loss.cpu().item(), cnter['ctc_nan'],
                                cnter['unp_sph'], cnter['unp_txt'], grad_norm,
                                self.timer.show()))
                    self.write_log(
                        'txt_loss', {
                            'pair': asr_loss.item() if asr_loss is not None else None,
                            'unpair': unpair_text_loss.item() if unpair_text_loss is not None else None,
                            'post': asr_post_loss.item() if asr_post_loss is not None else None
                        })
                    self.write_log(
                        'speech_loss', {
                            'pair': tts_loss.item() if tts_loss is not None else None,
                            'unpair': unpair_speech_loss.item() if unpair_speech_loss is not None else None
                        })
                    for k in cnter.keys():
                        cnter[k] = 0
                if (self.step == 1) or (self.step % ATTENTION_PLOT_STEP == 0):
                    align = pair_align.cpu()  # align shape BxDsxEs
                    sup_pred = pair_prob.argmax(dim=-1).cpu()
                    sup_trans = text.cpu()
                    if self.model.use_asr_postnet:
                        post_pred = pair_post_prob.argmax(dim=-1).cpu()
                    self.write_log(
                        'per', {
                            'pair': cal_per(sup_pred, sup_trans),
                            'unpair': cal_per(unsup_pred, unsup_trans),
                            'post': cal_per(post_pred, sup_trans)
                        })
                    self.write_log(
                        'unpair_hist',
                        data_to_bar(tok_usage, gt_usage, self.vocab_size,
                                    self.tokenizer._vocab_list))
                    for i in range(LISTEN_N_EXAMPLES):
                        self.write_log('pair_align{}'.format(i),
                                       feat_to_fig(align[i].cpu().detach()))
                        if unsup_align is not None and unsup_align[i] is not None:
                            self.write_log(
                                'unpair_align{}'.format(i),
                                feat_to_fig(unsup_align[i].cpu().detach()))
                    tok_usage, gt_usage = [], []

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                if self.step > self.max_step:
                    break
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    self.verbose('WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    torch.cuda.empty_cache()
                else:
                    print(repr(e))
                    raise

    def validate(self):
        # Eval mode
        self.model.eval()
        dev_tts_loss, dev_per, dev_post_per = [], [], []

        for i in range(len(self.dev_set)):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dev_set)))
            # Fetch data
            mel, aug_mel, linear, text, sid = self.fetch_data(iter_name='dev_iter')

            # Forward model
            with torch.no_grad():
                # Test ASR
                pair_prob, _, _, _, _, pair_post_prob, _ = \
                    self.model.speech_to_text(paired_mel=mel, unpaired_mel=None)
                dev_per.append(cal_per(pair_prob, text))
                if pair_post_prob is not None:
                    dev_post_per.append(cal_per(pair_post_prob, text))
                # Test TTS (Note: absolute dec step now)
                pair_mel_pred, pair_linear_pred, pair_align, _, _, _, _, _ = \
                    self.model.text_to_speech(paired_text=text,
                                              paired_sid=sid,
                                              unpaired_sid=None,
                                              unpaired_latent=None,
                                              unpaired_text=None,
                                              unpaired_latent_len=None,
                                              paired_teacher=mel.shape[1],
                                              unpaired_teacher=None,
                                              tf_rate=0.0)
                dev_tts_loss.append(
                    self.freq_loss(pair_mel_pred, mel) +
                    self.freq_loss(pair_linear_pred, linear))
                if i == len(self.dev_set) // 2:
                    # Pick n longest samples in the median batch
                    sample_txt = text.cpu()[:LISTEN_N_EXAMPLES]
                    hyp = pair_prob.argmax(dim=-1).cpu()[:LISTEN_N_EXAMPLES]
                    mel_p = pair_mel_pred.cpu()[:LISTEN_N_EXAMPLES]
                    linear_p = pair_linear_pred.cpu()[:LISTEN_N_EXAMPLES]
                    align_p = pair_align.cpu()[:LISTEN_N_EXAMPLES]
                    sample_mel = mel.cpu()[:LISTEN_N_EXAMPLES]
                    sample_linear = linear.cpu()[:LISTEN_N_EXAMPLES]

        # Ckpt if performance improves
        dev_tts_loss = sum(dev_tts_loss) / len(dev_tts_loss)
        dev_per = sum(dev_per) / len(dev_per)
        dev_post_per = sum(dev_post_per) / len(dev_post_per) \
            if len(dev_post_per) > 0 else None
        if self.paras.store_best_per:
            if dev_per < self.best_per:
                self.best_per = dev_per
                self.save_checkpoint('best_per.pth', dev_per)
            if (dev_post_per is not None) and (dev_post_per < self.best_per):
                self.best_per = dev_post_per
                self.save_checkpoint('best_post_per.pth', dev_post_per)
        else:
            if dev_tts_loss < self.best_tts_loss:
                self.best_tts_loss = dev_tts_loss
                if self.step > 1:
                    self.save_checkpoint('tts_{}.pth'.format(self.step),
                                         dev_tts_loss)
            if dev_per < self.best_per:
                self.best_per = dev_per
                if self.step > 1:
                    self.save_checkpoint('asr_{}.pth'.format(self.step), dev_per)
            if (dev_post_per is not None) and (dev_post_per < self.best_per):
                self.best_per = dev_post_per
                # Note: did not record whether the best PER came from the postnet
                self.save_checkpoint('best_post_per.pth', dev_post_per)
        if ((self.step > 1) and (self.step % CKPT_STEP == 0)) and \
                not self.paras.store_best_per:
            # Regular ckpt
            self.save_checkpoint('step_{}.pth'.format(self.step), dev_tts_loss)

        # Logger
        # Write model output (no Griffin-Lim if picking PER)
        for i, (m_p, l_p, a_p, h_p) in enumerate(zip(mel_p, linear_p, align_p, hyp)):
            self.write_log('hyp_text{}'.format(i),
                           self.tokenizer.decode(h_p.tolist()))
            self.write_log('mel_spec{}'.format(i), feat_to_fig(m_p))
            self.write_log('linear_spec{}'.format(i), feat_to_fig(l_p))
            self.write_log('dv_align{}'.format(i), feat_to_fig(a_p))
            if not self.paras.store_best_per:
                self.write_log('mel_wave{}'.format(i),
                               self.audio_converter.feat_to_wave(m_p))
                self.write_log('linear_wave{}'.format(i),
                               self.audio_converter.feat_to_wave(l_p))
        # Write ground truth
        if self.step == 1:
            for i, (mel, linear, gt_txt) in enumerate(
                    zip(sample_mel, sample_linear, sample_txt)):
                self.write_log('truth_text{}'.format(i),
                               self.tokenizer.decode(gt_txt.tolist()))
                self.write_log('mel_spec{}_gt'.format(i), feat_to_fig(mel))
                self.write_log('mel_wave{}_gt'.format(i),
                               self.audio_converter.feat_to_wave(mel))
                self.write_log('linear_spec{}_gt'.format(i), feat_to_fig(linear))
                self.write_log('linear_wave{}_gt'.format(i),
                               self.audio_converter.feat_to_wave(linear))
        self.write_log('speech_loss', {'dev': dev_tts_loss})
        self.write_log('per', {'dev': dev_per, 'dev_post': dev_post_per})
        self.write_log('codebook', (self.model.codebook.embedding.weight.data,
                                    self.tokenizer._vocab_list))

        # Resume training
        self.model.train()

    def compute_ctcloss(self, model_input, model_output, target, apply_log=True):
        if apply_log:
            ctc_input = (model_output + EPS).transpose(0, 1).log()
        else:
            ctc_input = model_output.transpose(0, 1)
        if self.paras.actual_len:
            # Count frames that are not entirely SPEC_PAD_VALUE
            asr_input_len = torch.sum(
                (model_input == SPEC_PAD_VALUE).long().sum(dim=-1) != model_input.shape[-1],
                dim=-1)
            ctc_len = asr_input_len // self.model.time_reduce_factor
            ctc_target = target
        else:
            ctc_target = target.to_sparse().values()
            ctc_len = torch.LongTensor(
                [model_output.shape[1]] * model_output.shape[0]).to(
                    device=self.device)
        return self.ctc_loss(ctc_input, ctc_target, ctc_len,
                             torch.sum(target != 0, dim=-1))
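
# A standalone sketch (not part of the trainer above) of the torch.nn.CTCLoss
# calling convention that compute_ctcloss follows: inputs must be (T, B, V)
# log-probabilities, and zero-padded targets can be flattened into a 1-D
# tensor with to_sparse().values(). The epsilon and all shapes below are
# illustrative assumptions.
def _ctc_convention_demo():
    import torch
    eps = 1e-8
    B, T, V = 2, 20, 10
    prob = torch.softmax(torch.randn(B, T, V), dim=-1)   # model posteriors
    ctc_input = (prob + eps).transpose(0, 1).log()       # (T, B, V) log-probs
    target = torch.tensor([[3, 4, 5, 0], [6, 7, 0, 0]])  # 0 = padding = blank
    target_len = torch.sum(target != 0, dim=-1)          # true lengths [3, 2]
    flat_target = target.to_sparse().values()            # drops zero padding
    input_len = torch.full((B,), T, dtype=torch.long)
    return torch.nn.CTCLoss(blank=0)(ctc_input, flat_target,
                                     input_len, target_len)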
class Solver(BaseSolver):
    ''' Solver for training '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        self.best_wer = {'att': 3.0, 'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length '''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)
        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape '''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, self.curriculum > 0,
                         **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu,
                                 self.paras.pin_memory, False,
                                 **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(
                        att_output.contiguous().view(b * t, -1),
                        txt.contiguous().view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(total_loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    self.write_log('loss', {'tr_ctc': ctc_loss,
                                            'tr_att': att_loss})
                    self.write_log('emb_loss', {'tr': emb_loss})
                    self.write_log('wer', {
                        'tr_att': cal_er(self.tokenizer, att_output, txt),
                        'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)
                    })
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log('fuse_lambda',
                                           {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len,
                               int(max(txt_len) * self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)
            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard
            if i == len(self.dv_set) // 2:
                for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text{}'.format(i),
                                       self.tokenizer.decode(txt[i].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align{}'.format(i),
                            feat_to_fig(att_align[i, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text{}'.format(i),
                            self.tokenizer.decode(
                                att_output[i].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text{}'.format(i),
                            self.tokenizer.decode(
                                ctc_output[i].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format(task), 'wer',
                                     dev_wer[task])
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth', 'wer', dev_wer['att'], show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
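
# A minimal, standalone sketch (not part of the solver above) of the
# multi-task interpolation exec() uses: total = w * CTC + (1 - w) * attention
# cross-entropy, with the attention logits flattened to (B*T, V). The CTC
# value and all shapes below are illustrative assumptions.
def _joint_loss_demo():
    import torch
    B, T, V = 2, 5, 10
    ctc_weight = 0.5
    att_output = torch.randn(B, T, V)          # decoder logits
    txt = torch.randint(1, V, (B, T))          # targets (0 would be padding)
    att_loss = torch.nn.CrossEntropyLoss(ignore_index=0)(
        att_output.contiguous().view(B * T, -1), txt.contiguous().view(-1))
    ctc_loss = torch.tensor(1.23)              # placeholder CTC loss value
    return ctc_loss * ctc_weight + att_loss * (1 - ctc_weight)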
class Solver(BaseSolver):
    ''' Solver for training '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == \
            self.src_config['data']['corpus']['name']
        self.config['data']['corpus']['path'] = \
            self.src_config['data']['corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes should be identical to the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['corpus']['train_split'] = \
            self.src_config['data']['corpus']['train_split']
        self.config['data']['text'] = self.src_config['data']['text']
        self.tokenizer = load_text_encoder(**self.config['data']['text'])
        self.config['model'] = self.src_config['model']
        self.finetune_first = 5
        self.best_wer = {'att': 3.0, 'ctc': 3.0}

        # Output file
        self.output_file = str(self.ckpdir) + '_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        self.dealer = Datadealer(self.config['data']['audio'])
        self.ctc = self.config['decode']['ctc_weight'] == 1.0
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

        # Logger settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(self.logdir, flush_secs=self.TB_FLUSH_FREQ)
        self.timer = Timer()

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length '''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)
        return feat, feat_len, txt, txt_len

    def load_data(self, batch_size=7):
        ''' Load data for training/validation, store tokenizer and input/output shape '''
        prev_batch_size = self.config['data']['corpus']['batch_size']
        self.config['data']['corpus']['batch_size'] = batch_size
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, False, **self.config['data'])
        self.config['data']['corpus']['batch_size'] = prev_batch_size
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        if self.finetune_first > 0:
            # Only fine-tune the first few encoder layers
            names = ['encoder.layers.%d' % i
                     for i in range(self.finetune_first)]
            model_paras = [{
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in names)]
            }]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # Beam decoder
        self.decoder = BeamDecoder(self.model, self.emb_decoder,
                                   **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        self.decoder.to(self.device)

    def exec(self):
        ''' Testing End-to-end ASR system '''
        while True:
            try:
                filename = input('Input wav file name: ')
                if filename == 'exit':
                    return
                feat, feat_len = self.dealer(filename)
                feat = feat.to(self.device)
                feat_len = feat_len.to(self.device)
                # Decode
                with torch.no_grad():
                    hyps = self.decoder(feat, feat_len)
                hyp_seqs = [hyp.outIndex for hyp in hyps]
                hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc)
                            for hyp in hyp_seqs]
                for txt in hyp_txts:
                    print(txt)
            except Exception:
                print('Invalid file')

    def recognize(self, filename):
        try:
            feat, feat_len = self.dealer(filename)
            feat = feat.to(self.device)
            feat_len = feat_len.to(self.device)
            # Decode
            with torch.no_grad():
                hyps = self.decoder(feat, feat_len)
            hyp_seqs = [hyp.outIndex for hyp in hyps]
            hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc)
                        for hyp in hyp_seqs]
            return hyp_txts[0]
        except Exception as e:
            print(e)
            app.logger.debug(e)
            return 'Invalid file'

    def fetch_finetune_data(self, filename, fixed_text):
        feat, feat_len = self.dealer(filename)
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        text = self.tokenizer.encode(fixed_text)
        text = torch.tensor(text).to(self.device)
        text_len = len(text)
        return [feat, feat_len, text, text_len]

    def merge_batch(self, main_batch, attach_batch):
        max_feat_len = max(main_batch[1])
        max_text_len = max(main_batch[3])
        if attach_batch[0].shape[1] > max_feat_len:
            # Trim the extra-long example
            attach_batch[0] = attach_batch[0][:, :max_feat_len]
            attach_batch[1][0] = max_feat_len
        else:
            # Pad to max_feat_len
            padding = torch.zeros(1, max_feat_len - attach_batch[0].shape[1],
                                  attach_batch[0].shape[2],
                                  dtype=attach_batch[0].dtype).to(self.device)
            attach_batch[0] = torch.cat([attach_batch[0], padding], dim=1)
        if attach_batch[2].shape[0] > max_text_len:
            attach_batch[2] = attach_batch[2][:max_text_len]
            # Keep the attached example's text length consistent after trimming
            attach_batch[3] = max_text_len
        else:
            padding = torch.zeros(max_text_len - attach_batch[2].shape[0],
                                  dtype=attach_batch[2].dtype).to(self.device)
            attach_batch[2] = torch.cat([attach_batch[2], padding], dim=0)
        attach_batch[2] = attach_batch[2].unsqueeze(0)
        new_batch = (
            torch.cat([main_batch[0], attach_batch[0]], dim=0),
            torch.cat([main_batch[1], attach_batch[1]], dim=0),
            torch.cat([main_batch[2], attach_batch[2]], dim=0),
            torch.cat([main_batch[3],
                       torch.tensor([attach_batch[3]]).to(self.device)], dim=0)
        )
        return new_batch

    def finetune(self, filename, fixed_text, max_step=5):
        # Load data for finetune
        self.verbose('Total training steps {}.'.format(human_format(max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        self.timer.set()
        step = 0

        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            if max_step == 0:
                break
            # Use a large fixed step so scheduled values (e.g. tf_rate) are final
            tf_rate = self.optimizer.pre_step(400000)
            total_loss = 0

            # Fetch data
            finetune_data = self.fetch_finetune_data(filename, fixed_text)
            main_batch = self.fetch_data(data)
            new_batch = self.merge_batch(main_batch, finetune_data)
            feat, feat_len, txt, txt_len = new_batch
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(dec_state, att_output,
                                                         label=txt)
                total_loss += self.emb_decoder.weight * emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)
                total_loss += ctc_loss * self.model.ctc_weight
            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(att_output.contiguous().view(b * t, -1),
                                         txt.contiguous().view(-1))
                total_loss += att_loss * (1 - self.model.ctc_weight)
            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            step += 1

            # Logger
            self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                          .format(total_loss.cpu().item(), grad_norm,
                                  self.timer.show()))
            self.write_log('loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
            self.write_log('emb_loss', {'tr': emb_loss})
            self.write_log('wer', {
                'tr_att': cal_er(self.tokenizer, att_output, txt),
                'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)
            })
            if self.emb_fuse:
                if self.emb_decoder.fuse_learnable:
                    self.write_log('fuse_lambda',
                                   {'emb': self.emb_decoder.get_weight()})
                self.write_log('fuse_temp',
                               {'temp': self.emb_decoder.get_temp()})

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            if step > max_step:
                break
        ret = self.validate()
        self.log.close()
        return ret

    def validate(self):
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len,
                               int(max(txt_len) * self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)
            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard
            if i == len(self.dv_set) // 2:
                for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(i),
                                   self.tokenizer.decode(txt[i].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align{}'.format(i),
                            feat_to_fig(att_align[i, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text{}'.format(i),
                            self.tokenizer.decode(
                                att_output[i].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text{}'.format(i),
                            self.tokenizer.decode(
                                ctc_output[i].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Report whether performance improves (model saving skipped here)
        to_prints = []
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                to_print = f"WER of {task}: {dev_wer[task]} < prev best ({self.best_wer[task]})"
                self.best_wer[task] = dev_wer[task]
            else:
                to_print = f"WER of {task}: {dev_wer[task]} >= prev best ({self.best_wer[task]})"
            print(to_print, flush=True)
            to_prints.append(to_print)
            # self.save_checkpoint('best_{}.pth'.format(task), 'wer', dev_wer[task])
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
        # self.save_checkpoint('latest.pth', 'wer', dev_wer['att'], show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
        return '\n'.join(to_prints)
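
# A standalone sketch (not part of the solver above) of the pad-or-trim rule
# merge_batch applies to attach one extra utterance to an existing batch:
# features longer than the batch maximum are trimmed, shorter ones are
# zero-padded. The lengths and feature dimension below are illustrative
# assumptions.
def _pad_or_trim_demo():
    import torch
    max_feat_len = 6
    extra = torch.randn(1, 4, 3)            # (1, frames, feat_dim)
    if extra.shape[1] > max_feat_len:
        extra = extra[:, :max_feat_len]     # trim an over-long example
    else:
        # zero-pad up to the batch's max length
        pad = torch.zeros(1, max_feat_len - extra.shape[1], extra.shape[2])
        extra = torch.cat([extra, pad], dim=1)
    return extra.shape                      # torch.Size([1, 6, 3])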
class Solver(BaseSolver):
    ''' Solver for training '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        self.best_wer = {'ctc': 3.0}
        self.best_per = {'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape '''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, self.curriculum > 0,
                         **self.config['data'])
        self.verbose(msg)

    def transfer_weight(self):
        # Transfer weights from a source checkpoint
        ckpt_path = self.config['data']['transfer'].pop('src_ckpt')
        ckpt = torch.load(ckpt_path, map_location=self.device)
        # Optimizer state transfer disabled:
        #optim_ckpt = ckpt['optimizer']
        #for ctc_final_related_param in optim_ckpt['param_groups'][0]['params'][-2:]:
        #    optim_ckpt['state'].pop(ctc_final_related_param)
        #self.optimizer.load_opt_state_dict(optim_ckpt)
        # Load weights
        msg = self.model.transfer_with_mapping(ckpt,
                                               self.config['data']['transfer'],
                                               self.tokenizer)
        del ckpt
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)
        self.eval_target = 'phone' \
            if self.config['data']['corpus']['target'] == 'ipa' else 'char'

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        if self.paras.transfer:
            self.transfer_weight()

        # Automatically load pre-trained model if self.paras.load is given
        if self.paras.load:
            self.load_ckpt()

        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss = None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu,
                                 self.paras.pin_memory, False,
                                 **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len = self.model(feat, feat_len)

                # Compute all objectives
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)
                total_loss = ctc_loss
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(total_loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    ctc_output = [
                        x[:length].argmax(dim=-1)
                        for x, length in zip(ctc_output, encode_len)
                    ]
                    self.write_log('per', {
                        'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                         mode='per', ctc=True)})
                    self.write_log('wer', {
                        'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                         mode='wer', ctc=True)})
                    self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        #self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        dev_per = {'ctc': []}
        dev_wer = {'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len = self.model(feat, feat_len)
            ctc_output = [
                x[:length].argmax(dim=-1)
                for x, length in zip(ctc_output, encode_len)
            ]
            dev_per['ctc'].append(cal_er(self.tokenizer, ctc_output, txt,
                                         mode='per', ctc=True))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt,
                                         mode='wer', ctc=True))

            # Show some example on tensorboard
            if i == len(self.dv_set) // 2:
                for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(i),
                                   self.tokenizer.decode(txt[i].tolist()))
                    self.write_log(
                        'ctc_text{}'.format(i),
                        self.tokenizer.decode(ctc_output[i].tolist(),
                                              ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_per[task] = sum(dev_per[task]) / len(dev_per[task])
            if dev_per[task] < self.best_per[task]:
                self.best_per[task] = dev_per[task]
                self.save_checkpoint('best_{}.pth'.format('per'), 'per',
                                     dev_per[task])
                self.log.log_other('dv_best_per', self.best_per['ctc'])
            if self.eval_target == 'char' and dev_wer[task] < self.best_wer[task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format('wer'), 'wer',
                                     dev_wer[task])
                self.log.log_other('dv_best_wer', self.best_wer['ctc'])
            self.write_log('per', {'dv_' + task: dev_per[task]})
            if self.eval_target == 'char':
                self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth', 'per', dev_per['ctc'], show_msg=False)
        if self.paras.save_every:
            self.save_checkpoint(f'{self.step}.pth', 'per', dev_per['ctc'],
                                 show_msg=False)

        # Resume training
        self.model.train()
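
# A standalone sketch (not part of the solver above) of the standard CTC
# greedy-decoding collapse that the tokenizer's decode(..., ignore_repeat=True)
# performs after the per-frame argmax: merge consecutive repeats, then drop
# blanks (index 0). The frame sequence below is an illustrative assumption.
def _ctc_collapse_demo():
    frame_ids = [0, 3, 3, 0, 0, 5, 5, 5, 2, 0]  # per-frame argmax, blank = 0
    out, prev = [], None
    for t in frame_ids:
        if t != prev and t != 0:
            out.append(t)
        prev = t
    return out  # [3, 5, 2]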
class Solver(BaseSolver):
    ''' Solver for training '''

    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        self.val_loss = 1000
        self.cur_epoch = 0

    def fetch_data(self, data):
        ''' Move data to device '''
        file_id, audio_feat, audio_len = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
        return file_id, audio_feat, audio_len

    def load_data(self):
        ''' Load data for training/validation '''
        self.tr_set, self.dv_set, _, self.audio_dim, msg = \
            prepare_data(self.paras.njobs, self.paras.dev_njobs,
                         self.paras.gpu, self.paras.pin_memory,
                         **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup model and optimizer '''
        # Model
        self.method = self.config['model']['method']
        if self.method in ['apc', 'vqapc']:
            self.n_future = self.config['model']['n_future']
            from model.apc import APC as Net
        elif self.method == 'npc':
            from model.npc import NPC as Net
        else:
            raise NotImplementedError
        self.model = Net(input_size=self.audio_dim,
                         **self.config['model']['paras'])
        if self.gpu:
            self.model = self.model.cuda()
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Loss
        if 'npc' in self.method:
            # Avoid reduction for NPC so zero-padding can be masked out
            self.loss = torch.nn.L1Loss(reduction='none')
        else:
            # APC family handles zero-padding with the torch API
            self.loss = torch.nn.L1Loss()
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
        # ToDo: Data Parallel?
        # self.model = torch.nn.DataParallel(self.model)
        self.model.train()

    def exec(self):
        ''' Training self-supervised model '''
        self.verbose('Total training epoch {}.'.format(
            human_format(self.epoch)))
        self.timer.set()
        ep_len = len(self.tr_set)

        for ep in range(self.epoch):
            # Pre-step : learning rate decay
            if ep > 0:
                self.optimizer.decay()
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                _, audio_feat, audio_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward real data
                if 'npc' in self.method:
                    # NPC: input = target
                    pred, _ = self.model(audio_feat)
                    loss = self.loss(pred, audio_feat)
                    # Compute loss on valid part only
                    effective_loss = 0
                    for i, a_len in enumerate(audio_len):
                        effective_loss += loss[i, :a_len, :].mean(dim=-1).sum()
                    loss = effective_loss / sum(audio_len)
                else:
                    # APC: input = shifted target
                    audio_len = [l - self.n_future for l in audio_len]
                    pred, _ = self.model(audio_feat[:, :-self.n_future, :],
                                         audio_len, testing=False)
                    loss = self.loss(pred, audio_feat[:, self.n_future:, :])
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(100 * float(self.step % ep_len) / ep_len,
                                loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    self.write_log('loss', {'tr': loss})
                if (self.step == 1) or (self.step % self.PLOT_STEP == 0):
                    # Perplexity of P(token)
                    g1_ppx, g2_ppx = self.model.report_ppx()
                    self.write_log('ppx', {'group 1': g1_ppx,
                                           'group 2': g2_ppx})
                    g1_usg, g2_usg = self.model.report_usg()
                    # Plots
                    if self.paras.draw:
                        g1_hist = draw(g1_usg, hist=True)
                        g2_hist = draw(g2_usg, hist=True)
                        self.write_log('VQ Group 1 Hist.', g1_hist)
                        self.write_log('VQ Group 2 Hist.', g2_hist)
                        # Some spectrograms
                        plt_idx = 0
                        self.write_log('Spectrogram (raw)',
                                       draw(audio_feat[plt_idx]))
                        self.write_log('Spectrogram (pred)',
                                       draw(pred[plt_idx]))

                # End of step
                self.timer.set()
            # End of epoch
            self.cur_epoch += 1
            self.validate()
        self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            _, audio_feat, audio_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                if 'npc' in self.method:
                    pred, _ = self.model(audio_feat, testing=True)
                    loss = self.loss(pred, audio_feat)
                    # Compute loss on valid part only
                    effective_loss = 0
                    for i, a_len in enumerate(audio_len):
                        effective_loss += loss[i, :a_len, :].mean(dim=-1).sum()
                    loss = effective_loss / sum(audio_len)
                else:
                    audio_len = [l - self.n_future for l in audio_len]
                    pred, _ = self.model(audio_feat[:, :-self.n_future, :],
                                         audio_len, testing=True)
                    loss = self.loss(pred, audio_feat[:, self.n_future:, :])
            dev_loss.append(loss.cpu().item())

        # Record metric
        dev_loss = sum(dev_loss) / len(dev_loss)
        self.write_log('loss', {'dev': dev_loss})
        if dev_loss < self.val_loss:
            self.val_loss = dev_loss
            self.save_checkpoint('best_loss.pth', 'loss', dev_loss)

        # Resume training
        self.model.train()
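
# A standalone sketch (not part of the solver above) of the APC-style shifted
# objective used in exec(): the model reads frames [0, T - n_future) and is
# scored with L1 loss against frames [n_future, T). The model forward is
# replaced by an identity stand-in here; all shapes are illustrative
# assumptions.
def _apc_shift_demo():
    import torch
    n_future = 3
    feat = torch.randn(2, 10, 4)                 # (batch, frames, dim)
    pred = feat[:, :-n_future, :]                # stand-in for model output
    target = feat[:, n_future:, :]               # the n_future-shifted target
    assert pred.shape == target.shape
    return torch.nn.L1Loss()(pred, target)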
class Solver(BaseSolver): ''' Solver for training''' def __init__(self, config, paras, mode): super().__init__(config, paras, mode) # Curriculum learning affects data loader self.curriculum = self.config['hparas']['curriculum'] self.val_mode = self.config['hparas']['val_mode'].lower() self.WER = 'per' if self.val_mode == 'per' else 'wer' def fetch_data(self, data, train=False): ''' Move data to device and compute text seq. length''' # feat: B x T x D _, feat, feat_len, txt = data if self.paras.upstream is not None: # feat is raw waveform device = 'cpu' if self.paras.deterministic else self.device self.upstream.to(device) self.specaug.to(device) def to_device(feat): return [f.to(device) for f in feat] def extract_feature(feat): feat = self.upstream(to_device(feat)) if train and self.config['data']['audio'][ 'augment'] and 'aug' not in self.paras.upstream: feat = [self.specaug(f) for f in feat] return feat if HALF_BATCHSIZE_AUDIO_LEN < 3500 and train: first_len = extract_feature(feat[:1])[0].shape[0] if first_len > HALF_BATCHSIZE_AUDIO_LEN: feat = feat[::2] txt = txt[::2] if self.paras.upstream_trainable: self.upstream.train() feat = extract_feature(feat) else: with torch.no_grad(): self.upstream.eval() feat = extract_feature(feat) feat_len = torch.LongTensor([len(f) for f in feat]) feat = pad_sequence(feat, batch_first=True) txt = pad_sequence(txt, batch_first=True) feat = feat.to(self.device) feat_len = feat_len.to(self.device) txt = txt.to(self.device) txt_len = torch.sum(txt != 0, dim=-1) return feat, feat_len, txt, txt_len def load_data(self): ''' Load data for training/validation, store tokenizer and input/output shape''' if self.paras.upstream is not None: print(f'[Solver] - using S3PRL {self.paras.upstream}') self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \ load_wav_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, self.curriculum>0, **self.config['data']) self.upstream = torch.hub.load( 's3prl/s3prl', self.paras.upstream, feature_selection=self.paras.upstream_feature_selection, refresh=self.paras.upstream_refresh, ckpt=self.paras.upstream_ckpt, force_reload=True, ) self.feat_dim = self.upstream.get_output_dim() self.specaug = Augment() else: self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \ load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, self.curriculum>0, **self.config['data']) self.verbose(msg) # Dev set sames self.dv_names = [] if type(self.dv_set) is list: for ds in self.config['data']['corpus']['dev_split']: self.dv_names.append(ds[0]) else: self.dv_names = self.config['data']['corpus']['dev_split'][0] # Logger settings if type(self.dv_names) is str: self.best_wer = { 'att': { self.dv_names: 3.0 }, 'ctc': { self.dv_names: 3.0 } } else: self.best_wer = {'att': {}, 'ctc': {}} for name in self.dv_names: self.best_wer['att'][name] = 3.0 self.best_wer['ctc'][name] = 3.0 def set_model(self): ''' Setup ASR model and optimizer ''' # Model #print(self.feat_dim) #160 batch_size = self.config['data']['corpus']['batch_size'] // 2 self.model = ASR(self.feat_dim, self.vocab_size, batch_size, **self.config['model']).to(self.device) self.verbose(self.model.create_msg()) model_paras = [{'params': self.model.parameters()}] # Losses '''label smoothing''' if self.config['hparas']['label_smoothing']: self.seq_loss = LabelSmoothingLoss(31, 0.1) print('[INFO] using label smoothing. 
') else: self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0) self.ctc_loss = torch.nn.CTCLoss( blank=0, zero_infinity=False) # Note: zero_infinity=False is unstable? # Plug-ins self.emb_fuse = False self.emb_reg = ('emb' in self.config) and (self.config['emb']['enable']) if self.emb_reg: from src.plugin import EmbeddingRegularizer self.emb_decoder = EmbeddingRegularizer( self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device) model_paras.append({'params': self.emb_decoder.parameters()}) self.emb_fuse = self.emb_decoder.apply_fuse if self.emb_fuse: self.seq_loss = torch.nn.NLLLoss(ignore_index=0) self.verbose(self.emb_decoder.create_msg()) # Optimizer self.optimizer = Optimizer(model_paras, **self.config['hparas']) self.lr_scheduler = self.optimizer.lr_scheduler self.verbose(self.optimizer.create_msg()) # Enable AMP if needed self.enable_apex() # Transfer Learning if self.transfer_learning: self.verbose('Apply transfer learning: ') self.verbose(' Train encoder layers: {}'.format( self.train_enc)) self.verbose(' Train decoder: {}'.format( self.train_dec)) self.verbose(' Save name: {}'.format( self.save_name)) # Automatically load pre-trained model if self.paras.load is given self.load_ckpt() def exec(self): ''' Training End-to-end ASR system ''' self.verbose('Total training steps {}.'.format( human_format(self.max_step))) if self.transfer_learning: self.model.encoder.fix_layers(self.fix_enc) if self.fix_dec and self.model.enable_att: self.model.decoder.fix_layers() if self.fix_dec and self.model.enable_ctc: self.model.fix_ctc_layer() self.n_epochs = 0 self.timer.set() '''early stopping for ctc ''' self.early_stoping = self.config['hparas']['early_stopping'] stop_epoch = 10 batch_size = self.config['data']['corpus']['batch_size'] stop_step = len(self.tr_set) * stop_epoch // batch_size while self.step < self.max_step: ctc_loss, att_loss, emb_loss = None, None, None # Renew dataloader to enable random sampling if self.curriculum > 0 and n_epochs == self.curriculum: self.verbose( 'Curriculum learning ends after {} epochs, starting random sampling.' 
                             .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu,
                                 self.paras.pin_memory, False,
                                 **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data, train=True)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len),
                               tf_rate=tf_rate, teacher=txt,
                               get_dec_state=self.emb_reg)
                # Clear unused objects
                del att_align

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(
                        dec_state, att_output, label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss
                else:
                    del dec_state

                # Early stopping for CTC: drop the CTC branch after stop_step
                if self.early_stoping and self.step > stop_step:
                    ctc_output = None
                    self.model.ctc_weight = 0

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                    del encode_len

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(att_output.view(b * t, -1),
                                             txt.view(-1))
                    # Alternative: sum over each utterance, divide by its
                    # length, then mean over the batch:
                    # att_loss = torch.mean(torch.sum(att_loss.view(b, t), dim=-1)
                    #                       / torch.sum(txt != 0, dim=-1).float())
                    total_loss += att_loss * (1 - self.model.ctc_weight)
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(total_loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    if self.emb_reg:
                        self.write_log('emb_loss', {'tr': emb_loss})
                    if att_output is not None:
                        self.write_log('loss', {'tr_att': att_loss})
                        self.write_log(self.WER, {
                            'tr_att': cal_er(self.tokenizer, att_output, txt)})
                        self.write_log('cer', {
                            'tr_att': cal_er(self.tokenizer, att_output, txt,
                                             mode='cer')})
                    if ctc_output is not None:
                        self.write_log('loss', {'tr_ctc': ctc_loss})
                        self.write_log(self.WER, {
                            'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                             ctc=True)})
                        self.write_log('cer', {
                            'tr_ctc': cal_er(self.tokenizer, ctc_output, txt,
                                             mode='cer', ctc=True)})
                        self.write_log(
                            'ctc_text_train',
                            self.tokenizer.decode(
                                ctc_output[0].argmax(dim=-1).tolist(),
                                ignore_repeat=True))
                    # if self.step == 1 or self.step % (self.PROGRESS_STEP * 5) == 0:
                    #     self.write_log('spec_train',
                    #                    feat_to_fig(feat[0].transpose(0, 1).cpu().detach(), spec=True))
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log('fuse_lambda',
                                           {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    if isinstance(self.dv_set, list):
                        for dv_id in range(len(self.dv_set)):
                            self.validate(self.dv_set[dv_id],
                                          self.dv_names[dv_id])
                    else:
                        self.validate(self.dv_set, self.dv_names)

                if self.step % (len(self.tr_set) // batch_size) == 0:
                    # One epoch done
                    print('[INFO] Finished epoch:', self.n_epochs)
                    self.n_epochs += 1

                # Manual LR schedule, used when no lr_scheduler is set
                if self.lr_scheduler is None:
                    lr = self.optimizer.opt.param_groups[0]['lr']
                    if self.step == 1:
                        print('[INFO] using lr scheduler defined by Daniel, '
                              'init lr =', lr)
                    if self.step > 99999 and self.step % 2000 == 0:
                        lr = lr * 0.85
                        for param_group in self.optimizer.opt.param_groups:
                            param_group['lr'] = lr
                        print('[INFO] at step:', self.step)
                        print('[INFO] lr reduced to', lr)

                # End of step
                # Empty CUDA cache periodically to avoid fragmentation, see
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break

        self.log.close()
        print('[INFO] Finished training after', human_format(self.max_step),
              'steps.')

    def validate(self, _dv_set, _name):
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}
        dev_cer = {'att': [], 'ctc': []}
        dev_er = {'att': [], 'ctc': []}

        for i, data in enumerate(_dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(_dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len,
                               int(max(txt_len) * self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            if att_output is not None:
                dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt,
                                             mode='wer'))
                dev_cer['att'].append(cal_er(self.tokenizer, att_output, txt,
                                             mode='cer'))
                dev_er['att'].append(cal_er(self.tokenizer, att_output, txt,
                                            mode=self.val_mode))
            if ctc_output is not None:
                dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt,
                                             mode='wer', ctc=True))
                dev_cer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt,
                                             mode='cer', ctc=True))
                dev_er['ctc'].append(cal_er(self.tokenizer, ctc_output, txt,
                                            mode=self.val_mode, ctc=True))

            # Show some examples of the middle batch on tensorboard
            if i == len(_dv_set) // 2:
                for j in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text_{}_{}'.format(_name, j),
                                       self.tokenizer.decode(txt[j].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align_{}_{}'.format(_name, j),
                            feat_to_fig(att_align[j, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text_{}_{}'.format(_name, j),
                            self.tokenizer.decode(
                                att_output[j].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text_{}_{}'.format(_name, j),
                            self.tokenizer.decode(
                                ctc_output[j].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        tasks = []
        if len(dev_er['att']) > 0:
            tasks.append('att')
        if len(dev_er['ctc']) > 0:
            tasks.append('ctc')
        for task in tasks:
            dev_er[task] = sum(dev_er[task]) / len(dev_er[task])
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_cer[task] = sum(dev_cer[task]) / len(dev_cer[task])
            if dev_er[task] < self.best_wer[task][_name]:
                self.best_wer[task][_name] = dev_er[task]
                self.save_checkpoint(
                    'best_{}_{}.pth'.format(
                        task,
                        _name + (self.save_name if self.transfer_learning
                                 else '')),
                    self.val_mode, dev_er[task], _name)
            if self.step >= self.max_step:
                self.save_checkpoint(
                    'last_{}_{}.pth'.format(
                        task,
                        _name + (self.save_name if self.transfer_learning
                                 else '')),
                    self.val_mode, dev_er[task], _name)
            self.write_log(self.WER,
                           {'dv_' + task + '_' + _name.lower(): dev_wer[task]})
            self.write_log('cer',
                           {'dv_' + task + '_' + _name.lower(): dev_cer[task]})
            # if self.transfer_learning:
            #     print('[{}] WER {:.4f} / CER {:.4f} on {}'.format(
            #         human_format(self.step), dev_wer[task], dev_cer[task], _name))

        # Resume training
        self.model.train()
        if self.transfer_learning:
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
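
# ---------------------------------------------------------------------------
# Illustration only: a minimal, self-contained sketch (the helper name and all
# tensor sizes below are made up for this example, it is not part of the
# solver) of the input convention the CTC branch in exec() relies on.
# torch.nn.CTCLoss expects time-major log-probabilities of shape T x B x V,
# which is why `ctc_output` (B x T x V) is transposed before the call.
def _ctc_loss_shape_sketch():
    import torch
    batch, enc_t, vocab = 2, 8, 5                      # assumed sizes, blank id = 0
    ctc_loss = torch.nn.CTCLoss(blank=0)
    ctc_output = torch.randn(batch, enc_t, vocab).log_softmax(dim=-1)  # B x T x V
    txt = torch.tensor([[1, 2, 3, 0], [2, 4, 0, 0]])   # zero-padded label batch
    encode_len = torch.full((batch,), enc_t, dtype=torch.long)
    txt_len = torch.tensor([3, 2])                     # true label lengths
    # Time-major input, matching ctc_output.transpose(0, 1) in exec() above
    return ctc_loss(ctc_output.transpose(0, 1), txt, encode_len, txt_len)
# ---------------------------------------------------------------------------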
class Solver(BaseSolver):
    ''' Solver for training downstream classifiers (phone/speaker) '''

    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        self.best_dev_er = 1.0
        self.cur_epoch = 0
        # Configs following self-supervised learning
        self.task = self.paras.task
        assert self.task in ['phn-clf', 'spk-clf'], 'unsupported task'
        with open(self.config['model']['feat']['config'], 'r') as f:
            self.ssl_config = yaml.load(f, Loader=yaml.FullLoader)
        self.feature = self.ssl_config['model']['method']
        if self.feature == 'npc' and 'spec' in self.config['model']['feat']:
            # NPC has an additional option to use the unmasked feature
            self.feat_spec = self.config['model']['feat']['spec']
        else:
            self.feat_spec = None
        self.config['data']['audio'] = self.ssl_config['data']['audio']

    def fetch_data(self, data, train=True):
        ''' Move data to device '''
        file_id, audio_feat, audio_len, label = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
            label = label.cuda()
        # Extract feature
        with torch.no_grad():
            if self.feat_spec is not None:
                # Get unmasked feature from a particular NPC layer
                n_layer_feat = int(self.feat_spec.split('-')[-1])
                audio_feat = self.feat_extractor.get_unmasked_feat(
                    audio_feat, n_layer_feat)
            elif self.feature == 'npc':
                # Get masked feature from NPC
                _, audio_feat = self.feat_extractor(audio_feat, testing=True)
            else:
                # Get feature from APC based model
                _, audio_feat = self.feat_extractor(audio_feat, audio_len,
                                                    testing=True)
        # Mean pool feature over time for speaker classification
        if self.task == 'spk-clf':
            single_feat = []
            for a_feat, a_len in zip(audio_feat, audio_len):
                single_feat.append(a_feat[:a_len].mean(dim=0))
            audio_feat = torch.stack(single_feat, dim=0)
        return file_id, audio_feat, audio_len, label

    def load_data(self):
        ''' Load data for training/validation '''
        self.tr_set, self.dv_set, self.tt_set, self.audio_dim, msg = \
            prepare_data(self.paras.njobs, self.paras.dev_njobs,
                         self.paras.gpu, self.paras.pin_memory,
                         **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup model and optimizer '''
        # Load SSL model for feature extraction
        self.verbose([' Load feat. extractor ckpt from '
                      + self.config['model']['feat']['ckpt']])
        if self.feature in ['apc', 'vqapc']:
            from model.apc import APC as Net
        elif self.feature == 'npc':
            from model.npc import NPC as Net
            if self.feat_spec is not None:
                self.verbose([' Using specific feature: ' + self.feat_spec])
        else:
            raise NotImplementedError
        self.feat_extractor = Net(input_size=self.audio_dim,
                                  **self.ssl_config['model']['paras'])
        ckpt = torch.load(
            self.config['model']['feat']['ckpt'],
            map_location=self.device if self.mode == 'train' else 'cpu')
        # Strip the 'module.' prefix left by DataParallel checkpoints
        ckpt['model'] = {k.replace('module.', '', 1): v
                         for k, v in ckpt['model'].items()}
        self.feat_extractor.load_state_dict(ckpt['model'])
        # Classifier model
        self.model = CLF(feat_dim=self.feat_extractor.code_dim,
                         **self.config['model']['clf'])
        if self.gpu:
            self.feat_extractor = self.feat_extractor.cuda()
            self.model = self.model.cuda()
        self.feat_extractor.eval()
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        ignore_idx = 0 if self.task == 'phn-clf' else -1
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())
        self.load_ckpt()
        self.model.train()

    def exec(self):
        ''' Training downstream classifier '''
        if self.paras.mode == 'train':
            self.verbose('Total training epoch {}.'.format(
                human_format(self.epoch)))
            self.timer.set()
            ep_len = len(self.tr_set)
            for ep in range(self.epoch):
                if ep > 0:
                    # LR decay if needed
                    self.optimizer.decay()
                for data in self.tr_set:
                    # Pre-step : do zero_grad
                    self.optimizer.pre_step(self.step)

                    # Fetch data
                    _, audio_feat, audio_len, label = self.fetch_data(data)
                    self.timer.cnt('rd')

                    # Forward
                    pred = self.model(audio_feat)
                    if self.task == 'phn-clf':
                        pred = pred.permute(0, 2, 1)  # B x C x T for phn clf
                    loss = self.loss(pred, label)
                    self.timer.cnt('fw')

                    # Backprop
                    grad_norm = self.backward(loss)
                    self.step += 1

                    # Logger
                    if (self.step == 1) or \
                            (self.step % self.PROGRESS_STEP == 0):
                        self.progress(
                            ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                            .format(100 * float(self.step % ep_len) / ep_len,
                                    loss.cpu().item(), grad_norm,
                                    self.timer.show()))
                        self.write_log(self.task + '_loss', {'tr': loss})
                        if self.task == 'phn-clf':
                            tr_er = cal_per(pred, label, audio_len)[0]
                        else:
                            tr_er = (pred.argmax(dim=-1) != label)
                            tr_er = tr_er.sum().detach().cpu().float() \
                                / len(label)
                        self.write_log(self.task + '_er', {'tr': tr_er})

                    # End of step
                    self.timer.set()
                # End of epoch
                self.cur_epoch += 1
                self.validate()
        # Test at the end
        self.validate(test=True)
        self.log.close()

    def validate(self, test=False):
        # Eval mode
        self.model.eval()
        val_loss = []
        split = 'dev'
        val_hit, val_total = 0.0, 0.0
        ds = self.tt_set if test else self.dv_set
        # In training mode, the best model is kept in RAM for testing
        # ToDo: load ckpt instead
        if test:
            split = 'test'
            if self.paras.mode == 'train':
                self.model = self.best_model
                if self.gpu:
                    self.model = self.model.cuda()

        for i, data in enumerate(ds):
            self.progress('Valid step - {}/{}'.format(i + 1, len(ds)))
            # Fetch data
            _, audio_feat, audio_len, label = self.fetch_data(data)
            # Forward model
            with torch.no_grad():
                # Prediction
                pred = self.model(audio_feat)
                if self.task == 'phn-clf':
                    pred = pred.permute(0, 2, 1)  # B x C x T
                # Accumulate batch result
                val_loss.append(self.loss(pred, label))
                if self.task == 'phn-clf':
                    _, hit, total = cal_per(pred, label, audio_len)
                    val_hit += hit
                    val_total += total
                else:
                    hit = (pred.argmax(dim=-1) == label).sum()
                    val_hit += hit.detach().cpu().float()
                    val_total += len(label)
            # Write testing prediction if needed
            if test and self.paras.write_test:
                if self.task == 'phn-clf':
                    pred = pred.argmax(dim=1).detach().cpu()
                    label = label.cpu()
                    with open(os.path.join(self.ckpdir,
                                           self.task + '.csv'), 'a') as f:
                        for p, l, a_len in zip(pred, label, audio_len):
                            for x, y in zip(p[:a_len].tolist(),
                                            l[:a_len].tolist()):
                                f.write('{}\t{}\n'.format(x, y))

        # Record metric, store ckpt by dev error rate
        val_loss = sum(val_loss) / len(val_loss)
        val_er = 1.0 - val_hit / val_total
        self.write_log(self.task + '_loss', {split: val_loss})
        self.write_log(self.task + '_er', {split: val_er})
        if split == 'dev' and self.best_dev_er > val_er:
            self.best_dev_er = val_er
            self.save_checkpoint('best.pth', self.task + '_er', val_er)
            # Clone best model for testing (.cpu() moves self.model in-place)
            self.best_model = copy.deepcopy(self.model.cpu())

        # Resume training
        if self.gpu:
            self.model = self.model.cuda()
        self.model.train()
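
# ---------------------------------------------------------------------------
# Illustration only: a minimal sketch (the helper name and all sizes below are
# made up for this example, it is not used by the solver) of the per-utterance
# mean pooling that fetch_data() applies for 'spk-clf'. Only the valid frames
# of each zero-padded feature sequence are averaged, collapsing a B x T x D
# batch into a B x D batch of utterance-level speaker embeddings.
def _mean_pool_sketch():
    import torch
    audio_feat = torch.randn(3, 10, 4)         # B x T_max x D, zero-padded
    audio_len = [10, 7, 5]                     # valid frames per utterance
    single_feat = [a_feat[:a_len].mean(dim=0)  # average valid frames only
                   for a_feat, a_len in zip(audio_feat, audio_len)]
    pooled = torch.stack(single_feat, dim=0)   # B x D
    assert pooled.shape == (3, 4)
    return pooled
# ---------------------------------------------------------------------------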