def prepare_sample(batch, device=None, fp16=False):
    """Re-wrap a minibatch and move it onto the requested device.

    :param batch: the minibatch, or a list whose first element is the minibatch
    :param device: forwarded to ``batch.cuda`` as its ``device`` argument
    :param fp16: forwarded to ``batch.cuda`` as its ``fp16`` argument
    :return: the re-wrapped batch after ``cuda`` has been called on it
    """
    # Data iterators may hand the batch over wrapped in a list; unwrap it first.
    if isinstance(batch, list):
        batch = batch[0]

    batch = rewrap(batch)
    batch.cuda(fp16=fp16, device=device)

    return batch
def prepare_sample(batch, device=None):
    """Put a minibatch on the corresponding GPU.

    :param batch: the minibatch, or a list whose first element is the minibatch
    :param device: forwarded to ``batch.cuda`` as its ``device`` argument
    :return: the re-wrapped batch after ``cuda(fp16=False, ...)`` has been called on it
    """
    sample = rewrap(batch[0] if isinstance(batch, list) else batch)
    sample.cuda(fp16=False, device=device)
    return sample
def train_epoch(self, epoch, resume=False, itr_progress=None):
    """Run one training epoch with gradient accumulation.

    For every batch the model and loss function are run, optional CTC / mirror /
    reconstruction losses are added, and the sum is back-propagated. Parameters are
    updated once enough batches or target words have been accumulated, or when the
    epoch's last batch has been consumed. Batches that run out of GPU memory or
    produce a NaN loss are skipped.

    :param epoch: epoch number; passed to the data iterator and used in log lines
        and for the fractional epoch handed to ``self.save``
    :param resume: when true, the iterator state is restored from ``itr_progress``
    :param itr_progress: state passed to ``data_iterator.load_state_dict``
    :return: accumulated ``loss_dict['data']`` divided by the accumulated
        ``batch.tgt_size`` of all batches that were not skipped
    :raises ValueError: after 15 consecutive NaN losses
    """
    # The reconstruction perplexity is assigned at module level, as before.
    global rec_ppl

    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model.train()
    self.loss_function.train()
    # Clear the gradients of the model
    self.model.zero_grad()
    self.model.reset_states()

    data_iterator = generate_data_iterator(train_data, seed=self.opt.seed,
                                           num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_loss, total_words = 0, 0
    report_loss, report_tgt_words = 0, 0
    report_src_words = 0
    report_ctc_loss = 0
    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    num_accumulated_words = 0
    num_accumulated_sents = 0
    grad_scaler = 1
    nan_counter = 0

    # Loop-invariant: the wrapped optimizer used for amp scaling and gradient post-processing.
    optimizer = self.optim.optimizer

    if opt.streaming:
        streaming_state = self.model.init_stream()
    else:
        streaming_state = None

    if not isinstance(train_data, list):
        i = data_iterator.iterations_in_epoch
    else:
        i = epoch_iterator.n_yielded

    while not data_iterator.end_of_epoch():

        # this batch generator is not very clean atm
        batch = next(epoch_iterator)
        if isinstance(batch, list) and self.n_gpus == 1:
            batch = batch[0]
        batch = rewrap(batch)

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for loss function
            # can be flexibly controlled within models for easier extensibility
            targets = batch.get('target_output')
            tgt_mask = targets.ne(onmt.constants.PAD)
            outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                 zero_encoder=opt.zero_encoder,
                                 mirror=opt.mirror_loss, streaming_state=streaming_state,
                                 nce=opt.nce)

            batch_size = batch.size

            outputs['tgt_mask'] = tgt_mask

            loss_dict = self.loss_function(outputs, targets, model=self.model,
                                           vocab_mask=batch.vocab_mask)
            loss_data = loss_dict['data']
            loss = loss_dict['loss']
            full_loss = loss

            if opt.ctc_loss > 0.0:
                ctc_loss = self.ctc_loss_function(outputs, targets)
                ctc_loss_data = ctc_loss.item()
                full_loss = full_loss + opt.ctc_loss * ctc_loss
                report_ctc_loss += ctc_loss_data

            if opt.mirror_loss:
                rev_loss = loss_dict['rev_loss']
                rev_loss_data = loss_dict['rev_loss_data']
                mirror_loss = loss_dict['mirror_loss']
                full_loss = full_loss + rev_loss + mirror_loss
                mirror_loss_data = loss_dict['mirror_loss'].item()
            else:
                rev_loss_data = None
                mirror_loss_data = 0

            # reconstruction loss
            if opt.reconstruct:
                full_loss = full_loss + loss_dict['rec_loss']
                rec_loss_data = loss_dict['rec_loss_data']
            else:
                rec_loss_data = None

            # When the batch size is large, each gradient step is very easy to explode on fp16
            # Normalizing the loss to grad scaler ensures this will not happen
            full_loss.div_(grad_scaler)

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                full_loss.backward()

            del outputs

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU , skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
            else:
                raise e

        if loss != loss:
            # catching NAN problem: drop everything accumulated so far
            oom = True
            self.model.zero_grad()
            self.optim.zero_grad()
            num_accumulated_words = 0
            num_accumulated_sents = 0
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is Nan")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurence. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
            # NOTE(review): read as part of the NaN branch (gradients were just cleared) — confirm
            counter = 0
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size
            tgt_size = batch.tgt_size

            counter = counter + 1
            num_accumulated_words += tgt_size
            num_accumulated_sents += batch_size

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif 0 < opt.batch_size_update <= num_accumulated_words:
                update_flag = True
            elif data_iterator.end_of_epoch():
                # Update for the last minibatch. The former test `i == n_samples` could never
                # hold because `i` is 0-based and only incremented afterwards, so gradients
                # accumulated at the end of the epoch were never applied. The iterator itself
                # reports exhaustion as soon as the final batch has been drawn.
                update_flag = True

            if update_flag:
                # When we accumulate the gradients, each gradient is already normalized
                # by a constant grad_scaler
                grad_denom = 1 / grad_scaler
                if self.opt.normalize_gradient:
                    grad_denom = num_accumulated_words * grad_denom
                normalize_gradients(amp.master_params(optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   self.opt.max_grad_norm)
                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                num_accumulated_words = 0
                num_accumulated_sents = 0

                num_updates = self.optim._step
                # `-1 % save_every` is `save_every - 1`
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))
                    print('Validation perplexity: %g' % valid_ppl)

                    ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                    self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size
            report_loss += loss_data
            report_tgt_words += num_words
            report_src_words += src_size
            total_loss += loss_data
            total_words += num_words

            if opt.reconstruct:
                report_rec_loss += rec_loss_data

            if opt.mirror_loss:
                report_rev_loss += rev_loss_data
                report_mirror_loss += mirror_loss_data

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                              (epoch, i + 1, len(data_iterator),
                               math.exp(report_loss / report_tgt_words)))

                if opt.reconstruct:
                    rec_ppl = math.exp(report_rec_loss / report_src_words.item())
                    log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                if opt.mirror_loss:
                    rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                    log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                    # mirror loss per word
                    log_string += (" mir_loss: %6.2f ; " % (report_mirror_loss / report_tgt_words))

                log_string += ("lr: %.7f ; updates: %7d; " %
                               (self.optim.getLearningRate(),
                                self.optim._step))

                log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                               (report_src_words / (time.time() - start),
                                report_tgt_words / (time.time() - start)))

                if opt.ctc_loss > 0.0:
                    ctc_loss = report_ctc_loss / report_tgt_words
                    log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss = 0
                report_tgt_words, report_src_words = 0, 0
                report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                report_ctc_loss = 0
                start = time.time()

        # `i` starts from the iterator position, so it advances for every batch drawn.
        i = i + 1

    return total_loss / total_words
