Example #1
    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sent_words = [e.src for e in examples]
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = to_input_variable(sent_words, self.vocab.src, training=False, cuda=self.args.cuda,
                                    append_boundary_sym=True, batch_first=True)
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd > 0.:
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -torch.sum(self.decoder.score_decoding_results(tgt_token_scores, tgt_var))
        else:
            reconstruct_loss = -torch.sum(self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=None,
                encoder_hidden=decode_init,
            ))

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }
Example #2
    def sentence_encode(self, sent_words):
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]

        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -sent_lengths[x])

        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)

        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate,
                                          self.vocab.src)

        sorted_sent_lengths = [
            len(sent_word) for sent_word in sorted_sent_words
        ]

        _, sent_hidden = self.encoder.forward(sorted_sent_var,
                                              sorted_sent_lengths)

        hidden = sent_hidden[:, example_old_pos_map, :]

        return hidden
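The trickiest part of sentence_encode is the index bookkeeping: the batch is sorted by descending length before the RNN call, and the encoder output is then put back into the original order. The toy snippet below (not part of the project code) walks through the same permutation logic with concrete lengths.

    # Toy illustration of the sort/unsort bookkeeping in sentence_encode.
    sent_lengths = [3, 5, 2, 4]
    batch_size = len(sent_lengths)

    sorted_example_ids = sorted(range(batch_size),
                                key=lambda x: -sent_lengths[x])
    # sorted_example_ids == [1, 3, 0, 2]  (indices by descending length)

    example_old_pos_map = [-1] * batch_size
    for new_pos, old_pos in enumerate(sorted_example_ids):
        example_old_pos_map[old_pos] = new_pos
    # example_old_pos_map == [2, 0, 3, 1]: original example i now sits at
    # sorted position example_old_pos_map[i], so indexing the encoder output
    # with this list restores the original batch order.

    sorted_lengths = [sent_lengths[i] for i in sorted_example_ids]  # [5, 4, 3, 2]
    restored = [sorted_lengths[example_old_pos_map[i]] for i in range(batch_size)]
    assert restored == sent_lengths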
Example #3
    def encode(self, input_var, length):
        if self.training and self.args.src_wd > 0.:
            input_var = unk_replace(input_var, self.step_unk_rate,
                                    self.vocab.src)

        encoder_output, encoder_hidden = self.encoder.forward(
            input_var, length)
        return encoder_output, encoder_hidden
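All of these examples guard the unk_replace call with a training-mode check and a positive dropout rate, and call it either positionally (input, rate, vocab) or with the dropoutr=/vocab= keywords. The project's own implementation is not reproduced on this page; the sketch below is only an assumed, minimal word-dropout helper compatible with those call sites (the unk_id, sos_id and eos_id attributes are assumptions; pad_id appears in Example #7).

    import torch

    def unk_replace_sketch(input_sequence, dropoutr, vocab):
        """Randomly replace non-special tokens with <unk> at rate `dropoutr`."""
        if dropoutr <= 0.:
            return input_sequence
        prob = torch.rand_like(input_sequence, dtype=torch.float)
        # Never drop padding or boundary symbols.
        special = ((input_sequence == vocab.pad_id)
                   | (input_sequence == vocab.sos_id)
                   | (input_sequence == vocab.eos_id))
        prob.masked_fill_(special, 1.0)
        dropped = input_sequence.clone()
        dropped[prob < dropoutr] = vocab.unk_id
        return dropped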
Example #4
File: ae.py  Project: XL2248/DSS-VAE
    def encode_var(self, src_var, src_length):
        if self.training and self.word_drop > 0.:
            src_var = unk_replace(src_var,
                                  dropoutr=self.word_drop,
                                  vocab=self.vocab.src)
        encoder_outputs, encoder_hidden = self.encoder.forward(
            input_var=src_var, input_lengths=src_length)
        encoder_hidden = self.bridger.forward(encoder_hidden)
        return encoder_outputs, encoder_hidden
Example #5
    def encode(self, input_var, length):
        if self.training and self.args.src_wd > 0.:
            input_var = unk_replace(input_var, self.step_unk_rate, self.vocab.src)

        _, word_hidden = self.word_encoder.forward(input_var, length)
        _, syntax_hidden = self.syntax_encoder.forward(input_var, length)
        bs = word_hidden.size(1)
        # Flatten each encoder's final hidden state to (batch, -1) and
        # concatenate the word and syntax representations.
        hidden = torch.cat([word_hidden.permute(1, 0, 2).contiguous().view(bs, -1),
                            syntax_hidden.permute(1, 0, 2).contiguous().view(bs, -1)], dim=-1)

        return _, hidden  # `_` here holds the syntax_encoder's output
Example #6
    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = ret['tgt_var']
        syntax_output = ret['syn_output']
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -self.decoder.score_decoding_results(
                tgt_token_scores, tgt_var)
        else:
            reconstruct_loss = -self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
            )

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }
Example #7
    def encode(self, seqs_x, seqs_length=None):
        """
        Args:
            seqs_x: torch.LongTensor<batch,seq_len(max)>
            seqs_length: truly sequence length

        Returns:

        """
        if self.training and self.word_drop > 0.:
            seqs_x = unk_replace(seqs_x,
                                 dropoutr=self.word_drop,
                                 vocab=self.vocab.src)

        if self.args.enc_type == "att":
            enc_ret = self.encoder.forward(seqs_x)
        else:
            enc_out, enc_hid = self.encoder.forward(seqs_x, None)
            enc_ret = {
                'out': enc_out,
                'hidden': enc_hid,
                'mask': seqs_x.detach().eq(self.vocab.src.pad_id)
            }
        return enc_ret
Example #8
    def forward(self, examples, is_dis=False):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)

        words = [e.src for e in examples]
        tgt_var = to_input_variable(words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        syn_seqs = [e.tgt for e in examples]
        syn_var = to_input_variable(syn_seqs,
                                    self.vocab.tgt,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        syn_hidden = ret['syn_hidden']
        sem_hidden = ret['sem_hidden']

        if is_dis:
            dis_syn_loss, dis_sem_loss = self.get_dis_loss(
                syntax_hidden=syn_hidden,
                semantic_hidden=sem_hidden,
                syn_tgt=syn_var,
                sem_tgt=tgt_var)
            ret['dis syn'] = dis_syn_loss
            ret['dis sem'] = dis_sem_loss
            return ret

        decode_init = ret['decode_init']

        sentence_decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_log_score = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=sentence_decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0)
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_log_score, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(inputs=tgt_var,
                                   encoder_outputs=None,
                                   encoder_hidden=sentence_decode_init))

        mul_syn_loss, mul_sem_loss = self.get_mul_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)

        adv_syn_loss, adv_sem_loss = self.get_adv_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)
        ret['adv'] = adv_syn_loss + adv_sem_loss
        ret['mul'] = mul_syn_loss + mul_sem_loss

        ret['nll_loss'] = reconstruct_loss
        ret['sem_loss'] = mul_sem_loss
        ret['syn_loss'] = mul_syn_loss
        ret['batch_size'] = batch_size
        return ret
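The forward methods above return dictionaries containing the posterior statistics ("mean", "logv"), the reconstruction term "nll_loss" and the batch size (Example #8 adds adversarial and multi-task losses on top). The sketch below shows one assumed way the simpler return value of Example #1 or #6 could be consumed in a VAE-style training step; the diagonal-Gaussian KL term, the beta weight, the model and the optimizer are illustrative and not taken from the project.

    import torch

    def vae_training_step(model, batch, optimizer, beta=1.0):
        ret = model(batch)
        # Standard diagonal-Gaussian KL: -0.5 * sum(1 + logv - mean^2 - exp(logv)).
        kl = -0.5 * torch.sum(1 + ret['logv'] - ret['mean'].pow(2) - ret['logv'].exp())
        loss = (ret['nll_loss'] + beta * kl) / ret['batch_size']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()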