def _async_update(self, rank, device_id, grad_denom, is_global_overflow):
    # if is_global_overflow:
    #     def patch_step(opt):
    #         """this function is copied from apex"""
    #         opt_step = opt.step
    #
    #         def skip_step(closure=None):
    #             if closure is not None:
    #                 raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
    #             # logger.info(f"Device[{self.gpu_rank}] Gradient overflow. Skipping step. "
    #             #             "(This is from hack-for-optimizer-sync)")
    #             if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
    #                 # Clear the master grads that wouldn't be zeroed by model.zero_grad()
    #                 for param in opt._amp_stash.all_fp32_from_fp16_params:
    #                     param.grad = None
    #             if hasattr(opt, "most_recent_scale"):
    #                 opt.most_recent_scale = 1.0
    #                 opt.scale_set_by_backward = False
    #             opt.step = opt_step
    #             opt._amp_stash.already_patched = False
    #
    #         return skip_step
    #
    #     # since someone in the GPU pool got an overflow, we need to skip one step and keep going
    #     if not self.optim.optimizer._amp_stash.already_patched:
    #         patch_step(self.optim.optimizer)
    #         self.dummy = 'dummy'
    # else:
    #     # is it possible to run out of memory in this case? ...

    if self.num_replicas > 1:
        self._all_reduce_and_rescale_grads(grad_denom=grad_denom)

    # for param in amp.master_params(optimizer):
    #     param.grad.div_(iters_to_accumulate)

    normalize_gradients(amp.master_params(self.optim.optimizer),
                        grad_denom * self.args.update_frequency)

    if self.args.max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim.optimizer),
                                       self.args.max_grad_norm)

    self.optim.step()
    return [0]
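
# Illustrative sketch (not part of the trainer): the commented-out overflow
# handling above patches optimizer.step so that exactly one update is skipped
# after a global fp16 overflow. A minimal, self-contained version of that
# pattern, with a plain `already_patched` attribute standing in for apex's
# opt._amp_stash bookkeeping:
def patch_step_to_skip_once(opt):
    """Replace opt.step with a one-shot no-op that restores the real step."""
    original_step = opt.step

    def skip_step(closure=None):
        if closure is not None:
            raise RuntimeError("closures are not supported by this sketch")
        # drop this update, then put the real step back for the next call
        opt.step = original_step
        opt.already_patched = False

    opt.step = skip_step
    opt.already_patched = True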
def train_epoch(self, epoch, resume=False, itr_progress=None):
    global rec_ppl
    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model.train()
    self.loss_function.train()
    # Clear the gradients of the model
    # self.runner.zero_grad()
    self.model.zero_grad()
    self.model.reset_states()

    dataset = train_data

    # data iterator: object that controls the order and batching of the data
    # data_iterator = DataIterator(dataset, dataset.collater, dataset.batches, seed=self.opt.seed,
    #                              num_workers=opt.num_workers, epoch=epoch, buffer_size=opt.buffer_size)
    data_iterator = generate_data_iterator(dataset, seed=self.opt.seed, num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_tokens, total_loss, total_words = 0, 0, 0
    total_non_pads = 0
    report_loss, report_tgt_words = 0, 0
    report_src_words = 0
    report_ctc_loss = 0
    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    num_accumulated_words = 0
    num_accumulated_sents = 0
    grad_scaler = 1

    nan = False
    nan_counter = 0

    if opt.streaming:
        streaming_state = self.model.init_stream()
    else:
        streaming_state = None

    i = data_iterator.iterations_in_epoch if not isinstance(train_data, list) else epoch_iterator.n_yielded

    while not data_iterator.end_of_epoch():

        curriculum = (epoch < opt.curriculum)

        # this batch generator is not very clean atm
        batch = next(epoch_iterator)
        if isinstance(batch, list) and self.n_gpus == 1:
            batch = batch[0]
        batch = rewrap(batch)

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        # if opt.streaming:
        #     if train_data.is_new_stream():
        #         streaming_state = self.model.init_stream()
        # else:
        #     streaming_state = None

        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for the loss function;
            # it can be flexibly controlled within models for easier extensibility
            targets = batch.get('target_output')
            tgt_mask = targets.ne(onmt.constants.PAD)
            outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                 zero_encoder=opt.zero_encoder, mirror=opt.mirror_loss,
                                 streaming_state=streaming_state, nce=opt.nce)

            batch_size = batch.size
            outputs['tgt_mask'] = tgt_mask

            loss_dict = self.loss_function(outputs, targets, model=self.model, vocab_mask=batch.vocab_mask)
            loss_data = loss_dict['data']
            loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
            full_loss = loss

            if opt.ctc_loss > 0.0:
                ctc_loss = self.ctc_loss_function(outputs, targets)
                ctc_loss_data = ctc_loss.item()
                full_loss = full_loss + opt.ctc_loss * ctc_loss
                report_ctc_loss += ctc_loss_data

            if opt.mirror_loss:
                rev_loss = loss_dict['rev_loss']
                rev_loss_data = loss_dict['rev_loss_data']
                mirror_loss = loss_dict['mirror_loss']
                full_loss = full_loss + rev_loss + mirror_loss
                mirror_loss_data = loss_dict['mirror_loss'].item()
            else:
                rev_loss_data = None
                mirror_loss_data = 0

            # reconstruction loss
            if opt.reconstruct:
                rec_loss = loss_dict['rec_loss']
                full_loss = full_loss + rec_loss
                rec_loss_data = loss_dict['rec_loss_data']
            else:
                rec_loss_data = None

            # if opt.lfv_multilingual:
            #     lid_logits = outputs['lid_logits']
            #     lid_labels = batch.get('target_lang')
            #     lid_loss_function = self.loss_function.get_loss_function('lid_loss')
            #     lid_loss = lid_loss_function(lid_logits, lid_labels)
            #     full_loss = full_loss + lid_loss

            optimizer = self.optim.optimizer

            # When the batch size is large, each gradient step can easily overflow in fp16.
            # Normalizing the loss by grad_scaler ensures this will not happen.
            full_loss.div_(grad_scaler)

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                full_loss.backward()

            del outputs

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU, skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
            else:
                raise e

        if loss != loss:
            # catching NaN problem
            oom = True
            self.model.zero_grad()
            self.optim.zero_grad()
            num_accumulated_words = 0
            num_accumulated_sents = 0
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is NaN")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurrences. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
            counter = 0
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size
            tgt_size = batch.tgt_size

            counter = counter + 1
            num_accumulated_words += tgt_size
            num_accumulated_sents += batch_size

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif 0 < opt.batch_size_update <= num_accumulated_words:
                update_flag = True
            elif i == n_samples:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency is > 1
                grad_denom = 1 / grad_scaler
                if self.opt.normalize_gradient:
                    grad_denom = num_accumulated_words * grad_denom

                # When we accumulate the gradients, each gradient is already normalized by grad_scaler
                normalize_gradients(amp.master_params(optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.opt.max_grad_norm)
                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                num_accumulated_words = 0
                num_accumulated_sents = 0

                num_updates = self.optim._step
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))
                    print('Validation perplexity: %g' % valid_ppl)
                    ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                    self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size
            report_loss += loss_data
            report_tgt_words += num_words
            report_src_words += src_size
            total_loss += loss_data
            total_words += num_words
            total_tokens += batch.get('target_output').nelement()
            total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
            optim = self.optim
            batch_efficiency = total_non_pads / total_tokens

            if opt.reconstruct:
                report_rec_loss += rec_loss_data

            if opt.mirror_loss:
                report_rev_loss += rev_loss_data
                report_mirror_loss += mirror_loss_data

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("Epoch %2d, %5d/%5d; ppl: %6.2f ; " %
                              (epoch, i + 1, len(data_iterator),
                               math.exp(report_loss / report_tgt_words)))

                if opt.reconstruct:
                    rec_ppl = math.exp(report_rec_loss / report_src_words)
                    log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                if opt.mirror_loss:
                    rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                    log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                    # mirror loss per word
                    log_string += (" mir_loss: %6.2f ; " % (report_mirror_loss / report_tgt_words))

                log_string += ("lr: %.7f ; updates: %7d; " %
                               (optim.getLearningRate(), optim._step))
                log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                               (report_src_words / (time.time() - start),
                                report_tgt_words / (time.time() - start)))

                if opt.ctc_loss > 0.0:
                    ctc_loss = report_ctc_loss / report_tgt_words
                    log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss = 0
                report_tgt_words, report_src_words = 0, 0
                report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                report_ctc_loss = 0
                start = time.time()

        i = i + 1

    return total_loss / total_words
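
# Illustrative sketch (assumptions: a generic torch model, optimizer and an
# iterable of (inputs, targets, n_target_words) mini-batches; none of these
# names come from the trainer above). Stripped of fp16 handling and
# bookkeeping, the update logic of train_epoch reduces to this skeleton:
# accumulate gradients over update_frequency mini-batches, optionally
# normalize per target token, clip, then step.
import torch


def accumulate_and_step(model, optimizer, batches, loss_fn,
                        update_frequency=4, max_grad_norm=5.0,
                        normalize_gradient=True):
    accumulated_words = 0
    for k, (inputs, targets, n_words) in enumerate(batches, start=1):
        loss = loss_fn(model(inputs), targets)
        loss.backward()  # gradients accumulate across mini-batches
        accumulated_words += n_words
        if k % update_frequency == 0:
            if normalize_gradient:
                # per-target-token normalization, as in the trainer above
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.div_(accumulated_words)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            accumulated_words = 0
    # leftover gradients from a final partial accumulation cycle are
    # intentionally ignored in this sketch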
def train_epoch(self, epoch, resume=False, itr_progress=None):
    global rec_ppl
    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model_ae.train()
    self.loss_function_ae.train()
    self.lat_dis.train()
    self.loss_lat_dis.train()
    # Clear the gradients of the model
    # self.runner.zero_grad()
    self.model_ae.zero_grad()
    self.lat_dis.zero_grad()

    dataset = train_data
    data_iterator = generate_data_iterator(dataset, seed=self.opt.seed, num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_loss_ae, total_loss_lat_dis, total_frames, total_loss_adv = 0, 0, 0, 0
    report_loss_ae, report_loss_lat_dis, report_loss_adv, report_tgt_frames, report_sent = 0, 0, 0, 0, 0
    report_dis_frames, report_adv_frames = 0, 0
    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    step = 0
    num_accumulated_sents = 0
    grad_scaler = -1

    nan = False
    nan_counter = 0

    n_step_ae = opt.update_frequency
    n_step_lat_dis = opt.update_frequency
    mode_ae = True
    loss_lat_dis = 0.0

    i = data_iterator.iterations_in_epoch if not isinstance(train_data, list) else epoch_iterator.n_yielded

    # while not data_iterator.end_of_epoch():
    while True:
        if data_iterator.end_of_epoch():
            data_iterator = generate_data_iterator(dataset, seed=self.opt.seed, num_workers=opt.num_workers,
                                                   epoch=epoch, buffer_size=opt.buffer_size)
            epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

        curriculum = (epoch < opt.curriculum)

        # this batch generator is not very clean atm
        batch = next(epoch_iterator)
        if isinstance(batch, list) and self.n_gpus == 1:
            batch = batch[0]
        batch = rewrap(batch)
        batch_size = batch.size

        if grad_scaler == -1:
            grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for the loss function;
            # it can be flexibly controlled within models for easier extensibility
            # targets = batch.get('target_output')
            # tgt_mask = targets.ne(onmt.constants.PAD)
            if mode_ae:
                step = self.optim_ae._step
                loss_ae, loss_adv, encoder_outputs = self.autoencoder_backward(batch, step)
            else:
                loss_lat_dis, encoder_outputs = self.lat_dis_backward(batch)

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU, skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss_ae = 0
            else:
                raise e

        if loss_ae != loss_ae:
            # catching NaN problem
            oom = True
            self.model_ae.zero_grad()
            self.optim_ae.zero_grad()
            # self.lat_dis.zero_grad()
            # self.optim_lat_dis.zero_grad()
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is NaN")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurrences. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size

            if mode_ae:
                report_adv_frames += encoder_outputs['src_mask'].sum().item()
                report_loss_adv += loss_adv
            else:
                report_dis_frames += encoder_outputs['src_mask'].sum().item()
                report_loss_lat_dis += loss_lat_dis

            counter = counter + 1

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif i == n_samples:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency is > 1
                if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                    grad_denom = 1 / grad_scaler
                else:
                    grad_denom = 1.0

                # When we accumulate the gradients, each gradient is already normalized by grad_scaler
                normalize_gradients(amp.master_params(self.optim_ae.optimizer), grad_denom)
                normalize_gradients(amp.master_params(self.optim_lat_dis.optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim_ae.optimizer),
                                                   self.opt.max_grad_norm)
                    # torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim_lat_dis.optimizer),
                    #                                self.opt.max_grad_norm)
                    torch.nn.utils.clip_grad_value_(amp.master_params(self.optim_lat_dis.optimizer), 0.01)

                if mode_ae:
                    self.optim_ae.step()
                    self.optim_ae.zero_grad()
                    self.model_ae.zero_grad()
                    self.optim_lat_dis.zero_grad()
                    self.lat_dis.zero_grad()
                else:
                    self.optim_lat_dis.step()
                    self.optim_lat_dis.zero_grad()
                    self.lat_dis.zero_grad()
                    self.optim_ae.zero_grad()
                    self.model_ae.zero_grad()

                num_updates = self.optim_ae._step
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every and mode_ae:
                    # if mode_ae is not checked here then it will save continuously
                    valid_loss_ae, valid_loss_lat_dis = self.eval(self.valid_data)
                    print('Validation loss ae: %g' % valid_loss_ae)
                    print('Validation loss latent discriminator: %g' % valid_loss_lat_dis)
                    self.save(0, valid_loss_ae, itr=data_iterator)

                if num_updates == 1000000:
                    break

                mode_ae = not mode_ae
                counter = 0
                # num_accumulated_words = 0
                grad_scaler = -1
                num_updates = self.optim_ae._step

            report_loss_ae += loss_ae
            # report_tgt_words += num_words
            num_accumulated_sents += batch_size
            report_sent += batch_size
            total_frames += src_size
            report_tgt_frames += src_size
            total_loss_ae += loss_ae
            total_loss_lat_dis += loss_lat_dis
            total_loss_adv += loss_adv

            optim_ae = self.optim_ae
            optim_lat_dis = self.optim_lat_dis
            # batch_efficiency = total_non_pads / total_tokens

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("loss_ae : %6.2f ; loss_lat_dis : %6.2f, loss_adv : %6.2f " %
                              (report_loss_ae / report_tgt_frames,
                               report_loss_lat_dis / (report_dis_frames + 1e-5),
                               report_loss_adv / (report_adv_frames + 1e-5)))
                log_string += ("lr_ae: %.7f ; updates: %7d; " %
                               (optim_ae.getLearningRate(), optim_ae._step))
                log_string += ("lr_lat_dis: %.7f ; updates: %7d; " %
                               (optim_lat_dis.getLearningRate(), optim_lat_dis._step))
                # log_string += ("%5.0f src tok/s " % (report_tgt_frames / (time.time() - start)))
                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss_ae = 0
                report_loss_lat_dis = 0
                report_loss_adv = 0
                report_tgt_frames = 0
                report_dis_frames = 0
                report_adv_frames = 0
                report_sent = 0
                start = time.time()

        i = i + 1

    return total_loss_ae / total_frames * 100, \
        total_loss_lat_dis / n_samples * 100, \
        total_loss_adv / n_samples * 100
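
# Illustrative sketch (hypothetical helpers, not part of the trainer): the
# loop above interleaves autoencoder and latent-discriminator updates through
# the mode_ae flag. With the accumulation and logging stripped away, the
# alternation is simply:
def alternate_ae_and_discriminator(ae_update, dis_update, n_updates):
    """ae_update/dis_update are callables that each perform one optimizer step."""
    mode_ae = True
    for _ in range(n_updates):
        if mode_ae:
            ae_update()   # AE step, including the adversarial term
        else:
            dis_update()  # discriminator step on the AE's latent codes
        mode_ae = not mode_ae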
def train_epoch(self, epoch, resume=False, itr_progress=None):
    global rec_ppl
    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model.train()
    self.loss_function.train()
    # Clear the gradients of the model
    # self.runner.zero_grad()
    self.model.zero_grad()

    dataset = train_data
    data_iterator = generate_data_iterator(dataset, seed=self.opt.seed, num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_loss, total_frames = 0, 0
    report_loss, report_tgt_frames, report_sent = 0, 0, 0
    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    num_accumulated_sents = 0
    grad_scaler = -1

    nan = False
    nan_counter = 0

    i = data_iterator.iterations_in_epoch if not isinstance(train_data, list) else epoch_iterator.n_yielded

    while not data_iterator.end_of_epoch():

        curriculum = (epoch < opt.curriculum)

        # this batch generator is not very clean atm
        batch = next(epoch_iterator)
        if isinstance(batch, list) and self.n_gpus == 1:
            batch = batch[0]
        batch = rewrap(batch)

        if grad_scaler == -1:
            grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for the loss function;
            # it can be flexibly controlled within models for easier extensibility
            # targets = batch.get('target_output')
            # tgt_mask = targets.ne(onmt.constants.PAD)
            outputs = self.model(batch)

            gate_padded = batch.get('gate_padded')
            if self.opt.n_frames_per_step > 1:
                # keep one gate target per decoding step (every n_frames_per_step-th frame)
                frame_slice = torch.arange(0, gate_padded.size(1), self.opt.n_frames_per_step)
                gate_padded = gate_padded[:, frame_slice]

            src_org = batch.get('source_org')
            src_org = src_org.narrow(2, 1, src_org.size(2) - 1)
            target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]

            loss = self.loss_function(outputs, target)
            batch_size = batch.size
            loss_data = loss.data.item()

            # a little trick to avoid gradient overflow with fp16
            full_loss = loss

            optimizer = self.optim.optimizer

            # When the batch size is large, each gradient step can easily overflow in fp16.
            # Normalizing the loss by grad_scaler ensures this will not happen.
            full_loss.div_(grad_scaler)

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                full_loss.backward()

            del outputs

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU, skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
            else:
                raise e

        if loss != loss:
            # catching NaN problem
            oom = True
            self.model.zero_grad()
            self.optim.zero_grad()
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is NaN")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurrences. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size
            counter = counter + 1

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif i == n_samples:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency is > 1
                if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                    grad_denom = 1 / grad_scaler
                    # if self.opt.normalize_gradient:
                    #     grad_denom = num_accumulated_words * grad_denom
                else:
                    grad_denom = 1.0

                # When we accumulate the gradients, each gradient is already normalized by grad_scaler
                normalize_gradients(amp.master_params(optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.opt.max_grad_norm)
                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                # num_accumulated_words = 0
                grad_scaler = -1

                num_updates = self.optim._step
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))
                    print('Validation perplexity: %g' % valid_ppl)
                    ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                    self.save(ep, valid_ppl, itr=data_iterator)

            report_loss += loss_data
            # report_tgt_words += num_words
            num_accumulated_sents += batch_size
            report_sent += batch_size
            total_frames += src_size
            report_tgt_frames += src_size
            total_loss += loss_data
            optim = self.optim
            # batch_efficiency = total_non_pads / total_tokens

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("Epoch %2d, %5d/%5d; loss: %6.2f ; " %
                              (epoch, i + 1, len(data_iterator), report_loss))
                log_string += ("lr: %.7f ; updates: %7d; " %
                               (optim.getLearningRate(), optim._step))
                # log_string += ("%5.0f src tok/s " % (report_tgt_frames / (time.time() - start)))
                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss = 0
                report_tgt_frames = 0
                report_sent = 0
                start = time.time()

        i = i + 1

    return total_loss / n_samples * 100
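
# Illustrative note on the checkpointing guard used throughout this file:
# `num_updates % opt.save_every == -1 % opt.save_every` fires every
# save_every updates, on update save_every - 1, 2 * save_every - 1, ...,
# because Python defines -1 % n == n - 1. A tiny self-check (the helper name
# is ours, not the trainer's):
def _save_every_demo():
    save_every = 5
    assert -1 % save_every == save_every - 1
    assert [u for u in range(1, 16)
            if u % save_every == -1 % save_every] == [4, 9, 14]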
def train_epoch(self, epoch, resume=False, batch_order=None, iteration=0):
    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    # Clear the gradients of the model
    # self.runner.zero_grad()
    self.model.zero_grad()
    self.model.reset_states()

    if resume:
        train_data.batch_order = batch_order
        train_data.set_index(iteration)
        print("Resuming from iteration: %d" % iteration)
    else:
        batch_order = train_data.create_order()
        iteration = 0

    total_tokens, total_loss, total_words = 0, 0, 0
    total_non_pads = 0
    report_loss, report_tgt_words = 0, 0
    report_src_words = 0
    start = time.time()
    n_samples = len(train_data)

    counter = 0
    num_accumulated_words = 0
    num_accumulated_sents = 0
    denom = 3584
    nan = False

    if opt.streaming:
        streaming_state = self.model.init_stream()
    else:
        streaming_state = None

    for i in range(iteration, n_samples):

        curriculum = (epoch < opt.curriculum)

        batches = [train_data.next(curriculum=curriculum)[0]]

        if (len(self.additional_data) > 0 and
                i % self.additional_data_ratio[0] == 0):
            for j in range(len(self.additional_data)):
                for k in range(self.additional_data_ratio[j + 1]):
                    if self.additional_data_iteration[j] == len(self.additional_data[j]):
                        self.additional_data_iteration[j] = 0
                        self.additional_data[j].shuffle()
                        self.additional_batch_order[j] = self.additional_data[j].create_order()
                    batches.append(self.additional_data[j].next()[0])
                    self.additional_data_iteration[j] += 1

        for b in range(len(batches)):
            batch = batches[b]
            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            # if opt.streaming:
            #     if train_data.is_new_stream():
            #         streaming_state = self.model.init_stream()
            # else:
            #     streaming_state = None

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for the loss function;
                # it can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.data.ne(onmt.constants.PAD)
                outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder, mirror=opt.mirror_loss,
                                     streaming_state=streaming_state)

                batch_size = batch.size
                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs, targets, model=self.model)
                loss_data = loss_dict['data']
                loss = loss_dict['loss'].div(denom)  # a little trick to avoid gradient overflow with fp16

                optimizer = self.optim.optimizer

                if self.cuda:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # catching NaN problem
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words = 0
                num_accumulated_sents = 0

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                # We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if 0 < opt.batch_size_update <= num_accumulated_words:
                    update_flag = True
                elif counter >= opt.update_frequency and 0 >= opt.batch_size_update:
                    update_flag = True
                elif i == n_samples - 1:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    grad_denom = 1 / denom
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words / denom
                    normalize_gradients(amp.master_params(optimizer), grad_denom)

                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.opt.max_grad_norm)
                    self.optim.step(grad_denom=grad_denom)
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0

                    num_updates = self.optim._step
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)
                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                        self.save(ep, valid_ppl, batch_order=batch_order, iteration=i)

                num_words = tgt_size
                report_loss += loss_data
                report_tgt_words += num_words
                report_src_words += src_size
                total_loss += loss_data
                total_words += num_words
                total_tokens += batch.get('target_output').nelement()
                total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
                optim = self.optim
                batch_efficiency = total_non_pads / total_tokens

                if b == 0 and (i == 0 or (i % opt.log_interval == -1 % opt.log_interval)):
                    print(("Epoch %2d, %5d/%5d; ppl: %6.2f ; lr: %.7f ; num updates: %7d " +
                           "%5.0f src tok/s; %5.0f tgt tok/s; %s elapsed") %
                          (epoch, i + 1, len(train_data),
                           math.exp(report_loss / report_tgt_words),
                           optim.getLearningRate(),
                           optim._step,
                           report_src_words / (time.time() - start),
                           report_tgt_words / (time.time() - start),
                           str(datetime.timedelta(seconds=int(time.time() - self.start_time)))))

                    report_loss, report_tgt_words = 0, 0
                    report_src_words = 0
                    start = time.time()

    return total_loss / total_words
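
# Illustrative note on the constant denom = 3584 above: the loss is divided by
# denom before backward, and the accumulated gradients are later normalized by
# grad_denom = num_accumulated_words / denom, so the two scalings cancel into
# a plain per-target-token normalization:
#   (1 / denom) * grads / (num_accumulated_words / denom) = grads / num_accumulated_words
# The intermediate division only keeps the fp16 gradients small enough not to overflow.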
def train_epoch(self, epoch, resume=False, itr_progress=None):
    global rec_ppl
    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    # Clear the gradients of the model
    self.model.zero_grad()
    # self.model.module.reset_states()

    dataset = train_data
    data_iterator = generate_data_iterator(dataset, self.rank, self.world_size,
                                           seed=self.opt.seed, num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    # TODO: fix resume which is currently buggy
    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_tokens, total_loss, total_words = zero_tensor(), zero_tensor(), zero_tensor()
    total_non_pads = zero_tensor()
    report_loss, report_tgt_words = zero_tensor(), zero_tensor()
    report_ctc_loss = zero_tensor()
    report_src_words = zero_tensor()
    report_rec_loss, report_rev_loss, report_mirror_loss = zero_tensor(), zero_tensor(), zero_tensor()
    start = time.time()
    n_samples = len(data_iterator)

    counter = 0
    num_accumulated_words = zero_tensor()
    num_accumulated_sents = zero_tensor()
    grad_div = 1

    nan = False
    nan_counter = zero_tensor()

    if opt.streaming:
        streaming_state = self.model.init_stream()
    else:
        streaming_state = None

    i = data_iterator.iterations_in_epoch if not isinstance(train_data, list) else epoch_iterator.n_yielded
    i = i * self.world_size

    while not data_iterator.end_of_epoch():

        curriculum = (epoch < opt.curriculum)

        # this batch generator is not very clean atm
        # TODO: move everything to the multiGPU trainer
        samples = next(epoch_iterator)
        batch = prepare_sample(samples, device=self.device)

        if opt.streaming:
            if train_data.is_new_stream():
                streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        # TODO: dealing with oom during distributed training
        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for the loss function;
            # it can be flexibly controlled within models for easier extensibility
            counter = counter + 1
            reduction_disabled = False if counter >= opt.update_frequency or i == (n_samples - 1) else True

            def maybe_no_sync():
                if not reduction_disabled and isinstance(self.model, DDP_model):
                    return self.model.no_sync()
                else:
                    # when we don't reach the updating step, we do not need to synchronize the gradients;
                    # thus disabling the backward grad sync improves speed
                    return contextlib.ExitStack()  # dummy contextmanager

            with maybe_no_sync():
                with autocast():
                    targets = batch.get('target_output')
                    tgt_mask = targets.ne(onmt.constants.PAD)
                    outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                         zero_encoder=opt.zero_encoder, mirror=opt.mirror_loss,
                                         streaming_state=streaming_state, nce=opt.nce)

                    batch_size = batch.size
                    outputs['tgt_mask'] = tgt_mask

                    loss_dict = self.loss_function(outputs, targets, model=self.model)
                    loss_data = loss_dict['data']
                    loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
                    full_loss = loss

                    if opt.ctc_loss > 0.0:
                        ctc_loss = self.ctc_loss_function(outputs, targets)
                        ctc_loss_data = ctc_loss.item()
                        full_loss = full_loss + opt.ctc_loss * ctc_loss

                    if opt.mirror_loss:
                        rev_loss = loss_dict['rev_loss']
                        rev_loss_data = loss_dict['rev_loss_data']
                        mirror_loss = loss_dict['mirror_loss']
                        full_loss = full_loss + rev_loss + mirror_loss
                        mirror_loss_data = loss_dict['mirror_loss'].item()
                    else:
                        rev_loss_data = None
                        mirror_loss_data = 0

                    # reconstruction loss
                    if opt.reconstruct:
                        rec_loss = loss_dict['rec_loss']
                        full_loss = full_loss + rec_loss
                        rec_loss_data = loss_dict['rec_loss_data']
                    else:
                        rec_loss_data = None

                    if opt.lfv_multilingual:
                        lid_logits = outputs['lid_logits']
                        lid_labels = batch.get('target_lang')
                        lid_loss_function = self.loss_function.get_loss_function('lid_loss')
                        lid_loss = lid_loss_function(lid_logits, lid_labels)
                        full_loss = full_loss + lid_loss

                    optimizer = self.optim.optimizer

                    # When the batch size is large, each gradient step can easily overflow in fp16.
                    # Normalizing the loss by grad_div ensures this will not happen.
                    full_loss.div_(grad_div)

                # grad scaling has to be done outside of the autocast
                self.grad_scaler.scale(full_loss).backward()

            del outputs

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('[WARNING]: ran out of memory on GPU %d' % self.rank, flush=True)
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
                raise e
            else:
                raise e

        batch_size = batch.size
        src_size = batch.src_size
        tgt_size = batch.tgt_size
        num_accumulated_words.add_(tgt_size)
        num_accumulated_sents.add_(batch_size)

        # We only update the parameters after getting gradients from n mini-batches
        update_flag = False
        if counter >= opt.update_frequency:
            update_flag = True
        elif i == n_samples - 1:  # update for the last minibatch
            update_flag = True

        if update_flag:
            # accumulated gradient case, in this case the update frequency is > 1
            self.all_reduce(num_accumulated_words, op=dist.ReduceOp.SUM, group=self.group)

            grad_denom = 1.0 / grad_div
            if self.opt.normalize_gradient:
                grad_denom = num_accumulated_words.item() * grad_denom
            else:
                grad_denom = 1

            # the gradient is scaled by world size, so in order to match the model without multiGPU
            # we rescale the model parameters w.r.t. the world size
            grad_denom = grad_denom / self.world_size

            # When we accumulate the gradients, each gradient is already normalized by a constant grad scaler
            normalize_gradients(self.model.parameters(), grad_denom)

            # Update the parameters.
            if self.opt.max_grad_norm > 0:
                self.grad_scaler.unscale_(self.optim.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt.max_grad_norm)
            self.optim.step(scaler=self.grad_scaler)
            self.grad_scaler.update()
            # self.optim.zero_grad()
            self.model.zero_grad()
            counter = 0
            num_accumulated_words.zero_()
            num_accumulated_sents.zero_()

            num_updates = self.optim._step
            if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                valid_loss = self.eval(self.valid_data)
                valid_ppl = math.exp(min(valid_loss, 100))

                if self.is_main():
                    print('Validation perplexity: %g' % valid_ppl)

                ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                self.save(ep, valid_ppl, itr=data_iterator)

        num_words = tgt_size
        report_loss.add_(loss_data)
        report_tgt_words.add_(num_words)
        report_src_words.add_(src_size)
        total_loss.add_(loss_data)
        total_words.add_(num_words)
        # total_tokens += batch.get('target_output').nelement()
        # total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
        # batch_efficiency = total_non_pads / total_tokens

        if opt.reconstruct:
            report_rec_loss.add_(rec_loss_data)

        if opt.mirror_loss:
            report_rev_loss.add_(rev_loss_data)
            report_mirror_loss.add_(mirror_loss_data)

        if opt.ctc_loss > 0.0:
            report_ctc_loss.add_(ctc_loss_data)

        # control the index a little bit to ensure the log is always printed
        if i == 0 or ((i + 1) % opt.log_interval < self.world_size):

            self.all_reduce(report_loss, op=dist.ReduceOp.SUM, group=self.group)
            self.all_reduce(report_tgt_words, op=dist.ReduceOp.SUM, group=self.group)
            self.all_reduce(report_src_words, op=dist.ReduceOp.SUM, group=self.group)

            if self.is_main():
                log_string = ("Epoch %2d, %5d/%5d; ppl: %6.2f ; " %
                              (epoch, i + 1, len(data_iterator),
                               math.exp(report_loss.item() / report_tgt_words.item())))

                if opt.reconstruct:
                    self.all_reduce(report_rec_loss, op=dist.ReduceOp.SUM, group=self.group)
                    rec_ppl = math.exp(report_rec_loss.item() / report_src_words.item())
                    log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                if opt.mirror_loss:
                    self.all_reduce(report_rev_loss, op=dist.ReduceOp.SUM, group=self.group)
                    rev_ppl = math.exp(report_rev_loss.item() / report_tgt_words.item())
                    log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                    log_string += (" mir_loss: %6.2f ; " % (report_mirror_loss / report_tgt_words))

                if opt.ctc_loss > 0.0:
                    # if torch.isinf(report_ctc_loss):
                    #     report_ctc_loss.zero_()
                    # self.all_reduce(report_ctc_loss, op=dist.ReduceOp.SUM, group=self.group)
                    ctc_loss = report_ctc_loss.item() / report_tgt_words.item()
                    log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                log_string += ("lr: %.7f ; updates: %7d; " %
                               (self.optim.getLearningRate(), self.optim._step))
                log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                               (report_src_words.item() / (time.time() - start),
                                report_tgt_words.item() / (time.time() - start)))
                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                self.print(log_string, flush=True)

            report_loss.zero_()
            report_tgt_words.zero_()
            report_src_words.zero_()
            report_rec_loss.zero_()
            report_rev_loss.zero_()
            report_mirror_loss.zero_()
            report_ctc_loss.zero_()
            start = time.time()

        # increase i by world size
        i = i + self.world_size

    return total_loss / total_words
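
# Illustrative sketch (assumes torch.nn.parallel.DistributedDataParallel; the
# helper below is hypothetical and shows the usual accumulation pattern that
# maybe_no_sync above is built around): DDP all-reduces gradients during
# backward, so intermediate accumulation steps wrap the backward pass in
# no_sync() and only the final step of the cycle synchronizes.
import contextlib

from torch.nn.parallel import DistributedDataParallel


def grad_sync_context(model, will_update_now):
    if isinstance(model, DistributedDataParallel) and not will_update_now:
        # skip the gradient all-reduce for this backward pass
        return model.no_sync()
    return contextlib.ExitStack()  # no-op context manager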
def train_epoch(self, epoch, resume=False, itr_progress=None):
    global rec_ppl
    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model.train()
    self.loss_function.train()
    # Clear the gradients of the model
    # self.runner.zero_grad()
    self.model.zero_grad()
    self.model.reset_states()

    dataset = train_data
    data_iterator = DataIterator(dataset, dataset.collater, dataset.batches, seed=self.opt.seed,
                                 num_workers=opt.num_workers, epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_tokens, total_loss, total_words = 0, 0, 0
    total_non_pads = 0
    report_loss, report_tgt_words = 0, 0
    report_src_words = 0
    report_sents = 0
    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
    report_log_prior = 0
    report_log_variational_posterior = 0
    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    update_counter = 0
    num_accumulated_words = 0
    num_accumulated_sents = 0

    nan = False
    nan_counter = 0

    if opt.streaming:
        streaming_state = self.model.init_stream()
    else:
        streaming_state = None

    i = data_iterator.iterations_in_epoch

    while not data_iterator.end_of_epoch():

        curriculum = (epoch < opt.curriculum)

        batch = next(epoch_iterator)
        batch = rewrap(batch)

        grad_scaler = self.opt.batch_size_words if self.opt.update_frequency > 1 else batch.tgt_size

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for the loss function;
            # it can be flexibly controlled within models for easier extensibility
            targets = batch.get('target_output')
            tgt_mask = targets.data.ne(onmt.constants.PAD)
            outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                 zero_encoder=opt.zero_encoder, mirror=opt.mirror_loss,
                                 streaming_state=streaming_state)

            batch_size = batch.size
            outputs['tgt_mask'] = tgt_mask

            loss_dict = self.loss_function(outputs, targets, model=self.model)
            loss_data = loss_dict['data']
            loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16

            log_prior = self.model.log_prior()
            log_variational_posterior = self.model.log_variational_posterior()

            # the coeff starts off at 1 for each epoch
            # from the BBB paper: the first mini-batches in each epoch have a large KL coeff,
            # while the later mini-batches are influenced more by the data
            # denom = math.pow(1.5, min(32, update_counter))
            # min_coeff = 1 / (self.opt.model_size ** 2)
            # kl_coeff = max(1 / denom, min_coeff)
            kl_coeff = 1 / (batch.tgt_size * opt.update_frequency)
            # kl_coeff = 1 / (self.opt.model_size ** 2)
            # kl_coeff = 1
            full_loss = loss + kl_coeff * (log_variational_posterior - log_prior)
            # print(log_variational_posterior, log_prior)

            if opt.mirror_loss:
                rev_loss = loss_dict['rev_loss']
                rev_loss_data = loss_dict['rev_loss_data']
                mirror_loss = loss_dict['mirror_loss']
                full_loss = full_loss + rev_loss + mirror_loss
                mirror_loss_data = loss_dict['mirror_loss'].item()
            else:
                rev_loss = None
                rev_loss_data = None
                mirror_loss_data = 0

            # reconstruction loss
            if opt.reconstruct:
                rec_loss = loss_dict['rec_loss']
                full_loss = full_loss + rec_loss
                rec_loss_data = loss_dict['rec_loss_data']
            else:
                rec_loss_data = None

            optimizer = self.optim.optimizer

            # When the batch size is large, each gradient step can easily overflow in fp16.
            # Normalizing the loss by grad_scaler ensures this will not happen.
            full_loss.div_(grad_scaler)

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                full_loss.backward()

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU, skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
            else:
                raise e

        if loss != loss:
            # catching NaN problem
            oom = True
            self.model.zero_grad()
            self.optim.zero_grad()
            num_accumulated_words = 0
            num_accumulated_sents = 0
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is NaN")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurrences. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size
            tgt_size = batch.tgt_size

            counter = counter + 1
            num_accumulated_words += tgt_size
            num_accumulated_sents += batch_size

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif 0 < opt.batch_size_update <= num_accumulated_words:
                update_flag = True
            elif i == n_samples:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency is > 1
                if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                    grad_denom = 1 / grad_scaler
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words * grad_denom
                else:
                    grad_denom = 1

                # When we accumulate the gradients, each gradient is already normalized by grad_scaler
                normalize_gradients(amp.master_params(optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.opt.max_grad_norm)
                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                num_accumulated_words = 0
                num_accumulated_sents = 0

                num_updates = self.optim._step
                update_counter += 1

                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))
                    print('Validation perplexity: %g' % valid_ppl)
                    ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                    self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size
            report_loss += loss_data
            report_log_prior += log_prior.item()
            report_log_variational_posterior += log_variational_posterior.item()
            report_tgt_words += num_words
            report_src_words += src_size
            report_sents += 1
            total_loss += loss_data
            total_words += num_words
            total_tokens += batch.get('target_output').nelement()
            total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
            optim = self.optim
            batch_efficiency = total_non_pads / total_tokens

            if opt.reconstruct:
                report_rec_loss += rec_loss_data

            if opt.mirror_loss:
                report_rev_loss += rev_loss_data
                report_mirror_loss += mirror_loss_data

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("Epoch %2d, %5d/%5d; ppl: %6.2f ; " %
                              (epoch, i + 1, len(data_iterator),
                               math.exp(report_loss / report_tgt_words)))

                kl_div = report_log_variational_posterior - report_log_prior
                log_string += ("KL q||p: %6.2f ; " % (kl_div / report_sents))

                if opt.reconstruct:
                    rec_ppl = math.exp(report_rec_loss / report_src_words)
                    log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                if opt.mirror_loss:
                    rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                    log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                    # mirror loss per word
                    log_string += (" mir_loss: %6.2f ; " % (report_mirror_loss / report_tgt_words))

                log_string += ("lr: %.7f ; updates: %7d; " %
                               (optim.getLearningRate(), optim._step))
                log_string += ("%5.0f src/s; %5.0f tgt/s; " %
                               (report_src_words / (time.time() - start),
                                report_tgt_words / (time.time() - start)))
                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss = 0
                report_tgt_words, report_src_words = 0, 0
                report_sents = 0
                report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                report_log_prior, report_log_variational_posterior = 0, 0
                start = time.time()

        i = i + 1

    return total_loss / total_words
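
# Illustrative sketch (hypothetical helper, not part of the trainer): the
# objective assembled above is the usual Bayes-by-Backprop loss, i.e. the NLL
# plus a weighted KL between the variational posterior q(w) and the prior
# p(w), estimated as log q(w) - log p(w) for the sampled weights and spread
# over one effective batch:
def bayes_by_backprop_loss(nll, log_q, log_p, tgt_size, update_frequency):
    """nll + kl_coeff * (log q(w) - log p(w)), with kl_coeff = 1 / (#tokens * update freq)."""
    kl_coeff = 1.0 / (tgt_size * update_frequency)
    return nll + kl_coeff * (log_q - log_p)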
def train_epoch(self, epoch, resume=False, batch_order=None, iteration=0):
    opt = self.opt
    train_data = self.train_data

    # Clear the gradients of the model
    # self.runner.zero_grad()
    self.model.zero_grad()
    self.model.reset_states()

    if resume:
        train_data.batch_order = batch_order
        train_data.set_index(iteration)
        print("Resuming from iteration: %d" % iteration)
    else:
        batch_order = train_data.create_order()
        iteration = 0

    total_loss, total_words = 0, 0
    report_loss, report_tgt_words = 0, 0
    report_src_words = 0
    start = time.time()
    n_samples = len(train_data)

    counter = 0
    num_accumulated_words = 0
    num_accumulated_sents = 0

    for i in range(iteration, n_samples):

        curriculum = (epoch < opt.curriculum)

        batches = [train_data.next(curriculum=curriculum)[0]]

        if (len(self.additional_data) > 0 and
                i % self.additional_data_ratio[0] == 0):
            for j in range(len(self.additional_data)):
                for k in range(self.additional_data_ratio[j + 1]):
                    if self.additional_data_iteration[j] == len(self.additional_data[j]):
                        self.additional_data_iteration[j] = 0
                        self.additional_data[j].shuffle()
                        self.additional_batch_order[j] = self.additional_data[j].create_order()
                    batches.append(self.additional_data[j].next()[0])
                    self.additional_data_iteration[j] += 1

        for b in range(len(batches)):
            batch = batches[b]
            if self.cuda:
                batch.cuda(fp16=self.opt.fp16)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for the loss function;
                # it can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.data.ne(onmt.Constants.PAD)
                outputs = self.model(batch, target_masking=tgt_mask)

                batch_size = batch.size
                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs, targets, model=self.model)
                loss_data = loss_dict['data']
                loss = loss_dict['loss']

                optimizer = self.optim.optimizer

                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                else:
                    raise e

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                # We only update the parameters after getting gradients from n mini-batches,
                # simulating the multi-gpu situation
                # if counter == opt.virtual_gpu:
                # if counter >= opt.batch_size_update:
                if num_accumulated_words >= opt.batch_size_update * 0.95:
                    grad_denom = 1
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words
                    normalize_gradients(apex.amp.master_params(optimizer), grad_denom)

                    # Update the parameters.
                    self.optim.step(grad_denom=grad_denom)
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0

                    num_updates = self.optim._step
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)
                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                        self.save(ep, valid_ppl, batch_order=batch_order, iteration=i)

                num_words = tgt_size
                report_loss += loss_data
                report_tgt_words += num_words
                report_src_words += src_size
                total_loss += loss_data
                total_words += num_words
                optim = self.optim

                if b == 0 and (i == 0 or (i % opt.log_interval == -1 % opt.log_interval)):
                    print(("Epoch %2d, %5d/%5d; ppl: %6.2f ; lr: %.7f ; num updates: %7d " +
                           "%5.0f src tok/s; %5.0f tgt tok/s; %s elapsed") %
                          (epoch, i + 1, len(train_data),
                           math.exp(report_loss / report_tgt_words),
                           optim.getLearningRate(),
                           optim._step,
                           report_src_words / (time.time() - start),
                           report_tgt_words / (time.time() - start),
                           str(datetime.timedelta(seconds=int(time.time() - self.start_time)))))

                    report_loss, report_tgt_words = 0, 0
                    report_src_words = 0
                    start = time.time()

    return total_loss / total_words