def eval(self, data):
    """Compute the average per-word loss of the model on ``data`` without gradients.

    The model and loss function are put in eval mode for the pass and switched
    back to train mode before returning.

    :param data: dataset handed to ``generate_data_iterator``
    :return: accumulated ``loss_dict['data']`` divided by accumulated ``batch.tgt_size``
    """
    opt = self.opt
    loss_sum = 0
    word_count = 0

    self.model.eval()
    self.loss_function.eval()
    self.model.reset_states()

    # the data iterator creates an epoch iterator
    data_iterator = generate_data_iterator(data, seed=self.opt.seed,
                                           num_workers=opt.num_workers,
                                           epoch=1, buffer_size=opt.buffer_size)
    epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

    streaming_state = self.model.init_stream() if opt.streaming else None

    data_size = len(epoch_iterator)
    n_batches = 0

    # PyTorch semantics: save space by not creating gradients
    with torch.no_grad():
        while not data_iterator.end_of_epoch():
            batch = next(epoch_iterator)
            if isinstance(batch, list):
                batch = batch[0]
            batch = rewrap(batch)

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            # outputs can be either hidden states from the decoder
            # or a probability distribution from the decoder generator
            targets = batch.get('target_output')
            tgt_mask = targets.ne(onmt.constants.PAD)
            outputs = self.model(batch,
                                 streaming=opt.streaming,
                                 target_mask=tgt_mask,
                                 mirror=opt.mirror_loss,
                                 streaming_state=streaming_state,
                                 nce=opt.nce)

            if opt.streaming:
                # carry the stream state over to the next batch
                streaming_state = outputs['streaming_state']

            outputs['tgt_mask'] = tgt_mask

            loss_dict = self.loss_function(outputs, targets, model=self.model,
                                           eval=True, vocab_mask=batch.vocab_mask)

            loss_sum += loss_dict['data']
            word_count += batch.tgt_size
            n_batches += 1

    self.model.train()
    self.loss_function.train()
    return loss_sum / word_count
