def set_model(self):
    ''' Setup ASR (and CLM if enabled).

    Builds the attention-based Seq2Seq ASR model, its losses (CE for the
    attention decoder, CTC for the CTC head), the optimizer, and — when
    enabled in config — the adversarial character LM (CLM) with its
    text-only data source.
    '''
    self.verbose('Init ASR model. Note: validation is done through greedy decoding w/ attention decoder.')

    # Build attention end-to-end ASR
    self.asr_model = Seq2Seq(self.sample_x, self.mapper.get_dim(),
                             self.config['asr_model']).to(self.device)
    if 'VGG' in self.config['asr_model']['encoder']['enc_type']:
        # Fixed message typo: the check is for the VGG extractor, not "VCC".
        self.verbose('VGG Extractor in Encoder is enabled, time subsample rate = 4.')
    # ignore_index=0 skips padding; reduction='none' keeps per-token loss so
    # the caller can normalize by utterance length before averaging.
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0,
                                              reduction='none').to(self.device)

    # Involve CTC
    self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean')
    self.ctc_weight = self.config['asr_model']['optimizer']['joint_ctc']

    # TODO: load pre-trained model
    if self.paras.load:
        raise NotImplementedError

    # Setup optimizer (apex FusedAdam only when apex is enabled AND Adam requested)
    if self.apex and self.config['asr_model']['optimizer']['type'] == 'Adam':
        import apex
        self.asr_opt = apex.optimizers.FusedAdam(
            self.asr_model.parameters(),
            lr=self.config['asr_model']['optimizer']['learning_rate'])
    else:
        self.asr_opt = getattr(torch.optim,
                               self.config['asr_model']['optimizer']['type'])
        self.asr_opt = self.asr_opt(
            self.asr_model.parameters(),
            lr=self.config['asr_model']['optimizer']['learning_rate'],
            eps=1e-8)

    # Apply CLM
    if self.apply_clm:
        self.clm = CLM_wrapper(self.mapper.get_dim(),
                               self.config['clm']).to(self.device)
        # BUGFIX: take a shallow copy — the original aliased
        # self.config['solver'] and injected 'train_set'/'use_gpu' keys into
        # the shared solver config, polluting every later **config splat.
        clm_data_config = dict(self.config['solver'])
        clm_data_config['train_set'] = self.config['clm']['source']
        clm_data_config['use_gpu'] = self.paras.gpu
        self.clm.load_text(clm_data_config)
        self.verbose('CLM is enabled with text-only source: ' +
                     str(clm_data_config['train_set']))
        self.verbose('Extra text set total ' + str(len(self.clm.train_set)) +
                     ' batches.')
class Trainer(Solver):
    ''' Handler for complete training progress.

    Drives end-to-end attention ASR training with optional joint CTC loss
    and an optional adversarial character LM (CLM). Validation is done via
    greedy decoding with the attention decoder.
    '''

    def __init__(self, config, paras):
        super(Trainer, self).__init__(config, paras)
        # Logger Settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(self.logdir)
        self.valid_step = config['solver']['dev_step']
        # Best dev error rate so far; 2.0 is an upper sentinel (CER <= 1.0 normally).
        self.best_val_ed = 2.0

        # Training details
        self.step = 0
        self.max_step = config['solver']['total_steps']
        self.tf_start = config['solver']['tf_start']  # teacher forcing rate, start
        self.tf_end = config['solver']['tf_end']      # teacher forcing rate, end
        self.apex = config['solver']['apex']

        # CLM option
        self.apply_clm = config['clm']['enable']

    def load_data(self):
        ''' Load data for training/validation.'''
        self.verbose('Loading data from ' + self.config['solver']['data_path'])
        self.train_set = LoadDataset('train', text_only=False,
                                     use_gpu=self.paras.gpu,
                                     **self.config['solver'])
        self.dev_set = LoadDataset('dev', text_only=False,
                                   use_gpu=self.paras.gpu,
                                   **self.config['solver'])

        # Get 1 example for auto constructing model
        for self.sample_x, _ in self.train_set:
            break
        # Bucketing wraps batches as 1xBxTxD; unwrap the leading dummy axis.
        if len(self.sample_x.shape) == 4:
            self.sample_x = self.sample_x[0]

    def set_model(self):
        ''' Setup ASR (and CLM if enabled).'''
        self.verbose('Init ASR model. Note: validation is done through greedy decoding w/ attention decoder.')

        # Build attention end-to-end ASR
        self.asr_model = Seq2Seq(self.sample_x, self.mapper.get_dim(),
                                 self.config['asr_model']).to(self.device)
        if 'VGG' in self.config['asr_model']['encoder']['enc_type']:
            # Fixed message typo: the check is for the VGG extractor, not "VCC".
            self.verbose('VGG Extractor in Encoder is enabled, time subsample rate = 4.')
        # ignore_index=0 skips padding; reduction='none' keeps per-token loss
        # so exec()/valid() can normalize by utterance length themselves.
        self.seq_loss = torch.nn.CrossEntropyLoss(
            ignore_index=0, reduction='none').to(self.device)

        # Involve CTC
        self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean')
        self.ctc_weight = self.config['asr_model']['optimizer']['joint_ctc']

        # TODO: load pre-trained model
        if self.paras.load:
            raise NotImplementedError

        # Setup optimizer (apex FusedAdam only when apex is on AND Adam requested)
        if self.apex and self.config['asr_model']['optimizer']['type'] == 'Adam':
            import apex
            self.asr_opt = apex.optimizers.FusedAdam(
                self.asr_model.parameters(),
                lr=self.config['asr_model']['optimizer']['learning_rate'])
        else:
            self.asr_opt = getattr(torch.optim,
                                   self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(
                self.asr_model.parameters(),
                lr=self.config['asr_model']['optimizer']['learning_rate'],
                eps=1e-8)

        # Apply CLM
        if self.apply_clm:
            self.clm = CLM_wrapper(self.mapper.get_dim(),
                                   self.config['clm']).to(self.device)
            # BUGFIX: take a shallow copy — the original aliased
            # self.config['solver'] and injected 'train_set'/'use_gpu' keys
            # into the shared solver config, polluting later **config splats.
            clm_data_config = dict(self.config['solver'])
            clm_data_config['train_set'] = self.config['clm']['source']
            clm_data_config['use_gpu'] = self.paras.gpu
            self.clm.load_text(clm_data_config)
            self.verbose('CLM is enabled with text-only source: ' +
                         str(clm_data_config['train_set']))
            self.verbose('Extra text set total ' +
                         str(len(self.clm.train_set)) + ' batches.')

    def exec(self):
        ''' Training End-to-end ASR system.'''
        self.verbose('Training set total ' + str(len(self.train_set)) +
                     ' batches.')

        while self.step < self.max_step:
            for x, y in self.train_set:
                self.progress('Training step - ' + str(self.step))

                # Perform teacher forcing rate decaying (linear tf_start -> tf_end)
                tf_rate = self.tf_start - self.step * (
                    self.tf_start - self.tf_end) / self.max_step

                # Hack bucket, record state length for each uttr, get longest label seq for decode step
                assert len(x.shape) == 4, 'Bucketing should cause acoustic feature to have shape 1xBxTxD'
                assert len(y.shape) == 3, 'Bucketing should cause label have to shape 1xBxT'
                x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
                y = y.squeeze(0).to(device=self.device, dtype=torch.long)
                # Per-utterance acoustic length = # of non-all-zero frames.
                state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0,
                                   axis=-1)
                state_len = [int(sl) for sl in state_len]
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # ASR forwarding
                self.asr_opt.zero_grad()
                ctc_pred, state_len, att_pred, _ = self.asr_model(
                    x, ans_len, tf_rate=tf_rate, teacher=y,
                    state_len=state_len)

                # Calculate loss function
                loss_log = {}
                label = y[:, 1:ans_len + 1].contiguous()  # drop <sos>
                ctc_loss = 0
                att_loss = 0

                # CE loss on attention decoder
                if self.ctc_weight < 1:
                    b, t, c = att_pred.shape
                    att_loss = self.seq_loss(att_pred.view(b * t, c),
                                             label.view(-1))
                    # Sum each uttr and divide by its length, then mean by batch
                    att_loss = torch.sum(att_loss.view(b, t), dim=-1) \
                        / torch.sum(y != 0, dim=-1).to(device=self.device,
                                                       dtype=torch.float32)
                    att_loss = torch.mean(att_loss)
                    loss_log['train_att'] = att_loss

                # CTC loss on CTC decoder
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0, 1), dim=-1),
                        label, torch.LongTensor(state_len), target_len)
                    loss_log['train_ctc'] = ctc_loss

                asr_loss = (1 - self.ctc_weight) * att_loss \
                    + self.ctc_weight * ctc_loss
                loss_log['train_full'] = asr_loss

                # Adversarial loss from CLM
                if self.apply_clm and att_pred.shape[1] >= CLM_MIN_SEQ_LEN:
                    if (self.step % self.clm.update_freq) == 0:
                        # update CLM once in a while
                        clm_log, gp = self.clm.train(att_pred.detach(),
                                                     CLM_MIN_SEQ_LEN)
                        self.write_log('clm_score', clm_log)
                        self.write_log('clm_gp', gp)
                    # BUGFIX: softmax over the class axis explicitly —
                    # dim-less F.softmax is deprecated and ambiguous.
                    adv_feedback = self.clm.compute_loss(
                        F.softmax(att_pred, dim=-1))
                    asr_loss -= adv_feedback

                # Backprop; skip the optimizer step if the gradient blew up.
                asr_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    self.verbose('Error : grad norm is NaN @ step ' +
                                 str(self.step))
                else:
                    self.asr_opt.step()

                # Logger
                self.write_log('loss', loss_log)
                if self.ctc_weight < 1:
                    self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {
                        'train': cal_cer(att_pred, label, mapper=self.mapper)
                    })

                # Validation
                if self.step % self.valid_step == 0:
                    self.asr_opt.zero_grad()
                    self.valid()

                self.step += 1
                if self.step > self.max_step:
                    break

    def write_log(self, val_name, val_dict):
        '''Write log to TensorBoard.

        Routes by name: attention maps go to add_image, hypothesis/ground
        truth text to add_text, everything else to add_scalars.
        '''
        if 'att' in val_name:
            self.log.add_image(val_name, val_dict, self.step)
        elif 'txt' in val_name or 'hyp' in val_name:
            self.log.add_text(val_name, val_dict, self.step)
        else:
            self.log.add_scalars(val_name, val_dict, self.step)

    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)'''
        self.asr_model.eval()

        # Init stats
        val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
        val_len = 0
        all_pred, all_true = [], []

        # Perform validation
        for cur_b, (x, y) in enumerate(self.dev_set):
            self.progress(' '.join([
                'Valid step -', str(self.step), '(', str(cur_b), '/',
                str(len(self.dev_set)), ')'
            ]))

            # Prepare data (unwrap 1xBx... bucketing axis if present)
            if len(x.shape) == 4:
                x = x.squeeze(0)
            if len(y.shape) == 3:
                y = y.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
            state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # Forward (decode a few extra steps beyond the reference length)
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                x, ans_len + VAL_STEP, state_len=state_len)

            # Compute attention loss & get decoding results
            label = y[:, 1:ans_len + 1].contiguous()
            if self.ctc_weight < 1:
                seq_loss = self.seq_loss(
                    att_pred[:, :ans_len, :].contiguous().view(
                        -1, att_pred.shape[-1]), label.view(-1))
                # Sum each uttr and divide by its length, then mean by batch
                seq_loss = torch.sum(seq_loss.view(x.shape[0], -1), dim=-1) \
                    / torch.sum(y != 0, dim=-1).to(device=self.device,
                                                   dtype=torch.float32)
                seq_loss = torch.mean(seq_loss)
                val_att += seq_loss.detach() * int(x.shape[0])
                t1, t2 = cal_cer(att_pred, label, mapper=self.mapper,
                                 get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(att_pred, label) * int(x.shape[0])
                val_cer += cal_cer(att_pred, label,
                                   mapper=self.mapper) * int(x.shape[0])

            # Compute CTC loss
            if self.ctc_weight > 0:
                target_len = torch.sum(y != 0, dim=-1)
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                    torch.LongTensor(state_len), target_len)
                val_ctc += ctc_loss.detach() * int(x.shape[0])

            val_len += int(x.shape[0])

        # Logger (all stats are sample-weighted sums; normalize by val_len)
        val_loss = (1 - self.ctc_weight) * val_att \
            + self.ctc_weight * val_ctc
        loss_log = {}
        for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'],
                        [val_loss, val_ctc, val_att]):
            if v > 0.0:
                loss_log[k] = v / val_len
        self.write_log('loss', loss_log)

        if self.ctc_weight < 1:
            # Plot attention map to log (uses the last batch of the dev set)
            val_hyp, val_txt = cal_cer(att_pred, label, mapper=self.mapper,
                                       get_sentence=True)
            val_attmap = draw_att(att_maps, att_pred)

            # Record loss
            self.write_log('error rate', {'dev': val_cer / val_len})
            self.write_log('acc', {'dev': val_acc / val_len})
            for idx, attmap in enumerate(val_attmap):
                self.write_log('att_' + str(idx), attmap)
                self.write_log('hyp_' + str(idx), val_hyp[idx])
                self.write_log('txt_' + str(idx), val_txt[idx])

            # Save model by val er.
            if val_cer / val_len < self.best_val_ed:
                self.best_val_ed = val_cer / val_len
                self.verbose('Best val er : {:.4f} @ step {}'.format(
                    self.best_val_ed, self.step))
                torch.save(self.asr_model,
                           os.path.join(self.ckpdir, 'asr'))
                if self.apply_clm:
                    torch.save(self.clm.clm,
                               os.path.join(self.ckpdir, 'clm'))
                # Save hyps.
                with open(os.path.join(self.ckpdir, 'best_hyp.txt'),
                          'w') as f:
                    for t1, t2 in zip(all_pred, all_true):
                        f.write(t1 + ',' + t2 + '\n')

        self.asr_model.train()