Example 1
0
def build_encoder(opt, embeddings):
    """Dispatch to the encoder implementation selected by ``opt.encoder_type``.

    Args:
        opt: the option namespace for the current environment; note that
            this function mutates it by forcing ``opt.input_size`` to 1.
        embeddings (Embeddings): vocab embeddings for this encoder.

    Returns:
        The constructed encoder module.
    """
    # Every variant below is built with a fixed input size of 1.
    opt.input_size = 1

    enc_type = opt.encoder_type

    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, opt.input_size, embeddings)
    if enc_type == "ctransformer":
        return CTransformerEncoder([2, 2, 2, 2], opt.enc_layers,
                                   opt.enc_rnn_size, opt.heads,
                                   opt.transformer_ff, opt.dropout,
                                   embeddings)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width, opt.dropout,
                          opt.input_size, embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, opt.input_size, embeddings)
    if enc_type == "nano":
        return NanoEncoder(opt.rnn_type, opt.enc_layers, opt.dec_layers,
                           opt.enc_rnn_size, opt.dec_rnn_size,
                           opt.audio_enc_pooling, opt.dropout,
                           opt.sample_rate, opt.window_size, opt.input_size)
    if enc_type == "crnn":
        return CRNNEncoder([2, 2, 2, 2], opt.rnn_type, opt.enc_layers,
                           opt.dec_layers, opt.enc_rnn_size,
                           opt.dec_rnn_size, opt.audio_enc_pooling,
                           opt.dropout, opt.sample_rate, opt.window_size)
    if enc_type == "resnet":
        # The ResNet front-end pairs with either a CNN or an RNN decoder.
        if opt.decoder_type == 'cnn':
            return ResNetEncoder(opt.enc_layers, opt.enc_rnn_size,
                                 [2, 2, 2, 2], opt.input_size, embeddings)
        return ResNetForRNNEncoder(opt.dec_layers, opt.enc_rnn_size,
                                   opt.dec_rnn_size, [2, 2, 2, 2],
                                   opt.input_size, embeddings,
                                   opt.rnn_type)
    # Fallback: plain (possibly bidirectional) RNN encoder.
    return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                      opt.enc_rnn_size, opt.dropout, opt.input_size,
                      embeddings, opt.bridge)
