Example #1
    def _reranker_gradient_accumulation(self, true_batchs, normalization,
                                        total_stats, report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        assert len(true_batchs) == 1
        batch = true_batchs[0]
        src = inputters.make_features(
            batch, 'src',
            self.data_type)  # [src_len, batch_size, num_features]
        _, src_lengths = batch.src

        tgt = inputters.make_features(
            batch, 'tgt',
            self.data_type)  # [tgt_len, batch_size, num_features]
        _, tgt_lengths = batch.tgt

        # 1. F-prop all.
        if self.grad_accum_count == 1:
            self.model.zero_grad()

        logits, probs = self.model(src, tgt, src_lengths, tgt_lengths)

        batch_stats = self.train_loss.sharded_compute_loss(
            batch, logits, probs, None, normalization)
        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        assert self.n_gpu == 1
        # 2. Update the parameters and statistics.
        self.optim.step(self.cur_valid_ppl)
        report_stats.lr_rate = self.optim.learning_rate
        report_stats.total_norm += self.optim.total_norm
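
All of these snippets share the same preamble built on onmt.inputters.make_features, which stacks a batch's word indices plus any extra feature columns into a single [len, batch_size, num_features] tensor. Below is a minimal standalone sketch of that recurring preamble, assuming an OpenNMT-py 0.x style torchtext batch; the helper name unpack_batch is hypothetical:

import onmt.inputters as inputters

def unpack_batch(batch, data_type='text'):
    # For text data, batch.src is a (data, lengths) pair; make_features
    # returns a [src_len, batch_size, num_features] tensor.
    src = inputters.make_features(batch, 'src', data_type)
    src_lengths = batch.src[1] if data_type == 'text' else None
    tgt = inputters.make_features(batch, 'tgt')
    return src, src_lengths, tgt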
Example #2
    def _run_target(self, batch, data):
        data_type = data.data_type
        if data_type == 'text':
            _, src_lengths = batch.src
        else:
            src_lengths = None
        src = inputters.make_features(batch, 'src', data_type)
        tgt_in = inputters.make_features(batch, 'tgt')[:-1]

        #  (1) run the encoder on the src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = \
            self.model.decoder.init_decoder_state(src, memory_bank, enc_states)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        tt = torch.cuda if self.cuda else torch
        gold_scores = tt.FloatTensor(batch.batch_size).fill_(0)
        dec_out, _, _ = self.model.decoder(
            tgt_in, memory_bank, dec_states, memory_lengths=src_lengths)

        tgt_pad = self.fields["tgt"].vocab.stoi[inputters.PAD_WORD]
        for dec, tgt in zip(dec_out, batch.tgt[1:].data):
            # Log prob of each word.
            out = self.model.generator.forward(dec)
            tgt = tgt.unsqueeze(1)
            scores = out.data.gather(1, tgt)
            scores.masked_fill_(tgt.eq(tgt_pad), 0)
            gold_scores += scores.view(-1)
        return gold_scores
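
The per-step loop above can also be collapsed into one batched gather. A self-contained sketch of that pattern, with shapes and the pad id chosen purely for illustration:

import torch

tgt_len, batch_size, vocab = 5, 2, 10
pad_id = 1
log_probs = torch.log_softmax(torch.randn(tgt_len, batch_size, vocab), dim=-1)
gold = torch.randint(0, vocab, (tgt_len, batch_size))  # reference token ids

# Pick out the log-prob of each gold token, zero the padding positions,
# and sum over time to get one score per sentence.
scores = log_probs.gather(2, gold.unsqueeze(2)).squeeze(2)
scores = scores.masked_fill(gold.eq(pad_id), 0.0)
gold_scores = scores.sum(dim=0)  # shape: [batch_size]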
Example #3
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            target_size = batch.tgt.size(0)
            # Truncated BPTT
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            dec_state = None
            src = inputters.make_features(
                batch, 'src',
                self.data_type)  # [src_len, batch_size, num_features]
            if self.data_type == 'text':
                _, src_lengths = batch.src
                report_stats.n_src_words += src_lengths.sum().item()
            else:
                src_lengths = None

            tgt_outer = inputters.make_features(batch, 'tgt')

            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j:j + trunc_size]

                # 2. F-prop all but generator.
                if self.grad_accum_count == 1:
                    self.model.zero_grad()
                outputs, attns, dec_state = \
                    self.model(src, tgt, src_lengths, dec_state)

                # 3. Compute loss in shards for memory efficiency.
                batch_stats = self.train_loss.sharded_compute_loss(
                    batch, outputs, attns, j, trunc_size, self.shard_size,
                    normalization)
                total_stats.update(batch_stats)
                report_stats.update(batch_stats)

                # If truncated, don't backprop fully.
                if dec_state is not None:
                    dec_state.detach()

        # 3.bis Multi GPU gradient gather
        if self.n_gpu > 1:
            grads = [
                p.grad.data for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            onmt.utils.distributed.all_reduce_and_rescale_tensors(
                grads, float(1))

        # 4. Update the parameters and statistics.
        # changed for KE_KG
        self.optim.step(self.cur_valid_ppl)
        report_stats.lr_rate = self.optim.learning_rate
        report_stats.total_norm = self.optim.total_norm
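
Several trainers here repeat the same zero_grad discipline: clear gradients once per accumulation window when grad_accum_count > 1 and step once at the end, otherwise clear and step per batch. A minimal sketch of just that control flow, where model, optimizer, and compute_loss are stand-ins:

def accumulate(model, optimizer, compute_loss, batches, grad_accum_count):
    if grad_accum_count > 1:
        model.zero_grad()
    for batch in batches:
        if grad_accum_count == 1:
            model.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()  # gradients sum across batches until the next zero_grad
        if grad_accum_count == 1:
            optimizer.step()
    if grad_accum_count > 1:
        optimizer.step()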
Example #4
    def _run_encoder(self, batch, data_type):
        src = inputters.make_features(batch, 'src', data_type)
        src_lengths = None
        tgt = inputters.make_features(batch, 'tgt', data_type)
        if data_type == 'text':
            _, src_lengths = batch.src
        elif data_type == 'audio':
            src_lengths = batch.src_lengths
        tgt = tgt[1:, :, :]  # drop the first (presumably BOS) token
        tgt[tgt == 3] = 1    # map id 3 (presumably EOS) to id 1 (presumably PAD)
        enc_states, memory_bank, src_lengths = self.model.encoder(
            src, src_lengths)
        _, tgt_memory_bank, _ = self.model.encoder(tgt, None)

        if src_lengths is None:
            assert not isinstance(memory_bank, tuple), \
                'Ensemble decoding only supported for text data'
            src_lengths = torch.Tensor(batch.batch_size) \
                               .type_as(memory_bank) \
                               .long() \
                               .fill_(memory_bank.size(0))
        return src, enc_states, memory_bank, src_lengths
Example #5
    def _end2end_gradient_accumulation(self, true_batchs, normalization,
                                       total_stats, report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            # 1. prepare the batch
            src = inputters.make_features(
                batch, 'src',
                self.data_type)  # [src_len, batch_size, num_features]
            _, src_lengths = batch.src

            tgt = inputters.make_features(batch, 'tgt')

            # 2. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            dec_outputs, attns, dec_state, sel_outputs, sel_probs = \
                self.model(src, tgt, src_lengths,
                           gt_probs=batch.key_indicators[0])

            # 3. Compute loss
            # (self, batch, sel_outputs, sel_probs, dec_outputs, attns, normalization)
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, sel_outputs, sel_probs, dec_outputs, attns,
                normalization)
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

        # assert self.n_gpu == 1
        # 4. Update the parameters and statistics.
        self.optim.step(self.cur_valid_ppl)
        report_stats.lr_rate = self.optim.learning_rate
        report_stats.total_norm = self.optim.total_norm
Example #6
    def _run_target(self, batch, data):
        data_type = data.data_type
        if data_type == 'text':
            _, src_lengths = batch.src
        else:
            src_lengths = None
        src = inputters.make_features(batch, 'src', data_type)
        tgt_in = inputters.make_features(batch, 'tgt')[:-1]

        #  (1) run the encoder on the src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = \
            self.model.decoder.init_decoder_state(src, memory_bank, enc_states)

        #  (2) compare the model's argmax predictions against the target,
        #  counting per-token-type totals and correct matches
        correct = defaultdict(int)
        total = defaultdict(int)

        dec_out, _, _ = self.model.decoder(tgt_in,
                                           memory_bank,
                                           dec_states,
                                           memory_lengths=src_lengths)

        for dec, tgt in zip(dec_out, batch.tgt[1:].data):
            # Log prob of each word.
            out = self.model.generator.forward(dec)
            for _x, _y in zip(out.max(1)[1], tgt):
                x = _x.item()
                y = _y.item()
                total[y] += 1
                if x == y:
                    correct[x] += 1

        return correct, total
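
For reference, the same bookkeeping in isolation; the method above assumes `from collections import defaultdict` at module level. Toy tensors stand in for the decoder output and target:

from collections import defaultdict
import torch

predictions = torch.tensor([3, 5, 3, 1])  # argmax over the vocab dimension
references = torch.tensor([3, 5, 2, 1])

correct, total = defaultdict(int), defaultdict(int)
for x, y in zip(predictions.tolist(), references.tolist()):
    total[y] += 1          # how often each token type occurs in the reference
    if x == y:
        correct[x] += 1    # how often it was predicted correctly

per_type_accuracy = {y: correct[y] / total[y] for y in total}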
Example #7
    def step(self, batch, normalization, total_stats, report_stats):
        dec_state = None
        src = inputters.make_features(batch, 'src', 'text')
        _, src_lengths = batch.src
        report_stats.n_src_words += src_lengths.sum().item()

        tgt = inputters.make_features(batch, 'tgt')

        # 2. F-prop all AND generator.
        self.model.zero_grad()
        outputs, attns, dec_state = self.model(src, tgt, src_lengths, dec_state)
        scores = self.model.generator(outputs.view(-1, outputs.size(2)))

        # 3. Compute loss in shards for memory efficiency.
        if self.pair_size != 0:
            target = batch.tgt[self.pair_size:batch.tgt.size(0)]
        else:
            target = batch.tgt[1:batch.tgt.size(0)]
        gtruth = target.view(-1)
        loss = self.creterion(scores, gtruth)
        loss.div(float(normalization)).backward()
        batch_stats = self._stats(loss.data.clone(), scores.data, target.view(-1).data)

        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        self.optim.step()
Example #8
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = inputters.make_features(batch, 'tgt')

            # F-prop through the model.
            outputs, attns, _ = self.model(src, tgt, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #9
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text' and \
                    not self.model.decoder.decoder_type.startswith('vecdif'):
                _, src_lengths = batch.src
            elif self.data_type == 'audio':
                src_lengths = batch.src_lengths
            else:
                src_lengths = src.size()  # torch.Size; src_lengths[0] (= src_len) is used below

            tgt = inputters.make_features(batch, 'tgt', self.data_type)

            # F-prop through the model.
            if self.model.decoder.decoder_type.startswith('vecdif'):
                if self.data_type == 'text':
                    src = torch.squeeze(src, 2)
                    tgt = torch.squeeze(tgt, 2)
                    all_outputs = []
                    if self.n_gpu > 0:
                        covered_target = torch.zeros((src_lengths[0], 512), dtype=torch.float).cuda()
                    else:
                        covered_target = torch.zeros((src_lengths[0], 512), dtype=torch.float)
                    src_representation = self.assemble_src_representation(src)
                    for target_id in range(0, tgt.size()[1]):
                        outputs, scores, covered_target = self.model(
                            -1, src, None, covered_target,
                            src_representation, target_id)
                        all_outputs.append(outputs.detach())
                else:
                    src_representation = self.assemble_src_representation(src)
                    outputs, scores, _ = self.model(
                        -1, src, None, source_vector=src_representation)

                attns = None
            else:
                outputs, attns, _ = self.model(src, tgt, src_lengths)

            # Compute loss.
            if self.model.decoder.decoder_type == "vecdif_multi":
                batch_stats = self.valid_loss.monolithic_compute_loss_multivec(
                    batch, all_outputs)
            else:
                batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #10
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        with torch.no_grad():
            stats = onmt.utils.Statistics()

            for batch in valid_iter:
                src, src_lengths = inputters.make_features(batch, 'src')
                tgt, _ = inputters.make_features(batch, 'tgt')

                # F-prop through the model.
                outputs, attns = self.model(src, tgt, src_lengths)

                # Compute loss.
                batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

                # Update statistics.
                stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
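
This is the only validate variant here that wraps the loop in torch.no_grad(); the others rely on model.eval() alone, which disables dropout but still records the autograd graph. A minimal sketch of the combined idiom, with model, batches, and loss_fn as stand-ins:

import torch

def evaluate(model, batches, loss_fn):
    model.eval()  # disable dropout / freeze batch-norm statistics
    total = 0.0
    with torch.no_grad():  # skip graph construction to save memory
        for src, tgt in batches:
            total += loss_fn(model(src), tgt).item()
    model.train()
    return total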
Example #11
    def _meta_gradient_grads(self, batch, meta_model, meta_train_loss):

        src = inputters.make_features(batch, 'src', self.data_type)
        if self.data_type == 'text':
            _, src_lengths = batch.src
        else:
            src_lengths = None

        tgt_outer = inputters.make_features(batch, 'tgt')
        tgt = tgt_outer

        meta_model.zero_grad()
        outputs, attns = meta_model(src, tgt, src_lengths)

        _, loss = meta_train_loss.monolithic_compute_raw_loss(
            batch, outputs, attns)

        meta_loss_train = torch.sum(loss)

        meta_model.zero_grad()
        grads = torch.autograd.grad(meta_loss_train, (meta_model.params()),
                                    create_graph=True)

        if meta_model.decoder.state is not None:
            meta_model.decoder.detach_state()

        return grads
Example #12
    def _meta_gradient_weighted_train(self,
                                      batch,
                                      weights,
                                      report_stats,
                                      total_stats,
                                      task_id,
                                      zero_gradients=True):

        src = inputters.make_features(batch, 'src', self.data_type)
        if self.data_type == 'text':
            _, src_lengths = batch.src
            report_stats.n_src_words += src_lengths.sum().item()
        elif self.data_type == 'audio':
            src_lengths = batch.src_lengths
        else:
            src_lengths = None

        tgt_outer = inputters.make_features(batch, 'tgt')
        tgt = tgt_outer
        if zero_gradients:
            self.models_list[task_id].zero_grad()
        outputs, attns = self.models_list[task_id](src, tgt, src_lengths)

        batch_stats, loss = \
            self.train_loss_list[task_id].monolithic_compute_raw_loss(
                batch, outputs, attns)

        weighted_loss = torch.sum(loss * weights)
        weighted_loss.backward()

        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        if self.models_list[task_id].decoder.state is not None:
            self.models_list[task_id].decoder.detach_state()
Example #13
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            target_size = batch.tgt.size(0)
            # Truncated BPTT
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            dec_state = None
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
                report_stats.n_src_words += src_lengths.sum().item()
            else:
                src_lengths = None

            tgt_outer = inputters.make_features(batch, 'tgt')

            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j:j + trunc_size]

                # 2. F-prop all but generator.
                if self.grad_accum_count == 1:
                    self.model.zero_grad()
                enc_outputs, outputs, attns, dec_state = \
                    self.model(src, tgt, src_lengths, dec_state)

                # 3. Compute loss for G in shards for memory efficiency.
                batch_stats_G = self.train_loss.sharded_compute_loss(
                    batch, enc_outputs, outputs, attns, j, trunc_size,
                    self.shard_size, normalization)
                total_stats.update(batch_stats_G)
                report_stats.update(batch_stats_G)
                self.optim.step()  # update G

                # 4. Compute loss for D in shards for memory efficiency.
                tgt_lengths = torch.ones(tgt.size(1)).long() * tgt.size(0)
                tgt_lengths = tgt_lengths.to(tgt.device)
                enc_final, memory_bank = self.model.encoder(tgt, tgt_lengths)
                self.model.zero_grad()
                # batch_stats_D =
                self.train_loss.sharded_compute_loss_discriminator(
                    batch, enc_outputs, outputs, attns, j, trunc_size,
                    self.shard_size, memory_bank)

                # TODO: update stats for D
                # total_stats.update(batch_stats_D)
                # report_stats.update(batch_stats_D)
                self.optim_D.step()  # update D

                # If truncated, don't backprop fully.
                if dec_state is not None:
                    dec_state.detach()
Example #14
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:

            #TODO: access wals features for the language pairs from batch variable.
            #src_language = batch.examples[0].src_language
            #tgt_language = batch.examples[0].tgt_language

            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = inputters.make_features(batch, 'tgt')

            def process_minibatch(src, tgt, src_lengths, flipping):
                # F-prop through the model.
                #TODO pass in the WALS features to self.model()
                outputs, attns, _ = self.model(src,
                                               tgt,
                                               src_lengths,
                                               dec_state=None,
                                               flipping=flipping)

                # Compute loss.
                batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

                # Update statistics.
                stats.update(batch_stats)

            if min(batch.indices) > (len(batch.dataset.examples) // 2) - 1:
                # we are in the upper half
                flipping = True
                process_minibatch(src, tgt, src_lengths, flipping)
            elif max(batch.indices) < (len(batch.dataset.examples) // 2) - 1:
                # we are in the lower half
                flipping = False
                process_minibatch(src, tgt, src_lengths, flipping)
            else:
                # we are in gray territory :O
                # skip minibatch
                continue

        # Set model back to training mode.
        self.model.train()

        return stats
Example #15
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            elif self.data_type == 'audio':
                src_lengths = batch.src_lengths
            else:
                src_lengths = None

            knl = inputters.make_features(batch, 'knl', self.data_type)
            _, knl_lengths = batch.knl

            tgt = inputters.make_features(batch, 'tgt')

            # F-prop through the model.
            # TODO: apply to all_acts
            if self.model_mode in ['default', 'top_act']:
                first_outputs, first_attns, second_outputs, second_attns = self.model(
                    knl, src, tgt, src_lengths, knl_lengths,
                    src_da_label=(batch.src_da_label[:, 0], batch.src_da_label[:, 1], batch.src_da_label[:, 2]),
                    tgt_da_label=(batch.tgt_da_label,)
                )
            elif self.model_mode in ['all_acts']:
                first_outputs, first_attns, second_outputs, second_attns = self.model(
                    knl, src, tgt, src_lengths, knl_lengths,
                    src_da_label=(
                        batch.src_da_label[:, 0], batch.src_da_label[:, 1], batch.src_da_label[:, 2], batch.src_da_label[:, 3],
                        batch.src_da_label[:, 4], batch.src_da_label[:, 5], batch.src_da_label[:, 6], batch.src_da_label[:, 7],
                        batch.src_da_label[:, 8], batch.src_da_label[:, 9], batch.src_da_label[:, 10], batch.src_da_label[:, 11]
                    ),
                    tgt_da_label=(batch.tgt_da_label,)
                )

            # Compute loss.
            #batch_stats1 = self.valid_loss.monolithic_compute_loss(
            #    batch, first_outputs, first_attns)
            batch_stats2 = self.valid_loss.monolithic_compute_loss(
                batch, second_outputs, second_attns)

            # Update statistics.
            #stats.update(batch_stats1)
            stats.update(batch_stats2)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #16
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            cur_dataset = valid_iter.get_cur_dataset()
            self.valid_loss.cur_dataset = cur_dataset

            src = inputters.make_features(batch, 'src', self.data_type)
            ans = inputters.make_features(batch, 'ans', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
                _, ans_lengths = batch.ans
            else:
                src_lengths = None
                ans_lengths = None

            tgt = inputters.make_features(batch, 'tgt')
            # F-prop through the model.
            outputs, attns, _ = self.model(src, ans, tgt, src_lengths, ans_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns, train=False)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #17
    def validate(self, lambda_, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        with torch.no_grad():
            content_stats = onmt.utils.Statistics()
            style_stats = onmt.utils.Statistics()

            for batch in valid_iter:

                src = inputters.make_features(batch, 'src', self.data_type)
                if self.data_type == 'text':
                    _, src_lengths = batch.src
                elif self.data_type == 'audio':
                    src_lengths = batch.src_lengths
                else:
                    src_lengths = None

                tgt = inputters.make_features(batch, 'tgt')

                # reference batch
                ref_src = inputters.make_features(batch, 'ref_src', self.data_type)
                if self.data_type == 'text':
                    _, ref_src_lengths = batch.ref_src
                elif self.data_type == 'audio':
                    ref_src_lengths = batch.ref_src_lengths
                else:
                    ref_src_lengths = None

                ref_tgt = inputters.make_features(batch, 'ref_tgt')
                ref_tgt_lengths = None

                # F-prop through the model.
                outputs, attns, back_outputs, back_attns, ref_outputs, \
                    ref_attns = self.model(
                        src, batch.src_map, ref_src, tgt, ref_tgt, ref_tgt,
                        src_lengths, ref_src_lengths, ref_tgt_lengths)

                # Compute loss.
                batch_content_stats, batch_style_stats = \
                    self.valid_loss.monolithic_compute_loss(
                        lambda_, batch, outputs, back_outputs, ref_outputs,
                        attns, back_attns, ref_attns)

                # Update statistics.
                content_stats.update(batch_content_stats)
                style_stats.update(batch_style_stats)

        # Set model back to training mode.
        self.model.train()

        return content_stats, style_stats
Example #18
    def _gradient_accumulation(self, true_batchs, total_stats, report_stats,
                               normalization):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            target_size = batch.tgt.size(0)
            # Truncated BPTT
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            dec_state = None
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
                report_stats.n_src_words += src_lengths.sum().item()
            else:
                src_lengths = None

            tgt_outer = inputters.make_features(batch, 'tgt')

            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j:j + trunc_size]

                # 2. F-prop all but generator.
                if self.grad_accum_count == 1:
                    self.model.zero_grad()
                outputs, attns, dec_state = \
                    self.model(src, tgt, src_lengths, dec_state)

                # 3. Compute loss in shards for memory efficiency.
                batch_stats = self.train_loss.sharded_compute_loss(
                    batch, outputs, attns, j, trunc_size, self.shard_size,
                    normalization)

                # 4. Update the parameters and statistics.
                if self.grad_accum_count == 1:
                    self.optim.step()
                total_stats.update(batch_stats)
                report_stats.update(batch_stats)

                # If truncated, don't backprop fully.
                if dec_state is not None and j + trunc_size < target_size - 1:
                    dec_state.detach()

        if self.grad_accum_count > 1:
            self.optim.step()
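
The dec_state.detach() calls in these loops are what make truncated BPTT truncated: detaching the carried state cuts the autograd graph at the window boundary, so each backward pass spans at most trunc_size steps. A toy illustration with a plain GRU, where all sizes are arbitrary:

import torch

rnn = torch.nn.GRU(input_size=8, hidden_size=16)
seq = torch.randn(100, 4, 8)  # [len, batch, feat]
trunc_size, state = 25, None

for j in range(0, seq.size(0), trunc_size):
    rnn.zero_grad()
    out, state = rnn(seq[j:j + trunc_size], state)
    loss = out.pow(2).mean()  # placeholder objective
    loss.backward()           # only reaches back to the last detach
    state = state.detach()    # cut the graph at the window boundary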
Example #19
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            elif self.data_type == 'audio':
                src_lengths = batch.src_lengths
            else:
                src_lengths = None

            if self.refer:
                ref = inputters.make_features(batch, 'ref', self.data_type)
                ref = (ref, batch.ref[1])
            else:
                ref = None
            tgt = inputters.make_features(batch, 'tgt')

            if self.use_mask:
                length = src.size(0)
                indices = batch.indices.tolist()
                masks = [
                    self.valid_masks[i][:length]
                    + [0] * (length - len(self.valid_masks[i]))
                    for i in indices
                ]
                masks = torch.tensor(masks, device=src.device).t().unsqueeze(-1)
            else:
                masks = None

            # F-prop through the model.
            outputs, attns = self.model(src, tgt, src_lengths, ref, masks=masks)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #20
    def _run_encoder(self, batch, data_type):
        src = inputters.make_features(batch, 'src', data_type)
        src_lengths = None
        if data_type == 'text':
            _, src_lengths = batch.src
        elif data_type == 'audio':
            src_lengths = batch.src_lengths
        enc_states, memory_bank, src_lengths, enc_out = self.model.encoder(
            src, src_lengths)

        ctc_scores = None
        if self.ctc_ratio > 0:
            batch_size = enc_out.size(1)
            bottled_enc_out = self._bottle(enc_out)
            ctc_scores = self.model.encoder.ctc_gen(bottled_enc_out)
            ctc_scores = ctc_scores.view(-1, batch_size, ctc_scores.size(-1))

        if src_lengths is None:
            assert not isinstance(memory_bank, tuple), \
                'Ensemble decoding only supported for text data'
            src_lengths = torch.Tensor(batch.batch_size) \
                               .type_as(memory_bank) \
                               .long() \
                               .fill_(memory_bank.size(0))
        return src, enc_states, memory_bank, src_lengths, ctc_scores
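
The _bottle call above is the usual OpenNMT-py fold of [len, batch, dim] into [len * batch, dim], so one linear generator application covers every time step; the view afterwards restores the time dimension. A sketch with illustrative sizes:

import torch

enc_out = torch.randn(7, 3, 16)                # [len, batch, dim]
ctc_gen = torch.nn.Linear(16, 30)              # 30 = hypothetical CTC vocab size
bottled = enc_out.view(-1, enc_out.size(2))    # [len * batch, dim]
ctc_scores = ctc_gen(bottled).view(-1, 3, 30)  # back to [len, batch, vocab]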
Example #21
    def _selector_gradient_accumulation(self, true_batchs, normalization,
                                        total_stats, report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            src = inputters.make_features(
                batch, 'src',
                self.data_type)  # [src_len, batch_size, num_features]

            _, src_lengths = batch.src

            # 1. F-prop all.
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            logits, probs = self.model(src, src_lengths)

            # 2. Compute loss in shards for memory efficiency.
            # sharded_compute_loss(self, batch, sel_outputs, sel_probs, dec_outputs, attns, normalization)
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, logits, probs, None, None, normalization)
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

        assert self.n_gpu == 1
        # 3. Update the parameters and statistics.
        self.optim.step(self.cur_valid_ppl)
        report_stats.lr_rate = self.optim.learning_rate
        report_stats.total_norm = self.optim.total_norm
Example #22
    def _score_target(self, batch, his_memory_bank, src_memory_bank,
                      knl_memory_bank, knl, src_lengths, data, src_map):
        tgt_in = inputters.make_features(batch, 'tgt')[:-1]

        first_dec_out, first_attns = self.model.decoder(tgt_in,
                                                        src_memory_bank,
                                                        his_memory_bank,
                                                        memory_lengths=None)
        # log_probs [tgt_len, batch_size, vocab_size]
        first_log_probs = self.model.generator(first_dec_out.squeeze(0))
        _, first_dec_words = torch.max(first_log_probs, 2)
        first_dec_words = first_dec_words.unsqueeze(2)
        self.model.decoder2.init_state(first_dec_words, knl[600:, :, :], None,
                                       None)
        emb, decode1_bank, decode1_mask = self.model.encoder.histransformer(
            first_dec_words, None)
        second_dec_out, attn = self.model.decoder2(tgt_in,
                                                   decode1_bank,
                                                   knl_memory_bank,
                                                   memory_lengths=None)
        log_probs = self.model.generator(second_dec_out.squeeze(0))
        #first_dec_out, attn = self.model.decoder(tgt_in, his_memory_bank, src_memory_bank,
        #                                                memory_lengths=None)
        #log_probs = self.model.generator(first_dec_out.squeeze(0))

        tgt_pad = self.fields["tgt"].vocab.stoi[self.fields["tgt"].pad_token]

        log_probs[:, :, tgt_pad] = 0
        gold = batch.tgt[1:].unsqueeze(2)
        gold_scores = log_probs.gather(2, gold)
        gold_scores = gold_scores.sum(dim=0).view(-1)

        return gold_scores
Example #23
    def _generator_gradient_accumulation(self, true_batchs, normalization,
                                         total_stats, report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:

            dec_state = None
            src = inputters.make_features(
                batch, 'src',
                self.data_type)  # [src_len, batch_size, num_features]
            # add for RK
            retrieved_keys = inputters.make_features(
                batch, 'retrieved_keys',
                self.data_type)  # [rk_len, batch_size, num_features]
            if self.data_type == 'text':
                _, src_lengths = batch.src
                _, rk_lengths = batch.retrieved_keys
                # report_stats.n_src_words += src_lengths.sum().item()
            else:
                src_lengths = None

            tgt_outer = inputters.make_features(batch, 'tgt')

            # 1. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            outputs, attns, dec_state = \
                self.model(src, tgt_outer, src_lengths, dec_state, retrieved_keys, rk_lengths)

            # 2. Compute loss.
            # sharded_compute_loss(self, batch, sel_outputs, sel_probs, dec_outputs, attns, normalization)
            batch_stats = self.train_loss.sharded_compute_loss(
                batch=batch,
                sel_outputs=None,
                sel_probs=None,
                dec_outputs=outputs,
                attns=attns,
                normalization=normalization)
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

        # 3. Update the parameters and statistics.
        self.optim.step(self.cur_valid_ppl)
        report_stats.lr_rate = self.optim.learning_rate
        report_stats.total_norm = self.optim.total_norm
Example #24
    def _run_target(self, batch, data):
        data_type = data.data_type
        if data_type == 'text':
            _, src_lengths = batch.src
        else:
            src_lengths = None
        src = inputters.make_features(batch, 'src', data_type)
        tgt_in = inputters.make_features(batch, 'tgt')[:-1]

        #  (1) run the encoder on the src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = \
            self.model.decoder.init_decoder_state(src, memory_bank, enc_states)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        tt = torch.cuda if self.cuda else torch
        gold_scores = tt.FloatTensor(batch.batch_size).fill_(0)
        dec_out, _, attn = self.model.decoder(tgt_in,
                                              memory_bank,
                                              dec_states,
                                              memory_lengths=src_lengths)

        tgt_pad = self.fields["tgt"].vocab.stoi[inputters.PAD_WORD]
        for i, (dec, tgt) in enumerate(zip(dec_out, batch.tgt[1:].data)):
            # Log prob of each word.
            if not self.copy_attn:
                out = self.model.generator.forward(dec)
            else:
                out = self.model.generator.forward(dec, attn["copy"][i],
                                                   batch.src_map)
                # data.collapse_copy_scores expects beam-search-shaped
                # scores, hence the unsqueeze/squeeze around the call
                out = out.unsqueeze(0)
                out = data.collapse_copy_scores(out, batch,
                                                self.fields["tgt"].vocab,
                                                data.src_vocabs)
                out = out.squeeze(0)
            out = out.log()
            tgt = tgt.unsqueeze(1)
            scores = out.data.gather(1, tgt)
            scores.masked_fill_(tgt.eq(tgt_pad), 0)
            gold_scores += scores.view(-1)
        return gold_scores
Example #25
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        torch.cuda.empty_cache()
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            elif self.data_type == 'audio':
                src_lengths = batch.src_lengths
            else:
                src_lengths = None

            knl = inputters.make_features(batch, 'knl', self.data_type)
            _, knl_lengths = batch.knl

            tgt = inputters.make_features(batch, 'tgt')

            # F-prop through the model.
            first_outputs, first_attns, second_outputs, second_attns = self.model(
                src, knl, tgt)

            # Compute loss.
            #batch_stats1 = self.valid_loss.monolithic_compute_loss(
            #    batch, first_outputs, first_attns)
            batch_stats2 = self.valid_loss.monolithic_compute_loss(
                batch, second_outputs, second_attns)

            # Update statistics.
            #stats.update(batch_stats1)
            stats.update(batch_stats2)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #26
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            if self.s_optim is not None:
                src_session = inputters.make_features(batch, 'src_item_sku')
                user = inputters.make_features(batch, 'user')
                stm = inputters.make_features(batch, 'stm')
                _, session_lengths = batch.src_item_sku
            else:
                src_session = None
                user = None
                stm = None
                session_lengths = None
            # tgt_session = inputters.make_features(batch, 'tgt_item_sku')

            src = inputters.make_features(batch, 'src', self.data_type)
            _, src_lengths = batch.src

            tgt = inputters.make_features(batch, 'tgt')

            # F-prop through the model.
            click_score, outputs, attns, _ = self.model(
                src_session, user, stm, src, tgt, session_lengths, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, click_score, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #27
    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        stats = onmt.utils.Statistics()
        batch_size = batch.batch_size
        data_type = data.data_type
        vocab = self.fields["tgt"].vocab

        # Define a list of tokens to exclude from ngram-blocking
        # exclusion_list = ["<t>", "</t>", "."]
        exclusion_tokens = set([vocab.stoi[t]
                                for t in self.ignore_when_blocking])

        # Help functions for working with beams and batches
        def var(a): return torch.tensor(a, requires_grad=False)

        # (1) Run the encoder on the src.
        src = inputters.make_features(batch, 'src', data_type)
        ans = inputters.make_features(batch, 'ans', data_type)
        tgt = inputters.make_features(batch, 'tgt', data_type)

        src_lengths = None
        ans_lengths = None
        if data_type == 'text':
            _, src_lengths = batch.src
            _, ans_lengths = batch.ans

        outputs, attns, _ = self.model(src, ans, tgt, src_lengths, ans_lengths)

        # Compute loss.
        batch_stats = self.valid_loss.monolithic_compute_loss(
            batch, outputs, attns, train=False)

        # Update statistics.
        stats.update(batch_stats)
        # TODO: incorporate BLEU and other stats
        return stats
Example #28
    def _forward_prop(self, batch):
        """forward propagation"""
        # 1. Get all data
        dec_state = None
        src = inputters.make_features(batch, 'src')
        qa = inputters.make_features(batch, 'qa')
        tgt = inputters.make_features(batch, 'tgt')
        _, src_lengths = batch.src
        _, qa_sent_lengths, qa_word_lengths = batch.qa
        # # make word features for qa
        # qa = qa.unsqueeze(-1)

        # 2. F-prop all but generator.
        if self.grad_accum_count == 1:
            self.model.zero_grad()
        outputs, attns, dec_state = \
            self.model(src, src_lengths,
                       qa, qa_sent_lengths, qa_word_lengths,
                       tgt, dec_state)
        return outputs, attns, dec_state
Example #29
    def _run_encoder(self, batch, data_type):
        src = inputters.make_features(batch, 'src', data_type)
        knl = inputters.make_features(batch, 'knl', data_type)
        src_lengths = None
        knl_lengths = None
        if data_type == 'text':
            _, src_lengths = batch.src
            _, knl_lengths = batch.knl
        elif data_type == 'audio':
            src_lengths = batch.src_lengths
        enc_states, his_memory_bank, src_memory_bank, knl_memory_bank, \
            src_lengths = self.model.encoder(src, knl, src_lengths, knl_lengths)
        if src_lengths is None:
            assert not isinstance(src_memory_bank, tuple), \
                'Ensemble decoding only supported for text data'
            src_lengths = torch.Tensor(batch.batch_size) \
                               .type_as(src_memory_bank) \
                               .long() \
                               .fill_(src_memory_bank.size(0))
        return src, knl, enc_states, his_memory_bank, src_memory_bank, \
            knl_memory_bank, src_lengths
Example #30
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            batch.condition_target = get_emb(batch.condition_target)
            batch.graph = myutils.pad_for_graph(batch.graph,
                                                torch.max(batch.src[1]).item())
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            elif self.data_type == 'audio':
                src_lengths = batch.src_lengths
            else:
                src_lengths = None

            tgt = inputters.make_features(batch, 'tgt')

            # F-prop through the model.
            outputs, attns, _ = self.model(
                (src, batch.graph, batch.condition_target), tgt, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats