def report_func(epoch, batch, num_batches, start_time, lr, report_stats):
    """Batch-level training progress report callback.

    On the last batch of every ``REPORT_EVERY``-sized window the accumulated
    statistics are printed and a fresh accumulator is handed back.

    Args:
        epoch (int): current epoch count.
        batch (int): current batch index (reported as ``batch + 1``).
        num_batches (int): total number of batches.
        start_time (float): last report time.
        lr (float): current learning rate (not used here).
        report_stats (Statistics): statistics accumulated since the last report.

    Returns:
        Statistics: ``report_stats`` unchanged, or a new empty instance right
        after a report has been printed.
    """
    # For a positive REPORT_EVERY, ``-1 % REPORT_EVERY == REPORT_EVERY - 1``,
    # so this selects the last batch of each reporting window.
    at_window_end = batch % REPORT_EVERY == -1 % REPORT_EVERY
    if not at_window_end:
        return report_stats
    report_stats.output(epoch, batch + 1, num_batches, start_time)
    return onmt.Statistics()
def sharded_compute_loss(self, batch, output, attns, cur_trunc, trunc_size,
                         shard_size, teacher_outputs=None):
    """Compute the loss shard by shard, backpropagating each shard.

    Args:
        batch: batch of examples; ``batch.batch_size`` normalizes each shard loss.
        output: decoder output, forwarded to ``make_gen_state``.
        attns: attention distributions, forwarded to ``make_gen_state``.
        cur_trunc (int): start position of the truncation window.
        trunc_size (int): length of the truncation window.
        shard_size (int): shard size handed to ``shards``.
        teacher_outputs: optional extra outputs forwarded to ``make_gen_state``.

    Returns:
        onmt.Statistics: statistics accumulated over all shards.
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    state = make_gen_state(output, batch, attns, window, self.copy_attn,
                           teacher_outputs)
    for piece in shards(state, shard_size):
        piece_loss, piece_stats = self.compute_loss(batch, **piece)
        # The graph is retained after every shard's backward pass; presumably
        # because part of it is shared between shards -- TODO confirm.
        normalized = piece_loss.div(batch.batch_size)
        normalized.backward(retain_graph=True)
        accumulated.update(piece_stats)
    return accumulated
def _stats(self, loss, xent, kl, scores, target, sample_xents):
    """Build the ``onmt.Statistics`` record for one batch.

    Args:
        loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
        xent (:obj:`FloatTensor`): presumably the cross-entropy term of the
            loss -- TODO confirm.
        kl (:obj:`FloatTensor`): presumably the KL term of the loss -- TODO confirm.
        scores (:obj:`FloatTensor`): a score for each possible output; the
            prediction is the argmax over dimension 1.
        target (:obj:`FloatTensor`): true targets; entries equal to
            ``self.padding_idx`` are ignored when counting words/correct ones.
        sample_xents (:obj:`FloatTensor`): forwarded to ``onmt.Statistics``.

    Returns:
        :obj:`Statistics`: statistics for this batch, with every value moved
        to CPU and converted to numpy.
    """
    _, predictions = scores.max(1)
    is_real_token = target.ne(self.padding_idx)
    hits = predictions.eq(target).masked_select(is_real_token)
    correct_count = hits.long().sum()
    word_count = is_real_token.long().sum()
    return onmt.Statistics(
        loss.cpu().numpy(),
        xent.cpu().numpy(),
        kl.cpu().numpy(),
        word_count.cpu().numpy(),
        correct_count.cpu().numpy(),
        sample_xents.cpu().numpy(),
    )
def sharded_compute_loss(self, batch, output, attns, cur_trunc, trunc_size,
                         shard_size, output_t, attns_t):
    """Compute the loss over the whole truncation window and backpropagate.

    Despite the name, sharding is currently disabled: the window is scored in
    a single ``compute_loss`` call and ``shard_size`` is ignored.

    Args:
        batch: batch of examples; ``batch.batch_size`` normalizes the loss.
        output, attns: decoder output and attentions of the trained model.
        cur_trunc (int): start position of the truncation window.
        trunc_size (int): length of the truncation window.
        shard_size (int): unused (kept for interface compatibility).
        output_t, attns_t: a second output/attention pair (presumably from a
            teacher model -- TODO confirm); only its ``'output'`` entry is used.

    Returns:
        onmt.Statistics: statistics for this batch.
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    own_state = make_gen_state(output, batch, attns, window, self.copy_attn)
    other_state = make_gen_state(output_t, batch, attns_t, window,
                                 self.copy_attn)
    window_loss, window_stats = self.compute_loss(
        batch,
        own_state['output'],
        own_state['target'],
        other_state['output'],
    )
    window_loss.div(batch.batch_size).backward()
    accumulated.update(window_stats)
    return accumulated