def train_epoch(self, epoch, resume=False, itr_progress=None):
    """Alternately train the autoencoder and the latent discriminator.

    The loop starts in autoencoder mode and flips between autoencoder mode and
    latent-discriminator mode after every parameter update. The data iterator is
    re-created whenever it is exhausted, so the loop only ends through the
    ``num_updates == 1000000`` break.

    :param epoch: epoch number passed to the data iterator
    :param resume: when true, the iterator state is restored from ``itr_progress``
    :param itr_progress: state passed to ``data_iterator.load_state_dict``
    :return: tuple ``(total_loss_ae / total_frames * 100,
        total_loss_lat_dis / n_samples * 100, total_loss_adv / n_samples * 100)``
    :raises ValueError: after 15 consecutive NaN autoencoder losses
    """
    global rec_ppl

    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model_ae.train()
    self.loss_function_ae.train()
    self.lat_dis.train()
    self.loss_lat_dis.train()
    # Clear the gradients of the model
    self.model_ae.zero_grad()
    self.lat_dis.zero_grad()

    dataset = train_data
    data_iterator = generate_data_iterator(dataset, seed=self.opt.seed,
                                           num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_loss_ae, total_loss_lat_dis, total_frames, total_loss_adv = 0, 0, 0, 0

    report_loss_ae, report_loss_lat_dis, report_loss_adv, report_tgt_frames, report_sent = 0, 0, 0, 0, 0
    report_dis_frames, report_adv_frames = 0, 0

    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    step = 0
    num_accumulated_sents = 0
    grad_scaler = -1

    nan_counter = 0
    mode_ae = True

    # Bound up front: the NaN check and the accumulators below read these even when
    # the very first batch is skipped because of an out-of-memory error.
    loss_ae = 0.0
    loss_adv = 0.0
    loss_lat_dis = 0.0

    if not isinstance(train_data, list):
        i = data_iterator.iterations_in_epoch
    else:
        i = epoch_iterator.n_yielded

    while True:
        # Start over on the data once the iterator is exhausted.
        if data_iterator.end_of_epoch():
            data_iterator = generate_data_iterator(dataset, seed=self.opt.seed,
                                                   num_workers=opt.num_workers,
                                                   epoch=epoch, buffer_size=opt.buffer_size)
            epoch_iterator = data_iterator.next_epoch_itr(not streaming,
                                                          pin_memory=opt.pin_memory)

        # this batch generator is not very clean atm
        batch = next(epoch_iterator)
        if isinstance(batch, list) and self.n_gpus == 1:
            batch = batch[0]
        batch = rewrap(batch)

        batch_size = batch.size
        if grad_scaler == -1:
            grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            if mode_ae:
                step = self.optim_ae._step
                loss_ae, loss_adv, encoder_outputs = self.autoencoder_backward(batch, step)
            else:
                loss_lat_dis, encoder_outputs = self.lat_dis_backward(batch)

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU , skipping batch')
                oom = True
                torch.cuda.empty_cache()
            else:
                raise e

        if loss_ae != loss_ae:
            # catching NAN problem
            oom = True
            self.model_ae.zero_grad()
            self.optim_ae.zero_grad()

            nan_counter = nan_counter + 1
            print("Warning!!! Loss is Nan")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurence. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size

            if mode_ae:
                report_adv_frames += encoder_outputs['src_mask'].sum().item()
                report_loss_adv += loss_adv
            else:
                report_dis_frames += encoder_outputs['src_mask'].sum().item()
                report_loss_lat_dis += loss_lat_dis

            counter = counter + 1

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif i == n_samples:
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency
                if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                    grad_denom = 1 / grad_scaler
                else:
                    grad_denom = 1.0

                # When we accumulate the gradients, each gradient is already normalized
                # by a constant grad_scaler
                normalize_gradients(amp.master_params(self.optim_ae.optimizer), grad_denom)
                normalize_gradients(amp.master_params(self.optim_lat_dis.optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim_ae.optimizer),
                                                   self.opt.max_grad_norm)
                    # The discriminator gradients are clipped by value instead of by norm.
                    # NOTE(review): read as belonging under the max_grad_norm guard — confirm
                    torch.nn.utils.clip_grad_value_(amp.master_params(self.optim_lat_dis.optimizer),
                                                    0.01)

                # Step only the optimizer of the active mode; clear both sets of gradients.
                if mode_ae:
                    self.optim_ae.step()
                    self.optim_ae.zero_grad()
                    self.model_ae.zero_grad()
                    self.optim_lat_dis.zero_grad()
                    self.lat_dis.zero_grad()
                else:
                    self.optim_lat_dis.step()
                    self.optim_lat_dis.zero_grad()
                    self.lat_dis.zero_grad()
                    self.optim_ae.zero_grad()
                    self.model_ae.zero_grad()

                num_updates = self.optim_ae._step
                # if mode_ae is not here then it will continuously save
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every and mode_ae:
                    valid_loss_ae, valid_loss_lat_dis = self.eval(self.valid_data)
                    print('Validation loss ae: %g' % valid_loss_ae)
                    print('Validation loss latent discriminator: %g' % valid_loss_lat_dis)
                    self.save(0, valid_loss_ae, itr=data_iterator)

                    # Validation switches these modules to eval mode; make sure training
                    # continues in train mode.
                    self.model_ae.train()
                    self.loss_function_ae.train()
                    self.lat_dis.train()
                    self.loss_lat_dis.train()

                # NOTE(review): read as the loop's stop condition at update level — confirm
                if num_updates == 1000000:
                    break

                mode_ae = not mode_ae
                counter = 0
                grad_scaler = -1

            report_loss_ae += loss_ae
            num_accumulated_sents += batch_size
            report_sent += batch_size
            total_frames += src_size
            report_tgt_frames += src_size
            total_loss_ae += loss_ae
            total_loss_lat_dis += loss_lat_dis
            total_loss_adv += loss_adv

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                # the small epsilon avoids a division by zero for the mode not yet run
                log_string = ("loss_ae : %6.2f ; loss_lat_dis : %6.2f, loss_adv : %6.2f " %
                              (report_loss_ae / report_tgt_frames,
                               report_loss_lat_dis / (report_dis_frames + 1e-5),
                               report_loss_adv / (report_adv_frames + 1e-5)))

                log_string += ("lr_ae: %.7f ; updates: %7d; " %
                               (self.optim_ae.getLearningRate(),
                                self.optim_ae._step))

                log_string += ("lr_lat_dis: %.7f ; updates: %7d; " %
                               (self.optim_lat_dis.getLearningRate(),
                                self.optim_lat_dis._step))

                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss_ae = 0
                report_loss_lat_dis = 0
                report_loss_adv = 0
                report_tgt_frames = 0
                report_dis_frames = 0
                report_adv_frames = 0
                report_sent = 0
                start = time.time()

        i = i + 1

    return total_loss_ae / total_frames * 100, total_loss_lat_dis / n_samples * 100, total_loss_adv / n_samples * 100
def eval(self, data):
    """Evaluate the autoencoder and the latent discriminator on ``data``.

    All four modules (autoencoder, its loss, latent discriminator, its loss) are
    put in eval mode for the pass and switched back to train mode afterwards, so
    that a validation run in the middle of training does not leave them in eval mode.

    :param data: dataset handed to ``generate_data_iterator``
    :return: tuple ``(total_loss_ae / total_tgt_frames,
        total_loss_lat_dis / total_src * 100)``, where ``total_tgt_frames`` sums
        ``batch.src_size`` and ``total_src`` sums the encoder's ``src_mask``
    """
    total_loss_ae = 0
    total_loss_lat_dis = 0
    total_tgt_frames = 0
    total_src = 0
    opt = self.opt

    self.model_ae.eval()
    self.loss_function_ae.eval()
    self.lat_dis.eval()
    self.loss_lat_dis.eval()

    # the data iterator creates an epoch iterator
    data_iterator = generate_data_iterator(data, seed=self.opt.seed,
                                           num_workers=opt.num_workers,
                                           epoch=1, buffer_size=opt.buffer_size)
    epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

    try:
        # PyTorch semantics: save space by not creating gradients
        with torch.no_grad():
            while not data_iterator.end_of_epoch():
                batch = next(epoch_iterator)
                if isinstance(batch, list):
                    batch = batch[0]
                batch = rewrap(batch)

                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

                encoder_outputs, decoder_outputs = self.model_ae(batch)

                gate_padded = batch.get('gate_padded')

                if self.opt.n_frames_per_step > 1:
                    # keep the gate value of the last frame of every group of frames
                    frame_idx = torch.arange(self.opt.n_frames_per_step - 1,
                                             gate_padded.size(1),
                                             self.opt.n_frames_per_step)
                    gate_padded = gate_padded[:, frame_idx]

                src_org = batch.get('source_org')
                # drop index 0 along dimension 2
                src_org = src_org.narrow(2, 1, src_org.size(2) - 1)
                target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]
                loss_ae = self.loss_function_ae(decoder_outputs, target)

                preds = self.lat_dis(encoder_outputs['context'])
                loss_lat_dis = self.loss_lat_dis(preds, batch.get('source_lang'),
                                                 mask=encoder_outputs['src_mask'],
                                                 adversarial=False)

                total_src += encoder_outputs['src_mask'].float().sum().item()
                total_loss_ae += loss_ae.data.item()
                total_loss_lat_dis += loss_lat_dis.data.item()
                total_tgt_frames += batch.src_size
    finally:
        # Switch back to train mode even if evaluation fails part-way.
        self.model_ae.train()
        self.loss_function_ae.train()
        self.lat_dis.train()
        self.loss_lat_dis.train()

    return total_loss_ae / total_tgt_frames, total_loss_lat_dis / total_src * 100
