Ejemplo n.º 1
0
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Assemble the encoder, bridge, decoder and beam decoder of the seq2seq model.

        Args:
            args: configuration namespace; every size, dropout rate and layer
                count used below is read from it.
            vocab: object exposing ``src`` and ``tgt`` vocabularies.
            src_embed: optional embedding module handed to the encoder as-is.
            tgt_embed: optional embedding module handed to the decoder as-is.
        """
        super(BaseSeq2seq, self).__init__()
        self.vocab = vocab
        self.src_vocab, self.tgt_vocab = vocab.src, vocab.tgt
        self.args = args

        src_vocab_size = len(self.src_vocab)
        self.encoder = RNNEncoder(
            vocab_size=src_vocab_size,
            max_len=args.src_max_time_step,
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
            embedding=src_embed,
        )

        # A bidirectional encoder yields two directions per layer, doubling its width.
        self.enc_factor = 2 if args.bidirectional else 1
        self.enc_dim = self.enc_factor * args.enc_hidden_dim

        # The decoder takes the encoder's width when the bridge is a direct "link"
        # or when attention is enabled; otherwise it uses its own configured size.
        match_encoder = args.mapper_type == "link" or args.use_attention
        self.dec_hidden = self.enc_dim if match_encoder else args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        tgt_vocab_size = len(self.tgt_vocab)
        self.decoder = RNNDecoder(
            vocab=tgt_vocab_size,
            max_len=args.tgt_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=tgt_embed,
            eos_id=self.tgt_vocab.eos_id,
            sos_id=self.tgt_vocab.sos_id,
        )

        # Beam search reuses the decoder RNN built above, keeping args.sample_size candidates.
        self.beam_decoder = TopKDecoder(decoder_rnn=self.decoder, k=args.sample_size)

        summary = "enc layer: {}, dec layer: {}, type: {}, with attention: {}".format(
            args.enc_num_layers, args.dec_num_layers, args.rnn_type, args.use_attention)
        print(summary)
Ejemplo n.º 2
0
    def __init__(self, args, vocab, word_embed=None):
        """Build the encoder, bridge, decoder and latent projections of the sentence VAE.

        Args:
            args: configuration namespace. NOTE: mutated in place --
                ``dec_embed_dim`` is overwritten when ``share_embed`` is set, and
                ``use_attention`` is always forced to False.
            vocab: object exposing a ``src`` vocabulary. Both the encoder and the
                decoder are built over ``vocab.src`` (the model reconstructs its input).
            word_embed: optional pre-built source embedding. When omitted, a new
                ``nn.Embedding(len(vocab.src), args.embed_size)`` is created.
        """
        super(VanillaVAE, self).__init__(args, vocab, name='MySentVAE')
        print("This is {} with parameter\n{}".format(self.name, self.base_information()))
        if word_embed is None:
            src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            src_embed = word_embed

        # With share_embed the decoder reuses the encoder's embedding, so its
        # input width must match the encoder's.
        if args.share_embed:
            tgt_embed = src_embed
            args.dec_embed_dim = args.enc_embed_dim
        else:
            tgt_embed = None

        self.latent_size = args.latent_size
        self.unk_rate = args.unk_rate
        # Current word-dropout (unk replacement) rate; presumably annealed during training.
        self.step_unk_rate = 0.0
        self.hidden_size = args.enc_hidden_dim
        # Number of hidden-state slices the encoder returns (directions x layers).
        self.hidden_factor = (2 if args.bidirectional else 1) * args.enc_num_layers
        args.use_attention = False

        # layer size setting
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)  # single layer unit size
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.encoder = RNNEncoder(
            vocab_size=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
            embedding=src_embed
        )

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=False,
            embedding=tgt_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

        # The encoder is built with hidden_size=args.enc_hidden_dim (== self.hidden_size),
        # so its flattened final state is hidden_factor * self.hidden_size wide. Size the
        # latent projections from that, not from a separate args.hidden_size option
        # that may differ (or be missing).
        flat_hidden_dim = self.hidden_size * self.hidden_factor
        self.hidden2mean = nn.Linear(flat_hidden_dim, args.latent_size)
        self.hidden2logv = nn.Linear(flat_hidden_dim, args.latent_size)
        self.latent2hidden = nn.Linear(args.latent_size, flat_hidden_dim)
Ejemplo n.º 3
0
class VanillaVAE(BaseVAE):
    """Sentence VAE with a single Gaussian latent variable.

    The encoder's final hidden state is projected to a mean / log-variance pair.
    A latent ``z`` is drawn with the reparameterisation trick (or set to the mean
    when not sampling), then projected back to an initial decoder state used to
    reconstruct the source sentence.
    """

    def __init__(self, args, vocab, word_embed=None):
        """Build the encoder, bridge, decoder and latent projections.

        Args:
            args: configuration namespace. NOTE: mutated in place --
                ``dec_embed_dim`` is overwritten when ``share_embed`` is set, and
                ``use_attention`` is always forced to False.
            vocab: object exposing a ``src`` vocabulary. Both the encoder and the
                decoder are built over ``vocab.src`` (the model reconstructs its input).
            word_embed: optional pre-built source embedding. When omitted, a new
                ``nn.Embedding(len(vocab.src), args.embed_size)`` is created.
        """
        super(VanillaVAE, self).__init__(args, vocab, name='MySentVAE')
        print("This is {} with parameter\n{}".format(self.name, self.base_information()))
        if word_embed is None:
            src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            src_embed = word_embed

        # With share_embed the decoder reuses the encoder's embedding, so its
        # input width must match the encoder's.
        if args.share_embed:
            tgt_embed = src_embed
            args.dec_embed_dim = args.enc_embed_dim
        else:
            tgt_embed = None

        self.latent_size = args.latent_size
        self.unk_rate = args.unk_rate
        # Current word-dropout (unk replacement) rate; re-computed each step in get_loss.
        self.step_unk_rate = 0.0
        self.hidden_size = args.enc_hidden_dim
        # Number of hidden-state slices the encoder returns (directions x layers).
        self.hidden_factor = (2 if args.bidirectional else 1) * args.enc_num_layers
        args.use_attention = False

        # layer size setting
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)  # single layer unit size
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.encoder = RNNEncoder(
            vocab_size=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
            embedding=src_embed
        )

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=False,
            embedding=tgt_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

        # hidden_to_latent flattens the encoder state to hidden_size * hidden_factor,
        # and latent_for_init reshapes back with the same sizes. The projections must
        # therefore use self.hidden_size, not a separate args.hidden_size option.
        flat_hidden_dim = self.hidden_size * self.hidden_factor
        self.hidden2mean = nn.Linear(flat_hidden_dim, args.latent_size)
        self.hidden2logv = nn.Linear(flat_hidden_dim, args.latent_size)
        self.latent2hidden = nn.Linear(args.latent_size, flat_hidden_dim)

    def encode(self, input_var, length):
        """Run the RNN encoder, applying unk-replacement word dropout while training.

        Returns:
            ``(encoder_output, encoder_hidden)`` as produced by ``self.encoder``.
        """
        if self.training and self.args.src_wd:
            input_var = unk_replace(input_var, self.step_unk_rate, self.vocab.src)
        encoder_output, encoder_hidden = self.encoder.forward(input_var, length)
        return encoder_output, encoder_hidden

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        """Thin wrapper around ``self.decoder.forward``."""
        return self.decoder.forward(
            inputs=inputs,
            encoder_outputs=encoder_outputs,
            encoder_hidden=encoder_hidden
        )

    def forward(self, examples):
        """Encode, pick a latent, and score the reconstruction of each example's source.

        Args:
            examples: a single example or a list of examples, each exposing ``src``.

        Returns:
            dict with ``mean``, ``logv``, ``z`` (the latent), ``nll_loss`` (negative
            log-likelihood summed over the batch) and ``batch_size``.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sent_words = [e.src for e in examples]
        ret = self.encode_to_hidden(examples)
        # Sample z only while training; in eval mode z is the posterior mean.
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = to_input_variable(sent_words, self.vocab.src, training=False, cuda=self.args.cuda,
                                    append_boundary_sym=True, batch_first=True)
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd > 0.:
            # Word dropout on the decoder inputs; the loss is still scored against tgt_var.
            input_var = unk_replace(tgt_var, self.step_unk_rate, self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -torch.sum(self.decoder.score_decoding_results(tgt_token_scores, tgt_var))
        else:
            reconstruct_loss = -torch.sum(self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=None,
                encoder_hidden=decode_init,
            ))

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        """Compute per-example KL / NLL terms and the training loss for one batch.

        Also updates ``self.step_unk_rate`` from the annealing schedule at ``step``.
        """
        self.step_unk_rate = wd_anneal_function(unk_max=self.unk_rate, anneal_function=self.args.unk_schedule,
                                                step=step, x0=self.args.x0,
                                                k=self.args.k)
        explore = self.forward(examples)
        batch_size = explore['batch_size']
        kl_loss, kl_weight = self.compute_kl_loss(explore['mean'], explore['logv'], step)
        kl_weight *= self.args.kl_factor
        nll_loss = explore['nll_loss'] / batch_size
        kl_loss = kl_loss / batch_size
        kl_item = kl_loss * kl_weight
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'KL Weight': kl_weight,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss,
            'KL Item': kl_item,
        }

    def sample_latent(self, batch_size):
        """Draw ``batch_size`` latents from the standard normal prior."""
        z = to_var(torch.randn([batch_size, self.latent_size]))
        return {
            "latent": z
        }

    def latent_for_init(self, ret):
        """Project ``ret['latent']`` to an initial RNN state stored in ``ret['decode_init']``.

        The state has shape [hidden_factor, batch, hidden_size].
        """
        z = ret['latent']
        batch_size = z.size(0)
        hidden = self.latent2hidden(z)

        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_factor, self.hidden_size)
            hidden = hidden.permute(1, 0, 2)
        else:
            hidden = hidden.unsqueeze(0)
        ret['decode_init'] = hidden
        return ret

    def batch_beam_decode(self, examples):
        raise NotImplementedError

    def hidden_to_latent(self, ret, is_sampling=True):
        """Map the encoder's final hidden state to the Gaussian posterior and a latent.

        Args:
            ret: dict holding ``hidden`` with shape [hidden_factor, batch, hidden_size].
            is_sampling: draw ``z = mean + std * eps`` if True, else use ``z = mean``.

        Returns:
            ``ret`` updated with ``latent``, ``mean`` and ``logv`` (each [batch, latent_size]).
        """
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        # [hidden_factor, batch, hidden_size] -> [batch, hidden_factor * hidden_size].
        # Reshape with the explicit batch size: a bare squeeze() would also drop the
        # batch axis when batch_size == 1, leaving 1-D mean/logv.
        hidden = hidden.permute(1, 0, 2).contiguous()
        hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def conditional_generating(self, condition='sem', examples=None):
        """Generate sentences, either conditioned on ``examples`` or from the prior.

        Args:
            condition: a string starting with "sem" to encode ``examples`` and decode
                from their sampled latents, or None for unconditional generation of
                ``len(examples)`` sentences.
            examples: an example or a list of examples.

        Returns:
            ``{'res': ...}``, or None when ``condition`` matches neither case.
        """
        if not isinstance(examples, list):
            examples = [examples]
        # Check None first: calling .startswith on None would raise AttributeError.
        if condition is None:
            return {
                "res": self.unsupervised_generating(sample_num=len(examples))
            }
        if condition.startswith("sem"):
            ret = self.encode_to_hidden(examples)
            ret = self.hidden_to_latent(ret=ret, is_sampling=True)
            ret = self.latent_for_init(ret=ret)
            return {
                'res': self.decode_to_sentence(ret=ret)
            }
        return None

    def eval_adv(self, sem_in, syn_ref):
        """Decode from the average of the latents of ``sem_in`` and ``syn_ref``.

        Returns:
            dict with the mixed ``latent``, its ``decode_init`` and the decoded ``res``.
        """
        sem_ret = self.encode_to_hidden(sem_in)
        sem_ret = self.hidden_to_latent(sem_ret, is_sampling=self.training)
        syn_ret = self.encode_to_hidden(syn_ref, need_sort=True)
        syn_ret = self.hidden_to_latent(syn_ret, is_sampling=self.training)
        # Only the two latents are needed; the decoder state is built once, from their mean.
        ret = dict()
        ret["latent"] = (sem_ret['latent'] + syn_ret['latent']) * 0.5
        ret = self.latent_for_init(ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret
Ejemplo n.º 4
0
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the encoder, sentence decoder, latent projections and auxiliary heads.

        Args:
            args: configuration namespace. NOTE: mutated in place --
                ``use_attention`` is forced to False.
            vocab: object exposing ``src`` (sentence) and ``tgt`` vocabularies.
            src_embed: optional source embedding. When omitted, a new
                ``nn.Embedding(len(vocab.src), args.embed_size)`` is created.
            tgt_embed: optional target embedding. When omitted, a new
                ``nn.Embedding(len(vocab.tgt), args.embed_size)`` is created.
        """
        super(DisentangleVAE,
              self).__init__(args,
                             vocab,
                             name="Disentangle VAE with deep encoder")
        print("This is {} with parameter\n{}".format(self.name,
                                                     self.base_information()))
        if src_embed is None:
            self.src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            self.src_embed = src_embed
        if tgt_embed is None:
            self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size)
        else:
            self.tgt_embed = tgt_embed

        # NOTE(review): the pad index is taken from the vocabulary's *sos* id; if the
        # vocab defines a separate padding id this looks like a mix-up -- confirm.
        self.pad_idx = vocab.src.sos_id

        self.latent_size = int(args.latent_size)
        self.rnn_type = args.rnn_type
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.direction_num = 2 if args.bidirectional else 1

        self.enc_hidden_dim = args.enc_hidden_dim
        # Width of one encoder layer's output (all directions concatenated).
        self.enc_layer_dim = args.enc_hidden_dim * self.direction_num
        # Number of [batch, enc_hidden_dim] slices in the encoder's final hidden state.
        self.enc_hidden_factor = self.direction_num * args.enc_num_layers
        self.dec_hidden_factor = args.dec_num_layers
        # Attention is never used by this model.
        args.use_attention = False
        # use_attention was just forced off, so only the mapper type decides the
        # decoder width; no separate attention branch is needed.
        if args.mapper_type == "link":
            self.dec_layer_dim = self.enc_layer_dim
        else:
            self.dec_layer_dim = args.dec_hidden_dim

        # The syntactic and semantic parts each get half of the flattened encoder
        # state width (presumably the state is split in two elsewhere -- confirm).
        syn_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        sem_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)

        # Auxiliary heads operate on half-width layer states.
        task_enc_dim = int(self.enc_layer_dim / 2)
        task_dec_dim = int(self.dec_layer_dim / 2)

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=self.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=self.src_embed)
        # output: [layer*direction ,batch_size, enc_hidden_dim]

        # Sentence reconstruction path: only the bridge and decoder of this
        # BridgeRNN are kept as attributes.
        pack_decoder = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=self.enc_layer_dim,
            dec_hidden_dim=self.dec_layer_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')
        self.bridger = pack_decoder.bridger
        self.decoder = pack_decoder.decoder

        if "report" in self.args:
            # Mean and log-variance heads share one hidden ReLU layer per latent.
            syn_common = nn.Sequential(
                nn.Linear(syn_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.syn_mean = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.syn_logv = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))

            sem_common = nn.Sequential(
                nn.Linear(sem_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.sem_mean = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.sem_logv = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
        else:
            self.syn_mean = nn.Linear(syn_var_dim, self.latent_size)
            self.syn_logv = nn.Linear(syn_var_dim, self.latent_size)
            self.sem_mean = nn.Linear(sem_var_dim, self.latent_size)
            self.sem_logv = nn.Linear(sem_var_dim, self.latent_size)

        # Latent -> hidden projections back to each part's width.
        self.syn_to_h = nn.Linear(self.latent_size, syn_var_dim)
        self.sem_to_h = nn.Linear(self.latent_size, sem_var_dim)

        # Multitask head: predict the tgt sequence from the syntactic part.
        # NOTE(review): passes the raw ``tgt_embed`` argument (possibly None), not
        # ``self.tgt_embed`` as syn_adv does -- confirm this is intended.
        self.sup_syn = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=task_enc_dim,
                                 dec_hidden_dim=task_dec_dim,
                                 embed=tgt_embed,
                                 mode='tgt')

        # Multitask head for the semantic part.
        self.sup_sem = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        # Adversarial and inference heads.
        self.syn_adv = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.tgt_embed if args.share_embed else None,
            mode='tgt')

        self.syn_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')

        self.sem_adv = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.sem_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')
Ejemplo n.º 5
0
class DisentangleVAE(BaseVAE):
    """
    Encoder the sentence, predict the parser,
    """
    def score(self, **kwargs):
        pass

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        return self.decoder.forward(inputs=inputs,
                                    encoder_outputs=encoder_outputs,
                                    encoder_hidden=encoder_hidden)

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(DisentangleVAE,
              self).__init__(args,
                             vocab,
                             name="Disentangle VAE with deep encoder")
        print("This is {} with parameter\n{}".format(self.name,
                                                     self.base_information()))
        if src_embed is None:
            self.src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            self.src_embed = src_embed
        if tgt_embed is None:
            self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size)
        else:
            self.tgt_embed = tgt_embed

        self.pad_idx = vocab.src.sos_id

        self.latent_size = int(args.latent_size)
        self.rnn_type = args.rnn_type
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0
        self.direction_num = 2 if args.bidirectional else 1

        self.enc_hidden_dim = args.enc_hidden_dim
        self.enc_layer_dim = args.enc_hidden_dim * self.direction_num
        self.enc_hidden_factor = self.direction_num * args.enc_num_layers
        self.dec_hidden_factor = args.dec_num_layers
        args.use_attention = False
        if args.mapper_type == "link":
            self.dec_layer_dim = self.enc_layer_dim
        elif args.use_attention:
            self.dec_layer_dim = self.enc_layer_dim
        else:
            self.dec_layer_dim = args.dec_hidden_dim

        syn_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        sem_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)

        task_enc_dim = int(self.enc_layer_dim / 2)
        task_dec_dim = int(self.dec_layer_dim / 2)

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=self.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=self.src_embed)
        # output: [layer*direction ,batch_size, enc_hidden_dim]

        pack_decoder = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=self.enc_layer_dim,
            dec_hidden_dim=self.dec_layer_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')
        self.bridger = pack_decoder.bridger
        self.decoder = pack_decoder.decoder

        if "report" in self.args:
            syn_common = nn.Sequential(
                nn.Linear(syn_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.syn_mean = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.syn_logv = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))

            sem_common = nn.Sequential(
                nn.Linear(sem_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.sem_mean = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.sem_logv = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
        else:
            self.syn_mean = nn.Linear(syn_var_dim, self.latent_size)
            self.syn_logv = nn.Linear(syn_var_dim, self.latent_size)
            self.sem_mean = nn.Linear(sem_var_dim, self.latent_size)
            self.sem_logv = nn.Linear(sem_var_dim, self.latent_size)

        self.syn_to_h = nn.Linear(self.latent_size, syn_var_dim)
        self.sem_to_h = nn.Linear(self.latent_size, sem_var_dim)

        self.sup_syn = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=task_enc_dim,
                                 dec_hidden_dim=task_dec_dim,
                                 embed=tgt_embed,
                                 mode='tgt')

        self.sup_sem = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.syn_adv = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.tgt_embed if args.share_embed else None,
            mode='tgt')

        self.syn_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')

        self.sem_adv = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.sem_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')

    def base_information(self):
        origin = super().base_information()
        return origin + "mul_syn:{}\n" \
                        "mul_sen:{}\n" \
                        "adv_syn:{}\n" \
                        "adv_sem:{}\n" \
                        "inf_syn:{}\n" \
                        "inf_sem:{}\n" \
                        "kl_syn:{}\n" \
                        "kl_sem:{}\n".format(str(self.args.mul_syn),
                                             str(self.args.mul_sem),
                                             str(self.args.adv_syn),
                                             str(self.args.adv_sem),
                                             str(self.args.inf_syn * self.args.infer_weight),
                                             str(self.args.inf_sem * self.args.infer_weight),
                                             str(self.args.syn_weight),
                                             str(self.args.sem_weight)
                                             )

    def encode(self, input_var, length):
        if self.training and self.args.src_wd > 0.:
            input_var = unk_replace(input_var, self.step_unk_rate,
                                    self.vocab.src)

        encoder_output, encoder_hidden = self.encoder.forward(
            input_var, length)
        return encoder_output, encoder_hidden

    def forward(self, examples, is_dis=False):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)

        words = [e.src for e in examples]
        tgt_var = to_input_variable(words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        syn_seqs = [e.tgt for e in examples]
        syn_var = to_input_variable(syn_seqs,
                                    self.vocab.tgt,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        syn_hidden = ret['syn_hidden']
        sem_hidden = ret['sem_hidden']

        if is_dis:
            dis_syn_loss, dis_sem_loss = self.get_dis_loss(
                syntax_hidden=syn_hidden,
                semantic_hidden=sem_hidden,
                syn_tgt=syn_var,
                sem_tgt=tgt_var)
            ret['dis syn'] = dis_syn_loss
            ret['dis sem'] = dis_sem_loss
            return ret

        decode_init = ret['decode_init']

        sentence_decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_log_score = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=sentence_decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0)
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_log_score, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(inputs=tgt_var,
                                   encoder_outputs=None,
                                   encoder_hidden=sentence_decode_init))

        mul_syn_loss, mul_sem_loss = self.get_mul_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)

        adv_syn_loss, adv_sem_loss = self.get_adv_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)
        ret['adv'] = adv_syn_loss + adv_sem_loss
        ret['mul'] = mul_syn_loss + mul_sem_loss

        ret['nll_loss'] = reconstruct_loss
        ret['sem_loss'] = mul_sem_loss
        ret['syn_loss'] = mul_syn_loss
        ret['batch_size'] = batch_size
        return ret

    def get_loss(self, examples, step, is_dis=False):
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=self.args.unk_schedule,
            step=step,
            x0=self.args.x0,
            k=self.args.k)
        explore = self.forward(examples, is_dis)

        if is_dis:
            return explore

        sem_kl, kl_weight = self.compute_kl_loss(
            mean=explore['sem_mean'],
            logv=explore['sem_logv'],
            step=step,
        )
        syn_kl, _ = self.compute_kl_loss(
            mean=explore['syn_mean'],
            logv=explore['syn_logv'],
            step=step,
        )

        batch_size = explore['batch_size']
        kl_weight *= self.args.kl_factor
        kl_loss = (self.args.sem_weight * sem_kl + self.args.syn_weight *
                   syn_kl) / (self.args.sem_weight + self.args.syn_weight)
        kl_loss /= batch_size
        mul_loss = explore['mul'] / batch_size
        adv_loss = explore['adv'] / batch_size
        nll_loss = explore['nll_loss'] / batch_size
        kl_item = kl_loss * kl_weight

        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'MUL Loss': mul_loss,
            'ADV Loss': adv_loss,
            'KL Weight': kl_weight,
            'KL Item': kl_item,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss + mul_loss - adv_loss,
            'SYN KL Loss': syn_kl / explore['batch_size'],
            'SEM KL Loss': sem_kl / explore['batch_size'],
        }

    def get_adv_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Turn the raw discriminator losses into weighted adversarial terms.

        Returns:
            ``(adv_syn, adv_sem)`` where ``adv_syn = args.adv_syn * adv_syn_sup``
            and ``adv_sem = args.adv_sem * adv_sem_sup``. When
            ``args.infer_weight > 0`` they additionally receive
            ``infer_weight * inf_sem * adv_sem_inf`` and
            ``infer_weight * inf_syn * adv_syn_inf`` respectively.
        """
        opts = self.args

        def weighted(losses):
            syn_term = opts.adv_syn * losses['adv_syn_sup']
            sem_term = opts.adv_sem * losses['adv_sem_sup']
            if opts.infer_weight > 0.:
                syn_term = syn_term + opts.infer_weight * opts.inf_sem * losses['adv_sem_inf']
                sem_term = sem_term + opts.infer_weight * opts.inf_syn * losses['adv_syn_inf']
            return syn_term, sem_term

        if not self.training:
            return weighted(self._dis_loss(syntax_hidden, semantic_hidden,
                                           syn_tgt, sem_tgt))

        # NOTE(review): in training mode the discriminator losses are computed
        # under torch.no_grad(), so these terms carry no gradient; confirm this
        # is intended, since the training loss subtracts the adversarial term.
        with torch.no_grad():
            losses = self._dis_loss(syntax_hidden, semantic_hidden, syn_tgt,
                                    sem_tgt)
        return weighted(losses)

    def get_dis_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Discriminator losses on detached hidden states (no encoder gradient).

        Returns:
            ``(syn_loss, sem_loss)``: ``adv_syn_sup`` and ``adv_sem_sup``, each
            summed with ``adv_sem_inf`` / ``adv_syn_inf`` respectively when
            ``args.infer_weight > 0``.
        """
        losses = self._dis_loss(syntax_hidden.detach(), semantic_hidden.detach(),
                                syn_tgt, sem_tgt)
        syn_loss = losses['adv_syn_sup']
        sem_loss = losses['adv_sem_sup']
        if self.args.infer_weight > 0.:
            syn_loss = syn_loss + losses['adv_sem_inf']
            sem_loss = sem_loss + losses['adv_syn_inf']
        return syn_loss, sem_loss

    def _dis_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Evaluate each discriminator and return its loss keyed by name.

        Always runs ``syn_adv`` (semantic hidden -> ``syn_tgt``) and ``sem_adv``
        (syntax hidden -> ``sem_tgt``). When ``args.infer_weight > 0`` it also
        runs ``syn_infer`` (syntax hidden -> ``sem_tgt``) and ``sem_infer``
        (semantic hidden -> ``sem_tgt``). In that case every entry whose
        weight (``adv_syn``/``adv_sem``/``inf_syn``/``inf_sem``) is not
        positive is replaced by ``0.``.
        """
        opts = self.args
        syn_sup = self.syn_adv.forward(hidden=semantic_hidden, tgt_var=syn_tgt)
        sem_sup = self.sem_adv.forward(hidden=syntax_hidden, tgt_var=sem_tgt)

        if not (opts.infer_weight > 0.):
            return {'adv_syn_sup': syn_sup, 'adv_sem_sup': sem_sup}

        syn_inf = self.syn_infer.forward(hidden=syntax_hidden, tgt_var=sem_tgt)
        sem_inf = self.sem_infer.forward(hidden=semantic_hidden, tgt_var=sem_tgt)

        def gate(value, weight):
            return value if weight > 0. else 0.

        return {
            'adv_syn_sup': gate(syn_sup, opts.adv_syn),
            'adv_sem_sup': gate(sem_sup, opts.adv_sem),
            'adv_syn_inf': gate(syn_inf, opts.inf_syn),
            "adv_sem_inf": gate(sem_inf, opts.inf_sem),
        }

    def get_mul_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Weighted multi-task losses: syntax hidden -> ``syn_tgt`` via ``sup_syn``
        and semantic hidden -> ``sem_tgt`` via ``sup_sem``, scaled by
        ``args.mul_syn`` / ``args.mul_sem``.
        """
        opts = self.args
        syntax_term = opts.mul_syn * self.sup_syn.forward(hidden=syntax_hidden,
                                                          tgt_var=syn_tgt)
        semantic_term = opts.mul_sem * self.sup_sem.forward(hidden=semantic_hidden,
                                                            tgt_var=sem_tgt)
        return syntax_term, semantic_term

    def sample_latent(self, batch_size):
        """Draw standard-normal latents for both spaces.

        Returns a dict with ``"syn_z"`` and ``"sem_z"``, each produced by
        ``torch.randn([batch_size, self.latent_size])`` and wrapped by
        ``to_var``; the syntax sample is drawn first.
        """
        shape = [batch_size, self.latent_size]
        syn_z = to_var(torch.randn(shape))
        sem_z = to_var(torch.randn(shape))
        return {"syn_z": syn_z, "sem_z": sem_z}

    def hidden_to_latent(self, ret, is_sampling=True):
        """Project the encoder state onto semantic and syntax Gaussians.

        ``ret['hidden']`` is indexed as ``[factor, batch, dim]`` (batch taken
        from dim 1). For every factor slice, the first half of the last dim
        feeds the semantic heads and the second half the syntax heads.

        Adds ``syn_mean``, ``syn_logv``, ``sem_mean``, ``sem_logv``, ``syn_z``
        and ``sem_z`` to ``ret`` (in that order) and returns it. With
        ``is_sampling`` a latent is ``eps * exp(0.5 * logv) + mean``; otherwise
        it is the mean itself.
        """
        hidden = ret['hidden']
        factor, batch_size = hidden.size(0), hidden.size(1)

        halves = hidden.permute(1, 0, 2).contiguous().view(batch_size, factor, 2, -1)
        sem_hid = halves[:, :, 0, :].contiguous().view(batch_size, -1)
        syn_hid = halves[:, :, 1, :].contiguous().view(batch_size, -1)

        sem_mean = self.sem_mean(sem_hid)
        sem_logv = self.sem_logv(sem_hid)
        syn_mean = self.syn_mean(syn_hid)
        syn_logv = self.syn_logv(syn_hid)

        def reparameterize(mean, logv):
            if not is_sampling:
                return mean
            noise = to_var(torch.randn([batch_size, self.latent_size]))
            return noise * torch.exp(0.5 * logv) + mean

        # Syntax latent is drawn before the semantic one (RNG order).
        syn_z = reparameterize(syn_mean, syn_logv)
        sem_z = reparameterize(sem_mean, sem_logv)

        for key, value in (('syn_mean', syn_mean), ('syn_logv', syn_logv),
                           ('sem_mean', sem_mean), ('sem_logv', sem_logv),
                           ('syn_z', syn_z), ('sem_z', sem_z)):
            ret[key] = value
        return ret

    def latent_for_init(self, ret):
        """Map ``ret['syn_z']`` / ``ret['sem_z']`` back to decoder-initial states.

        Each latent goes through its ``*_to_h`` projection and is reshaped to
        ``[enc_hidden_factor, batch, enc_hidden_dim // 2]``. Stores
        ``syn_hidden``, ``sem_hidden`` and their last-dim concatenation
        ``decode_init`` (syntax first) in ``ret`` and returns it.
        """
        # Integer division is required: Tensor.view rejects float sizes, and
        # ``self.enc_hidden_dim / 2`` is a float under Python 3.
        half_dim = int(self.enc_hidden_dim / 2)

        def reshape(xx_hidden):
            # [batch, factor * half_dim] -> [factor, batch, half_dim]
            xx_hidden = xx_hidden.view(batch_size, self.enc_hidden_factor, half_dim)
            return xx_hidden.permute(1, 0, 2)

        syntax_latent = ret['syn_z']
        semantic_latent = ret['sem_z']
        batch_size = semantic_latent.size(0)
        syntax_hidden = reshape(self.syn_to_h(syntax_latent))
        semantic_hidden = reshape(self.sem_to_h(semantic_latent))

        ret['syn_hidden'] = syntax_hidden
        ret['sem_hidden'] = semantic_hidden
        ret['decode_init'] = torch.cat([syntax_hidden, semantic_hidden],
                                       dim=-1)
        return ret

    def evaluate_(self, examples, beam_size=5):
        """Encode ``examples``, build latents and decode them into sentences.

        ``beam_size`` is accepted but unused here. Latents are sampled via
        ``hidden_to_latent(is_sampling=self.training)``. The decoded output is
        stored under ``'res'`` in the returned dict.
        """
        batch = examples if isinstance(examples, list) else [examples]
        state = self.hidden_to_latent(ret=self.encode_to_hidden(batch),
                                      is_sampling=self.training)
        state = self.latent_for_init(ret=state)
        state['res'] = self.decode_to_sentence(ret=state)
        return state

    def predict_syntax(self, hidden, predictor):
        """Decode ``predictor.predict(hidden)`` column by column into tgt-vocab words.

        Each column of the prediction is one example; the result is a list of
        ``[[words], [len(words)]]`` entries, one per column.
        """
        predicted = predictor.predict(hidden)

        def to_entry(ids):
            words = id2word(ids, self.vocab.tgt)
            return [[words], [len(words)]]

        return [to_entry(predicted[:, column].data.tolist())
                for column in range(predicted.size(1))]

    def extract_variable(self, examples):
        """Placeholder hook: ignores ``examples`` and returns ``None``."""
        return None

    def eval_syntax(self, examples):
        """Predict syntax for ``examples`` from their non-sampled syntax latents.

        Encodes with ``need_sort=True``, uses ``is_sampling=False`` and decodes
        ``syn_hidden`` with ``self.sup_syn`` via ``predict_syntax``.
        """
        encoded = self.encode_to_hidden(examples, need_sort=True)
        latent = self.latent_for_init(self.hidden_to_latent(encoded, is_sampling=False))
        return self.predict_syntax(hidden=latent['syn_hidden'],
                                   predictor=self.sup_syn)

    def eval_adv(self, sem_in, syn_ref):
        """Decode the semantics of ``sem_in`` combined with the syntax of ``syn_ref``.

        Returns the mixed latent dict with ``'res'`` (decoded sentences),
        ``'ori syn'`` (syntax predicted from ``sem_in``) and ``'ref syn'``
        (syntax predicted from ``syn_ref``, encoded with ``need_sort=True``).
        """
        sampling = self.training
        sem_side = self.hidden_to_latent(self.encode_to_hidden(sem_in),
                                         is_sampling=sampling)
        syn_side = self.hidden_to_latent(self.encode_to_hidden(syn_ref, need_sort=True),
                                         is_sampling=sampling)
        sem_side = self.latent_for_init(ret=sem_side)
        syn_side = self.latent_for_init(ret=syn_side)

        mixed = self.latent_for_init({'sem_z': sem_side['sem_z'],
                                      'syn_z': syn_side['syn_z']})
        mixed['res'] = self.decode_to_sentence(ret=mixed)
        mixed['ori syn'] = self.predict_syntax(hidden=sem_side['syn_hidden'],
                                               predictor=self.sup_syn)
        mixed['ref syn'] = self.predict_syntax(hidden=syn_side['syn_hidden'],
                                               predictor=self.sup_syn)
        return mixed

    def conditional_generating(self, condition="sem", examples=None):
        """Generate sentences conditioned on one latent factor of ``examples``.

        Latents are sampled, then the conditioned factor (``sem`` when
        ``condition`` starts with ``"sem"``, otherwise ``syn``) is replaced by
        its mean. With ``condition == "sem-only"`` the syntax latent is
        replaced by a prior sample of size ``ret['batch_size']``.
        """
        state = self.hidden_to_latent(self.encode_to_hidden(examples),
                                      is_sampling=True)
        if condition.startswith("sem"):
            z_key, mean_key = 'sem_z', 'sem_mean'
        else:
            z_key, mean_key = 'syn_z', 'syn_mean'
        state[z_key] = state[mean_key]

        if condition == "sem-only":
            prior = self.sample_latent(batch_size=state['batch_size'])
            state['syn_z'] = prior['syn_z']

        result = self.latent_for_init(ret=state)
        result['res'] = self.decode_to_sentence(ret=result)
        return result
Ejemplo n.º 6
0
class BaseSeq2seq(nn.Module, BaseGenerator):
    """RNN encoder-decoder whose encoder state is bridged to the decoder.

    The final encoder state is mapped to the decoder's initial state by an
    ``MLPBridger``. Decoding is greedy in ``predict`` and uses ``TopKDecoder``
    in ``beam_search``.
    """

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(BaseSeq2seq, self).__init__()
        self.vocab = vocab
        self.src_vocab = vocab.src
        self.tgt_vocab = vocab.tgt
        self.args = args
        self.encoder = RNNEncoder(vocab_size=len(self.src_vocab),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        self.enc_factor = 2 if args.bidirectional else 1
        self.enc_dim = args.enc_hidden_dim * self.enc_factor

        # The decoder must match the encoder width when states are linked
        # directly or when attention reads encoder outputs.
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(self.tgt_vocab),
            max_len=args.tgt_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=tgt_embed,
            eos_id=self.tgt_vocab.eos_id,
            sos_id=self.tgt_vocab.sos_id,
        )

        self.beam_decoder = TopKDecoder(decoder_rnn=self.decoder,
                                        k=args.sample_size)
        print("enc layer: {}, dec layer: {}, type: {}, with attention: {}".
              format(args.enc_num_layers, args.dec_num_layers, args.rnn_type,
                     args.use_attention))

    def get_loss(self, **kwargs):
        """Return ``{"Loss": -self.score(**kwargs)}``."""
        return {"Loss": -self.score(**kwargs)}

    def forward(self, seqs_x, x_length, to_word=False):
        """Unused placeholder; use ``score``, ``predict`` or ``beam_search``."""
        pass

    def init(self):
        """Call ``flatten_parameters`` on the encoder and decoder RNNs."""
        self.encoder.rnn.flatten_parameters()
        self.decoder.rnn.flatten_parameters()

    def encode(self, src_var, src_length):
        """Run the encoder; returns ``(encoder_outputs, encoder_hidden)``."""
        encoder_outputs, encoder_hidden = self.encoder.forward(
            input_var=src_var, input_lengths=src_length)
        return encoder_outputs, encoder_hidden

    def bridge(self, encoder_hidden):
        """Map the final encoder state to the decoder's initial state."""
        return self.bridger.forward(encoder_hidden)

    def _encode_examples(self, examples, use_tgt):
        """Tensorise ``examples``, encode them and bridge the final state.

        Returns:
            ``(input_dict, encoder_outputs, bridged_hidden)``.
        """
        if not isinstance(examples, list):
            examples = [examples]

        input_dict = to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=self.args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=use_tgt,
            use_tag=False,
            use_dst=False,
        )
        encoder_outputs, encoder_hidden = self.encode(
            src_var=input_dict['src'], src_length=input_dict['src_len'])
        return input_dict, encoder_outputs, self.bridge(encoder_hidden)

    def get_hidden(self, examples):
        """Return the bridged encoder state for one example or a list."""
        _, _, encoder_hidden = self._encode_examples(examples, use_tgt=True)
        return encoder_hidden

    def score(self, examples, return_enc_state=False, **kwargs):
        """Decoder score of each example's target given its source.

        Returns the scores, or ``(scores, bridged_hidden)`` when
        ``return_enc_state`` is true.
        """
        input_dict, encoder_outputs, encoder_hidden = self._encode_examples(
            examples, use_tgt=True)
        scores = self.decoder.score(inputs=input_dict['tgt'],
                                    encoder_hidden=encoder_hidden,
                                    encoder_outputs=encoder_outputs)

        if return_enc_state:
            return scores, encoder_hidden
        return scores

    def predict(self, examples, to_word=True):
        """Greedy decoding (``teacher_forcing_ratio=0``).

        Returns a list of ``[[words], [len(words)]]``, one per example.
        NOTE(review): with ``to_word=False`` an empty list is returned and the
        predicted ids are discarded; kept for compatibility.
        """
        _, encoder_outputs, encoder_hidden = self._encode_examples(
            examples, use_tgt=False)

        _, _, ret_dict, _ = self.decoder.forward(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            teacher_forcing_ratio=0.0)
        result = torch.stack(ret_dict['sequence']).squeeze()
        # squeeze() drops the batch dim for a single example; restore it so
        # columns always index examples.
        if result.dim() < 2:
            result = result.view(-1, 1)

        final_result = []
        if to_word:
            for i in range(result.size(-1)):
                hyp = result[:, i].data.tolist()
                res = id2word(hyp, self.vocab.tgt)
                final_result.append([[res], [len(res)]])
        return final_result

    def beam_search(self, src_sent, beam_size=5, dmts=None):
        """Beam-search decode one source sentence.

        Returns:
            ``(hypotheses, scores)``: word lists and float scores, one per
            returned beam.
        """
        if dmts is None:
            dmts = self.args.decode_max_time_step
        src_var = to_input_variable(src_sent,
                                    self.src_vocab,
                                    cuda=self.args.cuda,
                                    training=False,
                                    append_boundary_sym=False,
                                    batch_first=True)
        src_length = [len(src_sent)]

        encoder_outputs, encoder_hidden = self.encode(src_var=src_var,
                                                      src_length=src_length)
        encoder_hidden = self.bridger.forward(input_tensor=encoder_hidden)
        meta_data = self.beam_decoder.beam_search(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            beam_size=beam_size,
            decode_max_time_step=dmts)
        topk_sequence = meta_data['sequence']
        # reshape(-1) instead of squeeze(): with a single hypothesis squeeze()
        # yields a 0-d tensor and topk_score[i] would raise.
        topk_score = meta_data['score'].reshape(-1)

        completed_hypotheses = torch.cat(topk_sequence, dim=-1)

        final_result = []
        final_scores = []
        for i in range(completed_hypotheses.size(0)):
            hyp = completed_hypotheses[i, :].data.tolist()
            final_result.append(id2word(hyp, self.tgt_vocab))
            final_scores.append(topk_score[i].item())
        return final_result, final_scores

    def load_state_dict(self, state_dict, strict=True):
        """Delegate to ``nn.Module.load_state_dict``."""
        return super().load_state_dict(state_dict, strict=strict)

    def save(self, path):
        """Save ``args``, ``vocab`` and ``state_dict`` to ``path`` with ``torch.save``."""
        dir_name = os.path.dirname(path)
        # A bare filename has dirname '' and os.makedirs('') raises.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        params = {
            'args': self.args,
            'vocab': self.vocab,
            'state_dict': self.state_dict(),
        }

        torch.save(params, path)

    @classmethod
    def load(cls, load_path):
        """Rebuild a model from a file written by ``save`` (moved to GPU if ``args.cuda``).

        NOTE(review): ``torch.load`` unpickles ``args``/``vocab``; only load
        trusted checkpoints. Recent PyTorch versions default to
        ``weights_only=True``, which may reject these objects.
        """
        params = torch.load(load_path,
                            map_location=lambda storage, loc: storage)
        args = params['args']
        vocab = params['vocab']
        model = cls(args, vocab)
        model.load_state_dict(params['state_dict'])
        if args.cuda:
            model = model.cuda()
        return model