def sharded_compute_loss(self, batch, output, attns, cur_trunc, trunc_size,
                         shard_size, normalization):
    """Compute the forward loss and backpropagate.

    Computation is done with shards and optionally truncation for memory
    efficiency. Truncated BPTT is supported by taking the range
    ``(cur_trunc, cur_trunc + trunc_size)`` of the decoder output sequence.

    Sharding is an exact efficiency trick that relieves the memory required
    for the generation buffers; truncation is an approximate trick that
    relieves the memory required in the RNN buffers.

    Args:
        batch (batch): batch of labeled examples
        output (:obj:`FloatTensor`): output of decoder model
            `[tgt_len x batch x hidden]`
        attns (dict): dictionary of attention distributions
            `[tgt_len x batch x src_len]`
        cur_trunc (int): starting position of truncation window
        trunc_size (int): length of truncation window
        shard_size (int): maximum number of examples in a shard
        normalization (int): each shard loss is divided by this number

    Returns:
        :obj:`onmt.Statistics`: loss statistics accumulated over all shards
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    state = self._make_shard_state(batch, output, window, attns)
    for piece in shards(state, shard_size):
        piece_loss, piece_stats = self._compute_loss(batch, **piece)
        scaled_loss = piece_loss.div(normalization)
        scaled_loss.backward()
        accumulated.update(piece_stats)
    return accumulated
def report_func(epoch, batch, num_batches, start_time, lr, report_stats, opt):
    """Batch-level training progress report callback.

    On the last batch of every ``opt.report_every`` window the accumulated
    statistics are printed and, when ``opt.exp_host`` is set, also logged via
    ``report_stats.log`` with the module-level ``experiment`` object.

    Args:
        epoch (int): current epoch count.
        batch (int): current batch index (reported as ``batch + 1``).
        num_batches (int): total number of batches.
        start_time (float): last report time.
        lr (float): current learning rate (passed to the experiment log).
        report_stats (Statistics): statistics accumulated since the last report.
        opt: options object providing ``report_every`` and ``exp_host``.

    Returns:
        Statistics: ``report_stats`` unchanged, or a new empty instance right
        after a report.
    """
    if batch % opt.report_every != -1 % opt.report_every:
        return report_stats
    report_stats.output(epoch, batch + 1, num_batches, start_time)
    if opt.exp_host:
        report_stats.log("progress", experiment, lr)
    return onmt.Statistics()
def report_func(epoch, batch, num_batches, progress_step, start_time, lr,
                report_stats):
    """Batch-level training progress report callback.

    On the last batch of every ``opt.report_every`` window the accumulated
    statistics are printed and, when ``opt.tensorboard`` is set, logged to
    the module-level tensorboard ``writer``.

    :param epoch(int): current epoch count.
    :param batch(int): current batch index (reported as ``batch + 1``).
    :param num_batches(int): total number of batches.
    :param progress_step(int): x-axis value for the tensorboard log.
    :param start_time(float): last report time.
    :param lr(float): current learning rate.
    :param report_stats(Statistics): statistics accumulated since the last report.
    :return: ``report_stats`` unchanged, or a fresh Statistics after a report.
    """
    window_end = batch % opt.report_every == -1 % opt.report_every
    if not window_end:
        return report_stats
    report_stats.output(epoch, batch + 1, num_batches, start_time)
    if opt.tensorboard:
        # Log the progress using the number of batches on the x-axis.
        report_stats.log_tensorboard('progress', writer, lr, progress_step)
    return onmt.Statistics()
def sharded_compute_loss(self, batch, output, kappa_output, attns,
                         encoder_output, kappa_encoder_output,
                         decoder_output_wod, kappa_decoder_output_wod,
                         cur_trunc, trunc_size, shard_size):
    """Compute the loss in shards and backpropagate each shard.

    All model outputs (plain and ``kappa_*`` variants -- their meaning is not
    visible here) are packed by ``self.make_shard_state`` and split by
    ``shards``; each shard loss is divided by ``batch.batch_size``.

    Returns:
        onmt.Statistics: statistics accumulated over all shards.
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    state = self.make_shard_state(
        batch, window,
        output, kappa_output,
        encoder_output, kappa_encoder_output,
        decoder_output_wod, kappa_decoder_output_wod,
        attns,
    )
    for piece in shards(state, shard_size):
        piece_loss, piece_stats = self.compute_loss(batch, **piece)
        piece_loss.div(batch.batch_size).backward()
        accumulated.update(piece_stats)
    return accumulated