def train_epoch(self, epoch, resume=False, itr_progress=None):
    """Run one training epoch with gradient accumulation.

    The training target is built from the batch's ``source_org`` and ``gate_padded``
    entries. Parameters are updated once ``opt.update_frequency`` batches have been
    accumulated or when the epoch's last batch has been consumed. Batches that run
    out of GPU memory or produce a NaN loss are skipped.

    :param epoch: epoch number; passed to the data iterator and used in log lines
        and for the fractional epoch handed to ``self.save``
    :param resume: when true, the iterator state is restored from ``itr_progress``
    :param itr_progress: state passed to ``data_iterator.load_state_dict``
    :return: accumulated loss divided by ``len(epoch_iterator)``, times 100
    :raises ValueError: after 15 consecutive NaN losses
    """
    global rec_ppl

    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model.train()
    self.loss_function.train()
    # Clear the gradients of the model
    self.model.zero_grad()

    data_iterator = generate_data_iterator(train_data, seed=self.opt.seed,
                                           num_workers=opt.num_workers,
                                           epoch=epoch, buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_loss, total_frames = 0, 0
    report_loss, report_tgt_frames, report_sent = 0, 0, 0

    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    num_accumulated_sents = 0
    grad_scaler = -1
    nan_counter = 0

    # Loop-invariant: the wrapped optimizer used for amp scaling and gradient post-processing.
    optimizer = self.optim.optimizer

    if not isinstance(train_data, list):
        i = data_iterator.iterations_in_epoch
    else:
        i = epoch_iterator.n_yielded

    while not data_iterator.end_of_epoch():

        # this batch generator is not very clean atm
        batch = next(epoch_iterator)
        if isinstance(batch, list) and self.n_gpus == 1:
            batch = batch[0]
        batch = rewrap(batch)

        if grad_scaler == -1:
            grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            outputs = self.model(batch)

            gate_padded = batch.get('gate_padded')

            if self.opt.n_frames_per_step > 1:
                # NOTE(review): the sub-sampling here starts at offset 0, whereas the
                # evaluation code of this trainer starts at n_frames_per_step - 1 —
                # confirm which offset is intended.
                frame_idx = torch.arange(0, gate_padded.size(1), self.opt.n_frames_per_step)
                gate_padded = gate_padded[:, frame_idx]

            src_org = batch.get('source_org')
            # drop index 0 along dimension 2
            src_org = src_org.narrow(2, 1, src_org.size(2) - 1)
            target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]
            loss = self.loss_function(outputs, target)

            batch_size = batch.size
            loss_data = loss.data.item()
            full_loss = loss

            # When the batch size is large, each gradient step is very easy to explode on fp16
            # Normalizing the loss to grad scaler ensures this will not happen
            full_loss.div_(grad_scaler)

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                full_loss.backward()

            del outputs

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU , skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
            else:
                raise e

        if loss != loss:
            # catching NAN problem
            oom = True
            self.model.zero_grad()
            self.optim.zero_grad()
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is Nan")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurence. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size

            counter = counter + 1

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif data_iterator.end_of_epoch():
                # Update for the last minibatch. The former test `i == n_samples` could never
                # hold because `i` is 0-based and only incremented afterwards, so gradients
                # accumulated at the end of the epoch were never applied. The iterator itself
                # reports exhaustion as soon as the final batch has been drawn.
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency
                if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                    grad_denom = 1 / grad_scaler
                else:
                    grad_denom = 1.0

                # When we accumulate the gradients, each gradient is already normalized
                # by a constant grad_scaler
                normalize_gradients(amp.master_params(optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   self.opt.max_grad_norm)
                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                grad_scaler = -1

                num_updates = self.optim._step
                # `-1 % save_every` is `save_every - 1`
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))
                    print('Validation perplexity: %g' % valid_ppl)

                    ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                    self.save(ep, valid_ppl, itr=data_iterator)

            report_loss += loss_data
            num_accumulated_sents += batch_size
            report_sent += batch_size
            total_frames += src_size
            report_tgt_frames += src_size
            total_loss += loss_data

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("Epoch %2d, %5d/%5d; ; loss : %6.2f ; " %
                              (epoch, i + 1, len(data_iterator),
                               report_loss))

                log_string += ("lr: %.7f ; updates: %7d; " %
                               (self.optim.getLearningRate(),
                                self.optim._step))

                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss = 0
                report_tgt_frames = 0
                report_sent = 0
                start = time.time()

        # `i` starts from the iterator position, so it advances for every batch drawn.
        i = i + 1

    return total_loss / n_samples * 100
