Esempio n. 1
0
    def sentence_encode(self, sent_words):
        """Encode a batch of token sequences and return their final hidden states.

        The batch is handed to ``self.encoder.forward`` sorted by decreasing
        length; the resulting hidden states are then permuted back so that
        position ``i`` along the batch axis corresponds to ``sent_words[i]``.

        Args:
            sent_words: list of token sequences, one per example.

        Returns:
            The encoder's final hidden state in the caller's original example
            order. A state with more than two dims is reordered along dim 1
            (presumably ``(layers*directions, batch, hidden)`` -- TODO
            confirm); a 2-D state is reordered along dim 0.
        """
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]

        # Stable sort of example indices by decreasing length:
        # sorted_example_ids[new_pos] == old_pos.
        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -sent_lengths[x])

        # Inverse permutation: example_old_pos_map[old_pos] == new_pos, so
        # indexing the sorted batch with it restores the original order.
        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)

        # Source-side word dropout, only while training.
        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate,
                                          self.vocab.src)

        sorted_sent_lengths = [sent_lengths[i] for i in sorted_example_ids]

        _, sent_hidden = self.encoder.forward(sorted_sent_var,
                                              sorted_sent_lengths)

        # Undo the length sort on the batch axis; support both 3-D and 2-D
        # hidden states (previously a 2-D state raised an IndexError).
        if sent_hidden.dim() > 2:
            return sent_hidden[:, example_old_pos_map, :]
        return sent_hidden[example_old_pos_map, :]
Esempio n. 2
0
 def encode(self, input_var, length):
     """Run the source encoder, applying word dropout while training.

     Args:
         input_var: encoder input, passed to ``unk_replace`` and then to
             ``self.encoder.forward``.
         length: sequence lengths forwarded unchanged to the encoder.

     Returns:
         ``(encoder_output, encoder_hidden)`` as unpacked from
         ``self.encoder.forward``.
     """
     drop_words = self.training and self.args.src_wd
     source = (unk_replace(input_var, self.step_unk_rate, self.vocab.src)
               if drop_words else input_var)
     outputs, hidden = self.encoder.forward(source, length)
     return outputs, hidden
Esempio n. 3
0
    def encode_to_hidden(self, examples, need_sort=False):
        """Encode ``examples`` and package the encoder results in a dict.

        Args:
            examples: a single example or a list of examples; each must expose
                a ``src`` token sequence.
            need_sort: if False, encode the batch in the given order and return
                both encoder outputs and hidden state. If True, feed the
                encoder a batch sorted by decreasing length, then permute the
                hidden state back to input order; ``"outputs"`` is None in
                this mode.

        Returns:
            dict with keys ``"outputs"``, ``"hidden"``, ``'length'`` (source
            lengths in input order) and ``'batch_size'``.
        """
        if not isinstance(examples, list):
            examples = [examples]
        sent_words = [e.src for e in examples]
        sent_lengths = [len(sent_word) for sent_word in sent_words]
        batch_size = len(sent_words)

        if not need_sort:
            src_var = to_input_variable(sent_words,
                                        self.vocab.src,
                                        training=False,
                                        cuda=self.args.cuda,
                                        batch_first=True)

            encoder_output, encoder_hidden = self.encode(input_var=src_var,
                                                         length=sent_lengths)

            return {
                "outputs": encoder_output,
                "hidden": encoder_hidden,
                'length': sent_lengths,
                'batch_size': batch_size
            }

        # Stable sort by decreasing length; sorted_example_ids[new] == old.
        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -sent_lengths[x])

        # Inverse permutation: example_old_pos_map[old] == new.
        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)
        sorted_sent_lengths = [sent_lengths[i] for i in sorted_example_ids]

        # No unk_replace here. The unsorted branch relies on self.encode
        # alone for word dropout, and the encode(input_var, length)
        # implementations seen alongside this method apply unk_replace
        # themselves when training. Applying it here as well dropped words
        # twice in the sorted path only.
        # NOTE(review): confirm this class's encode() does the dropout.
        _, sent_hidden = self.encode(sorted_sent_var, sorted_sent_lengths)

        # Undo the length sort along the batch axis (dim 1 for >2-D states,
        # dim 0 for 2-D states).
        if sent_hidden.dim() > 2:
            hidden = sent_hidden[:, example_old_pos_map, :]
        else:
            hidden = sent_hidden[example_old_pos_map, :]

        return {
            "outputs": None,
            "hidden": hidden,
            'length': sent_lengths,
            'batch_size': batch_size
        }