def sharded_compute_loss(self, batch, output, sample_outputs, attns,
                         sample_attns, sample_batch_tgt,
                         sample_batch_alignment, cur_trunc, trunc_size,
                         shard_size, normalization, backward=True,
                         rewards=None):
    """Compute the mixed ML/RL loss in shards and optionally backpropagate.

    The maximum-likelihood (ML) loss is computed on the reference targets and
    the reinforcement (RL) loss on sampled targets weighted by ``rewards``.
    The two are mixed as
    ``(1 - self.apply_factor) * ML / normalization
    + self.apply_factor * RL / normalization``.

    Sharding is an exact efficiency trick for the generation buffers;
    truncation over ``(cur_trunc, cur_trunc + trunc_size)`` is an approximate
    trick for the RNN buffers.

    Args:
        batch (batch): batch of labeled examples (``batch.tgt`` and
            ``batch.alignment`` feed the ML shard state).
        output (:obj:`FloatTensor`): decoder output for the reference
            targets `[tgt_len x batch x hidden]`.
        sample_outputs (:obj:`FloatTensor`): decoder output for the sampled
            targets.
        attns (dict): attention distributions for ``output``.
        sample_attns (dict): attention distributions for ``sample_outputs``.
        sample_batch_tgt: sampled target sequences.
        sample_batch_alignment: alignment matching ``sample_batch_tgt``.
        cur_trunc (int): starting position of truncation window.
        trunc_size (int): length of truncation window.
        shard_size (int): maximum number of examples in a shard.
        normalization (int): each loss term is divided by this number.
        backward (bool): when False the losses are computed but no backward
            pass is run.
        rewards: rewards for the sampled sequences; required.

    Returns:
        :obj:`onmt.Statistics`: statistics accumulated from the RL loss shards
        (the ML statistics are discarded).
    """
    assert rewards is not None, "rewards must be provided for the RL loss"

    batch_stats = onmt.Statistics()
    range_ = (cur_trunc, cur_trunc + trunc_size)
    ml_shard_state = self._make_shard_state(batch.tgt, batch.alignment,
                                            output, range_, attns)
    rl_shard_state = self._make_shard_state(sample_batch_tgt,
                                            sample_batch_alignment,
                                            sample_outputs, range_,
                                            sample_attns, rewards)

    shard_fn = onmt.Loss.shards
    # NOTE(review): ``zip`` stops as soon as the ML shard generator is
    # exhausted and never resumes the RL generator after its last yield, so
    # any work ``shards`` does after its final yield runs for the ML state
    # only -- confirm this is intended.
    for ml_shard, rl_shard in zip(shard_fn(ml_shard_state, shard_size),
                                  shard_fn(rl_shard_state, shard_size)):
        ml_loss, _ml_stats = self.ml_loss_compute._compute_loss(
            batch, **ml_shard)
        # The same ``batch`` can be reused for the RL loss because that
        # computation does not read batch.tgt or batch.alignment.
        rl_loss, stats = self.rl_loss_compute._compute_loss(
            batch, **rl_shard)
        if backward:
            loss = ((1 - self.apply_factor) * ml_loss.div(normalization)
                    + self.apply_factor * rl_loss.div(normalization))
            loss.backward()
        batch_stats.update(stats)
    return batch_stats