def eval(self, data):
    """Compute the average per-batch loss of the model on ``data`` without gradients.

    The model and loss function are put in eval mode for the pass and switched
    back to train mode before returning.

    :param data: dataset handed to ``generate_data_iterator``
    :return: accumulated loss divided by ``len(epoch_iterator)``, times 100
    """
    opt = self.opt
    loss_sum = 0
    frame_count = 0
    sent_count = 0

    self.model.eval()
    self.loss_function.eval()

    # the data iterator creates an epoch iterator
    data_iterator = generate_data_iterator(data, seed=self.opt.seed,
                                           num_workers=opt.num_workers,
                                           epoch=1, buffer_size=opt.buffer_size)
    epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

    # The stream state is initialised as before, although the forward pass below
    # does not take it.
    streaming_state = self.model.init_stream() if opt.streaming else None

    data_size = len(epoch_iterator)
    n_batches = 0

    # PyTorch semantics: save space by not creating gradients
    with torch.no_grad():
        while not data_iterator.end_of_epoch():
            batch = next(epoch_iterator)
            if isinstance(batch, list):
                batch = batch[0]
            batch = rewrap(batch)

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            outputs = self.model(batch)

            gate_padded = batch.get('gate_padded')
            frames_per_step = self.opt.n_frames_per_step
            if frames_per_step > 1:
                # keep the gate value of the last frame of every group of frames
                frame_idx = torch.arange(frames_per_step - 1, gate_padded.size(1), frames_per_step)
                gate_padded = gate_padded[:, frame_idx]

            src_org = batch.get('source_org')
            # drop index 0 along dimension 2
            src_org = src_org.narrow(2, 1, src_org.size(2) - 1)
            target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]

            loss = self.loss_function(outputs, target)

            loss_sum += loss.data.item()
            frame_count += batch.src_size
            sent_count += batch.size
            n_batches += 1

    self.model.train()
    self.loss_function.train()
    return loss_sum / data_size * 100