Example 2
0
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the syntax-guided VAE.

        Components: a source-sentence encoder, a syntax (target-side)
        encoder, linear projections to/from the Gaussian latent space, a
        bridge that adapts encoder state to the decoder layout, and an
        RNN decoder over the source vocabulary.

        Args:
            args: hyper-parameter namespace (sizes, dropout rates,
                mapper type, attention flag, ...).
            vocab: vocabulary container with ``src`` and ``tgt`` fields.
            src_embed: optional pre-built source embedding layer.
            tgt_embed: optional pre-built target (syntax) embedding layer.
        """
        super(SyntaxGuideVAE, self).__init__(args,
                                             vocab,
                                             name='SyntaxGuideVAE')

        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        # Current word-dropout rate; annealed per step during training.
        self.step_unk_rate = 0.0

        # Encoder over source words.
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        # Encoder over target-side syntax sequences.
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         input_size=args.enc_embed_dim,
                                         hidden_size=args.enc_hidden_dim,
                                         embed_droprate=args.enc_ed,
                                         rnn_droprate=args.enc_rd,
                                         n_layers=args.enc_num_layers,
                                         bidirectional=args.bidirectional,
                                         rnn_cell=args.rnn_type,
                                         variable_lengths=True,
                                         embedding=tgt_embed)

        self.hidden_size = args.enc_hidden_dim
        # directions * layers: number of hidden-state slices that get
        # flattened before the latent projections.
        self.hidden_factor = (2 if args.bidirectional else
                              1) * args.enc_num_layers
        # Posterior parameters q(z|x), plus the projection of z back to
        # the encoder's hidden-state layout.
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor,
                                     self.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor,
                                     self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size,
                                       self.hidden_size * self.hidden_factor)

        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)
        # Decoder hidden size: tied to the encoder when the bridge links
        # them directly or attention is used, otherwise independent.
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        # Maps encoder-layout hidden state to the decoder's layout.
        self.bridger = Bridge(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        # Decoder reconstructs the source sentence.
        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )
Example 3
0
class SyntaxGuideVAE(BaseVAE):
    def encode(self, input_var, length):
        # Left as a no-op: encoding in this model goes through
        # encode_to_hidden / sentence_encode instead.
        pass

    def conditional_generating(self, condition=None, **kwargs):
        # Conditional generation is not implemented for this model.
        pass

    def generating(self, sample_num, batch_size=50):
        # Pure pass-through to the base-class implementation; the override
        # only makes the inherited API explicit.
        return super().generating(sample_num, batch_size)

    def unsupervised_generating(self, sample_num, batch_size=50):
        # Pure pass-through to the base-class implementation.
        return super().unsupervised_generating(sample_num, batch_size)

    def predict(self, examples, to_word=True):
        # Pure pass-through to the base-class implementation.
        return super().predict(examples, to_word)

    def base_information(self):
        # Pure pass-through: no extra summary fields for this model.
        return super().base_information()

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the SyntaxGuideVAE: sentence encoder, syntax encoder,
        Gaussian posterior projections, bridge, and RNN decoder over the
        source vocabulary.

        Args:
            args: hyper-parameter namespace (sizes, dropout rates,
                mapper type, attention flag, ...).
            vocab: vocabulary container with ``src`` and ``tgt`` fields.
            src_embed: optional pre-built source embedding layer.
            tgt_embed: optional pre-built target (syntax) embedding layer.
        """
        super(SyntaxGuideVAE, self).__init__(args,
                                             vocab,
                                             name='SyntaxGuideVAE')

        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        # Current word-dropout rate; annealed per step in get_loss.
        self.step_unk_rate = 0.0

        # Encoder over source words.
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        # Encoder over target-side syntax sequences.
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         input_size=args.enc_embed_dim,
                                         hidden_size=args.enc_hidden_dim,
                                         embed_droprate=args.enc_ed,
                                         rnn_droprate=args.enc_rd,
                                         n_layers=args.enc_num_layers,
                                         bidirectional=args.bidirectional,
                                         rnn_cell=args.rnn_type,
                                         variable_lengths=True,
                                         embedding=tgt_embed)

        self.hidden_size = args.enc_hidden_dim
        # directions * layers: hidden-state slices flattened before the
        # latent projections.
        self.hidden_factor = (2 if args.bidirectional else
                              1) * args.enc_num_layers
        # Posterior parameters q(z|x) and the reverse projection of z.
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor,
                                     self.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor,
                                     self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size,
                                       self.hidden_size * self.hidden_factor)

        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)
        # Decoder hidden size: tied to the encoder when the bridge links
        # them directly or attention is used, otherwise independent.
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        # Maps encoder-layout hidden state to the decoder's layout.
        self.bridger = Bridge(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        # Decoder reconstructs the source sentence.
        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

    def syntax_encode(self, syntax_var, length):
        """Run the syntax encoder and return its (outputs, hidden) pair."""
        return self.syntax_encoder.forward(syntax_var, length)

    def sentence_encode(self, sent_words):
        """Encode a batch of sentences, restoring the caller's batch order.

        Sentences are sorted by decreasing length for the packed RNN,
        encoded, and the hidden states are permuted back so that batch
        row i corresponds to ``sent_words[i]``.
        """
        n = len(sent_words)
        order = sorted(range(n), key=lambda i: -len(sent_words[i]))

        # Inverse permutation: where each original example landed.
        restore = [-1] * n
        for new_idx, old_idx in enumerate(order):
            restore[old_idx] = new_idx

        batch_words = [sent_words[i] for i in order]
        batch_var = to_input_variable(batch_words,
                                      self.vocab.src,
                                      cuda=self.args.cuda,
                                      batch_first=True)

        # Optional source-side word dropout while training.
        if self.training and self.args.src_wd:
            batch_var = unk_replace(batch_var, self.step_unk_rate,
                                    self.vocab.src)

        lengths = [len(words) for words in batch_words]
        _, hidden = self.encoder.forward(batch_var, lengths)

        # Undo the length sort on the batch dimension (dim 1).
        return hidden[:, restore, :]

    def forward(self, examples):
        """Encode `examples`, obtain a latent code, and score the
        reconstruction of the source sentence.

        Returns:
            dict with the posterior ``mean``/``logv``, latent ``z``, the
            negative log-likelihood ``nll_loss`` and ``batch_size``.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        ret = self.encode_to_hidden(examples)
        # Sample z only while training; use the posterior mean otherwise.
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = ret['tgt_var']
        syntax_output = ret['syn_output']
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            # Word dropout on decoder inputs, then fully teacher-forced
            # generation scored against the clean targets.
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -self.decoder.score_decoding_results(
                tgt_token_scores, tgt_var)
        else:
            reconstruct_loss = -self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
            )

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        """Run a forward pass and assemble the training losses for `step`."""
        # Anneal the word-dropout rate for this training step.
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=self.args.unk_schedule,
            step=step,
            x0=self.args.x0,
            k=self.args.k)
        stats = self.forward(examples)
        batch_size = stats['batch_size']
        kl_loss, kl_weight = self.compute_kl_loss(stats['mean'],
                                                  stats['logv'], step)
        kl_weight = kl_weight * self.args.kl_factor
        nll_loss = torch.sum(stats['nll_loss']) / batch_size
        kl_loss = kl_loss / batch_size
        weighted_kl = kl_loss * kl_weight
        elbo = weighted_kl + nll_loss
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'Model Score': nll_loss + kl_loss,
            'Loss': elbo,
            'ELBO': elbo,
            'KL Weight': kl_weight,
            'WD Drop': self.step_unk_rate,
            'KL Item': weighted_kl,
        }

    def batch_beam_decode(self, **kwargs):
        # Beam-search decoding is not implemented for this model.
        pass

    def latent_for_init(self, ret):
        """Project z back to the RNN hidden layout and average it with the
        syntax encoder's hidden state to form the decoder init."""
        z = ret['latent']
        mapped = self.latent2hidden(z)

        if self.hidden_factor > 1:
            # (batch, factor*dim) -> (factor, batch, dim)
            mapped = mapped.view(z.size(0), self.hidden_factor,
                                 self.args.enc_hidden_dim).permute(1, 0, 2)
        else:
            mapped = mapped.unsqueeze(0)
        ret['decode_init'] = (mapped + ret['syn_hidden']) / 2
        return ret

    def sample_latent(self, batch_size):
        """Draw `batch_size` latent codes from the standard normal prior."""
        noise = torch.randn([batch_size, self.latent_size])
        return {"latent": to_var(noise)}

    def encode_to_hidden(self, examples):
        """Encode a batch: the syntax encoder runs over tgt sequences and
        the sentence encoder over src words; also builds the decoder's
        gold target variable.

        Returns:
            dict with the sentence hidden state, syntax outputs/hidden,
            the target variable, and the inverse sort permutation.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        # Sort by decreasing tgt length for the packed syntax RNN.
        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -len(examples[x].tgt))
        example_old_pos_map = [-1] * batch_size

        sorted_examples = [examples[i] for i in sorted_example_ids]

        syntax_word = [e.tgt for e in sorted_examples]
        syntax_var = to_input_variable(syntax_word,
                                       self.vocab.tgt,
                                       training=False,
                                       cuda=self.args.cuda,
                                       batch_first=True)
        length = [len(e.tgt) for e in sorted_examples]
        syntax_output, syntax_hidden = self.syntax_encode(syntax_var, length)

        # sentence_encode re-sorts by src length internally and un-sorts
        # its output back to this (tgt-length-sorted) order; tgt_var below
        # also stays in the tgt-length-sorted order.
        sent_words = [e.src for e in sorted_examples]
        sentence_hidden = self.sentence_encode(sent_words)
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        # Inverse permutation: original position -> sorted position.
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        return {
            'hidden': sentence_hidden,
            "syn_output": syntax_output,
            "syn_hidden": syntax_hidden,
            'tgt_var': tgt_var,
            'old_pos': example_old_pos_map
        }

    def hidden_to_latent(self, ret, is_sampling):
        """Map the encoder hidden state to a latent code.

        Uses the reparameterization trick when `is_sampling` is True,
        otherwise returns the posterior mean.

        Args:
            ret: dict carrying 'hidden' of shape
                (layers*directions, batch, enc_hidden_dim).
            is_sampling: sample z ~ N(mean, std) if True, else z = mean.

        Returns:
            The same dict with 'latent', 'mean' and 'logv' filled in.
        """
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        # (factor, batch, dim) -> (batch, factor, dim), then flatten the
        # factor axis into the feature axis.
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size,
                                 self.args.enc_hidden_dim * self.hidden_factor)
        else:
            # Squeeze only the factor axis. The original bare `squeeze()`
            # also dropped the batch axis when batch_size == 1, feeding a
            # 1-D tensor into the linear layers.
            hidden = hidden.squeeze(1)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def decode_to_sentence(self, ret):
        """Decode word sequences from 'decode_init' (free-running, no
        teacher inputs) and restore the original example order."""
        sentence_decode_init = ret['decode_init']
        sentence_decode_init = self.bridger.forward(
            input_tensor=sentence_decode_init)

        decoder_outputs, decoder_hidden, ret_dict, enc_states = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=sentence_decode_init,
        )

        # Stack per-step predictions into a (steps, batch) id tensor.
        result = torch.stack(ret_dict['sequence']).squeeze()
        temp_result = []
        if result.dim() < 2:
            # Re-add the batch axis lost by squeeze() when batch == 1.
            result = result.unsqueeze(1)
        example_nums = result.size(1)
        for i in range(example_nums):
            hyp = result[:, i].data.tolist()
            res = id2word(hyp, self.vocab.src)
            seems = [[res], [len(res)]]
            temp_result.append(seems)

        # Map hypotheses back to the callers' original batch order.
        final_result = [temp_result[i] for i in ret['old_pos']]
        return final_result
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the disentangling VAE: one deep sentence encoder whose
        flattened hidden state is split into two halves, each with its own
        latent projections, plus auxiliary BridgeRNN/BridgeMLP heads
        (``sup_*``, ``*_adv``, ``*_infer``) over the two halves.

        Args:
            args: hyper-parameter namespace; NOTE this constructor mutates
                it (``args.use_attention`` is forced to False).
            vocab: vocabulary container with ``src`` and ``tgt`` fields.
            src_embed: optional shared source embedding layer.
            tgt_embed: optional shared target embedding layer.
        """
        super(DisentangleVAE,
              self).__init__(args,
                             vocab,
                             name="Disentangle VAE with deep encoder")
        print("This is {} with parameter\n{}".format(self.name,
                                                     self.base_information()))
        if src_embed is None:
            self.src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            self.src_embed = src_embed
        if tgt_embed is None:
            self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size)
        else:
            self.tgt_embed = tgt_embed

        # NOTE(review): pad_idx is assigned the SOS id, not a PAD id —
        # confirm this is intentional.
        self.pad_idx = vocab.src.sos_id

        self.latent_size = int(args.latent_size)
        self.rnn_type = args.rnn_type
        self.unk_rate = args.unk_rate
        # Current word-dropout rate; updated during training.
        self.step_unk_rate = 0.0
        self.direction_num = 2 if args.bidirectional else 1

        self.enc_hidden_dim = args.enc_hidden_dim
        self.enc_layer_dim = args.enc_hidden_dim * self.direction_num
        self.enc_hidden_factor = self.direction_num * args.enc_num_layers
        self.dec_hidden_factor = args.dec_num_layers
        # Attention is disabled unconditionally here, which also makes the
        # `elif args.use_attention` branch below unreachable.
        args.use_attention = False
        if args.mapper_type == "link":
            self.dec_layer_dim = self.enc_layer_dim
        elif args.use_attention:
            self.dec_layer_dim = self.enc_layer_dim
        else:
            self.dec_layer_dim = args.dec_hidden_dim

        # The flattened encoder state is split into two equal halves.
        syn_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        sem_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)

        task_enc_dim = int(self.enc_layer_dim / 2)
        task_dec_dim = int(self.dec_layer_dim / 2)

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=self.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=self.src_embed)
        # output: [layer*direction ,batch_size, enc_hidden_dim]

        # Main reconstruction path: bridge + decoder over source words.
        pack_decoder = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=self.enc_layer_dim,
            dec_hidden_dim=self.dec_layer_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')
        self.bridger = pack_decoder.bridger
        self.decoder = pack_decoder.decoder

        # Latent projections; the "report" variant adds a hidden layer.
        # NOTE(review): syn_common/sem_common are reused inside both the
        # mean and log-variance heads, so those heads share weights.
        if "report" in self.args:
            syn_common = nn.Sequential(
                nn.Linear(syn_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.syn_mean = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.syn_logv = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))

            sem_common = nn.Sequential(
                nn.Linear(sem_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.sem_mean = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.sem_logv = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
        else:
            self.syn_mean = nn.Linear(syn_var_dim, self.latent_size)
            self.syn_logv = nn.Linear(syn_var_dim, self.latent_size)
            self.sem_mean = nn.Linear(sem_var_dim, self.latent_size)
            self.sem_logv = nn.Linear(sem_var_dim, self.latent_size)

        # Map each latent back to its half of the hidden state.
        self.syn_to_h = nn.Linear(self.latent_size, syn_var_dim)
        self.sem_to_h = nn.Linear(self.latent_size, sem_var_dim)

        # Auxiliary heads over the two halves; presumably supervision
        # (sup_*), adversarial (*_adv) and inference (*_infer) objectives
        # matching the weights in base_information — confirm against the
        # training code.
        self.sup_syn = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=task_enc_dim,
                                 dec_hidden_dim=task_dec_dim,
                                 embed=tgt_embed,
                                 mode='tgt')

        self.sup_sem = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.syn_adv = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.tgt_embed if args.share_embed else None,
            mode='tgt')

        self.syn_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')

        self.sem_adv = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.sem_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')