def optimize_quantization_points(modelToQuantize, train_loader, test_loader,
                                 options, optim=None, numPointsPerTensor=16,
                                 assignBitsAutomatically=False,
                                 use_distillation_loss=False,
                                 bucket_size=None):
    """Learn per-tensor non-uniform quantization points by gradient descent.

    A deep copy of ``modelToQuantize`` is trained with its weights replaced
    by their non-uniform quantization. The gradients w.r.t. the quantized
    weights are mapped back to gradients w.r.t. the quantization points.
    Those points are updated by their own optimizer and re-sorted after every
    step.

    Args:
        modelToQuantize: full-precision model. It is put in eval mode and is
            also passed to ``trainer.forward_and_backward`` as the teacher model.
        train_loader: training loader whose ``dataset.fields`` provides the
            target vocabulary.
        test_loader: validation loader.
        options: training options (dict or object convertible by ``mhf``).
            ``None`` uses a deep copy of ``onmt.standard_options.stdOptions``.
        optim: unused; accepted for interface compatibility.
        numPointsPerTensor (int or list): number of quantization points per
            parameter tensor. An int is broadcast to every tensor.
        assignBitsAutomatically (bool): if True, redistribute the points
            across tensors with ``qhf.assign_bits_automatically``. The input
            is the gradient norm of every parameter after one batch.
        use_distillation_loss (bool): forwarded to ``forward_and_backward``.
        bucket_size: bucket size forwarded to the quantization functions.

    Returns:
        tuple: ``(pointsPerTensor, informationDict)``. ``pointsPerTensor`` is
        the list of learned point tensors. ``informationDict`` holds the
        per-epoch validation perplexities ('perplexity') and the number of
        epochs trained ('numEpochsTrained').

    Raises:
        ValueError: if ``numPointsPerTensor`` does not have one entry per
            parameter tensor.
    """
    print('Preparing training - pre processing tensors')

    if options is None:
        # Deep-copy (as train_model does) so that handle_options can never
        # mutate the shared module-level default options.
        options = copy.deepcopy(onmt.standard_options.stdOptions)
    if not isinstance(options, dict):
        options = mhf.convertToDictionary(options)
    options = handle_options(options)
    options = mhf.convertToNamedTuple(options)

    modelToQuantize.eval()
    quantizedModel = copy.deepcopy(modelToQuantize)

    fields = train_loader.dataset.fields
    train_loss = make_loss_compute(quantizedModel, fields["tgt"].vocab,
                                   train_loader.dataset, options.copy_attn,
                                   options.copy_attn_force)
    valid_loss = make_loss_compute(quantizedModel, fields["tgt"].vocab,
                                   test_loader.dataset, options.copy_attn,
                                   options.copy_attn_force)
    trunc_size = options.truncated_decoder  # Badly named...
    shard_size = options.max_generator_batches

    numTensorsNetwork = sum(1 for _ in quantizedModel.parameters())
    if isinstance(numPointsPerTensor, int):
        numPointsPerTensor = [numPointsPerTensor] * numTensorsNetwork
    if len(numPointsPerTensor) != numTensorsNetwork:
        raise ValueError(
            'numPointsPerTensor must be equal to the number of tensor in the network'
        )

    scalingFunction = quantization.ScalingFunction(type_scaling='linear',
                                                   max_element=False,
                                                   subtract_mean=False,
                                                   modify_in_place=False,
                                                   bucket_size=bucket_size)

    quantizedModel.zero_grad()
    # dummy optim, just to pass to trainer
    dummy_optim = create_optimizer(quantizedModel, options)

    if assignBitsAutomatically:
        trainer = thf.MyTrainer(quantizedModel, train_loader, test_loader,
                                train_loss, valid_loss, dummy_optim,
                                trunc_size, shard_size)
        batch = next(iter(train_loader))
        quantizedModel.zero_grad()
        trainer.forward_and_backward(0, batch, 0, onmt.Statistics(), None)
        fisherInformation = [p.grad.data.norm()
                             for p in quantizedModel.parameters()]
        numPointsPerTensor = qhf.assign_bits_automatically(
            fisherInformation, numPointsPerTensor, input_is_point=True)
        quantizedModel.zero_grad()
        del trainer
        del optim

    # Initialize the points with the percentile-based helper so that all of
    # them are usable.
    pointsPerTensor = []
    for idx, p in enumerate(quantizedModel.parameters()):
        initial_points = qhf.initialize_quantization_points(
            p.data, scalingFunction, numPointsPerTensor[idx])
        initial_points = Variable(initial_points, requires_grad=True)
        # Do a dummy backprop so that the grad attribute is initialized. This
        # is needed because .grad is assigned manually later on (PyTorch
        # cannot assign variables to model parameters).
        initial_points.sum().backward()
        pointsPerTensor.append(initial_points)

    optionsOpt = copy.deepcopy(mhf.convertToDictionary(options))
    optimizer = create_optimizer(pointsPerTensor,
                                 mhf.convertToNamedTuple(optionsOpt))
    trainer = thf.MyTrainer(quantizedModel, train_loader, test_loader,
                            train_loss, valid_loss, dummy_optim, trunc_size,
                            shard_size)
    perplexity_epochs = []

    quantizationFunctions = []
    for idx, p in enumerate(modelToQuantize.parameters()):
        # Efficient version of nonUniformQuantization: tensors that do not
        # change across iterations are pre-processed and stored once.
        quant_fun = quantization.nonUniformQuantization_variable(
            max_element=False, subtract_mean=False, modify_in_place=False,
            bucket_size=bucket_size, pre_process_tensors=True, tensor=p.data)
        quantizationFunctions.append(quant_fun)

    print('Pre processing done, training started')

    for epoch in range(options.start_epoch, options.epochs + 1):
        train_stats = onmt.Statistics()
        quantizedModel.train()
        for idx_batch, batch in enumerate(train_loader):
            # zero the gradient
            quantizedModel.zero_grad()

            # Quantize the weights. The pre-processed tensors live inside the
            # quantization functions; only the current points are passed in.
            for idx, p_quantized in enumerate(quantizedModel.parameters()):
                p_quantized.data = quantizationFunctions[idx].forward(
                    None, pointsPerTensor[idx].data)

            trainer.forward_and_backward(idx_batch, batch, epoch,
                                         train_stats, report_func,
                                         use_distillation_loss,
                                         modelToQuantize)

            # Map the weight gradients to gradients of the quantization points.
            for idx, p in enumerate(quantizedModel.parameters()):
                pointsPerTensor[idx].grad.data = quantizationFunctions[
                    idx].backward(p.grad.data)[1]

            optimizer.step()

            # After optimizer.step() the points must still be sorted.
            for points in pointsPerTensor:
                points.data = torch.sort(points.data)[0]

        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        perplexity_epochs.append(valid_stats.ppl())

        # 3. Update the learning rate
        optimizer.updateLearningRate(valid_stats.ppl(), epoch)

    informationDict = {
        'perplexity': perplexity_epochs,
        'numEpochsTrained': options.epochs + 1 - options.start_epoch,
    }
    return pointsPerTensor, informationDict
def _quantize_param_inplace(p, idx, backprop_quantization_style,
                            quantizeFunctions):
    """Replace ``p.data`` with its quantized value (helper for train_model).

    The 'truncated' style first clamps the weights in place to [-1, 1].
    The 'none' and 'truncated' styles use the single quantization callable.
    The 'complicated' style uses the per-parameter quantizer at ``idx``.
    """
    if backprop_quantization_style == 'truncated':
        p.data.clamp_(-1, 1)  # TODO: Is this necessary? Clamping the weights?
    if backprop_quantization_style in ('none', 'truncated'):
        p.data = quantizeFunctions(p.data)
    elif backprop_quantization_style == 'complicated':
        p.data = quantizeFunctions[idx].forward(p.data)
    else:
        raise ValueError