Ejemplo n.º 7
0
class EnhanceSyntaxVAE(BaseVAE):
    """
    Encoder the sentence, predict the parser,
    """

    def score(self, **kwargs):
        """Scoring is not supported by this model; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        """Forward ``inputs`` and the encoder state through ``self.decoder`` and return its output."""
        decoder_kwargs = dict(inputs=inputs,
                              encoder_outputs=encoder_outputs,
                              encoder_hidden=encoder_hidden)
        return self.decoder.forward(**decoder_kwargs)

    def base_information(self):
        """Return the parent's summary followed by this model's loss weights.

        Each ``name:value`` line shows the weight that ``get_adv_loss`` /
        ``get_mul_loss`` / ``get_loss`` apply. The ``*_to_sent`` weights are
        shown multiplied by ``infer_weight`` because that is how
        ``get_adv_loss`` scales them.
        """
        origin = super().base_information()
        args = self.args
        # Previously the values were shifted against their labels (syn_to_sem
        # printed syn_to_sent, ...) and "mul_sem" was misspelled.
        return origin + "mul_syn:{}\nmul_sem:{}\nsyn_to_sem:{}\nsem_to_syn:{}\nsyn_to_sent:{}\nsem_to_sent:{}" \
                        "\nkl_syn:{}\nkl_sem:{}\n".format(
            str(args.mul_syn),
            str(args.mul_sem),
            str(args.syn_to_sem),
            str(args.sem_to_syn),
            str(args.syn_to_sent * args.infer_weight),
            str(args.sem_to_sent * args.infer_weight),
            str(args.syn_weight),
            str(args.sem_weight)
        )

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the two encoders, sentence decoder, latent heads and predictors.

        Args:
            args: hyper-parameter namespace. NOTE: ``args.use_attention`` is
                overwritten with ``False`` on the caller's object.
            vocab: vocabulary with ``.src`` (sentence) and ``.tgt`` (syntax).
            src_embed: optional source embedding; otherwise a new
                ``nn.Embedding(len(vocab.src), args.embed_size)``.
            tgt_embed: optional target embedding; otherwise a new
                ``nn.Embedding(len(vocab.tgt), args.embed_size)``.

        Submodules are created in a fixed order; keep it, because it
        determines parameter initialisation and ``state_dict`` layout.
        """
        super(EnhanceSyntaxVAE, self).__init__(args, vocab, name="Syntax VAE with 2 seperate encoder")
        print("This is {} with parameter\n{}".format(self.name, self.base_information()))
        if src_embed is None:
            self.src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            self.src_embed = src_embed
        if tgt_embed is None:
            self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size)
        else:
            self.tgt_embed = tgt_embed

        self.pad_idx = vocab.src.pad_id

        self.latent_size = int(args.latent_size)
        self.rnn_type = args.rnn_type
        self.unk_rate = args.unk_rate
        # current word-dropout rate; get_loss re-anneals it every step
        self.step_unk_rate = 0.0
        self.direction_num = 2 if args.bidirectional else 1

        self.enc_hidden_dim = args.enc_hidden_dim
        self.enc_layer_dim = args.enc_hidden_dim * self.direction_num
        self.enc_hidden_factor = self.direction_num * args.enc_num_layers
        self.dec_hidden_factor = args.dec_num_layers
        # NOTE(review): mutates the caller's args (and anything saved from it);
        # it also makes the `elif args.use_attention` branch below unreachable.
        args.use_attention = False
        if args.mapper_type == "link":
            self.dec_layer_dim = self.enc_layer_dim
        elif args.use_attention:
            self.dec_layer_dim = self.enc_layer_dim
        else:
            self.dec_layer_dim = args.dec_hidden_dim

        # Each sub-encoder has hidden size enc_hidden_dim / 2; assuming its final
        # state is [layers * directions, batch, hidden] (see comment below),
        # encode() flattens it to enc_hidden_factor * enc_hidden_dim / 2 features.
        syn_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        sem_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)

        # half-width encoder/decoder dims used by the auxiliary predictors
        task_enc_dim = int(self.enc_layer_dim / 2)
        task_dec_dim = int(self.dec_layer_dim / 2)

        # Two half-width encoders sharing self.src_embed: semantic and syntactic.
        self.sem_encoder = RNNEncoder(
            vocab_size=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.enc_embed_dim,
            hidden_size=int(self.enc_hidden_dim / 2),
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
            embedding=self.src_embed
        )

        self.syn_encoder = RNNEncoder(
            vocab_size=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.enc_embed_dim,
            hidden_size=int(self.enc_hidden_dim / 2),
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
            embedding=self.src_embed
        )

        # output: [layer*direction ,batch_size, enc_hidden_dim]

        # Only this predictor's bridger and decoder are kept; they reconstruct
        # the source sentence (mode='src') from the combined latent state.
        pack_decoder = RNNPredictor(
            args,
            vocab,
            enc_hidden_dim=self.enc_layer_dim,
            dec_hidden_dim=self.dec_layer_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src'
        )
        self.bridger = pack_decoder.bridger
        self.decoder = pack_decoder.decoder

        # Gaussian heads (mean / log-variance) and latent-to-hidden projections.
        self.syn_mean = nn.Linear(syn_var_dim, self.latent_size)
        self.syn_logv = nn.Linear(syn_var_dim, self.latent_size)
        self.syn_to_h = nn.Linear(self.latent_size, syn_var_dim)

        self.sem_mean = nn.Linear(sem_var_dim, self.latent_size)
        self.sem_logv = nn.Linear(sem_var_dim, self.latent_size)
        self.sem_to_h = nn.Linear(self.latent_size, sem_var_dim)

        # Multi-task predictors (see get_mul_loss): syntax hidden -> syntax
        # sequence, semantic hidden -> sentence.
        # NOTE(review): passes the constructor argument `tgt_embed` (possibly
        # None) and ignores args.share_embed, unlike sem_to_syn below; confirm.
        self.mul_syn = RNNPredictor(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=tgt_embed,
            mode='tgt'
        )

        self.mul_sem = MLPPredictor(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        # Adversarial predictors used by _dis_loss / get_adv_loss / get_dis_loss.
        self.sem_to_syn = RNNPredictor(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.tgt_embed if args.share_embed else None,
            mode='tgt'
        )

        self.syn_to_sem = MLPPredictor(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.sem_to_sent = RNNPredictor(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src'
        )

        self.syn_to_sent = RNNPredictor(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src'
        )

    def cuda(self, device=None):
        """Move the whole model to a CUDA device, or stay on CPU if CUDA is unavailable.

        Args:
            device: optional CUDA device index, ``torch.device`` or device
                string; ``None`` selects ``cuda:0``.

        Returns:
            ``self``, as ``nn.Module.cuda`` does, so ``model = model.cuda()`` works.
        """
        # is_available must be *called*; testing the function object is always
        # true, which made the intended CPU fallback unreachable.
        if not torch.cuda.is_available():
            return self
        if device is None:
            device = torch.device("cuda:0")
        elif isinstance(device, int):
            device = torch.device("cuda", device)
        # Moving self moves every registered submodule, parameter and buffer,
        # not just a hand-maintained subset.
        return self.to(device)

    def encode(self, input_var, length):
        """Run both encoders and concatenate their flattened final states.

        While training with ``args.src_wd > 0`` the input first goes through
        ``unk_replace`` at rate ``self.step_unk_rate``.

        Returns:
            ``(syn_outputs, hidden)``: the first output of the *syntax* encoder
            (NOTE(review): returned as-is; confirm callers expect syntax, not
            semantic, outputs) and ``[batch, sem_flat + syn_flat]`` where each
            encoder's final state is permuted batch-major and flattened,
            semantic part first.
        """
        if self.training and self.args.src_wd > 0.:
            input_var = unk_replace(input_var, self.step_unk_rate, self.vocab.src)

        _, sem_state = self.sem_encoder.forward(input_var, length)
        syn_outputs, syn_state = self.syn_encoder.forward(input_var, length)
        batch = sem_state.size(1)

        def flatten(state):
            return state.permute(1, 0, 2).contiguous().view(batch, -1)

        hidden = torch.cat([flatten(sem_state), flatten(syn_state)], dim=-1)
        return syn_outputs, hidden

    def forward(self, examples, is_dis=False, norm_by_word=False):
        """Compute all model losses for a batch.

        Args:
            examples: one example or a list; each exposes ``.src`` (sentence
                tokens, ``vocab.src``) and ``.tgt`` (syntax tokens, ``vocab.tgt``).
            is_dis: only compute discriminator losses (stored under
                ``'dis syn'`` / ``'dis sem'``) and return early.
            norm_by_word: divide sequence losses by the non-pad target length.

        Returns:
            The latent dict from encode/hidden_to_latent/latent_for_init,
            extended with ``'adv'``, ``'mul'``, ``'nll_loss'``, ``'sem_loss'``,
            ``'syn_loss'`` and ``'batch_size'`` (or the ``'dis *'`` keys).
        """
        batch = examples if isinstance(examples, list) else [examples]
        sent_var = to_input_variable(
            [ex.src for ex in batch], self.vocab.src, training=False, cuda=self.args.cuda,
            append_boundary_sym=True, batch_first=True
        )
        syn_var = to_input_variable(
            [ex.tgt for ex in batch], self.vocab.tgt, training=False, cuda=self.args.cuda,
            append_boundary_sym=True, batch_first=True
        )

        ret = self.hidden_to_latent(ret=self.encode_to_hidden(batch), is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        syn_hidden, sem_hidden = ret['syn_hidden'], ret['sem_hidden']

        if is_dis:
            ret['dis syn'], ret['dis sem'] = self.get_dis_loss(
                syn_hid=syn_hidden,
                sem_hid=sem_hidden,
                syn_tgt=syn_var,
                sem_tgt=sent_var,
                norm_by_word=norm_by_word
            )
            return ret

        nll = self._reconstruction_nll(
            sent_var, self.bridger.forward(ret['decode_init']), norm_by_word)

        mul_syn_loss, mul_sem_loss = self.get_mul_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=sent_var,
            norm_by_word=norm_by_word
        )
        adv_syn_loss, adv_sem_loss = self.get_adv_loss(
            syn_hid=syn_hidden,
            sem_hid=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=sent_var,
            norm_by_word=norm_by_word
        )

        ret['adv'] = adv_syn_loss + adv_sem_loss
        ret['mul'] = mul_syn_loss + mul_sem_loss
        ret['nll_loss'] = nll
        ret['sem_loss'] = mul_sem_loss
        ret['syn_loss'] = mul_syn_loss
        ret['batch_size'] = len(batch)
        return ret

    def _reconstruction_nll(self, sent_var, dec_init, norm_by_word):
        """Summed negative log-likelihood of ``sent_var`` from decoder state ``dec_init``.

        While training with ``args.tgt_wd`` set, decoder inputs are word-dropped
        with ``unk_replace`` and the clean ``sent_var`` is scored.
        """
        if self.training and self.args.tgt_wd:
            noisy_inputs = unk_replace(sent_var, self.step_unk_rate, self.vocab.src)
            log_probs = self.decoder.decode(
                inputs=noisy_inputs,
                encoder_hidden=dec_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0)
            nll = -(self.decoder.score_decoding_results(log_probs, sent_var))
        else:
            nll = -(self.decoder.score(
                inputs=sent_var,
                encoder_outputs=None,
                encoder_hidden=dec_init))
        if norm_by_word:
            lengths = sent_var.ne(self.pad_idx).sum(dim=-1).float()
            nll = nll.div(lengths + 1e-9)
        return nll.sum()

    def get_loss(self, examples, train_iter, is_dis=False, **kwargs):
        """Anneal word dropout, run ``forward`` and assemble per-example losses.

        With ``is_dis`` the raw ``forward`` result is returned. Otherwise the
        result is a dict of batch-averaged terms; ``'Loss'`` subtracts the
        adversarial term only once ``train_iter > args.warm_up``.
        """
        args = self.args
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=args.unk_schedule,
            step=train_iter,
            x0=args.x0,
            k=args.k
        )
        explore = self.forward(examples, is_dis, norm_by_word=False)
        if is_dis:
            return explore

        sem_kl, kl_weight = self.compute_kl_loss(
            mean=explore['sem_mean'], logv=explore['sem_logv'], step=train_iter)
        syn_kl, _ = self.compute_kl_loss(
            mean=explore['syn_mean'], logv=explore['syn_logv'], step=train_iter)

        batch_size = explore['batch_size']
        kl_weight *= args.kl_factor
        kl_loss = (args.sem_weight * sem_kl + args.syn_weight * syn_kl) / (
                args.sem_weight + args.syn_weight)
        kl_loss /= batch_size
        mul_loss = explore['mul'] / batch_size
        adv_loss = explore['adv'] / batch_size
        nll_loss = explore['nll_loss'] / batch_size
        kl_item = kl_loss * kl_weight

        total = kl_item + nll_loss + mul_loss
        if train_iter > args.warm_up:
            total = total - adv_loss

        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'MUL Loss': mul_loss,
            'ADV Loss': adv_loss,
            'KL Weight': kl_weight,
            'KL Item': kl_item,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': total,
            'SYN KL Loss': syn_kl / batch_size,
            'SEM KL Loss': sem_kl / batch_size,
        }

    def get_adv_loss(self, syn_hid, sem_hid, syn_tgt, sem_tgt, norm_by_word=False):
        """Weighted adversarial terms the encoders are trained against.

        In training mode each adversarial predictor's ``detach()`` is called
        first (presumably freezing its parameters; see the predictor classes).

        Returns:
            ``(adv_for_syn, adv_for_sem)``: ``syn_to_sem``/``sem_to_syn`` losses
            scaled by their args weights, plus ``infer_weight``-scaled
            ``syn_to_sent``/``sem_to_sent`` terms when ``infer_weight > 0``.
        """
        if self.training:
            for adversary in (self.sem_to_sent, self.syn_to_sent,
                              self.syn_to_sem, self.sem_to_syn):
                adversary.detach()
        losses = self._dis_loss(
            syn_hid=syn_hid,
            sem_hid=sem_hid,
            syn_tgt=syn_tgt,
            sem_tgt=sem_tgt,
            norm_by_word=norm_by_word
        )
        opts = self.args
        adv_for_syn = opts.syn_to_sem * losses['syn_to_sem']
        adv_for_sem = opts.sem_to_syn * losses['sem_to_syn']
        if opts.infer_weight > 0.:
            adv_for_syn = adv_for_syn + opts.infer_weight * opts.syn_to_sent * losses['syn_to_sent']
            adv_for_sem = adv_for_sem + opts.infer_weight * opts.sem_to_sent * losses['sem_to_sent']
        return adv_for_syn, adv_for_sem

    def get_dis_loss(self, syn_hid, sem_hid, syn_tgt, sem_tgt, norm_by_word=False):
        """Unweighted discriminator losses on detached latent states.

        Calls ``reload()`` on every adversarial predictor first (presumably
        undoing ``detach()``; see the predictor classes).

        Returns:
            ``(syn_side, sem_side)``: ``syn_to_sem`` and ``sem_to_syn``, each
            summed with ``syn_to_sent`` / ``sem_to_sent`` when ``infer_weight > 0``.
        """
        for adversary in (self.sem_to_sent, self.syn_to_sent,
                          self.sem_to_syn, self.syn_to_sem):
            adversary.reload()
        losses = self._dis_loss(
            syn_hid=syn_hid.detach(),
            sem_hid=sem_hid.detach(),
            syn_tgt=syn_tgt,
            sem_tgt=sem_tgt,
            norm_by_word=norm_by_word
        )
        syn_side = losses['syn_to_sem']
        sem_side = losses['sem_to_syn']
        if self.args.infer_weight > 0.:
            syn_side = syn_side + losses['syn_to_sent']
            sem_side = sem_side + losses['sem_to_sent']
        return syn_side, sem_side

    def _dis_loss(self, syn_hid, sem_hid, syn_tgt, sem_tgt, norm_by_word=False):
        """Losses of the four adversarial predictors, keyed by predictor name.

        A predictor runs only when its ``args`` weight is positive; otherwise its
        entry is ``0.``. ``sem_tgt`` is the sentence target for the
        ``*_to_sem`` and ``*_to_sent`` predictors.
        """
        opts = self.args
        losses = {'sem_to_syn': 0., 'syn_to_sem': 0., 'sem_to_sent': 0., "syn_to_sent": 0.}

        if opts.sem_to_syn > 0.:
            losses['sem_to_syn'] = self.sem_to_syn.forward(
                hidden=sem_hid, tgt_var=syn_tgt, div_by_word=norm_by_word)
        if opts.syn_to_sem > 0.:
            losses['syn_to_sem'] = self.syn_to_sem.forward(hidden=syn_hid, tgt_var=sem_tgt)
        if opts.sem_to_sent > 0.:
            losses['sem_to_sent'] = self.sem_to_sent.forward(
                hidden=sem_hid, tgt_var=sem_tgt, div_by_word=norm_by_word)
        if opts.syn_to_sent > 0.:
            losses["syn_to_sent"] = self.syn_to_sent.forward(
                hidden=syn_hid, tgt_var=sem_tgt, div_by_word=norm_by_word)
        return losses

    def get_mul_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt, norm_by_word=False):
        """Multi-task losses scaled by ``args.mul_syn`` / ``args.mul_sem``.

        ``mul_syn`` predicts ``syn_tgt`` from the syntax hidden state (optionally
        normalised by length); ``mul_sem`` predicts ``sem_tgt`` from the
        semantic hidden state.
        """
        syntax_raw = self.mul_syn.forward(hidden=syntax_hidden, tgt_var=syn_tgt,
                                          div_by_word=norm_by_word)
        semantic_raw = self.mul_sem.forward(hidden=semantic_hidden, tgt_var=sem_tgt)
        opts = self.args
        return opts.mul_syn * syntax_raw, opts.mul_sem * semantic_raw

    def sample_latent(self, batch_size):
        """Draw ``[batch_size, latent_size]`` standard-normal latents (syntax first).

        Returns ``{"syn_z": ..., "sem_z": ...}``, each wrapped by ``to_var``.
        """
        shape = [batch_size, self.latent_size]
        syn_z = to_var(torch.randn(shape))
        sem_z = to_var(torch.randn(shape))
        return {"syn_z": syn_z, "sem_z": sem_z}

    def hidden_to_latent(self, ret, is_sampling=True):
        """Project ``ret['hidden']`` onto syntax and semantic Gaussians.

        ``ret['hidden']`` is ``[batch, 2 * d]`` as built by ``encode``: the first
        half is the semantic state, the second the syntax state. Adds
        ``syn_mean``, ``syn_logv``, ``sem_mean``, ``sem_logv``, ``syn_z`` and
        ``sem_z`` to ``ret``. When sampling, ``z = eps * exp(0.5 * logv) + mean``;
        otherwise ``z = mean``.
        """
        hidden = ret['hidden']
        batch_size = hidden.size(0)

        pair = hidden.contiguous().view(batch_size, 2, -1)
        sem_hid = pair[:, 0, :].contiguous()
        syn_hid = pair[:, 1, :].contiguous()

        def draw(mean, logv):
            if not is_sampling:
                return mean
            eps = to_var(torch.randn([batch_size, self.latent_size]))
            return eps * torch.exp(0.5 * logv) + mean

        ret['syn_mean'] = self.syn_mean(syn_hid)
        ret['syn_logv'] = self.syn_logv(syn_hid)
        ret['sem_mean'] = self.sem_mean(sem_hid)
        ret['sem_logv'] = self.sem_logv(sem_hid)
        # Syntax latent is drawn before the semantic one (RNG order).
        ret['syn_z'] = draw(ret['syn_mean'], ret['syn_logv'])
        ret['sem_z'] = draw(ret['sem_mean'], ret['sem_logv'])
        return ret

    def latent_for_init(self, ret):
        """Turn ``syn_z`` / ``sem_z`` into per-layer hidden states for decoding.

        Each latent is projected by ``*_to_h`` and laid out as
        ``[enc_hidden_factor, batch, enc_hidden_dim / 2]``. Adds ``syn_hidden``,
        ``sem_hidden`` and their last-dim concatenation ``decode_init``
        (syntax first) to ``ret``.
        """
        syn_z, sem_z = ret['syn_z'], ret['sem_z']
        batch_size = sem_z.size(0)
        layout = (batch_size, self.enc_hidden_factor, int(self.enc_hidden_dim / 2))

        syn_hidden = self.syn_to_h(syn_z).view(*layout).permute(1, 0, 2)
        sem_hidden = self.sem_to_h(sem_z).view(*layout).permute(1, 0, 2)

        ret['syn_hidden'] = syn_hidden
        ret['sem_hidden'] = sem_hidden
        ret['decode_init'] = torch.cat([syn_hidden, sem_hidden], dim=-1)
        return ret

    def evaluate_(self, examples, beam_size=5):
        """Encode a batch, infer its latents and decode sentences from them.

        Args:
            examples: a single example or a list of examples.
            beam_size: accepted for interface compatibility; not used here.

        Returns:
            The pipeline ``ret`` dict with the decoded output under ``'res'``.
        """
        batch = examples if isinstance(examples, list) else [examples]
        state = self.encode_to_hidden(batch)
        state = self.hidden_to_latent(ret=state, is_sampling=self.training)
        state = self.latent_for_init(ret=state)
        state['res'] = self.decode_to_sentence(ret=state)
        return state

    def predict_syntax(self, hidden, predictor):
        """Run ``predictor`` on ``hidden`` and convert each result column to target-vocab words.

        Args:
            hidden: input passed straight to ``predictor.predict``.
            predictor: object exposing ``predict``. Its result is indexed as
                ``result[:, i]`` per example ``i``.

        Returns:
            One ``[[words], [len(words)]]`` entry per column of the prediction.
        """
        predicted = predictor.predict(hidden)

        def decode_column(col):
            words = id2word(predicted[:, col].data.tolist(), self.vocab.tgt)
            return [[words], [len(words)]]

        return [decode_column(col) for col in range(predicted.size(1))]

    def eval_syntax(self, examples):
        """Predict a syntax sequence for each example from its mean (non-sampled) latent.

        Returns:
            The ``predict_syntax`` output of ``self.mul_syn`` applied to
            ``ret['syn_hidden']``.
        """
        encoded = self.encode_to_hidden(examples, need_sort=True)
        state = self.latent_for_init(self.hidden_to_latent(encoded, is_sampling=False))
        return self.predict_syntax(hidden=state['syn_hidden'], predictor=self.mul_syn)

    def eval_adv(self, sem_in, syn_ref):
        """Decode sentences that pair the semantic latent of ``sem_in`` with the syntax latent of ``syn_ref``.

        Args:
            sem_in: examples that supply ``sem_z``.
            syn_ref: examples that supply ``syn_z`` (encoded with ``need_sort=True``).

        Returns:
            dict with the mixed latents and decoder init, the decoded ``'res'``,
            and the syntax predicted from each source under ``'ori syn'`` and
            ``'ref syn'``.
        """
        sample = self.training
        sem_src = self.hidden_to_latent(self.encode_to_hidden(sem_in), is_sampling=sample)
        syn_src = self.hidden_to_latent(self.encode_to_hidden(syn_ref, need_sort=True), is_sampling=sample)
        sem_src = self.latent_for_init(ret=sem_src)
        syn_src = self.latent_for_init(ret=syn_src)

        mixed = self.latent_for_init(dict(sem_z=sem_src['sem_z'], syn_z=syn_src['syn_z']))
        mixed['res'] = self.decode_to_sentence(ret=mixed)
        mixed['ori syn'] = self.predict_syntax(hidden=sem_src['syn_hidden'], predictor=self.mul_syn)
        mixed['ref syn'] = self.predict_syntax(hidden=syn_src['syn_hidden'], predictor=self.mul_syn)
        return mixed

    def conditional_generating(self, condition="sem", examples=None):
        """Generate sentences while pinning one latent to the reference examples' mean.

        Args:
            condition: a value starting with ``"sem"`` pins ``sem_z`` to its mean;
                any other value pins ``syn_z``. ``"sem-only"`` additionally
                replaces ``syn_z`` with a fresh prior sample.
            examples: reference examples to encode.

        Returns:
            The pipeline dict with the decoded output under ``'res'``.
        """
        state = self.hidden_to_latent(self.encode_to_hidden(examples), is_sampling=True)
        z_key, mean_key = ('sem_z', 'sem_mean') if condition.startswith("sem") else ('syn_z', 'syn_mean')
        state[z_key] = state[mean_key]

        if condition == "sem-only":
            # NOTE(review): relies on encode_to_hidden providing 'batch_size'.
            state['syn_z'] = self.sample_latent(batch_size=state['batch_size'])['syn_z']

        state = self.latent_for_init(ret=state)
        state['res'] = self.decode_to_sentence(ret=state)
        return state