def train_epoch(self, epoch, resume=False, itr_progress=None):
    """Run one training epoch with a KL-regularised loss and gradient accumulation.

    On top of the loss returned by ``self.loss_function``, the difference between
    ``self.model.log_variational_posterior()`` and ``self.model.log_prior()`` is
    added, weighted by ``1 / (batch.tgt_size * opt.update_frequency)``. Parameters
    are updated once enough batches or target words have been accumulated, or when
    the epoch's last batch has been consumed. Batches that run out of GPU memory
    or produce a NaN loss are skipped.

    :param epoch: epoch number; passed to the data iterator and used in log lines
        and for the fractional epoch handed to ``self.save``
    :param resume: when true, the iterator state is restored from ``itr_progress``
    :param itr_progress: state passed to ``data_iterator.load_state_dict``
    :return: accumulated ``loss_dict['data']`` divided by the accumulated
        ``batch.tgt_size`` of all batches that were not skipped
    :raises ValueError: after 15 consecutive NaN losses
    """
    # The reconstruction perplexity is assigned at module level, as before.
    global rec_ppl

    opt = self.opt
    train_data = self.train_data
    streaming = opt.streaming

    self.model.train()
    self.loss_function.train()
    # Clear the gradients of the model
    self.model.zero_grad()
    self.model.reset_states()

    dataset = train_data
    data_iterator = DataIterator(dataset, dataset.collater, dataset.batches, seed=self.opt.seed,
                                 num_workers=opt.num_workers, epoch=epoch,
                                 buffer_size=opt.buffer_size)

    if resume:
        data_iterator.load_state_dict(itr_progress)

    epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

    total_loss, total_words = 0, 0
    report_loss, report_tgt_words = 0, 0
    report_src_words = 0
    report_sents = 0
    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
    report_log_prior = 0
    report_log_variational_posterior = 0
    start = time.time()
    n_samples = len(epoch_iterator)

    counter = 0
    update_counter = 0
    num_accumulated_words = 0
    num_accumulated_sents = 0
    nan_counter = 0

    # Loop-invariant: the wrapped optimizer used for amp scaling and gradient post-processing.
    optimizer = self.optim.optimizer

    if opt.streaming:
        streaming_state = self.model.init_stream()
    else:
        streaming_state = None

    i = data_iterator.iterations_in_epoch

    while not data_iterator.end_of_epoch():

        batch = next(epoch_iterator)
        batch = rewrap(batch)
        if self.opt.update_frequency > 1:
            grad_scaler = self.opt.batch_size_words
        else:
            grad_scaler = batch.tgt_size

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        oom = False
        try:
            # outputs is a dictionary containing keys/values necessary for loss function
            # can be flexibly controlled within models for easier extensibility
            targets = batch.get('target_output')
            tgt_mask = targets.data.ne(onmt.constants.PAD)
            outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                 zero_encoder=opt.zero_encoder,
                                 mirror=opt.mirror_loss, streaming_state=streaming_state)

            batch_size = batch.size

            outputs['tgt_mask'] = tgt_mask

            loss_dict = self.loss_function(outputs, targets, model=self.model)
            loss_data = loss_dict['data']
            loss = loss_dict['loss']

            log_prior = self.model.log_prior()
            log_variational_posterior = self.model.log_variational_posterior()

            # weight of the KL term for this batch
            kl_coeff = 1 / (batch.tgt_size * opt.update_frequency)

            full_loss = loss + kl_coeff * (log_variational_posterior - log_prior)

            if opt.mirror_loss:
                rev_loss = loss_dict['rev_loss']
                rev_loss_data = loss_dict['rev_loss_data']
                mirror_loss = loss_dict['mirror_loss']
                full_loss = full_loss + rev_loss + mirror_loss
                mirror_loss_data = loss_dict['mirror_loss'].item()
            else:
                rev_loss_data = None
                mirror_loss_data = 0

            # reconstruction loss
            if opt.reconstruct:
                full_loss = full_loss + loss_dict['rec_loss']
                rec_loss_data = loss_dict['rec_loss_data']
            else:
                rec_loss_data = None

            # When the batch size is large, each gradient step is very easy to explode on fp16
            # Normalizing the loss to grad scaler ensures this will not happen
            full_loss.div_(grad_scaler)

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                full_loss.backward()

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on GPU , skipping batch')
                oom = True
                torch.cuda.empty_cache()
                loss = 0
                if opt.streaming:  # reset stream in this case ...
                    streaming_state = self.model.init_stream()
            else:
                raise e

        if loss != loss:
            # catching NAN problem: drop everything accumulated so far
            oom = True
            self.model.zero_grad()
            self.optim.zero_grad()
            num_accumulated_words = 0
            num_accumulated_sents = 0
            nan_counter = nan_counter + 1
            print("Warning!!! Loss is Nan")
            if nan_counter >= 15:
                raise ValueError("Training stopped because of multiple NaN occurence. "
                                 "For ASR, using the Relative Transformer is more stable and recommended.")
        else:
            nan_counter = 0

        if not oom:
            src_size = batch.src_size
            tgt_size = batch.tgt_size

            counter = counter + 1
            num_accumulated_words += tgt_size
            num_accumulated_sents += batch_size

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency > 0:
                update_flag = True
            elif 0 < opt.batch_size_update <= num_accumulated_words:
                update_flag = True
            elif data_iterator.end_of_epoch():
                # Update for the last minibatch. The former test `i == n_samples` could never
                # hold because `i` is 0-based and only incremented afterwards, so gradients
                # accumulated at the end of the epoch were never applied. The iterator itself
                # reports exhaustion as soon as the final batch has been drawn.
                update_flag = True

            if update_flag:
                # accumulated gradient case, in this case the update frequency
                if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                    grad_denom = 1 / grad_scaler
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words * grad_denom
                else:
                    grad_denom = 1

                # When we accumulate the gradients, each gradient is already normalized
                # by a constant grad_scaler
                normalize_gradients(amp.master_params(optimizer), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   self.opt.max_grad_norm)
                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                num_accumulated_words = 0
                num_accumulated_sents = 0
                num_updates = self.optim._step
                update_counter += 1

                # `-1 % save_every` is `save_every - 1`
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))
                    print('Validation perplexity: %g' % valid_ppl)

                    ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                    self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size
            report_loss += loss_data
            report_log_prior += log_prior.item()
            report_log_variational_posterior += log_variational_posterior.item()
            report_tgt_words += num_words
            report_src_words += src_size
            report_sents += 1
            total_loss += loss_data
            total_words += num_words

            if opt.reconstruct:
                report_rec_loss += rec_loss_data

            if opt.mirror_loss:
                report_rev_loss += rev_loss_data
                report_mirror_loss += mirror_loss_data

            if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                              (epoch, i + 1, len(data_iterator),
                               math.exp(report_loss / report_tgt_words)))

                kl_div = report_log_variational_posterior - report_log_prior
                log_string += ("KL q||p: %6.2f ; " % (kl_div / report_sents))

                if opt.reconstruct:
                    rec_ppl = math.exp(report_rec_loss / report_src_words.item())
                    log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                if opt.mirror_loss:
                    rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                    log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                    # mirror loss per word
                    log_string += (" mir_loss: %6.2f ; " % (report_mirror_loss / report_tgt_words))

                log_string += ("lr: %.7f ; updates: %7d; " %
                               (self.optim.getLearningRate(),
                                self.optim._step))

                log_string += ("%5.0f src/s; %5.0f tgt/s; " %
                               (report_src_words / (time.time() - start),
                                report_tgt_words / (time.time() - start)))

                log_string += ("%s elapsed" %
                               str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                print(log_string)

                report_loss = 0
                report_tgt_words, report_src_words = 0, 0
                report_sents = 0
                report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                report_log_prior, report_log_variational_posterior = 0, 0
                start = time.time()

        # `i` starts from the iterator position, so it advances for every batch drawn.
        i = i + 1

    return total_loss / total_words