def update_core(self):
    """Main update routine of the CustomUpdater."""
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator('main')
    optimizer = self.get_optimizer('main')

    # Get the next batch (a list of json files)
    batch = train_iter.next()
    self.iteration += 1
    x = self.converter(batch, self.device)

    # Compute the loss at this time step and accumulate it
    loss = self.model(*x).mean() / self.accum_grad
    loss.backward()  # Backprop

    # gradient noise injection
    if self.grad_noise:
        from espnet.asr.asr_utils import add_gradient_noise
        add_gradient_noise(self.model, self.iteration,
                           duration=100, eta=1.0, scale_factor=0.55)
    loss.detach()  # Truncate the graph

    # update parameters
    self.forward_count += 1
    if self.forward_count != self.accum_grad:
        return
    self.forward_count = 0
    # compute the gradient norm to check if it is normal or not
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.parameters(), self.grad_clip_threshold)
    logging.info('grad norm={}'.format(grad_norm))
    if math.isnan(grad_norm):
        logging.warning('grad norm is nan. Do not update model.')
    else:
        optimizer.step()
    optimizer.zero_grad()
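# --- Illustrative sketch (not part of the original updater) ---
# A minimal, self-contained version of the accumulate-then-step pattern used
# above: each partial loss is divided by accum_grad so the summed gradients
# equal the gradient of the mean loss over accum_grad mini-batches, and the
# optimizer steps only once per accum_grad forward passes. All names below
# (model, optimizer, loader, clip) are placeholders, not ESPnet APIs.
import math

import torch
import torch.nn.functional as F


def accumulated_update(model, optimizer, loader, accum_grad=4, clip=5.0):
    forward_count = 0
    for xs, ys in loader:
        loss = F.mse_loss(model(xs), ys) / accum_grad
        loss.backward()  # gradients accumulate in param.grad
        forward_count += 1
        if forward_count != accum_grad:
            continue  # keep accumulating
        forward_count = 0
        # clip, then skip the step if the norm went non-finite
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        if not math.isnan(grad_norm):
            optimizer.step()
        optimizer.zero_grad()  # drop the (possibly NaN) accumulated gradients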
def update_core(self):
    """Main update routine of the CustomUpdater."""
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator("main")
    optimizer = self.get_optimizer("main")
    epoch = train_iter.epoch

    # Get the next batch (a list of json files)
    batch = train_iter.next()
    # self.iteration += 1  # Increase may result in early report,
    # which is done in other place automatically.
    x = _recursive_to(batch, self.device)
    is_new_epoch = train_iter.epoch != epoch
    # When the last minibatch in the current epoch is given,
    # gradient accumulation is turned off in order to evaluate the model
    # on the validation set in every epoch.
    # see details in https://github.com/espnet/espnet/pull/1388

    # Compute the loss at this time step and accumulate it
    if self.ngpu == 0:
        loss = self.model(*x).mean() / self.accum_grad
    else:
        # apex does not support torch.nn.DataParallel
        loss = (data_parallel(self.model, x, range(self.ngpu)).mean()
                / self.accum_grad)
    if self.use_apex:
        from apex import amp

        # NOTE: for a compatibility with noam optimizer
        opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
        with amp.scale_loss(loss, opt) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    # gradient noise injection
    if self.grad_noise:
        from espnet.asr.asr_utils import add_gradient_noise

        add_gradient_noise(self.model, self.iteration,
                           duration=100, eta=1.0, scale_factor=0.55)

    # update parameters
    self.forward_count += 1
    if not is_new_epoch and self.forward_count != self.accum_grad:
        return
    self.forward_count = 0
    # compute the gradient norm to check if it is normal or not
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.parameters(), self.grad_clip_threshold)
    logging.info("grad norm={}".format(grad_norm))
    if math.isnan(grad_norm):
        logging.warning("grad norm is nan. Do not update model.")
    else:
        optimizer.step()
    optimizer.zero_grad()
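# --- Illustrative sketch (not part of the original updater) ---
# How the apex mixed-precision path above fits together, using apex's
# documented amp.initialize / amp.scale_loss API. The model, optimizer, and
# loss below are toy placeholders. Note the same unwrapping trick as above:
# ESPnet's noam scheduler wraps the real torch optimizer in an `optimizer`
# attribute, and apex needs the inner torch optimizer.
import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# amp.initialize patches model/optimizer for mixed precision ("O1" is the
# conservative, op-level autocast mode)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(8, 10, device="cuda")).pow(2).mean()
# scale_loss multiplies the loss by the current loss scale so fp16 gradients
# do not underflow; backward() runs on the scaled loss and apex unscales the
# gradients when the context exits, before optimizer.step().
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()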
def update_core(self):
    """Main update routine of the CustomUpdater."""
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator('main')
    optimizer = self.get_optimizer('main')

    # Get the next batch (a list of json files)
    batch = train_iter.next()
    # self.iteration += 1  # Increase may result in early report,
    # which is done in other place automatically.
    x = self.converter(batch, self.device)

    # Compute the loss at this time step and accumulate it
    if self.ngpu == 0:
        loss = self.model(*x).mean() / self.accum_grad
    else:
        # apex does not support torch.nn.DataParallel
        if 'espnet.nets.pytorch_backend.e2e_asr_transformer' \
                in self.model.__class__.__module__:
            loss = data_parallel(self.model, x + (self.iteration,),
                                 range(self.ngpu)).mean() / self.accum_grad
        else:
            loss = data_parallel(self.model, x,
                                 range(self.ngpu)).mean() / self.accum_grad
    if self.use_apex:
        from apex import amp

        # NOTE: for a compatibility with noam optimizer
        opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
        with amp.scale_loss(loss, opt) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    # gradient noise injection
    if self.grad_noise:
        from espnet.asr.asr_utils import add_gradient_noise
        add_gradient_noise(self.model, self.iteration,
                           duration=100, eta=1.0, scale_factor=0.55)
    loss.detach()  # Truncate the graph

    # update parameters
    self.forward_count += 1
    if self.forward_count != self.accum_grad:
        return
    self.forward_count = 0
    # compute the gradient norm to check if it is normal or not
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.parameters(), self.grad_clip_threshold)
    logging.info('grad norm={}'.format(grad_norm))
    if math.isnan(grad_norm):
        logging.warning('grad norm is nan. Do not update model.')
    else:
        optimizer.step()
    optimizer.zero_grad()
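# --- Illustrative sketch (not part of the original updater) ---
# The functional data_parallel call used above replicates the module on each
# GPU, scatters tensor inputs along dim 0, runs the replicas in parallel, and
# gathers the outputs back on the first device; non-tensor extras (such as
# self.iteration above) are handed to every replica unchanged. The toy model
# and shapes below are placeholders.
import torch
from torch.nn.parallel import data_parallel

model = torch.nn.Linear(10, 1).cuda()
xs = torch.randn(16, 10, device="cuda")
device_ids = list(range(torch.cuda.device_count()))

# Gathering concatenates the per-replica outputs, so a per-sample loss must
# still be reduced with .mean(), exactly as the updater above does.
out = data_parallel(model, xs, device_ids)
loss = out.pow(2).mean()
loss.backward()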
def update_core(self):
    """Main update routine of the CustomUpdater."""
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator('main')
    optimizer = self.get_optimizer('main')

    # Get the next batch (a list of json files)
    train_unl_iter = self.get_iterator('sub')
    labeled_batch = train_iter.next()
    unlabeled_batch = train_unl_iter.next()

    # Yield process information for calculating current consistency weight
    epoch = train_iter.epoch
    process_info = {'epoch': epoch,
                    'current_position': train_iter.current_position,
                    'batch_len': train_iter.len}

    # self.iteration += 1  # Increase may result in early report,
    # which is done in other place automatically.
    labeled_x = _recursive_to(labeled_batch, self.device)
    unlabeled_x = _recursive_to(unlabeled_batch, self.device)
    is_new_epoch = train_iter.epoch != epoch
    # When the last minibatch in the current epoch is given,
    # gradient accumulation is turned off in order to evaluate the model
    # on the validation set in every epoch.
    # see details in https://github.com/espnet/espnet/pull/1388

    # Compute the loss at this time step and accumulate it
    if self.ngpu == 0:
        loss = self.model(*labeled_x, *unlabeled_x, process_info)
    else:
        # apex does not support torch.nn.DataParallel
        loss = data_parallel(self.model,
                             (*labeled_x, *unlabeled_x, process_info),
                             range(self.ngpu))
    loss = loss.mean() / self.accum_grad
    loss.backward()

    # learning rate cosine rampdown for SGD optimizer
    # TODO: make it only for sgd
    # if epoch > self.cosine_rampdown_starts:
    #     for p in optimizer.param_groups:
    #         p["lr"] *= cosine_rampdown(
    #             epoch - self.cosine_rampdown_starts,
    #             self.cosine_rampdown_ends - self.cosine_rampdown_starts)
    #         logging.info("learning rate decayed to " + str(p["lr"]))

    # gradient noise injection
    if self.grad_noise:
        from espnet.asr.asr_utils import add_gradient_noise
        add_gradient_noise(self.model, self.iteration,
                           duration=100, eta=1.0, scale_factor=0.55)
    loss.detach()  # Truncate the graph

    # update parameters
    self.forward_count += 1
    if not is_new_epoch and self.forward_count != self.accum_grad:
        return
    self.forward_count = 0
    # compute the gradient norm to check if it is normal or not
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.enc.parameters(), self.grad_clip_threshold)
    logging.info('grad norm={}'.format(grad_norm))
    if math.isnan(grad_norm):
        logging.warning('grad norm is nan. Do not update model.')
    else:
        optimizer.step()
        global_step = ((epoch - self.consistency_rampup_starts)
                       * train_iter.len + train_iter.current_position)
        global_step = global_step if global_step > 0 else 0
        if epoch < self.consistency_rampup_starts:
            update_ema_variables(self.model.enc, self.model.ema_enc,
                                 0, global_step)
        elif epoch < self.consistency_rampup_ends:
            update_ema_variables(self.model.enc, self.model.ema_enc,
                                 self.ema_pre_decay, global_step)
        else:
            update_ema_variables(self.model.enc, self.model.ema_enc,
                                 self.ema_post_decay, global_step)
    optimizer.zero_grad()
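# --- Illustrative sketch (not part of the original updater) ---
# update_ema_variables above maintains a mean-teacher copy of the encoder.
# A common implementation consistent with how it is called above (alpha=0
# copies the student; a larger alpha gives a slower-moving teacher) follows
# Tarvainen & Valpola (2017); the fork's exact version may differ.
import torch


def update_ema_variables(model, ema_model, alpha, global_step):
    # Ramp alpha up from 0 so the teacher tracks the student closely early on.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            # Exponential moving average: ema = alpha * ema + (1 - alpha) * param
            ema_p.mul_(alpha).add_(p, alpha=1 - alpha)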
def update_core(self):
    """Main update routine of the CustomUpdater."""
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator('main')
    asr_optimizer = self.get_optimizer('main')
    tts_optimizer = self.get_optimizer('tts_opt')

    # Get the next batch (a list of json files)
    batch = train_iter.next()

    # Compute the loss at this time step and accumulate it
    if self.ngpu == 0:
        x = self.converter(batch, self.device)
        asr_loss = self.model(*x).mean() / self.accum_grad
    else:
        # apex does not support torch.nn.DataParallel
        # if (batch[0][1][0][0:5] == np.array([1, 1, 1, 1, 1])).all():
        if len(batch[0]) == 3:
            xs_pad, ilens, ys_pad, spembs = self.converter(batch, self.device)
            x = (xs_pad, ilens, ys_pad)
            if 'espnet.nets.pytorch_backend.e2e_asr_transformer' \
                    in self.model.__class__.__module__:
                fake_loss, best_hyps = data_parallel(
                    self.model, x + (self.iteration, True,), range(self.ngpu))
            else:
                fake_loss, best_hyps = data_parallel(
                    self.model, x + (True,), range(self.ngpu))
            if self.text_only:
                ttsasr_loss = data_parallel(
                    self.model, x + (False, True,),
                    range(self.ngpu)).mean() / self.accum_grad
            # calculate no of nbest and repeat based on it
            # set_requires_grad(self.tts_model, False)
            if self.tts:
                x_tts = self.random_sampler(best_hyps, ilens, xs_pad, spembs)
                # tts_loss, after_outs, before_outs, logits, att_ws = \
                #     self.tts_model(*x_tts + (None, True,))
                tts_loss, after_outs, before_outs, logits, att_ws = \
                    self.tts_model(*x_tts + (True,))
                # tts_loss = self.loss_fn_tts(after_outs, before_outs, logits,
                #                             x_tts[4], x_tts[2])
                # comparison with orig hyp
                # x_tts_orig = self.random_sampler(x[2], x[1], x[0], spembs)
                # x_tts_orig[0][x_tts_orig[0] == -1] = 0
                # tts_loss_j, after_outs_j, before_outs_j, logits_j, att_ws_j = \
                #     self.tts_model(x_tts_orig[0], x_tts_orig[1], x_tts_orig[2],
                #                    x_tts_orig[3], x_tts_orig[4], x_tts_orig[5],
                #                    True)
                # tts_loss_j = self.loss_fn_tts(after_outs_j, before_outs_j,
                #                               logits_j, x_tts_orig[3],
                #                               x_tts_orig[2])
                # logging.info("true loss is: " + str(tts_loss_j.mean()))
                logging.info("fake loss is: " + str(fake_loss.mean()))
                policy_loss = self.policy_rewards(fake_loss, tts_loss)
                logging.info('tts_loss: ' + str(float(tts_loss.mean())))
                logging.info('policy_loss: ' + str(float(policy_loss.mean())))
                asr_loss = policy_loss.mean() / self.accum_grad
                # asr_loss = tts_loss.mean() / self.accum_grad
                if self.text_only:
                    asr_loss = asr_loss + ttsasr_loss
            else:
                asr_loss = fake_loss.mean() / self.accum_grad
            logging.info('asr_loss: ' + str(float(asr_loss)))
        else:
            xs_pad, ilens, ys_pad = self.converter(batch, self.device)
            x = (xs_pad, ilens, ys_pad)
            asr_loss = data_parallel(self.model, x,
                                     range(self.ngpu)).mean() / self.accum_grad
            logging.info('asr_sup_loss: ' + str(float(asr_loss)))
    if self.use_apex:
        from apex import amp

        # NOTE: for a compatibility with noam optimizer
        opt = (asr_optimizer.optimizer
               if hasattr(asr_optimizer, "optimizer") else asr_optimizer)
        with amp.scale_loss(asr_loss, opt) as scaled_loss:
            scaled_loss.backward()
    else:
        asr_loss.backward()

    # gradient noise injection
    if self.grad_noise:
        from espnet.asr.asr_utils import add_gradient_noise
        add_gradient_noise(self.model, self.iteration,
                           duration=100, eta=1.0, scale_factor=0.55)
    asr_loss.detach()  # Truncate the graph

    # update parameters
    self.forward_count += 1
    if self.forward_count != self.accum_grad:
        return
    self.forward_count = 0
    # compute the gradient norm to check if it is normal or not
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.parameters(), self.grad_clip_threshold)
    logging.info('ASR grad norm={}'.format(grad_norm))
    # if (batch[0][1][0][0:5] == np.array([1, 1, 1, 1, 1])).all():
    if len(batch[0]) == 3:
        tts_grad_norm = torch.nn.utils.clip_grad_norm_(
            self.tts_model.parameters(), self.grad_clip_threshold)
        logging.info('TTS grad norm={}'.format(tts_grad_norm))
        if math.isnan(tts_grad_norm):
            logging.warning('TTS grad norm is nan. Do not update model.')
        else:
            if self.update_tts:
                tts_optimizer.step()
    if math.isnan(grad_norm):
        logging.warning('grad norm is nan. Do not update model.')
    else:
        asr_optimizer.step()
    asr_optimizer.zero_grad()
    # if (batch[0][1][0][0:5] == np.array([1, 1, 1, 1, 1])).all():  # cheap trick by BMK
    if len(batch[0]) == 3:  # cheap trick by BMK
        tts_optimizer.zero_grad()
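# --- Illustrative sketch (not part of the original updater) ---
# Most variants above call espnet.asr.asr_utils.add_gradient_noise. Below is
# a sketch consistent with that call signature, following the annealed
# Gaussian gradient-noise scheme of Neelakantan et al. (2015), where the
# noise scale decays as eta / (1 + t/duration)^scale_factor; ESPnet's actual
# implementation may differ in detail.
import torch


def add_gradient_noise(model, iteration, duration=100, eta=1.0,
                       scale_factor=0.55):
    interval = (iteration // duration) + 1
    sigma = eta / interval ** scale_factor
    for param in model.parameters():
        if param.grad is not None:
            # add zero-mean Gaussian noise to every accumulated gradient
            param.grad += sigma * torch.randn_like(param.grad)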