Esempio n. 4
0
 def encode_var(self, src_var, src_length):
     """Encode ``src_var`` and map the final hidden state through the bridger.

     While training with a positive ``self.word_drop``, tokens of ``src_var``
     are first passed through ``unk_replace``.

     Returns:
         ``(encoder_outputs, bridged_hidden)`` where ``bridged_hidden`` is
         ``self.bridger.forward`` applied to the encoder's hidden state.
     """
     dropping_words = self.training and self.word_drop > 0.
     if dropping_words:
         src_var = unk_replace(src_var,
                               dropoutr=self.word_drop,
                               vocab=self.vocab.src)
     outputs, raw_hidden = self.encoder.forward(input_var=src_var,
                                                input_lengths=src_length)
     return outputs, self.bridger.forward(raw_hidden)
Esempio n. 5
0
 def encode(self, seqs_x, seqs_length=None):
     """Return a sentence encoding of ``seqs_x`` averaged over dim 1.

     With ``self.args.enc_type == "att"`` the encoder's ``'out'`` entry is
     used; otherwise the first value returned by the encoder (called with
     ``input_lengths=seqs_length``) is used.
     """
     if self.training and self.word_drop > 0.:
         seqs_x = unk_replace(seqs_x, dropoutr=self.word_drop, vocab=self.vocab.src)

     if self.args.enc_type != "att":
         enc_hid, _ = self.encoder.forward(seqs_x, input_lengths=seqs_length)
     else:
         enc_hid = self.encoder.forward(seqs_x)['out']

     # NOTE(review): if dim 1 is time and the batch is padded, this mean also
     # averages over padded positions -- confirm that is intended.
     return enc_hid.mean(dim=1)
Esempio n. 6
0
    def encode(self, input_var, length):
        """Encode with the word and syntax encoders and concatenate their states.

        Both encoders receive the same (possibly word-dropped) input.

        Returns:
            ``(syntax_outputs, hidden)``. ``syntax_outputs`` is the first value
            returned by ``self.syntax_encoder.forward``. ``hidden`` joins both
            encoders' final states, each flattened to ``(batch, -1)``, along the
            last dim.
        """
        if self.training and self.args.src_wd > 0.:
            input_var = unk_replace(input_var, self.step_unk_rate,
                                    self.vocab.src)

        _, word_hidden = self.word_encoder.forward(input_var, length)
        syntax_outputs, syntax_hidden = self.syntax_encoder.forward(input_var, length)
        batch_size = word_hidden.size(1)

        def flatten(state):
            # Move dim 1 (batch) to the front, then merge the remaining dims.
            return state.permute(1, 0, 2).contiguous().view(batch_size, -1)

        hidden = torch.cat([flatten(word_hidden), flatten(syntax_hidden)], dim=-1)
        return syntax_outputs, hidden