def train_model(model, train_loader, test_loader, plot_path, optim=None,
                options=None, stochasticRounding=False, quantizeWeights=False,
                numBits=8, maxElementAllowedForQuantization=False,
                bucket_size=None, subtractMeanInQuantization=False,
                quantizationFunctionToUse='uniformLinearScaling',
                backprop_quantization_style='none', num_estimate_quant_grad=1,
                use_distillation_loss=False, teacher_model=None,
                quantize_first_and_last_layer=True):
    """Train ``model``, optionally with quantized weights, and validate each epoch.

    When ``quantizeWeights`` is set, the weights are temporarily quantized
    for the forward/backward pass once ``num_estimate_quant_grad`` batches
    have passed since the last estimate. The full-precision weights are then
    restored. The gradients are adjusted according to
    ``backprop_quantization_style``:

    * 'none' (or None): no adjustment;
    * 'truncated': zero the gradients where ``|w| > 1``;
    * 'complicated': apply the custom backward of each quantizer.

    After training, the weights are quantized permanently. With
    ``quantize_first_and_last_layer=False`` the first and last parameter
    tensors are never quantized. Train/validation perplexity and accuracy are
    written to TensorBoard under ``plot_path + '_output/train'`` and
    ``plot_path + '_output/test'``.

    Args:
        model: model to train (modified in place and returned).
        train_loader, test_loader: data loaders; ``train_loader.dataset.fields``
            provides the vocabularies.
        plot_path (str): prefix of the TensorBoard log directories.
        optim: optimizer; created with ``create_optimizer`` when None.
        options: training options; None uses a deep copy of
            ``onmt.standard_options.stdOptions``.
        stochasticRounding, numBits, maxElementAllowedForQuantization,
        bucket_size, subtractMeanInQuantization: quantizer settings.
        quantizationFunctionToUse (str): 'uniformLinearScaling' or
            'uniformAbsMaxScaling' (case-insensitive).
        backprop_quantization_style (str or None): see above.
        num_estimate_quant_grad (int): batches between quantized estimates.
        use_distillation_loss (bool): train with the distillation loss.
        teacher_model: teacher used for distillation (put in eval mode).
        quantize_first_and_last_layer (bool): if False, skip the first and
            last parameter tensors.

    Returns:
        tuple: ``(model, informationDict)``. ``informationDict`` holds the
        per-epoch validation perplexities and the number of epochs trained.

    Raises:
        ValueError: unknown quantization function or backprop style, or
            distillation requested without a teacher model.
    """
    if options is None:
        options = copy.deepcopy(onmt.standard_options.stdOptions)
    if not isinstance(options, dict):
        options = mhf.convertToDictionary(options)
    options = handle_options(options)
    options = mhf.convertToNamedTuple(options)

    if optim is None:
        optim = create_optimizer(model, options)

    if use_distillation_loss is True and teacher_model is None:
        raise ValueError(
            'If training with distilled word level, we need teacher_model to be passed'
        )

    if teacher_model is not None:
        teacher_model.eval()

    # None used to be accepted when building the quantizers but rejected
    # (ValueError) inside the training loop; treat it as 'none' throughout.
    if backprop_quantization_style is None:
        backprop_quantization_style = 'none'

    step_since_last_grad_quant_estimation = 0
    num_param_model = sum(1 for _ in model.parameters())

    def skip_layer(idx):
        """True for the first/last tensor when they must stay full precision."""
        return (quantize_first_and_last_layer is False
                and (idx == 0 or idx == num_param_model - 1))

    if quantizeWeights:
        quantizationFunctionToUse = quantizationFunctionToUse.lower()
        if quantizationFunctionToUse == 'uniformAbsMaxScaling'.lower():
            s = 2**(numBits - 1)
            type_of_scaling = 'absmax'
        elif quantizationFunctionToUse == 'uniformLinearScaling'.lower():
            s = 2**numBits
            type_of_scaling = 'linear'
        else:
            raise ValueError(
                'The specified quantization function is not present')

        if backprop_quantization_style in ('none', 'truncated'):
            def quantizeFunctions(x):
                return quantization.uniformQuantization(
                    x, s, type_of_scaling=type_of_scaling,
                    stochastic_rounding=stochasticRounding,
                    max_element=maxElementAllowedForQuantization,
                    subtract_mean=subtractMeanInQuantization,
                    modify_in_place=False, bucket_size=bucket_size)[0]
        elif backprop_quantization_style == 'complicated':
            quantizeFunctions = [
                quantization.uniformQuantization_variable(
                    s, type_of_scaling=type_of_scaling,
                    stochastic_rounding=stochasticRounding,
                    max_element=maxElementAllowedForQuantization,
                    subtract_mean=subtractMeanInQuantization,
                    modify_in_place=False, bucket_size=bucket_size)
                for _ in model.parameters()
            ]
        else:
            raise ValueError(
                'The specified backprop_quantization_style not recognized')

    fields = train_loader.dataset.fields
    # Collect features.
    src_features = collect_features(train_loader.dataset, fields)
    for j, feat in enumerate(src_features):
        print(' * src feature %d size = %d' % (j, len(fields[feat].vocab)))

    train_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   train_loader.dataset, options.copy_attn,
                                   options.copy_attn_force,
                                   use_distillation_loss, teacher_model)
    # For validation we don't use the distilled loss; it would distort the
    # perplexity computation.
    valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   test_loader.dataset, options.copy_attn,
                                   options.copy_attn_force)

    trunc_size = None  # options.truncated_decoder  # Badly named...
    shard_size = options.max_generator_batches

    trn_writer = tbx.SummaryWriter(plot_path + '_output/train')
    tst_writer = tbx.SummaryWriter(plot_path + '_output/test')
    try:
        trainer = thf.MyTrainer(model, train_loader, test_loader, train_loss,
                                valid_loss, optim, trunc_size, shard_size)

        perplexity_epochs = []
        for epoch in range(options.start_epoch, options.epochs + 1):
            train_stats = onmt.Statistics()
            model.train()
            for idx_batch, batch in enumerate(train_loader):
                model.zero_grad()

                estimate_quant_grad = (
                    quantizeWeights and step_since_last_grad_quant_estimation
                    >= num_estimate_quant_grad)

                if estimate_quant_grad:
                    # Save the weights: they are quantized only to compute
                    # the gradients, and the non-quantized weights are kept
                    # for the optimizer step.
                    # NOTE(review): per the PyTorch docs state_dict() holds
                    # references to the parameter storage, so the in-place
                    # clamp_ of the 'truncated' style is NOT undone by
                    # load_state_dict below -- confirm this is intended.
                    model_state_dict = model.state_dict()
                    for idx, p in enumerate(model.parameters()):
                        if skip_layer(idx):
                            continue
                        _quantize_param_inplace(p, idx,
                                                backprop_quantization_style,
                                                quantizeFunctions)

                trainer.forward_and_backward(idx_batch, batch, epoch,
                                             train_stats, report_func,
                                             use_distillation_loss,
                                             teacher_model)

                if estimate_quant_grad:
                    model.load_state_dict(model_state_dict)
                    del model_state_dict  # free memory
                    # Backward through the quantizer. 'none' needs nothing;
                    # 'truncated' zeroes the gradients of the weights with
                    # |w| > 1 as in https://arxiv.org/pdf/1609.07061.pdf;
                    # 'complicated' uses the quantizer's own backward.
                    if backprop_quantization_style in ('truncated',
                                                       'complicated'):
                        for idx, p in enumerate(model.parameters()):
                            if skip_layer(idx):
                                continue
                            if backprop_quantization_style == 'truncated':
                                p.grad.data[p.data.abs() > 1] = 0
                            else:
                                p.grad.data = quantizeFunctions[idx].backward(
                                    p.grad.data)

                # update parameters after every batch
                trainer.optim.step()

                if step_since_last_grad_quant_estimation >= num_estimate_quant_grad:
                    step_since_last_grad_quant_estimation = 0
                step_since_last_grad_quant_estimation += 1

            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())
            trn_writer.add_scalar('ppl', train_stats.ppl(), epoch + 1)
            trn_writer.add_scalar('acc', train_stats.accuracy(), epoch + 1)

            # 2. Validate on the validation set.
            # max_memory_allocated() already reports the peak since program
            # start, so no running max is needed.
            max_memory = torch.cuda.max_memory_allocated()
            valid_stats = trainer.validate()
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())
            print('Max allocated memory: {:.2f}MB'.format(
                max_memory / (1024**2)))
            perplexity_epochs.append(valid_stats.ppl())
            tst_writer.add_scalar('ppl', valid_stats.ppl(), epoch + 1)
            tst_writer.add_scalar('acc', valid_stats.accuracy(), epoch + 1)

            # 3. Update the learning rate
            trainer.epoch_step(valid_stats.ppl(), epoch)

        if quantizeWeights:
            for idx, p in enumerate(model.parameters()):
                # Honour quantize_first_and_last_layer here too, as during
                # training.
                if skip_layer(idx):
                    continue
                _quantize_param_inplace(p, idx, backprop_quantization_style,
                                        quantizeFunctions)
                if backprop_quantization_style == 'complicated':
                    # Release the state saved for the custom backward.
                    del quantizeFunctions[idx].saved_for_backward
                    quantizeFunctions[idx].saved_for_backward = None
    finally:
        trn_writer.close()
        tst_writer.close()

    informationDict = {
        'perplexity': perplexity_epochs,
        'numEpochsTrained': options.epochs + 1 - options.start_epoch,
    }
    return model, informationDict
