Example #1
def report_func(epoch, batch, num_batches,
                start_time, lr, report_stats):
    """
    This is the user-defined batch-level training progress
    report function.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    if batch % REPORT_EVERY == -1 % REPORT_EVERY:
        report_stats.output(epoch, batch+1, num_batches, start_time)
        report_stats = onmt.Statistics()

    return report_stats
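
A note on the guard batch % REPORT_EVERY == -1 % REPORT_EVERY used here and in several later examples: since -1 % REPORT_EVERY evaluates to REPORT_EVERY - 1, the condition is true exactly when batch + 1 is a multiple of the report interval. A minimal sketch of that equivalence (the constant name and value are assumptions for illustration):

REPORT_EVERY = 50  # assumed module-level constant

for batch in range(500):
    # the guard used in report_func above ...
    fires = batch % REPORT_EVERY == -1 % REPORT_EVERY
    # ... reports once every REPORT_EVERY batches, counting batches from 1
    assert fires == ((batch + 1) % REPORT_EVERY == 0)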
Example #2
    def sharded_compute_loss(self,
                             batch,
                             output,
                             attns,
                             cur_trunc,
                             trunc_size,
                             shard_size,
                             teacher_outputs=None):
        """
        Compute the loss in shards for efficiency.
        """
        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        gen_state = make_gen_state(output, batch, attns, range_,
                                   self.copy_attn, teacher_outputs)
        for shard in shards(gen_state, shard_size):
            loss, stats = self.compute_loss(batch, **shard)
            loss.div(batch.batch_size).backward(retain_graph=True)
            batch_stats.update(stats)

        return batch_stats
Example #3
    def _stats(self, loss, xent, kl, scores, target, sample_xents):
        """
        Args:
            loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
            xent (:obj:`FloatTensor`): the cross-entropy component of the loss.
            kl (:obj:`FloatTensor`): the KL-divergence component of the loss.
            scores (:obj:`FloatTensor`): a score for each possible output
            target (:obj:`FloatTensor`): true targets
            sample_xents (:obj:`FloatTensor`): per-sample cross-entropy values.

        Returns:
            :obj:`Statistics` : statistics for this batch.
        """
        pred = scores.max(1)[1]
        non_padding = target.ne(self.padding_idx)
        num_correct = pred.eq(target) \
                          .masked_select(non_padding) \
                          .long().sum()
        return onmt.Statistics(loss.cpu().numpy(),
                               xent.cpu().numpy(),
                               kl.cpu().numpy(),
                               non_padding.long().sum().cpu().numpy(),
                               num_correct.cpu().numpy(),
                               sample_xents.cpu().numpy())
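
For orientation: the examples on this page all revolve around the same small onmt.Statistics API, with per-batch instances built from a loss value and word/correct counts, merged into running totals via update(), and read back through accuracy(), ppl() and output(). A minimal usage sketch under that assumption (the numbers are illustrative; some examples, such as this one, come from forks that extend the constructor with additional fields):

import time
import onmt

total_stats = onmt.Statistics()               # running totals for the epoch
batch_stats = onmt.Statistics(12.5, 40, 31)   # loss, n_words, n_correct

total_stats.update(batch_stats)               # accumulate batch statistics
print(total_stats.accuracy())                 # accuracy over non-padding tokens
print(total_stats.ppl())                      # perplexity, exp(loss / n_words)
total_stats.output(1, 1, 100, time.time())    # epoch, batch, num_batches, start_time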
Example #4
    def sharded_compute_loss(self, batch, output, attns,
                             cur_trunc, trunc_size, shard_size, output_t, attns_t):
        """
        Compute the loss in shards for efficiency.
        """
        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        gen_state = make_gen_state(output, batch, attns, range_,
                                   self.copy_attn)
        gen_state_t = make_gen_state(output_t, batch, attns_t, range_,
                                     self.copy_attn)
        #loss, stats = self.compute_loss(batch, **gen_state)
        loss, stats = self.compute_loss(batch, gen_state['output'],
                                        gen_state['target'],
                                        gen_state_t['output'])
        loss.div(batch.batch_size).backward()
        batch_stats.update(stats)
        #for shard in shards(gen_state, shard_size):
        #    loss, stats = self.compute_loss(batch, **shard)
        #    loss.div(batch.batch_size).backward()
        #    batch_stats.update(stats)

        return batch_stats
Example #5
    def sharded_compute_loss(self, batch, output, attns, cur_trunc, trunc_size,
                             shard_size, normalization):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard
          normalization (int) : Loss is divided by this number

        Returns:
            :obj:`onmt.Statistics`: validation loss statistics

        """
        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        shard_state = self._make_shard_state(batch, output, range_, attns)

        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)

            loss.div(normalization).backward()
            batch_stats.update(stats)

        return batch_stats
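
The sharding referenced in the docstring can be pictured as slicing every tensor in the shard state along the time dimension and yielding one sub-dictionary per slice, so the generator and softmax only ever materialise shard_size timesteps at a time. The sketch below illustrates that idea only; it is not the library implementation, which additionally splices the per-shard backward passes together:

import torch

def shards_sketch(state, shard_size):
    # Yield dicts whose tensor values are consecutive slices of length
    # shard_size along dim 0; non-tensor values are passed through unchanged.
    length = next(v.size(0) for v in state.values() if torch.is_tensor(v))
    for start in range(0, length, shard_size):
        yield {k: (v[start:start + shard_size] if torch.is_tensor(v) else v)
               for k, v in state.items()}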
Example #6
def report_func(epoch, batch, num_batches, start_time, lr, report_stats, opt):
    """
    This is the user-defined batch-level training progress
    report function.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
        if opt.exp_host:
            report_stats.log("progress", experiment, lr)
        report_stats = onmt.Statistics()

    return report_stats
Example #7
def report_func(epoch, batch, num_batches, progress_step, start_time, lr,
                report_stats):
    '''
    This is the user-defined batch-level training progress report function.
    :param epoch(int): current epoch count.
    :param batch(int): current batch count.
    :param num_batches(int): total number of batches.
    :param progress_step(int): global step count, used for logging progress.
    :param start_time(float): last report time.
    :param lr(float): current learning rate.
    :param report_stats(Statistics): old Statistics instance.
    :return: report_stats(Statistics): updated Statistics instance.
    '''
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
        if opt.tensorboard:
            # Log the progress using the number of batches on the x-axis.
            report_stats.log_tensorboard('progress', writer, lr, progress_step)

        report_stats = onmt.Statistics()

    return report_stats
Example #8
    def sharded_compute_loss(self, batch, output, kappa_output, attns,
                             encoder_output, kappa_encoder_output,
                             decoder_output_wod, kappa_decoder_output_wod,
                             cur_trunc, trunc_size, shard_size):
        """
        Compute the loss in shards for efficiency.
        """
        batch_stats = onmt.Statistics()

        range_ = (cur_trunc, cur_trunc + trunc_size)

        shard_state = self.make_shard_state(batch, range_, output,
                                            kappa_output, encoder_output,
                                            kappa_encoder_output,
                                            decoder_output_wod,
                                            kappa_decoder_output_wod, attns)
        #print(shard_state)
        for shard in shards(shard_state, shard_size):
            loss, stats = self.compute_loss(batch, **shard)
            loss.div(batch.batch_size).backward()
            batch_stats.update(stats)

        return batch_stats
Example #9
    def sharded_compute_loss(self,
                             batch,
                             output,
                             sample_outputs,
                             attns,
                             sample_attns,
                             sample_batch_tgt,
                             sample_batch_alignment,
                             cur_trunc,
                             trunc_size,
                             shard_size,
                             normalization,
                             backward=True,
                             rewards=None):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard
          normalization (int) : Loss is divided by this number

        Returns:
            :obj:`onmt.Statistics`: validation loss statistics

        """
        assert rewards is not None

        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        ml_shard_state = self._make_shard_state(batch.tgt, batch.alignment,
                                                output, range_, attns)
        rl_shard_state = self._make_shard_state(sample_batch_tgt,
                                                sample_batch_alignment,
                                                sample_outputs, range_,
                                                sample_attns, rewards)

        shards = onmt.Loss.shards
        for ml_shard, rl_shard in zip(shards(ml_shard_state, shard_size),
                                      shards(rl_shard_state, shard_size)):
            #             print("Loss, line:123", shard)
            ml_loss, stats = self.ml_loss_compute._compute_loss(
                batch, **ml_shard)
            # the same batch can be reused because this call doesn't use batch.tgt or batch.alignment
            rl_loss, stats = self.rl_loss_compute._compute_loss(
                batch, **rl_shard)
            if backward:
                loss = (1 - self.apply_factor) * ml_loss.div(
                    normalization) + self.apply_factor * rl_loss.div(
                        normalization)
                loss.backward()
            batch_stats.update(stats)

        return batch_stats
Example #10
def optimize_quantization_points(modelToQuantize,
                                 train_loader,
                                 test_loader,
                                 options,
                                 optim=None,
                                 numPointsPerTensor=16,
                                 assignBitsAutomatically=False,
                                 use_distillation_loss=False,
                                 bucket_size=None):

    print('Preparing training - pre-processing tensors')

    if options is None: options = onmt.standard_options.stdOptions
    if not isinstance(options, dict):
        options = mhf.convertToDictionary(options)
    options = handle_options(options)
    options = mhf.convertToNamedTuple(options)

    modelToQuantize.eval()
    quantizedModel = copy.deepcopy(modelToQuantize)

    fields = train_loader.dataset.fields
    train_loss = make_loss_compute(quantizedModel, fields["tgt"].vocab,
                                   train_loader.dataset, options.copy_attn,
                                   options.copy_attn_force)
    valid_loss = make_loss_compute(quantizedModel, fields["tgt"].vocab,
                                   test_loader.dataset, options.copy_attn,
                                   options.copy_attn_force)
    trunc_size = options.truncated_decoder  # Badly named...
    shard_size = options.max_generator_batches

    numTensorsNetwork = sum(1 for _ in quantizedModel.parameters())
    if isinstance(numPointsPerTensor, int):
        numPointsPerTensor = [numPointsPerTensor] * numTensorsNetwork
    if len(numPointsPerTensor) != numTensorsNetwork:
        raise ValueError(
            'numPointsPerTensor must be equal to the number of tensors in the network'
        )

    scalingFunction = quantization.ScalingFunction(type_scaling='linear',
                                                   max_element=False,
                                                   subtract_mean=False,
                                                   modify_in_place=False,
                                                   bucket_size=bucket_size)

    quantizedModel.zero_grad()
    dummy_optim = create_optimizer(
        quantizedModel, options)  #dummy optim, just to pass to trainer
    if assignBitsAutomatically:
        trainer = thf.MyTrainer(quantizedModel, train_loader, test_loader,
                                train_loss, valid_loss, dummy_optim,
                                trunc_size, shard_size)
        batch = next(iter(train_loader))
        quantizedModel.zero_grad()
        trainer.forward_and_backward(0, batch, 0, onmt.Statistics(), None)
        fisherInformation = []
        for p in quantizedModel.parameters():
            fisherInformation.append(p.grad.data.norm())
        numPointsPerTensor = qhf.assign_bits_automatically(fisherInformation,
                                                           numPointsPerTensor,
                                                           input_is_point=True)
        quantizedModel.zero_grad()
        del trainer
        del optim

    # initialize the points using the percentile function so as to make them all usable
    pointsPerTensor = []
    for idx, p in enumerate(quantizedModel.parameters()):
        initial_points = qhf.initialize_quantization_points(
            p.data, scalingFunction, numPointsPerTensor[idx])
        initial_points = Variable(initial_points, requires_grad=True)
        # do a dummy backprop so that the grad attribute is initialized. We need this because we call
        # the .backward() function manually later on (since pytorch can't assign variables to model
        # parameters)
        initial_points.sum().backward()
        pointsPerTensor.append(initial_points)

    optionsOpt = copy.deepcopy(mhf.convertToDictionary(options))
    optimizer = create_optimizer(pointsPerTensor,
                                 mhf.convertToNamedTuple(optionsOpt))
    trainer = thf.MyTrainer(quantizedModel, train_loader, test_loader,
                            train_loss, valid_loss, dummy_optim, trunc_size,
                            shard_size)
    perplexity_epochs = []

    quantizationFunctions = []
    for idx, p in enumerate(modelToQuantize.parameters()):
        #efficient version of nonUniformQuantization
        quant_fun = quantization.nonUniformQuantization_variable(
            max_element=False,
            subtract_mean=False,
            modify_in_place=False,
            bucket_size=bucket_size,
            pre_process_tensors=True,
            tensor=p.data)

        quantizationFunctions.append(quant_fun)

    print('Pre-processing done, training started')

    for epoch in range(options.start_epoch, options.epochs + 1):
        train_stats = onmt.Statistics()
        quantizedModel.train()
        for idx_batch, batch in enumerate(train_loader):

            #zero the gradient
            quantizedModel.zero_grad()

            # quantize the weights
            for idx, p_quantized in enumerate(quantizedModel.parameters()):
                #I am using the efficient version of nonUniformQuantization. The tensors (that don't change across
                #iterations) are saved inside the quantization function, and we only need to pass the quantization
                #points
                p_quantized.data = quantizationFunctions[idx].forward(
                    None, pointsPerTensor[idx].data)

            trainer.forward_and_backward(idx_batch, batch, epoch, train_stats,
                                         report_func, use_distillation_loss,
                                         modelToQuantize)

            # now get the gradient of the pointsPerTensor
            for idx, p in enumerate(quantizedModel.parameters()):
                pointsPerTensor[idx].grad.data = quantizationFunctions[
                    idx].backward(p.grad.data)[1]

            optimizer.step()

            # after optimizer.step() we need to make sure that the points are still sorted
            for points in pointsPerTensor:
                points.data = torch.sort(points.data)[0]

        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        perplexity_epochs.append(valid_stats.ppl())

        # 3. Update the learning rate
        optimizer.updateLearningRate(valid_stats.ppl(), epoch)

    informationDict = {}
    informationDict['perplexity'] = perplexity_epochs
    informationDict[
        'numEpochsTrained'] = options.epochs + 1 - options.start_epoch
    return pointsPerTensor, informationDict
Example #11
def train_model(model,
                train_loader,
                test_loader,
                plot_path,
                optim=None,
                options=None,
                stochasticRounding=False,
                quantizeWeights=False,
                numBits=8,
                maxElementAllowedForQuantization=False,
                bucket_size=None,
                subtractMeanInQuantization=False,
                quantizationFunctionToUse='uniformLinearScaling',
                backprop_quantization_style='none',
                num_estimate_quant_grad=1,
                use_distillation_loss=False,
                teacher_model=None,
                quantize_first_and_last_layer=True):

    if options is None:
        options = copy.deepcopy(onmt.standard_options.stdOptions)
    if not isinstance(options, dict):
        options = mhf.convertToDictionary(options)
    options = handle_options(options)
    options = mhf.convertToNamedTuple(options)

    if optim is None:
        optim = create_optimizer(model, options)

    if use_distillation_loss is True and teacher_model is None:
        raise ValueError(
            'If training with word-level distillation, teacher_model must be passed'
        )

    if teacher_model is not None:
        teacher_model.eval()

    step_since_last_grad_quant_estimation = 0
    num_param_model = sum(1 for _ in model.parameters())
    if quantizeWeights:
        quantizationFunctionToUse = quantizationFunctionToUse.lower()
        if quantizationFunctionToUse == 'uniformAbsMaxScaling'.lower():
            s = 2**(numBits - 1)
            type_of_scaling = 'absmax'
        elif quantizationFunctionToUse == 'uniformLinearScaling'.lower():
            s = 2**numBits
            type_of_scaling = 'linear'
        else:
            raise ValueError(
                'The specified quantization function is not present')

        if backprop_quantization_style is None or backprop_quantization_style in (
                'none', 'truncated'):
            quantizeFunctions = lambda x: quantization.uniformQuantization(
                x,
                s,
                type_of_scaling=type_of_scaling,
                stochastic_rounding=stochasticRounding,
                max_element=maxElementAllowedForQuantization,
                subtract_mean=subtractMeanInQuantization,
                modify_in_place=False,
                bucket_size=bucket_size)[0]

        elif backprop_quantization_style == 'complicated':
            quantizeFunctions = [quantization.uniformQuantization_variable(s, type_of_scaling=type_of_scaling,
                                                    stochastic_rounding=stochasticRounding,
                                                    max_element=maxElementAllowedForQuantization,
                                                    subtract_mean=subtractMeanInQuantization,
                                                    modify_in_place=False, bucket_size=bucket_size) \
                                 for _ in model.parameters()]
        else:
            raise ValueError(
                'The specified backprop_quantization_style is not recognized')

    fields = train_loader.dataset.fields
    # Collect features.
    src_features = collect_features(train_loader.dataset, fields)
    for j, feat in enumerate(src_features):
        print(' * src feature %d size = %d' % (j, len(fields[feat].vocab)))

    train_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   train_loader.dataset, options.copy_attn,
                                   options.copy_attn_force,
                                   use_distillation_loss, teacher_model)
    #for validation we don't use distilled loss; it would screw up the perplexity computation
    valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   test_loader.dataset, options.copy_attn,
                                   options.copy_attn_force)

    trunc_size = None  #options.truncated_decoder  # Badly named...
    shard_size = options.max_generator_batches

    trn_writer = tbx.SummaryWriter(plot_path + '_output/train')
    tst_writer = tbx.SummaryWriter(plot_path + '_output/test')

    trainer = thf.MyTrainer(model, train_loader, test_loader, train_loss,
                            valid_loss, optim, trunc_size, shard_size)

    perplexity_epochs = []
    for epoch in range(options.start_epoch, options.epochs + 1):
        MAX_Memory = 0
        train_stats = onmt.Statistics()
        model.train()
        for idx_batch, batch in enumerate(train_loader):

            model.zero_grad()

            if quantizeWeights:
                if step_since_last_grad_quant_estimation >= num_estimate_quant_grad:
                    # we save them because we only want to quantize weights to compute gradients,
                    # but keep using non-quantized weights during the algorithm
                    model_state_dict = model.state_dict()
                    for idx, p in enumerate(model.parameters()):
                        if quantize_first_and_last_layer is False:
                            if idx == 0 or idx == num_param_model - 1:
                                continue
                        if backprop_quantization_style == 'truncated':
                            p.data.clamp_(
                                -1, 1
                            )  # TODO: Is this necessary? Clamping the weights?
                        if backprop_quantization_style in ('none',
                                                           'truncated'):
                            p.data = quantizeFunctions(p.data)
                        elif backprop_quantization_style == 'complicated':
                            p.data = quantizeFunctions[idx].forward(p.data)
                        else:
                            raise ValueError
            trainer.forward_and_backward(idx_batch, batch, epoch, train_stats,
                                         report_func, use_distillation_loss,
                                         teacher_model)

            if quantizeWeights:
                if step_since_last_grad_quant_estimation >= num_estimate_quant_grad:
                    model.load_state_dict(model_state_dict)
                    del model_state_dict  # free memory

                    if backprop_quantization_style in ('truncated',
                                                       'complicated'):
                        for idx, p in enumerate(model.parameters()):
                            if quantize_first_and_last_layer is False:
                                if idx == 0 or idx == num_param_model - 1:
                                    continue
                            #Now some sort of backward. For the none style, we don't do anything.
                            #for the truncated style, we just need to truncate the grad weights
                            #as per the paper here: https://arxiv.org/pdf/1609.07061.pdf
                            #Complicated is my derivation, but unsure whether to use it or not
                            if backprop_quantization_style == 'truncated':
                                p.grad.data[p.data.abs() > 1] = 0
                            elif backprop_quantization_style == 'complicated':
                                p.grad.data = quantizeFunctions[idx].backward(
                                    p.grad.data)

            #update parameters after every batch
            trainer.optim.step()

            if step_since_last_grad_quant_estimation >= num_estimate_quant_grad:
                step_since_last_grad_quant_estimation = 0

            step_since_last_grad_quant_estimation += 1

        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        trn_writer.add_scalar('ppl', train_stats.ppl(), epoch + 1)
        trn_writer.add_scalar('acc', train_stats.accuracy(), epoch + 1)

        # 2. Validate on the validation set.
        MAX_Memory = max(MAX_Memory, torch.cuda.max_memory_allocated())
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        print('Max allocated memory: {:.2f}MB'.format(MAX_Memory / (1024**2)))
        perplexity_epochs.append(valid_stats.ppl())

        tst_writer.add_scalar('ppl', valid_stats.ppl(), epoch + 1)
        tst_writer.add_scalar('acc', valid_stats.accuracy(), epoch + 1)

        # 3. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

    if quantizeWeights:
        for idx, p in enumerate(model.parameters()):
            if backprop_quantization_style == 'truncated':
                p.data.clamp_(
                    -1, 1)  # TODO: Is this necessary? Clamping the weights?
            if backprop_quantization_style in ('none', 'truncated'):
                p.data = quantizeFunctions(p.data)
            elif backprop_quantization_style == 'complicated':
                p.data = quantizeFunctions[idx].forward(p.data)
                del quantizeFunctions[idx].saved_for_backward
                quantizeFunctions[idx].saved_for_backward = None  # free memory
            else:
                raise ValueError

    informationDict = {}
    informationDict['perplexity'] = perplexity_epochs
    informationDict[
        'numEpochsTrained'] = options.epochs + 1 - options.start_epoch
    return model, informationDict
Example #12
def report_func(epoch, batch, num_batches, progress_step, start_time, lr, report_stats):
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
        report_stats = onmt.Statistics()
    return report_stats
Example #13
def report_func(trnr, epoch, batch, num_batches, start_time, lr, report_stats, enc_output_dict, dec_output_dict, mem_dict):
    """
    This is the user-defined batch-level traing progress
    report function.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    global iteration_log_loss_file
    global mem_log_file

    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch+1, num_batches, start_time)

        if mem_dict:
            if opt.separate_buffers:
                enc_used_bits = mem_dict['enc_used_bits']
                dec_used_bits = mem_dict['dec_used_bits']
                enc_optimal_bits = mem_dict['enc_optimal_bits']
                dec_optimal_bits = mem_dict['dec_optimal_bits']
                enc_normal_bits = mem_dict['enc_normal_bits']
                dec_normal_bits = mem_dict['dec_normal_bits']

                enc_actual_ratio = enc_normal_bits / enc_used_bits
                enc_optimal_ratio = enc_normal_bits / enc_optimal_bits
                dec_actual_ratio = dec_normal_bits / dec_used_bits
                dec_optimal_ratio = dec_normal_bits / dec_optimal_bits

                print("Enc actual memory ratio {}".format(enc_actual_ratio))
                print("Enc optimal memory ratio {}".format(enc_optimal_ratio))
                print("Dec actual memory ratio {}".format(dec_actual_ratio))
                print("Dec optimal memory ratio {}".format(dec_optimal_ratio))

                mem_log_file.write('{} {} {} {} {} {}\n'.format(
                                    epoch, batch, enc_actual_ratio, enc_optimal_ratio, dec_actual_ratio, dec_optimal_ratio))
                mem_log_file.flush()
            else:
                used_bits = mem_dict['used_bits']
                optimal_bits = mem_dict['optimal_bits']
                normal_bits = mem_dict['normal_bits']

                actual_ratio = normal_bits / used_bits
                optimal_ratio = normal_bits / optimal_bits

                print("Actual memory ratio {}".format(actual_ratio))
                print("Optimal memory ratio {}".format(optimal_ratio))

                mem_log_file.write('{} {} {} {}\n'.format(
                                    epoch, batch, actual_ratio, optimal_ratio))
                mem_log_file.flush()


        if not opt.no_log_during_epoch:
            valid_stats = trnr.validate()

            iteration_log_loss_file.write('{} {} {} {} {} {}\n'.format(
                                           epoch, batch, report_stats.accuracy(), report_stats.ppl(), valid_stats.accuracy(),
                                           valid_stats.ppl()))
            iteration_log_loss_file.flush()

        report_stats = onmt.Statistics()

    return report_stats
Example #14
    def _compute_loss(self, batch, ret, doc_index):

        # obtain the hypotheses and their probabilities
        top_hyp = ret['predictions']
        top_probabilities = ret['out_prob']
        # and also the reference targets
        tgt = batch.tgt
        batch_num = tgt.size()[1]

        # print ('top_hyp ',len(top_hyp))
        # print ('top_probabilities ', len(top_probabilities))
        # print ('tgt ',tgt.size())
        if self.doc_level:
            sentences_pred, prob_per_word = self.words_probs_from_preds(
                top_hyp,
                top_probabilities,
                doc_index=doc_index,
                batch_num=batch_num)
            sentences_gt = self.obtain_words_from_tgt(tgt, doc_index=doc_index)
            # print (doc_index)
            # print (prob_per_word.size())
            # print (len(sentences_pred[0][0]))
            # print (len(sentences_pred[0][1]))
            # print (len(sentences_gt[0]))

            dl_rewards = 0

            if self.bleu_doc:
                bleu_doc_scores = self.BLEU_score(sentences_pred, sentences_gt)
                dl_rewards += bleu_doc_scores
            if self.LC_doc:
                LC_scores = self.LC_scores(sentences_pred)
                dl_rewards += LC_scores
            if self.COH_doc:
                coher_scores = self.coher_scores(sentences_pred)
                dl_rewards += coher_scores

            loss_doc = self.RISK_loss(prob_per_word, dl_rewards)
            loss_doc = loss_doc.sum(0)

        if self.bleu_sen:
            sentences_pred, prob_per_word = self.words_probs_from_preds(
                top_hyp, top_probabilities)
            # print (prob_per_word.size())
            # print ('i am printing')
            sentences_gt = self.obtain_words_from_tgt(tgt)
            sl_rewards = self.BLEU_score(sentences_pred, sentences_gt)
            loss_sen = self.RISK_loss(prob_per_word, sl_rewards)
            loss_sen = loss_sen.sum(0)

        if self.doc_level and self.bleu_sen:
            loss = loss_doc + loss_sen
        elif self.doc_level:
            loss = loss_doc
        else:
            loss = loss_sen

        # #NEED TO COMPUTE THE LOSS
        # loss = bleu_scores*prob_per_word
        # loss =loss.sum(0)
        loss = -loss  # negated because we don't have a baseline for the moment
        # print ('The loss: ', loss)

        loss_data = loss.data.clone()

        #loss.detach()

        stats = onmt.Statistics(loss_data.item(), n_sentences=tgt.size()[1])
        return loss, stats
Example #15
    def train(self, epoch, report_func=None):
        """ Train next epoch.
        Args:
            epoch(int): the epoch number
            report_func(fn): function for logging

        Returns:
            stats (:obj:`onmt.Statistics`): epoch loss statistics
        """
        total_stats = Statistics()
        report_stats = Statistics()

        for i, batch in enumerate(self.train_iter):
            batch_size = batch.tgt.size(1)
            target_size = batch.tgt.size(0)
            # Truncated BPTT
            trunc_size = self.trunc_size if self.trunc_size else target_size

            dec_state = None
            src = onmt.io.make_features(batch, 'src', self.data_type)

            _, src_lengths = batch.src
            report_stats.n_src_words += src_lengths.sum()

            tgt = onmt.io.make_features(batch, 'tgt')

            if self.opt.encoder_model == 'Rev' and self.opt.decoder_model == 'Rev':
                if self.opt.use_reverse:
                    tgt_lengths = torch.tensor([tgt.size(0)],
                                               device=tgt.device)

                    self.model.zero_grad()
                    loss, num_words, num_correct, enc_output_dict, dec_output_dict, mem_dict = self.model.forward_and_backward(
                        src, tgt, src_lengths, tgt_lengths, self.weight,
                        self.padding_idx)

                    loss_data = loss.data.clone()
                    batch_stats = onmt.Statistics(loss.item(),
                                                  float(num_words),
                                                  float(num_correct))

                    self.optim.step()

                    total_stats.update(batch_stats)
                    report_stats.update(batch_stats)
                else:
                    self.model.zero_grad()
                    tgt_lengths = torch.tensor([tgt.size(0)],
                                               device=tgt.device)
                    loss, num_words, num_correct, attn, enc_output_dict, dec_output_dict, mem_dict = self.model(
                        src, tgt, src_lengths, tgt_lengths, self.weight,
                        self.padding_idx)
                    loss_data = loss.data.clone()
                    # batch_stats = self._stats(loss_data, outputs.data, batch.tgt[1:batch.tgt.size(0)].view(-1).data)
                    # batch_stats = self._stats(loss_data, F.log_softmax(outputs.data), batch.tgt[1:batch.tgt.size(0)].view(-1).data)
                    batch_stats = onmt.Statistics(loss.item(),
                                                  float(num_words),
                                                  float(num_correct))

                    loss.div(batch_size).backward()
                    self.optim.step()

                    total_stats.update(batch_stats)
                    report_stats.update(batch_stats)
            else:
                self.model.zero_grad()
                model_output_tuple = self.model(src, tgt, src_lengths,
                                                dec_state)

                if len(model_output_tuple) == 5:
                    mem_dict = None  # not returned in this case; avoids a NameError in report_func below
                    outputs, attns, dec_state, enc_output_dict, dec_output_dict = model_output_tuple
                elif len(model_output_tuple) == 3:
                    enc_output_dict = dec_output_dict = mem_dict = None
                    outputs, attns, dec_state = model_output_tuple

                target = batch.tgt[1:batch.tgt.size(0)]
                scores = self.model.decoder.generator(
                    outputs.view(-1, outputs.size(2)))
                gtruth = target.view(-1)
                loss = F.cross_entropy(
                    scores, gtruth, weight=self.weight, size_average=False
                )  # , ignore_index=-100, reduce=None, reduction='elementwise_mean')
                loss_data = loss.data.clone()
                # batch_stats = self._stats(loss_data, scores.data, target.view(-1).data)
                batch_stats = self._stats(loss_data,
                                          F.log_softmax(scores.data),
                                          target.view(-1).data)

                loss.div(batch_size).backward()
                self.optim.step()

                total_stats.update(batch_stats)
                report_stats.update(batch_stats)

            if report_func is not None:
                report_stats = report_func(self, epoch, i, len(
                    self.train_iter), total_stats.start_time, self.optim.lr,
                                           report_stats, enc_output_dict,
                                           dec_output_dict, mem_dict)

        return total_stats