Esempio n. 7
0
    def forward(self, examples):
        """Encode ``examples`` to a latent code and score their reconstruction.

        Returns:
            dict with ``"mean"``, ``"logv"``, ``"z"`` (taken from the latent
            state's ``'latent'``), ``'nll_loss'`` (negated sum of the decoder
            scores for the boundary-marked source tokens) and ``'batch_size'``.
        """
        batch = examples if isinstance(examples, list) else [examples]
        n_examples = len(batch)
        source_tokens = [example.src for example in batch]

        state = self.encode_to_hidden(batch)
        state = self.hidden_to_latent(ret=state, is_sampling=self.training)
        state = self.latent_for_init(ret=state)
        raw_init = state['decode_init']

        target = to_input_variable(source_tokens,
                                   self.vocab.src,
                                   training=False,
                                   cuda=self.args.cuda,
                                   append_boundary_sym=True,
                                   batch_first=True)
        init_hidden = self.bridger.forward(raw_init)

        use_word_dropout = self.training and self.args.tgt_wd > 0.
        if use_word_dropout:
            # Teacher-forced decoding from a word-dropped copy of the target,
            # scored against the clean target.
            noisy_input = unk_replace(target, self.step_unk_rate,
                                      self.vocab.src)
            token_scores = self.decoder.generate(
                con_inputs=noisy_input,
                encoder_hidden=init_hidden,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            scores = self.decoder.score_decoding_results(token_scores, target)
        else:
            scores = self.decoder.score(
                inputs=target,
                encoder_outputs=None,
                encoder_hidden=init_hidden,
            )

        return {
            "mean": state['mean'],
            "logv": state['logv'],
            "z": state['latent'],
            'nll_loss': -torch.sum(scores),
            'batch_size': n_examples
        }
Esempio n. 8
0
    def forward(self, examples):
        """Encode ``examples`` and score reconstruction of the latent state's target.

        The decoder is conditioned on the bridged ``'decode_init'`` and
        attends over ``'syn_output'`` from the latent state.

        Returns:
            dict with ``"mean"``, ``"logv"``, ``"z"`` (the state's
            ``'latent'``), ``'nll_loss'`` (negated decoder scores, not summed)
            and ``'batch_size'``.
        """
        batch = examples if isinstance(examples, list) else [examples]
        n_examples = len(batch)

        state = self.encode_to_hidden(batch)
        state = self.hidden_to_latent(ret=state, is_sampling=self.training)
        state = self.latent_for_init(ret=state)

        raw_init = state['decode_init']
        target = state['tgt_var']
        syntax_output = state['syn_output']
        init_hidden = self.bridger.forward(raw_init)

        if self.training and self.args.tgt_wd:
            noisy_input = unk_replace(target, self.step_unk_rate,
                                      self.vocab.src)
            token_scores = self.decoder.generate(
                con_inputs=noisy_input,
                encoder_outputs=syntax_output,
                encoder_hidden=init_hidden,
                teacher_forcing_ratio=1.0,
            )
            scores = self.decoder.score_decoding_results(token_scores, target)
        else:
            scores = self.decoder.score(
                inputs=target,
                encoder_outputs=syntax_output,
                encoder_hidden=init_hidden,
            )

        return {
            "mean": state['mean'],
            "logv": state['logv'],
            "z": state['latent'],
            'nll_loss': -scores,
            'batch_size': n_examples
        }
    def forward(self, examples, is_dis=False):
        """Run the disentangling VAE on ``examples`` and collect its losses.

        Args:
            examples: a single example or a list; each exposes ``src`` (words)
                and ``tgt`` (syntax sequence).
            is_dis: if True, compute only the discriminator losses via
                ``get_dis_loss``, store them under ``'dis syn'`` / ``'dis sem'``
                and return the latent-state dict immediately.

        Returns:
            The latent-state dict, extended (when ``is_dis`` is False) with
            ``'adv'``, ``'mul'``, ``'nll_loss'``, ``'sem_loss'``, ``'syn_loss'``
            and ``'batch_size'``.
        """
        batch = examples if isinstance(examples, list) else [examples]
        n_examples = len(batch)

        def to_target(seqs, vocab):
            # Boundary-marked, batch-first input variable for a decoder target.
            return to_input_variable(seqs,
                                     vocab,
                                     training=False,
                                     cuda=self.args.cuda,
                                     append_boundary_sym=True,
                                     batch_first=True)

        tgt_var = to_target([example.src for example in batch], self.vocab.src)
        syn_var = to_target([example.tgt for example in batch], self.vocab.tgt)

        state = self.encode_to_hidden(batch)
        state = self.hidden_to_latent(ret=state, is_sampling=self.training)
        state = self.latent_for_init(ret=state)

        # Shared keyword arguments for the dis/mul/adv loss helpers.
        loss_kwargs = dict(syntax_hidden=state['syn_hidden'],
                           semantic_hidden=state['sem_hidden'],
                           syn_tgt=syn_var,
                           sem_tgt=tgt_var)

        if is_dis:
            dis_syn_loss, dis_sem_loss = self.get_dis_loss(**loss_kwargs)
            state['dis syn'] = dis_syn_loss
            state['dis sem'] = dis_sem_loss
            return state

        init_hidden = self.bridger.forward(state['decode_init'])
        if self.training and self.args.tgt_wd:
            noisy_input = unk_replace(tgt_var, self.step_unk_rate,
                                      self.vocab.src)
            token_scores = self.decoder.generate(
                con_inputs=noisy_input,
                encoder_hidden=init_hidden,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0)
            scores = self.decoder.score_decoding_results(token_scores, tgt_var)
        else:
            scores = self.decoder.score(inputs=tgt_var,
                                        encoder_outputs=None,
                                        encoder_hidden=init_hidden)
        reconstruct_loss = -torch.sum(scores)

        mul_syn_loss, mul_sem_loss = self.get_mul_loss(**loss_kwargs)
        adv_syn_loss, adv_sem_loss = self.get_adv_loss(**loss_kwargs)

        state.update({
            'adv': adv_syn_loss + adv_sem_loss,
            'mul': mul_syn_loss + mul_sem_loss,
            'nll_loss': reconstruct_loss,
            'sem_loss': mul_sem_loss,
            'syn_loss': mul_syn_loss,
            'batch_size': n_examples,
        })
        return state