def report_func(epoch, batch, num_batches, progress_step, start_time, lr,
                report_stats):
    """Batch-level training progress report callback.

    Prints the accumulated statistics on the last batch of every
    ``opt.report_every`` window (``opt`` is a module-level options object).

    Args:
        epoch (int): current epoch count.
        batch (int): current batch index (reported as ``batch + 1``).
        num_batches (int): total number of batches.
        progress_step (int): unused here.
        start_time (float): last report time.
        lr (float): unused here.
        report_stats (Statistics): statistics accumulated since the last report.

    Returns:
        Statistics: ``report_stats`` unchanged, or a fresh instance after a
        report.
    """
    if batch % opt.report_every != -1 % opt.report_every:
        return report_stats
    report_stats.output(epoch, batch + 1, num_batches, start_time)
    return onmt.Statistics()
def report_func(trnr, epoch, batch, num_batches, start_time, lr, report_stats,
                enc_output_dict, dec_output_dict, mem_dict):
    """Batch-level training progress report callback.

    On the last batch of every ``opt.report_every`` window this function:

    * prints the accumulated statistics;
    * if ``mem_dict`` is non-empty, prints memory-saving ratios
      (normal bits / used bits, normal bits / optimal bits) and appends them
      to ``mem_log_file``. The ratios are split by encoder/decoder when
      ``opt.separate_buffers`` is set;
    * unless ``opt.no_log_during_epoch`` is set, runs ``trnr.validate()`` and
      appends train/valid accuracy and perplexity to
      ``iteration_log_loss_file``.

    Args:
        trnr: trainer object used for mid-epoch validation.
        epoch (int): current epoch count.
        batch (int): current batch index (reported as ``batch + 1``).
        num_batches (int): total number of batches.
        start_time (float): last report time.
        lr (float): current learning rate (unused).
        report_stats (Statistics): statistics accumulated since the last report.
        enc_output_dict, dec_output_dict: unused here.
        mem_dict (dict or None): bit counts keyed by 'used_bits',
            'optimal_bits', 'normal_bits'. With separate buffers these keys
            carry an 'enc_'/'dec_' prefix.

    Returns:
        Statistics: ``report_stats`` unchanged, or a fresh instance after a
        report.
    """
    if batch % opt.report_every != -1 % opt.report_every:
        return report_stats

    report_stats.output(epoch, batch + 1, num_batches, start_time)

    if mem_dict:
        if opt.separate_buffers:
            (enc_used, dec_used, enc_opt, dec_opt,
             enc_norm, dec_norm) = (mem_dict[key] for key in (
                 'enc_used_bits', 'dec_used_bits',
                 'enc_optimal_bits', 'dec_optimal_bits',
                 'enc_normal_bits', 'dec_normal_bits'))
            enc_actual = enc_norm / enc_used
            enc_best = enc_norm / enc_opt
            dec_actual = dec_norm / dec_used
            dec_best = dec_norm / dec_opt
            for template, value in (
                    ("Enc actual memory ratio {}", enc_actual),
                    ("Enc optimal memory ratio {}", enc_best),
                    ("Dec actual memory ratio {}", dec_actual),
                    ("Dec optimal memory ratio {}", dec_best)):
                print(template.format(value))
            mem_log_file.write('{} {} {} {} {} {}\n'.format(
                epoch, batch, enc_actual, enc_best, dec_actual, dec_best))
        else:
            used, best, normal = (mem_dict[key] for key in (
                'used_bits', 'optimal_bits', 'normal_bits'))
            actual = normal / used
            optimal = normal / best
            print("Actual memory ratio {}".format(actual))
            print("Optimal memory ratio {}".format(optimal))
            mem_log_file.write('{} {} {} {}\n'.format(
                epoch, batch, actual, optimal))
        mem_log_file.flush()

    if not opt.no_log_during_epoch:
        valid_stats = trnr.validate()
        iteration_log_loss_file.write('{} {} {} {} {} {}\n'.format(
            epoch, batch, report_stats.accuracy(), report_stats.ppl(),
            valid_stats.accuracy(), valid_stats.ppl()))
        iteration_log_loss_file.flush()

    return onmt.Statistics()