class DisentangleVAE(BaseVAE):
    """
    Encoder the sentence, predict the parser,
    """
    def decode(self, inputs, encoder_outputs, encoder_hidden):
        # Thin delegation to the underlying RNN decoder's forward pass.
        return self.decoder.forward(inputs=inputs,
                                    encoder_outputs=encoder_outputs,
                                    encoder_hidden=encoder_hidden)

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build the DisentangleVAE: one deep sentence encoder whose
        flattened hidden state is split into two halves, each with its own
        latent projections, plus auxiliary BridgeRNN/BridgeMLP heads
        (``sup_*``, ``*_adv``, ``*_infer``) over the two halves.

        Args:
            args: hyper-parameter namespace; NOTE this constructor mutates
                it (``args.use_attention`` is forced to False).
            vocab: vocabulary container with ``src`` and ``tgt`` fields.
            src_embed: optional shared source embedding layer.
            tgt_embed: optional shared target embedding layer.
        """
        super(DisentangleVAE,
              self).__init__(args,
                             vocab,
                             name="Disentangle VAE with deep encoder")
        print("This is {} with parameter\n{}".format(self.name,
                                                     self.base_information()))
        if src_embed is None:
            self.src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            self.src_embed = src_embed
        if tgt_embed is None:
            self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size)
        else:
            self.tgt_embed = tgt_embed

        # NOTE(review): pad_idx is assigned the SOS id, not a PAD id —
        # confirm this is intentional.
        self.pad_idx = vocab.src.sos_id

        self.latent_size = int(args.latent_size)
        self.rnn_type = args.rnn_type
        self.unk_rate = args.unk_rate
        # Current word-dropout rate; updated during training.
        self.step_unk_rate = 0.0
        self.direction_num = 2 if args.bidirectional else 1

        self.enc_hidden_dim = args.enc_hidden_dim
        self.enc_layer_dim = args.enc_hidden_dim * self.direction_num
        self.enc_hidden_factor = self.direction_num * args.enc_num_layers
        self.dec_hidden_factor = args.dec_num_layers
        # Attention is disabled unconditionally here, which also makes the
        # `elif args.use_attention` branch below unreachable.
        args.use_attention = False
        if args.mapper_type == "link":
            self.dec_layer_dim = self.enc_layer_dim
        elif args.use_attention:
            self.dec_layer_dim = self.enc_layer_dim
        else:
            self.dec_layer_dim = args.dec_hidden_dim

        # The flattened encoder state is split into two equal halves.
        syn_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)
        sem_var_dim = int(self.enc_hidden_dim * self.enc_hidden_factor / 2)

        task_enc_dim = int(self.enc_layer_dim / 2)
        task_dec_dim = int(self.dec_layer_dim / 2)

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=self.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=self.src_embed)
        # output: [layer*direction ,batch_size, enc_hidden_dim]

        # Main reconstruction path: bridge + decoder over source words.
        pack_decoder = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=self.enc_layer_dim,
            dec_hidden_dim=self.dec_layer_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')
        self.bridger = pack_decoder.bridger
        self.decoder = pack_decoder.decoder

        # Latent projections; the "report" variant adds a hidden layer.
        # NOTE(review): syn_common/sem_common are reused inside both the
        # mean and log-variance heads, so those heads share weights.
        if "report" in self.args:
            syn_common = nn.Sequential(
                nn.Linear(syn_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.syn_mean = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.syn_logv = nn.Sequential(
                syn_common, nn.Linear(self.latent_size * 2, self.latent_size))

            sem_common = nn.Sequential(
                nn.Linear(sem_var_dim, self.latent_size * 2, True), nn.ReLU())
            self.sem_mean = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
            self.sem_logv = nn.Sequential(
                sem_common, nn.Linear(self.latent_size * 2, self.latent_size))
        else:
            self.syn_mean = nn.Linear(syn_var_dim, self.latent_size)
            self.syn_logv = nn.Linear(syn_var_dim, self.latent_size)
            self.sem_mean = nn.Linear(sem_var_dim, self.latent_size)
            self.sem_logv = nn.Linear(sem_var_dim, self.latent_size)

        # Map each latent back to its half of the hidden state.
        self.syn_to_h = nn.Linear(self.latent_size, syn_var_dim)
        self.sem_to_h = nn.Linear(self.latent_size, sem_var_dim)

        # Auxiliary heads over the two halves; presumably supervision
        # (sup_*), adversarial (*_adv) and inference (*_infer) objectives
        # matching the weights in base_information — confirm against the
        # training code.
        self.sup_syn = BridgeRNN(args,
                                 vocab,
                                 enc_hidden_dim=task_enc_dim,
                                 dec_hidden_dim=task_dec_dim,
                                 embed=tgt_embed,
                                 mode='tgt')

        self.sup_sem = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.syn_adv = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.tgt_embed if args.share_embed else None,
            mode='tgt')

        self.syn_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')

        self.sem_adv = BridgeMLP(
            args=args,
            vocab=vocab,
            enc_dim=task_enc_dim,
            dec_hidden=task_dec_dim,
        )

        self.sem_infer = BridgeRNN(
            args,
            vocab,
            enc_hidden_dim=task_enc_dim,
            dec_hidden_dim=task_dec_dim,
            embed=self.src_embed if args.share_embed else None,
            mode='src')

    def base_information(self):
        """Extend the base summary with this model's objective weights."""
        a = self.args
        summary = (
            "mul_syn:{}\n"
            "mul_sen:{}\n"
            "adv_syn:{}\n"
            "adv_sem:{}\n"
            "inf_syn:{}\n"
            "inf_sem:{}\n"
            "kl_syn:{}\n"
            "kl_sem:{}\n"
        ).format(str(a.mul_syn),
                 str(a.mul_sem),
                 str(a.adv_syn),
                 str(a.adv_sem),
                 str(a.inf_syn * a.infer_weight),
                 str(a.inf_sem * a.infer_weight),
                 str(a.syn_weight),
                 str(a.sem_weight))
        return super().base_information() + summary

    def get_gpu(self):
        """Move every sub-module to the GPU when one is available.

        Fixes the original ``torch.cuda.is_available`` check, which was
        missing the call parentheses: the bound method object is always
        truthy, so "cuda:0" was selected even on CPU-only hosts.
        """
        model_list = [
            self.encoder, self.bridger, self.decoder, self.syn_mean,
            self.syn_logv, self.syn_to_h, self.sem_mean, self.sem_logv,
            self.sem_to_h, self.sup_syn, self.sup_sem, self.syn_adv,
            self.syn_infer, self.sem_adv, self.sem_infer
        ]
        # Resolve the target device once instead of once per module.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        for model in model_list:
            # NOTE(review): the DataParallel wrapper is bound to a local
            # and discarded — only the in-place .to(device) has a lasting
            # effect. Assign the wrapper back to the attribute if data
            # parallelism is actually wanted.
            model = torch.nn.DataParallel(model)
            model.to(device)

    def encode(self, input_var, length):
        """Encode `input_var`, applying word dropout while training."""
        apply_wd = self.training and self.args.src_wd > 0.
        if apply_wd:
            input_var = unk_replace(input_var, self.step_unk_rate,
                                    self.vocab.src)
        return self.encoder.forward(input_var, length)

    def forward(self, examples, is_dis=False):
        """Full pass: encode, sample both latents, reconstruct and score.

        Args:
            examples: one example or a list; each must expose `.src`
                (word sequence) and `.tgt` (syntax target sequence).
            is_dis: when True, only the (detached) discriminator losses are
                computed — used for the adversary's own update step.

        Returns:
            dict carrying the latent stats plus either `dis syn`/`dis sem`
            (discriminator step) or `nll_loss`/`mul`/`adv`/... losses.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)

        # Word-level reconstruction/semantic target (source words).
        words = [e.src for e in examples]
        tgt_var = to_input_variable(words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        # Syntax-level target sequence (e.tgt against the tgt vocab).
        syn_seqs = [e.tgt for e in examples]
        syn_var = to_input_variable(syn_seqs,
                                    self.vocab.tgt,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        # encode -> split hidden into syntax/semantic latents -> project each
        # latent back to an RNN-shaped hidden state.
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        syn_hidden = ret['syn_hidden']
        sem_hidden = ret['sem_hidden']

        if is_dis:
            # Discriminator-only step: hiddens are detached inside
            # get_dis_loss, so no gradient reaches the encoder.
            dis_syn_loss, dis_sem_loss = self.get_dis_loss(
                syntax_hidden=syn_hidden,
                semantic_hidden=sem_hidden,
                syn_tgt=syn_var,
                sem_tgt=tgt_var)
            ret['dis syn'] = dis_syn_loss
            ret['dis sem'] = dis_sem_loss
            return ret

        decode_init = ret['decode_init']

        sentence_decode_init = self.bridger.forward(decode_init)
        # NOTE(review): `tgt_wd` is tested for truthiness here while encode()
        # compares `src_wd > 0.` — presumably equivalent; confirm.
        if self.training and self.args.tgt_wd:
            # Teacher forcing over a word-dropped copy of the target.
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_log_score = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=sentence_decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0)
            # Summed NLL over the batch (normalized by batch size later).
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_log_score, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(inputs=tgt_var,
                                   encoder_outputs=None,
                                   encoder_hidden=sentence_decode_init))

        # Multi-task losses: each latent predicts its own target.
        mul_syn_loss, mul_sem_loss = self.get_mul_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)

        # Adversarial losses: discriminators are frozen during training
        # (no_grad inside get_adv_loss).
        adv_syn_loss, adv_sem_loss = self.get_adv_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)
        ret['adv'] = adv_syn_loss + adv_sem_loss
        ret['mul'] = mul_syn_loss + mul_sem_loss

        ret['nll_loss'] = reconstruct_loss
        ret['sem_loss'] = mul_sem_loss
        ret['syn_loss'] = mul_syn_loss
        ret['batch_size'] = batch_size
        return ret

    def get_loss(self, examples, step, is_dis=False):
        """Compute the full training-loss dictionary for one batch.

        When `is_dis` is True the raw forward() result (discriminator
        losses) is returned untouched.
        """
        # Anneal the UNK word-dropout rate before the forward pass reads it.
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=self.args.unk_schedule,
            step=step,
            x0=self.args.x0,
            k=self.args.k)
        stats = self.forward(examples, is_dis)
        if is_dis:
            return stats

        sem_kl, kl_weight = self.compute_kl_loss(mean=stats['sem_mean'],
                                                 logv=stats['sem_logv'],
                                                 step=step)
        syn_kl, _ = self.compute_kl_loss(mean=stats['syn_mean'],
                                         logv=stats['syn_logv'],
                                         step=step)

        batch_size = stats['batch_size']
        kl_weight = kl_weight * self.args.kl_factor
        # Weighted average of the two KL terms, then per-example.
        sem_w = self.args.sem_weight
        syn_w = self.args.syn_weight
        kl_loss = (sem_w * sem_kl + syn_w * syn_kl) / (sem_w + syn_w)
        kl_loss = kl_loss / batch_size
        mul_loss = stats['mul'] / batch_size
        adv_loss = stats['adv'] / batch_size
        nll_loss = stats['nll_loss'] / batch_size
        kl_item = kl_loss * kl_weight

        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'MUL Loss': mul_loss,
            'ADV Loss': adv_loss,
            'KL Weight': kl_weight,
            'KL Item': kl_item,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss + mul_loss - adv_loss,
            'SYN KL Loss': syn_kl / batch_size,
            'SEM KL Loss': sem_kl / batch_size,
        }

    def get_adv_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Adversarial loss terms for the encoder update.

        Refactor: the original duplicated the entire weighting block across
        the training and evaluation branches; the only real difference is
        the `no_grad` context around `_dis_loss`, so the duplication is
        collapsed here (behaviour unchanged).

        Returns:
            (adv_syn, adv_sem) weighted loss terms.
        """
        if self.training:
            # Freeze the discriminators: no gradient must flow into them
            # while the encoder is being updated adversarially.
            with torch.no_grad():
                loss_dict = self._dis_loss(syntax_hidden, semantic_hidden,
                                           syn_tgt, sem_tgt)
        else:
            loss_dict = self._dis_loss(syntax_hidden, semantic_hidden,
                                       syn_tgt, sem_tgt)

        adv_syn = self.args.adv_syn * loss_dict['adv_syn_sup']
        adv_sem = self.args.adv_sem * loss_dict['adv_sem_sup']
        if self.args.infer_weight > 0.:
            # Mix in the inference terms, scaled by the global infer weight.
            adv_syn = adv_syn + (self.args.infer_weight * self.args.inf_sem *
                                 loss_dict['adv_sem_inf'])
            adv_sem = adv_sem + (self.args.infer_weight * self.args.inf_syn *
                                 loss_dict['adv_syn_inf'])
        return adv_syn, adv_sem

    def get_dis_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Discriminator losses on detached hiddens (no grad to encoder)."""
        loss_dict = self._dis_loss(syntax_hidden.detach(),
                                   semantic_hidden.detach(), syn_tgt, sem_tgt)
        if self.args.infer_weight > 0.:
            syn_total = loss_dict['adv_syn_sup'] + loss_dict['adv_sem_inf']
            sem_total = loss_dict['adv_sem_sup'] + loss_dict['adv_syn_inf']
            return syn_total, sem_total
        return loss_dict['adv_syn_sup'], loss_dict['adv_sem_sup']

    def _dis_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Raw discriminator losses.

        The adversarial ("sup") heads try to recover the *other* space's
        target from a latent's hidden; the inference ("inf") heads predict
        from the latent's own hidden. Entries whose weight option is 0 are
        returned as the scalar 0. instead of a tensor.
        """
        # Adversary: recover syntax from the SEMANTIC hidden...
        dis_syn_sup = self.syn_adv.forward(hidden=semantic_hidden,
                                           tgt_var=syn_tgt)
        # ...and semantics (words) from the SYNTAX hidden.
        dis_sem_sup = self.sem_adv.forward(hidden=syntax_hidden,
                                           tgt_var=sem_tgt)
        if self.args.infer_weight > 0.:

            # NOTE(review): `syn_infer` is fed `sem_tgt` here, while every
            # other pairing in this file is symmetric; by analogy this looks
            # like it was meant to be `tgt_var=syn_tgt` (predict syntax from
            # the syntax hidden) — confirm against the training script
            # before changing.
            dis_syn_inf = self.syn_infer.forward(hidden=syntax_hidden,
                                                 tgt_var=sem_tgt)
            dis_sem_inf = self.sem_infer.forward(hidden=semantic_hidden,
                                                 tgt_var=sem_tgt)
            return {
                'adv_syn_sup': dis_syn_sup if self.args.adv_syn > 0. else 0.,
                'adv_sem_sup': dis_sem_sup if self.args.adv_sem > 0. else 0.,
                'adv_syn_inf': dis_syn_inf if self.args.inf_syn > 0. else 0.,
                "adv_sem_inf": dis_sem_inf if self.args.inf_sem > 0. else 0.
            }
        else:
            return {
                'adv_syn_sup': dis_syn_sup,
                'adv_sem_sup': dis_sem_sup,
            }

    def get_mul_loss(self, syntax_hidden, semantic_hidden, syn_tgt, sem_tgt):
        """Weighted multi-task losses: each latent predicts its own target."""
        syntax_loss = self.sup_syn.forward(hidden=syntax_hidden,
                                           tgt_var=syn_tgt)
        semantic_loss = self.sup_sem.forward(hidden=semantic_hidden,
                                             tgt_var=sem_tgt)
        return (self.args.mul_syn * syntax_loss,
                self.args.mul_sem * semantic_loss)

    def sample_latent(self, batch_size):
        """Draw independent standard-normal latents for syntax and semantics."""
        shape = [batch_size, self.latent_size]
        return {
            "syn_z": to_var(torch.randn(shape)),
            "sem_z": to_var(torch.randn(shape)),
        }

    def hidden_to_latent(self, ret, is_sampling=True):
        """Split the encoder hidden into semantic/syntactic halves and
        project each half to a latent vector.

        Args:
            ret: dict carrying 'hidden' — assumed shape
                (factor, batch, 2*dim), per the split below — TODO confirm.
            is_sampling: reparameterized sample when True, posterior mean
                when False.

        Returns:
            `ret` extended with syn/sem mean, logv and latent sample.
        """
        hidden = ret['hidden']

        def sampling(mean, logv):
            # Reparameterization trick: z = mu + sigma * eps when sampling,
            # deterministic z = mu otherwise. Uses `batch_size` from the
            # enclosing scope (assigned below, before this is called).
            if is_sampling:
                std = torch.exp(0.5 * logv)
                z = to_var(torch.randn([batch_size, self.latent_size]))
                z = z * std + mean
            else:
                z = mean
            return z

        def split_hidden(encode_hidden):
            # (factor, batch, 2*dim) -> two (batch, factor*dim) tensors;
            # the last dimension is halved: first half -> semantics,
            # second half -> syntax.
            bs = encode_hidden.size(1)
            factor = encode_hidden.size(0)
            hid = encode_hidden.permute(1, 0, 2).contiguous().view(
                bs, factor, 2, -1)
            return hid[:, :, 0, :].contiguous().view(
                bs, -1), hid[:, :, 1, :].contiguous().view(bs, -1)

        batch_size = hidden.size(1)
        sem_hid, syn_hid = split_hidden(hidden)

        # Separate mean/log-variance projections per latent space.
        semantic_mean = self.sem_mean(sem_hid)
        semantic_logv = self.sem_logv(sem_hid)
        syntax_mean = self.syn_mean(syn_hid)
        syntax_logv = self.syn_logv(syn_hid)
        syntax_latent = sampling(syntax_mean, syntax_logv)
        semantic_latent = sampling(semantic_mean, semantic_logv)

        ret['syn_mean'] = syntax_mean
        ret['syn_logv'] = syntax_logv
        ret['sem_mean'] = semantic_mean
        ret['sem_logv'] = semantic_logv
        ret['syn_z'] = syntax_latent
        ret['sem_z'] = semantic_latent

        return ret

    def latent_for_init(self, ret):
        """Project both latents back to RNN-shaped hiddens and build the
        decoder's initial state by concatenating them on the last dim.

        Args:
            ret: dict carrying 'syn_z' and 'sem_z' latent batches.

        Returns:
            `ret` extended with 'syn_hidden', 'sem_hidden', 'decode_init'.
        """
        def reshape(xx_hidden):
            # Bug fix: `self.enc_hidden_dim / 2` is a float under Python 3
            # true division, and Tensor.view() requires integer sizes; use
            # floor division (identical result for the even int dims used
            # here).
            xx_hidden = xx_hidden.view(batch_size, self.enc_hidden_factor,
                                       self.enc_hidden_dim // 2)
            # (batch, factor, dim/2) -> (factor, batch, dim/2) for the RNN.
            xx_hidden = xx_hidden.permute(1, 0, 2)
            return xx_hidden

        syntax_latent = ret['syn_z']
        semantic_latent = ret['sem_z']
        batch_size = semantic_latent.size(0)
        syntax_hidden = reshape(self.syn_to_h(syntax_latent))
        semantic_hidden = reshape(self.sem_to_h(semantic_latent))

        ret['syn_hidden'] = syntax_hidden
        ret['sem_hidden'] = semantic_hidden
        ret['decode_init'] = torch.cat([syntax_hidden, semantic_hidden],
                                       dim=-1)
        return ret

    def evaluate_(self, examples, beam_size=5):
        """Encode, map to latents, and decode the batch back to sentences."""
        batch = examples if isinstance(examples, list) else [examples]
        ret = self.encode_to_hidden(batch)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret

    def predict_syntax(self, hidden, predictor):
        """Run `predictor` on `hidden` and unpack each column of the result
        into a `[[tokens], [length]]` entry.
        """
        result = predictor.predict(hidden)
        final_result = []
        for col in range(result.size(1)):
            tokens = id2word(result[:, col].data.tolist(), self.vocab.tgt)
            final_result.append([[tokens], [len(tokens)]])
        return final_result

    def extract_variable(self, examples):
        # Intentionally unimplemented placeholder; subclasses may override.
        pass

    def eval_syntax(self, examples):
        """Predict syntax from the deterministic (mean) syntactic latent."""
        state = self.encode_to_hidden(examples, need_sort=True)
        state = self.hidden_to_latent(state, is_sampling=False)
        state = self.latent_for_init(state)
        return self.predict_syntax(hidden=state['syn_hidden'],
                                   predictor=self.sup_syn)

    def eval_adv(self, sem_in, syn_ref):
        """Decode with `sem_in`'s semantic latent and `syn_ref`'s syntactic
        latent; also return the predicted syntax of both inputs.
        """
        content_state = self.encode_to_hidden(sem_in)
        content_state = self.hidden_to_latent(content_state,
                                              is_sampling=self.training)
        style_state = self.encode_to_hidden(syn_ref, need_sort=True)
        style_state = self.hidden_to_latent(style_state,
                                            is_sampling=self.training)
        content_state = self.latent_for_init(ret=content_state)
        style_state = self.latent_for_init(ret=style_state)
        # Cross the two latent spaces and rebuild a decoder init state.
        mixed = dict(sem_z=content_state['sem_z'],
                     syn_z=style_state['syn_z'])
        mixed = self.latent_for_init(mixed)
        mixed['res'] = self.decode_to_sentence(ret=mixed)
        mixed['ori syn'] = self.predict_syntax(
            hidden=content_state['syn_hidden'], predictor=self.sup_syn)
        mixed['ref syn'] = self.predict_syntax(
            hidden=style_state['syn_hidden'], predictor=self.sup_syn)
        return mixed

    def conditional_generating(self, condition="sem", examples=None):
        """Generate sentences conditioned on one latent of the references."""
        ref_ret = self.encode_to_hidden(examples)
        ref_ret = self.hidden_to_latent(ref_ret, is_sampling=True)
        # Pin the conditioned latent to its posterior mean.
        if condition.startswith("sem"):
            ref_ret['sem_z'] = ref_ret['sem_mean']
        else:
            ref_ret['syn_z'] = ref_ret['syn_mean']

        if condition == "sem-only":
            # Replace the syntactic latent with a fresh prior sample.
            prior = self.sample_latent(batch_size=ref_ret['batch_size'])
            ref_ret['syn_z'] = prior['syn_z']

        ret = self.latent_for_init(ret=ref_ret)
        ret['res'] = self.decode_to_sentence(ret=ret)
        return ret
# ---- Esempio n. 6 (0) — scraped-page snippet separator, commented out so the file parses ----
    def __init__(self, args, vocab, word_embed=None):
        """Build the vanilla VAE: RNN encoder, bridge, RNN decoder and the
        latent-space projections.

        Args:
            args: experiment options (embedding/hidden sizes, dropout, ...).
            vocab: container exposing the `src` vocabulary (len, eos/sos ids).
            word_embed: optional pre-built source embedding; a fresh
                nn.Embedding is created when None.
        """
        super(VanillaVAE, self).__init__(args, vocab, name='MySentVAE')
        print("This is {} with parameter\n{}".format(self.name,
                                                     self.base_information()))
        if word_embed is None:
            src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            src_embed = word_embed

        # Share source/target embeddings when requested; this forces the
        # decoder embedding dim to match the encoder's.
        if args.share_embed:
            tgt_embed = src_embed
            args.dec_embed_dim = args.enc_embed_dim
        else:
            tgt_embed = None

        self.latent_size = args.latent_size
        self.unk_rate = args.unk_rate  # max UNK word-dropout rate
        self.step_unk_rate = 0.0  # annealed per step in get_loss()
        self.hidden_size = args.enc_hidden_dim
        # Hidden vectors emitted per example: directions * layers.
        self.hidden_factor = (2 if args.bidirectional else
                              1) * args.enc_num_layers
        args.use_attention = False

        # layer size setting
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1
                                              )  # single layer unit size
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        # Maps the encoder's final hidden to the decoder's init shape.
        self.bridger = Bridge(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=False,
            embedding=tgt_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

        # Latent projections (hidden <-> latent).
        # NOTE(review): these are sized with `args.hidden_size`, while the
        # encoder above uses `args.enc_hidden_dim` (also stored as
        # self.hidden_size); presumably the two options are equal — confirm.
        self.hidden2mean = nn.Linear(args.hidden_size * self.hidden_factor,
                                     args.latent_size)
        self.hidden2logv = nn.Linear(args.hidden_size * self.hidden_factor,
                                     args.latent_size)
        self.latent2hidden = nn.Linear(args.latent_size,
                                       args.hidden_size * self.hidden_factor)
# ---- Esempio n. 7 (0) — scraped-page snippet separator, commented out so the file parses ----
class VanillaVAE(BaseVAE):
    """Vanilla sentence VAE: RNN encoder -> Gaussian latent -> RNN decoder."""

    def __init__(self, args, vocab, word_embed=None):
        """Build the VAE: RNN encoder, bridge, RNN decoder and the
        latent-space projections.

        Args:
            args: experiment options (embedding/hidden sizes, dropout, ...).
            vocab: container exposing the `src` vocabulary (len, eos/sos ids).
            word_embed: optional pre-built source embedding; a fresh
                nn.Embedding is created when None.
        """
        super(VanillaVAE, self).__init__(args, vocab, name='MySentVAE')
        print("This is {} with parameter\n{}".format(self.name,
                                                     self.base_information()))
        if word_embed is None:
            src_embed = nn.Embedding(len(vocab.src), args.embed_size)
        else:
            src_embed = word_embed

        # Share source/target embeddings when requested; this forces the
        # decoder embedding dim to match the encoder's.
        if args.share_embed:
            tgt_embed = src_embed
            args.dec_embed_dim = args.enc_embed_dim
        else:
            tgt_embed = None

        self.latent_size = args.latent_size
        self.unk_rate = args.unk_rate  # max UNK word-dropout rate
        self.step_unk_rate = 0.0  # annealed per step in get_loss()
        self.hidden_size = args.enc_hidden_dim
        # Hidden vectors emitted per example: directions * layers.
        self.hidden_factor = (2 if args.bidirectional else
                              1) * args.enc_num_layers
        args.use_attention = False

        # layer size setting
        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1
                                              )  # single layer unit size
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        # Maps the encoder's final hidden to the decoder's init shape.
        self.bridger = Bridge(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=False,
            embedding=tgt_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

        # Latent projections (hidden <-> latent).
        # NOTE(review): these are sized with `args.hidden_size`, while the
        # encoder above uses `args.enc_hidden_dim` (also stored as
        # self.hidden_size); presumably the two options are equal — confirm.
        self.hidden2mean = nn.Linear(args.hidden_size * self.hidden_factor,
                                     args.latent_size)
        self.hidden2logv = nn.Linear(args.hidden_size * self.hidden_factor,
                                     args.latent_size)
        self.latent2hidden = nn.Linear(args.latent_size,
                                       args.hidden_size * self.hidden_factor)

    def encode(self, input_var, length):
        """Encode `input_var`, applying UNK word dropout while training."""
        apply_dropout = self.training and self.args.src_wd
        if apply_dropout:
            input_var = unk_replace(input_var, self.step_unk_rate,
                                    self.vocab.src)
        outputs, hidden = self.encoder.forward(input_var, length)
        return outputs, hidden

    def decode(self, inputs, encoder_outputs, encoder_hidden):
        """Thin wrapper delegating to the RNN decoder's forward pass."""
        return self.decoder.forward(
            inputs=inputs,
            encoder_outputs=encoder_outputs,
            encoder_hidden=encoder_hidden,
        )

    def forward(self, examples):
        """Encode a batch, sample a latent, and score the reconstruction.

        Args:
            examples: one example or a list; each must expose `.src`.

        Returns:
            dict with posterior stats ('mean', 'logv', 'z'), the summed
            negative log-likelihood ('nll_loss'), and 'batch_size'.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sent_words = [e.src for e in examples]
        # encode -> latent -> decoder init state.
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd > 0.:
            # Word dropout on the decoder input, full teacher forcing.
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            # Summed NLL over the batch (per-example normalization happens
            # in get_loss()).
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_token_scores, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(
                    inputs=tgt_var,
                    encoder_outputs=None,
                    encoder_hidden=decode_init,
                ))

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        """Assemble the KL/NLL loss dictionary with the annealed KL weight."""
        # Update the annealed UNK-dropout rate used by encode()/forward().
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=self.args.unk_schedule,
            step=step,
            x0=self.args.x0,
            k=self.args.k)
        stats = self.forward(examples)
        n = stats['batch_size']
        kl_loss, kl_weight = self.compute_kl_loss(stats['mean'],
                                                  stats['logv'], step)
        kl_weight = kl_weight * self.args.kl_factor
        nll_loss = stats['nll_loss'] / n
        kl_loss = kl_loss / n
        kl_item = kl_loss * kl_weight
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'KL Weight': kl_weight,
            'Model Score': kl_loss + nll_loss,
            'ELBO': kl_item + nll_loss,
            'Loss': kl_item + nll_loss,
            'KL Item': kl_item,
        }

    def sample_latent(self, batch_size):
        """Sample a standard-normal latent batch from the prior."""
        return {"latent": to_var(torch.randn([batch_size, self.latent_size]))}

    def latent_for_init(self, ret):
        """Map the latent back to the decoder's initial hidden state."""
        latent = ret['latent']
        hidden = self.latent2hidden(latent)
        if self.hidden_factor > 1:
            # (batch, factor * size) -> (factor, batch, size) for the RNN.
            hidden = hidden.view(latent.size(0), self.hidden_factor,
                                 self.hidden_size).permute(1, 0, 2)
        else:
            # Single layer/direction: just add the leading factor dim.
            hidden = hidden.unsqueeze(0)
        ret['decode_init'] = hidden
        return ret

    def batch_beam_decode(self, examples):
        """Batched beam-search decoding; not implemented for this model."""
        raise NotImplementedError

    def hidden_to_latent(self, ret, is_sampling=True):
        """Project the encoder hidden state to a latent vector.

        Args:
            ret: dict carrying 'hidden' of shape
                (hidden_factor, batch, hidden_size).
            is_sampling: reparameterized sample when True, posterior mean
                when False.

        Returns:
            `ret` extended with 'latent', 'mean' and 'logv'.
        """
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        hidden = hidden.permute(1, 0, 2).contiguous()
        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size,
                                 self.hidden_size * self.hidden_factor)
        else:
            # Bug fix: bare squeeze() also dropped the batch dimension when
            # batch_size == 1, producing a 1-D tensor and breaking the linear
            # projections below; only the factor dimension should be removed.
            hidden = hidden.squeeze(1)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            # Reparameterization trick: z = mu + sigma * eps.
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def conditional_generating(self, condition='sem', examples=None):
        """Generate sentences from the posterior of `examples`, or
        unconditionally when `condition` is None.

        Returns:
            dict with key 'res' for the handled modes; None for any other
            non-None condition (preserved behaviour — see note below).
        """
        if not isinstance(examples, list):
            examples = [examples]
        # Bug fix: the None check must precede the startswith() call —
        # previously `condition.startswith("sem")` raised AttributeError for
        # condition=None, making the unconditional branch unreachable.
        if condition is None:
            return {
                "res": self.unsupervised_generating(sample_num=len(examples))
            }
        if condition.startswith("sem"):
            ret = self.encode_to_hidden(examples)
            ret = self.hidden_to_latent(ret=ret, is_sampling=True)
            ret = self.latent_for_init(ret=ret)
            return {'res': self.decode_to_sentence(ret=ret)}
        # NOTE(review): any other condition (e.g. "syn") still falls through
        # and implicitly returns None — confirm whether that is intended.

    def eval_adv(self, sem_in, syn_ref):
        """Decode from the average of two posteriors (content and style refs)."""
        content_state = self.encode_to_hidden(sem_in)
        content_state = self.hidden_to_latent(content_state,
                                              is_sampling=self.training)
        style_state = self.encode_to_hidden(syn_ref, need_sort=True)
        style_state = self.hidden_to_latent(style_state,
                                            is_sampling=self.training)
        content_state = self.latent_for_init(ret=content_state)
        style_state = self.latent_for_init(ret=style_state)
        # Interpolate the two latents at the midpoint and decode.
        mixed = {
            "latent": (content_state['latent'] + style_state['latent']) * 0.5
        }
        mixed = self.latent_for_init(mixed)
        mixed['res'] = self.decode_to_sentence(ret=mixed)
        return mixed