# Ejemplo n.º 8 (Example No. 8)
# 0
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the sentence/syntax encoders, latent projections, bridger and decoder.

        Args:
            args: hyper-parameter namespace (latent size, RNN type, layer counts,
                dropout rates, dimensions, ``mapper_type``, ``use_attention``, ...).
            vocab: vocabulary bundle; ``vocab.src`` is used for sentences and
                ``vocab.tgt`` for syntax sequences.
            src_embed: optional embedding shared by the sentence encoder and the decoder.
            tgt_embed: optional embedding for the syntax encoder.
        """
        super(SyntaxGuideVAE, self).__init__(args, vocab, name='SyntaxGuideVAE')

        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0

        # Both encoders share every setting except vocabulary, length limit and embedding.
        shared_encoder_opts = dict(
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
        )
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  embedding=src_embed,
                                  **shared_encoder_opts)
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         embedding=tgt_embed,
                                         **shared_encoder_opts)

        directions = 2 if args.bidirectional else 1
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = directions * args.enc_num_layers
        flat_hidden = self.hidden_size * self.hidden_factor
        self.hidden2mean = nn.Linear(flat_hidden, self.latent_size)
        self.hidden2logv = nn.Linear(flat_hidden, self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size, flat_hidden)

        self.enc_dim = args.enc_hidden_dim * directions
        # A "link" mapper or attention makes the decoder width equal the encoder output width.
        use_enc_width = args.mapper_type == "link" or args.use_attention
        self.dec_hidden = self.enc_dim if use_enc_width else args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        # The decoder emits source-vocabulary tokens and reuses the sentence embedding.
        src_vocab = vocab.src
        self.decoder = RNNDecoder(
            vocab=len(src_vocab),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=src_vocab.eos_id,
            sos_id=src_vocab.sos_id,
        )
# Ejemplo n.º 9 (Example No. 9)
# 0
class SyntaxGuideVAE(BaseVAE):
    def encode(self, input_var, length):
        """Not implemented for this model (encoding goes through ``encode_to_hidden``); returns None."""
        return None

    def conditional_generating(self, condition=None, **kwargs):
        """Not implemented for this model; accepts any arguments and returns None."""
        return None

    def generating(self, sample_num, batch_size=50):
        """Delegate to the parent class ``generating`` implementation unchanged."""
        parent = super(SyntaxGuideVAE, self)
        return parent.generating(sample_num, batch_size)

    def unsupervised_generating(self, sample_num, batch_size=50):
        """Delegate to the parent class ``unsupervised_generating`` implementation unchanged."""
        parent = super(SyntaxGuideVAE, self)
        return parent.unsupervised_generating(sample_num, batch_size)

    def predict(self, examples, to_word=True):
        """Delegate to the parent class ``predict`` implementation unchanged."""
        parent = super(SyntaxGuideVAE, self)
        return parent.predict(examples, to_word)

    def base_information(self):
        """Delegate to the parent class ``base_information`` implementation unchanged."""
        parent = super(SyntaxGuideVAE, self)
        return parent.base_information()

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the sentence/syntax encoders, latent projections, bridger and decoder.

        Args:
            args: hyper-parameter namespace (latent size, RNN type, layer counts,
                dropout rates, dimensions, ``mapper_type``, ``use_attention``, ...).
            vocab: vocabulary bundle; ``vocab.src`` is used for sentences and
                ``vocab.tgt`` for syntax sequences.
            src_embed: optional embedding shared by the sentence encoder and the decoder.
            tgt_embed: optional embedding for the syntax encoder.
        """
        super(SyntaxGuideVAE, self).__init__(args, vocab, name='SyntaxGuideVAE')

        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0

        # Both encoders share every setting except vocabulary, length limit and embedding.
        shared_encoder_opts = dict(
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
        )
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  embedding=src_embed,
                                  **shared_encoder_opts)
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         embedding=tgt_embed,
                                         **shared_encoder_opts)

        directions = 2 if args.bidirectional else 1
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = directions * args.enc_num_layers
        flat_hidden = self.hidden_size * self.hidden_factor
        self.hidden2mean = nn.Linear(flat_hidden, self.latent_size)
        self.hidden2logv = nn.Linear(flat_hidden, self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size, flat_hidden)

        self.enc_dim = args.enc_hidden_dim * directions
        # A "link" mapper or attention makes the decoder width equal the encoder output width.
        use_enc_width = args.mapper_type == "link" or args.use_attention
        self.dec_hidden = self.enc_dim if use_enc_width else args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        # The decoder emits source-vocabulary tokens and reuses the sentence embedding.
        src_vocab = vocab.src
        self.decoder = RNNDecoder(
            vocab=len(src_vocab),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=src_vocab.eos_id,
            sos_id=src_vocab.sos_id,
        )

    def syntax_encode(self, syntax_var, length):
        """Run the syntax encoder on a batch of syntax sequences.

        Args:
            syntax_var: input tensor for ``self.syntax_encoder``.
            length: per-example lengths passed to the encoder.

        Returns:
            Tuple ``(outputs, hidden)`` produced by the syntax encoder.
        """
        enc_out, enc_hid = self.syntax_encoder.forward(syntax_var, length)
        return enc_out, enc_hid

    def sentence_encode(self, sent_words):
        """Encode tokenized sentences and return the encoder's final hidden state.

        Sentences are sorted longest-first before encoding. The hidden state is
        then re-indexed along axis 1 back into the caller's order.

        Args:
            sent_words: list of token sequences, one per example.

        Returns:
            Encoder hidden state with examples on axis 1, in the original order of
            ``sent_words`` (axis layout inferred from the ``[:, idx, :]`` indexing).
        """
        lengths = [len(words) for words in sent_words]
        order = sorted(range(len(sent_words)), key=lambda idx: -lengths[idx])

        # position_of[i] is the sorted position of original example i.
        position_of = [-1] * len(sent_words)
        for rank, idx in enumerate(order):
            position_of[idx] = rank

        ordered_words = [sent_words[idx] for idx in order]
        ordered_var = to_input_variable(ordered_words,
                                        self.vocab.src,
                                        cuda=self.args.cuda,
                                        batch_first=True)
        if self.training and self.args.src_wd:
            # Presumably word dropout: tokens replaced at the annealed step_unk_rate.
            ordered_var = unk_replace(ordered_var, self.step_unk_rate, self.vocab.src)

        _, final_hidden = self.encoder.forward(ordered_var, [lengths[idx] for idx in order])
        return final_hidden[:, position_of, :]

    def forward(self, examples):
        """Encode a batch, sample the latent and score the sentence reconstruction.

        When training with ``args.tgt_wd`` set, the decoder inputs pass through
        ``unk_replace`` and are scored via ``generate``/``score_decoding_results``.
        Otherwise ``decoder.score`` is used directly.

        Args:
            examples: a single example or a list of examples.

        Returns:
            dict with ``mean``, ``logv``, ``z`` (the latent), ``nll_loss`` (the
            negated decoder score) and ``batch_size``.
        """
        batch = examples if isinstance(examples, list) else [examples]
        state = self.encode_to_hidden(batch)
        state = self.hidden_to_latent(ret=state, is_sampling=self.training)
        state = self.latent_for_init(ret=state)

        target = state['tgt_var']
        attn_memory = state['syn_output']
        init_hidden = self.bridger.forward(state['decode_init'])

        if not (self.training and self.args.tgt_wd):
            nll = -self.decoder.score(
                inputs=target,
                encoder_outputs=attn_memory,
                encoder_hidden=init_hidden,
            )
        else:
            noisy_inputs = unk_replace(target, self.step_unk_rate, self.vocab.src)
            token_scores = self.decoder.generate(
                con_inputs=noisy_inputs,
                encoder_outputs=attn_memory,
                encoder_hidden=init_hidden,
                teacher_forcing_ratio=1.0,
            )
            nll = -self.decoder.score_decoding_results(token_scores, target)

        return {
            "mean": state['mean'],
            "logv": state['logv'],
            "z": state['latent'],
            'nll_loss': nll,
            'batch_size': len(batch),
        }

    def get_loss(self, examples, step):
        """Compute the annealed VAE objective for one batch.

        Updates ``self.step_unk_rate`` from the word-dropout schedule, runs
        ``forward``, and adds the KL term (scaled by the annealed weight times
        ``args.kl_factor``) to the batch-averaged NLL.

        Args:
            examples: batch passed to ``forward``.
            step: training step that drives both annealing schedules.

        Returns:
            dict of diagnostics keyed by display name; ``'Loss'`` is the value
            to optimize.
        """
        cfg = self.args
        self.step_unk_rate = wd_anneal_function(unk_max=self.unk_rate,
                                                anneal_function=cfg.unk_schedule,
                                                step=step,
                                                x0=cfg.x0,
                                                k=cfg.k)
        outcome = self.forward(examples)
        kl_loss, kl_weight = self.compute_kl_loss(outcome['mean'], outcome['logv'], step)
        kl_weight *= cfg.kl_factor

        n_examples = outcome['batch_size']
        nll_loss = torch.sum(outcome['nll_loss']) / n_examples
        kl_loss = kl_loss / n_examples
        kl_item = kl_loss * kl_weight
        total = kl_item + nll_loss
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'Model Score': nll_loss + kl_loss,
            'Loss': total,
            'ELBO': total,
            'KL Weight': kl_weight,
            'WD Drop': self.step_unk_rate,
            'KL Item': kl_item,
        }

    def batch_beam_decode(self, **kwargs):
        """Not implemented for this model; accepts any keyword arguments and returns None."""
        return None

    def latent_for_init(self, ret):
        """Turn ``ret['latent']`` into a decoder init averaged with the syntax hidden state.

        The latent is projected by ``latent2hidden``. With ``hidden_factor > 1`` it
        is laid out as ``[hidden_factor, batch, enc_hidden_dim]``; otherwise a
        leading axis of size 1 is added. The result is averaged element-wise with
        ``ret['syn_hidden']``.

        Returns:
            The same ``ret`` dict with ``'decode_init'`` set.
        """
        latent = ret['latent']
        projected = self.latent2hidden(latent)
        if self.hidden_factor <= 1:
            layered = projected.unsqueeze(0)
        else:
            layered = projected.view(latent.size(0), self.hidden_factor,
                                     self.args.enc_hidden_dim).permute(1, 0, 2)
        ret['decode_init'] = (layered + ret['syn_hidden']) / 2
        return ret

    def sample_latent(self, batch_size):
        """Draw a ``[batch_size, latent_size]`` standard-normal prior sample wrapped by ``to_var``."""
        noise = torch.randn([batch_size, self.latent_size])
        return {"latent": to_var(noise)}

    def encode_to_hidden(self, examples):
        """Encode sentences and their syntax sequences, ordered by descending syntax length.

        Every returned tensor follows that sorted order. ``'old_pos'`` maps each
        original example index to its sorted position.

        Args:
            examples: a single example or a list of examples exposing ``.src``
                (sentence tokens) and ``.tgt`` (syntax tokens).

        Returns:
            dict with ``hidden`` (sentence encoder state), ``syn_output``,
            ``syn_hidden``, ``tgt_var`` (boundary-wrapped sentence ids) and ``old_pos``.
        """
        if not isinstance(examples, list):
            examples = [examples]
        order = sorted(range(len(examples)), key=lambda idx: -len(examples[idx].tgt))
        ordered = [examples[idx] for idx in order]

        syntax_seqs = [ex.tgt for ex in ordered]
        syntax_var = to_input_variable(syntax_seqs,
                                       self.vocab.tgt,
                                       training=False,
                                       cuda=self.args.cuda,
                                       batch_first=True)
        syntax_output, syntax_hidden = self.syntax_encode(syntax_var, [len(seq) for seq in syntax_seqs])

        sentences = [ex.src for ex in ordered]
        sentence_hidden = self.sentence_encode(sentences)
        tgt_var = to_input_variable(sentences,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        position_of = [-1] * len(examples)
        for rank, idx in enumerate(order):
            position_of[idx] = rank

        return {
            'hidden': sentence_hidden,
            "syn_output": syntax_output,
            "syn_hidden": syntax_hidden,
            'tgt_var': tgt_var,
            'old_pos': position_of
        }

    def hidden_to_latent(self, ret, is_sampling):
        """Project the sentence encoder state to a Gaussian latent and optionally sample it.

        ``ret['hidden']`` (examples on axis 1) is permuted to batch-major and
        flattened to ``[batch, enc_hidden_dim * hidden_factor]`` before the
        mean and log-variance heads are applied.

        Args:
            ret: dict holding ``'hidden'``; it is updated in place.
            is_sampling: if True, ``z = mean + eps * exp(0.5 * logv)``;
                otherwise ``z = mean``.

        Returns:
            The same ``ret`` dict with ``'latent'``, ``'mean'`` and ``'logv'`` set.
        """
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size,
                                 self.args.enc_hidden_dim * self.hidden_factor)
        else:
            # Drop only the singleton layer axis. A bare squeeze() would also drop
            # the batch axis when batch_size == 1, producing 1-D mean/logv and a
            # wrong batch size in latent_for_init.
            hidden = hidden.squeeze(1)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def decode_to_sentence(self, ret):
        """Decode sentences from ``ret['decode_init']`` and return them in the caller's order.

        Args:
            ret: dict holding ``'decode_init'`` and ``'old_pos'``. ``old_pos[i]``
                is the sorted position of original example ``i``.

        Returns:
            One ``[[words], [len(words)]]`` entry per example, in original order.
        """
        decode_init = self.bridger.forward(input_tensor=ret['decode_init'])

        _, _, ret_dict, _ = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )

        # Stack per-step predictions to [time, batch]. Flattening every axis after
        # time (rather than squeeze()) keeps the batch axis intact when the batch
        # size or the decoded length is 1.
        steps = torch.stack(ret_dict['sequence'])
        steps = steps.view(steps.size(0), -1)

        decoded = []
        for col in range(steps.size(1)):
            words = id2word(steps[:, col].data.tolist(), self.vocab.src)
            decoded.append([[words], [len(words)]])

        return [decoded[pos] for pos in ret['old_pos']]