def _compute_loss(self, batch, ret, doc_index):
    """Compute the negated risk loss of the decoded hypotheses.

    Document-level rewards (BLEU / LC / coherence, each enabled by its own
    flag) are used when ``self.doc_level`` is set. Sentence-level BLEU
    rewards are used when ``self.bleu_sen`` is set. When both are enabled,
    the two ``RISK_loss`` terms are summed.

    Args:
        batch: batch whose ``tgt`` holds the references; ``tgt.size()[1]`` is
            used as the number of sentences.
        ret (dict): decoding results with 'predictions' and 'out_prob'.
        doc_index: document index forwarded to the document-level helpers.

    Returns:
        tuple: ``(loss, stats)``, where ``loss`` is the negated summed risk
        loss and ``stats`` an ``onmt.Statistics`` with its value.

    Raises:
        ValueError: if neither ``doc_level`` nor ``bleu_sen`` is enabled
            (previously an UnboundLocalError).
    """
    # Decoded hypotheses and their probabilities.
    top_hyp = ret['predictions']
    top_probabilities = ret['out_prob']
    # References.
    tgt = batch.tgt
    batch_num = tgt.size()[1]

    if not (self.doc_level or self.bleu_sen):
        raise ValueError(
            'At least one of doc_level or bleu_sen must be enabled to compute the loss')

    if self.doc_level:
        sentences_pred, prob_per_word = self.words_probs_from_preds(
            top_hyp, top_probabilities, doc_index=doc_index,
            batch_num=batch_num)
        sentences_gt = self.obtain_words_from_tgt(tgt, doc_index=doc_index)
        dl_rewards = 0
        if self.bleu_doc:
            dl_rewards += self.BLEU_score(sentences_pred, sentences_gt)
        if self.LC_doc:
            dl_rewards += self.LC_scores(sentences_pred)
        if self.COH_doc:
            dl_rewards += self.coher_scores(sentences_pred)
        loss_doc = self.RISK_loss(prob_per_word, dl_rewards).sum(0)

    if self.bleu_sen:
        sentences_pred, prob_per_word = self.words_probs_from_preds(
            top_hyp, top_probabilities)
        sentences_gt = self.obtain_words_from_tgt(tgt)
        sl_rewards = self.BLEU_score(sentences_pred, sentences_gt)
        loss_sen = self.RISK_loss(prob_per_word, sl_rewards).sum(0)

    if self.doc_level and self.bleu_sen:
        loss = loss_doc + loss_sen
    elif self.doc_level:
        loss = loss_doc
    else:
        loss = loss_sen

    # Negated because there is no baseline for the moment.
    loss = -loss
    stats = onmt.Statistics(loss.item(), n_sentences=tgt.size()[1])
    return loss, stats
def train(self, epoch, report_func=None):
    """Train the model for one epoch.

    The code path is selected by ``self.opt``:

    * reversible encoder and decoder with ``use_reverse``: the model's
      ``forward_and_backward`` returns the loss. No backward call is made
      here, so it presumably also computes the gradients -- TODO confirm;
    * reversible encoder and decoder without ``use_reverse``: the model
      returns the loss, which is divided by the batch size and
      backpropagated here;
    * otherwise: the generator scores the decoder outputs, and the summed
      cross-entropy (divided by the batch size) is backpropagated.

    Args:
        epoch (int): the epoch number.
        report_func (callable, optional): called after every batch as
            ``report_func(self, epoch, i, num_batches, start_time, lr,
            report_stats, enc_output_dict, dec_output_dict, mem_dict)``. Its
            return value replaces ``report_stats``.

    Returns:
        stats (:obj:`onmt.Statistics`): epoch loss statistics.

    Raises:
        ValueError: if the (non-reversible) model returns a tuple whose
            length is neither 3 nor 5.
    """
    total_stats = Statistics()
    report_stats = Statistics()

    for i, batch in enumerate(self.train_iter):
        batch_size = batch.tgt.size(1)
        dec_state = None

        src = onmt.io.make_features(batch, 'src', self.data_type)
        _, src_lengths = batch.src
        report_stats.n_src_words += src_lengths.sum()
        tgt = onmt.io.make_features(batch, 'tgt')

        if self.opt.encoder_model == 'Rev' and self.opt.decoder_model == 'Rev':
            tgt_lengths = torch.tensor([tgt.size(0)], device=tgt.device)
            self.model.zero_grad()
            if self.opt.use_reverse:
                (loss, num_words, num_correct, enc_output_dict,
                 dec_output_dict, mem_dict) = self.model.forward_and_backward(
                     src, tgt, src_lengths, tgt_lengths, self.weight,
                     self.padding_idx)
                batch_stats = onmt.Statistics(loss.item(), float(num_words),
                                              float(num_correct))
            else:
                (loss, num_words, num_correct, _attn, enc_output_dict,
                 dec_output_dict, mem_dict) = self.model(
                     src, tgt, src_lengths, tgt_lengths, self.weight,
                     self.padding_idx)
                batch_stats = onmt.Statistics(loss.item(), float(num_words),
                                              float(num_correct))
                loss.div(batch_size).backward()
        else:
            self.model.zero_grad()
            model_output_tuple = self.model(src, tgt, src_lengths, dec_state)
            if len(model_output_tuple) == 5:
                (outputs, attns, dec_state, enc_output_dict,
                 dec_output_dict) = model_output_tuple
                # The 5-tuple carries no memory statistics. Without this,
                # the report_func call below raised NameError on mem_dict.
                mem_dict = None
            elif len(model_output_tuple) == 3:
                enc_output_dict = dec_output_dict = mem_dict = None
                outputs, attns, dec_state = model_output_tuple
            else:
                raise ValueError(
                    'Unexpected number of model outputs: %d'
                    % len(model_output_tuple))

            target = batch.tgt[1:]
            scores = self.model.decoder.generator(
                outputs.view(-1, outputs.size(2)))
            gtruth = target.view(-1)
            # reduction='sum' is the non-deprecated spelling of
            # size_average=False.
            loss = F.cross_entropy(scores, gtruth, weight=self.weight,
                                   reduction='sum')
            loss_data = loss.data.clone()
            # dim=1 matches the implicit choice for this 2-D input.
            batch_stats = self._stats(loss_data,
                                      F.log_softmax(scores.data, dim=1),
                                      target.view(-1).data)
            loss.div(batch_size).backward()

        self.optim.step()
        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        if report_func is not None:
            report_stats = report_func(
                self, epoch, i, len(self.train_iter),
                total_stats.start_time, self.optim.lr, report_stats,
                enc_output_dict, dec_output_dict, mem_dict)

    return total_stats