Ejemplo n.º 1
0
class BridgeRNN(nn.Module):
    """Decode a token sequence from an externally supplied hidden state.

    The hidden state is mapped to the decoder's initial state by an
    ``MLPBridger`` and fed to an ``RNNDecoder`` that runs without encoder
    outputs. ``mode == 'src'`` makes the decoder use ``vocab.src`` and
    ``args.src_max_time_step``; any other value uses ``vocab.tgt`` and
    ``args.tgt_max_time_step``.
    """

    def __init__(self,
                 args,
                 vocab,
                 enc_hidden_dim,
                 dec_hidden_dim,
                 embed,
                 mode='src'):
        super().__init__()

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=enc_hidden_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=dec_hidden_dim,
            decoder_layer=args.dec_num_layers,
        )

        # The two modes differ only in vocabulary and maximum decode length.
        if mode == 'src':
            side_vocab, max_len = vocab.src, args.src_max_time_step
        else:
            side_vocab, max_len = vocab.tgt, args.tgt_max_time_step

        self.decoder = RNNDecoder(
            vocab=len(side_vocab),
            max_len=max_len,
            input_size=args.dec_embed_dim,
            hidden_size=dec_hidden_dim,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=embed,
            eos_id=side_vocab.eos_id,
            sos_id=side_vocab.sos_id,
        )

    def forward(self, hidden, tgt_var):
        """Return ``-sum(decoder.score(tgt_var))`` given initial ``hidden``.

        Args:
            hidden: state passed to the bridger to initialise the decoder.
            tgt_var: target token tensor scored by the decoder.

        Returns:
            Scalar tensor: the negated sum of the decoder scores.
        """
        decode_init = self.bridger.forward(input_tensor=hidden)
        scores = self.decoder.score(
            inputs=tgt_var,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        return -torch.sum(scores)

    def predict(self, hidden):
        """Decode freely from ``hidden`` and return the generated token ids.

        Returns:
            A 2-D tensor whose dim 0 is the decoding step. Dim 1 holds
            everything in each step's output; this assumes each step yields
            one token per batch element (TODO confirm against RNNDecoder).
        """
        decode_init = self.bridger.forward(input_tensor=hidden)

        _, _, ret_dict, _ = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        # torch.stack puts the steps on dim 0; flattening the rest keeps the
        # result 2-D. A bare squeeze() dropped the step or batch axis
        # whenever either was 1, and with both equal to 1 produced a 0-d
        # tensor that unsqueeze(1) rejects.
        steps = torch.stack(ret_dict['sequence'])
        return steps.view(steps.size(0), -1)
Ejemplo n.º 2
0
class BaseSeq2seq(nn.Module, BaseGenerator):
    """RNN encoder -> MLP bridge -> RNN decoder model from ``src`` to ``tgt``.

    ``vocab`` must expose ``src`` and ``tgt`` sub-vocabularies; ``args``
    carries the hyper-parameters read in ``__init__``. ``beam_decoder`` wraps
    the same decoder for beam search.
    """

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        super(BaseSeq2seq, self).__init__()
        self.vocab = vocab
        self.src_vocab = vocab.src
        self.tgt_vocab = vocab.tgt
        self.args = args
        self.encoder = RNNEncoder(vocab_size=len(self.src_vocab),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        self.enc_factor = 2 if args.bidirectional else 1
        self.enc_dim = args.enc_hidden_dim * self.enc_factor

        # Decoder width follows the encoder width for a "link" mapper or when
        # attention is enabled; otherwise it comes from dec_hidden_dim.
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(self.tgt_vocab),
            max_len=args.tgt_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=tgt_embed,
            eos_id=self.tgt_vocab.eos_id,
            sos_id=self.tgt_vocab.sos_id,
        )

        self.beam_decoder = TopKDecoder(decoder_rnn=self.decoder,
                                        k=args.sample_size)
        print("enc layer: {}, dec layer: {}, type: {}, with attention: {}".
              format(args.enc_num_layers, args.dec_num_layers, args.rnn_type,
                     args.use_attention))

    def get_loss(self, **kwargs):
        """Return ``{"Loss": -self.score(**kwargs)}``."""
        return {"Loss": -self.score(**kwargs)}

    def forward(self, seqs_x, x_length, to_word=False):
        """Not implemented for this model; returns None."""
        pass

    def init(self):
        """Call ``flatten_parameters()`` on the encoder and decoder RNNs."""
        self.encoder.rnn.flatten_parameters()
        self.decoder.rnn.flatten_parameters()

    def encode(self, src_var, src_length):
        """Run the encoder; returns ``(encoder_outputs, encoder_hidden)``."""
        encoder_outputs, encoder_hidden = self.encoder.forward(
            input_var=src_var, input_lengths=src_length)
        return encoder_outputs, encoder_hidden

    def bridge(self, encoder_hidden):
        """Map the encoder's final hidden state to the decoder's initial one."""
        return self.bridger.forward(encoder_hidden)

    def _examples_to_input_dict(self, examples, use_tgt):
        """Batch ``examples`` (a list or a single example) via to_input_dict."""
        if not isinstance(examples, list):
            examples = [examples]
        return to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=self.args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=use_tgt,
            use_tag=False,
            use_dst=False,
        )

    def get_hidden(self, examples):
        """Encode ``examples`` and return the bridged decoder-initial state."""
        input_dict = self._examples_to_input_dict(examples, use_tgt=True)
        _, encoder_hidden = self.encode(src_var=input_dict['src'],
                                        src_length=input_dict['src_len'])
        return self.bridge(encoder_hidden)

    def score(self, examples, return_enc_state=False, **kwargs):
        """Score each example's ``tgt`` given its ``src`` with the decoder.

        Args:
            examples: a list of examples or a single example.
            return_enc_state: also return the bridged encoder hidden state.

        Returns:
            The decoder scores, or ``(scores, encoder_hidden)`` when
            ``return_enc_state`` is true.
        """
        input_dict = self._examples_to_input_dict(examples, use_tgt=True)
        src_var = input_dict['src']
        tgt_var = input_dict['tgt']
        src_length = input_dict['src_len']

        encoder_outputs, encoder_hidden = self.encode(src_var=src_var,
                                                      src_length=src_length)
        encoder_hidden = self.bridge(encoder_hidden)
        scores = self.decoder.score(inputs=tgt_var,
                                    encoder_hidden=encoder_hidden,
                                    encoder_outputs=encoder_outputs)

        if return_enc_state:
            return scores, encoder_hidden
        return scores

    def predict(self, examples, to_word=True):
        """Greedily decode each example (teacher_forcing_ratio=0.0).

        Returns:
            When ``to_word`` is true, one ``[[words], [len(words)]]`` entry per
            example. When it is false the list is empty (unchanged behaviour;
            NOTE(review): possibly unintended).
        """
        input_dict = self._examples_to_input_dict(examples, use_tgt=False)
        src_var = input_dict['src']
        src_length = input_dict['src_len']

        encoder_outputs, encoder_hidden = self.encode(src_var=src_var,
                                                      src_length=src_length)
        encoder_hidden = self.bridge(encoder_hidden)

        _, _, ret_dict, _ = self.decoder.forward(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            teacher_forcing_ratio=0.0)
        # Steps on dim 0, examples on dim 1 (assumes one token per example per
        # step). An explicit view keeps both axes even when either is 1; the
        # old squeeze() turned a single-step batch into a single "example".
        result = torch.stack(ret_dict['sequence'])
        result = result.view(result.size(0), -1)

        final_result = []
        if to_word:
            for i in range(result.size(-1)):
                res = id2word(result[:, i].tolist(), self.vocab.tgt)
                final_result.append([[res], [len(res)]])
        return final_result

    def beam_search(self, src_sent, beam_size=5, dmts=None):
        """Beam-decode one source sentence.

        Args:
            src_sent: a single source token sequence.
            beam_size: beam width passed to ``beam_decoder``.
            dmts: maximum decode steps; defaults to
                ``args.decode_max_time_step``.

        Returns:
            ``(hypotheses, scores)``: one word list and one float per returned
            hypothesis.
        """
        if dmts is None:
            dmts = self.args.decode_max_time_step
        src_var = to_input_variable(src_sent,
                                    self.src_vocab,
                                    cuda=self.args.cuda,
                                    training=False,
                                    append_boundary_sym=False,
                                    batch_first=True)
        src_length = [len(src_sent)]

        encoder_outputs, encoder_hidden = self.encode(src_var=src_var,
                                                      src_length=src_length)
        encoder_hidden = self.bridger.forward(input_tensor=encoder_hidden)
        meta_data = self.beam_decoder.beam_search(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            beam_size=beam_size,
            decode_max_time_step=dmts)
        topk_sequence = meta_data['sequence']
        # reshape(-1) instead of squeeze(): with beam_size == 1 squeeze() made
        # a 0-d tensor and topk_score[0] raised.
        topk_score = meta_data['score'].reshape(-1)

        completed_hypotheses = torch.cat(topk_sequence, dim=-1)

        final_result = []
        final_scores = []
        for i in range(completed_hypotheses.size(0)):
            hyp = completed_hypotheses[i, :].tolist()
            final_result.append(id2word(hyp, self.tgt_vocab))
            final_scores.append(topk_score[i].item())
        return final_result, final_scores

    def load_state_dict(self, state_dict, strict=True):
        """Delegate to ``nn.Module.load_state_dict``."""
        return super().load_state_dict(state_dict, strict)

    def save(self, path):
        """Save args, vocab and state_dict to ``path`` with ``torch.save``.

        Missing parent directories are created.
        """
        dir_name = os.path.dirname(path)
        # dirname() is '' for a bare file name and os.makedirs('') raises.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        params = {
            'args': self.args,
            'vocab': self.vocab,
            'state_dict': self.state_dict(),
        }

        torch.save(params, path)

    @classmethod
    def load(cls, load_path):
        """Rebuild a model from a checkpoint written by :meth:`save`.

        Tensors are loaded onto CPU, then the model is moved to CUDA if
        ``args.cuda`` is set. NOTE: torch.load unpickles arbitrary objects
        (args, vocab) - only load trusted checkpoints.
        """
        params = torch.load(load_path,
                            map_location=lambda storage, loc: storage)
        args = params['args']
        vocab = params['vocab']
        model = cls(args, vocab)
        model.load_state_dict(params['state_dict'])
        if args.cuda:
            model = model.cuda()
        return model
Ejemplo n.º 3
0
    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Create the sentence/syntax encoders, VAE projections, bridge and decoder.

        Args:
            args: hyper-parameter namespace (encoder/decoder sizes, dropout
                rates, ``latent_size``, ``mapper_type``, ``use_attention``...).
            vocab: object exposing ``src`` and ``tgt`` vocabularies.
            src_embed: optional embedding shared by the sentence encoder and
                the decoder.
            tgt_embed: optional embedding for the syntax encoder.
        """
        super(SyntaxGuideVAE, self).__init__(args,
                                             vocab,
                                             name='SyntaxGuideVAE')

        # Plain hyper-parameters mirrored from args.
        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        self.step_unk_rate = 0.0

        # Both encoders share everything except vocabulary, length and embedding.
        shared_encoder_kwargs = dict(
            input_size=args.enc_embed_dim,
            hidden_size=args.enc_hidden_dim,
            embed_droprate=args.enc_ed,
            rnn_droprate=args.enc_rd,
            n_layers=args.enc_num_layers,
            bidirectional=args.bidirectional,
            rnn_cell=args.rnn_type,
            variable_lengths=True,
        )
        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  embedding=src_embed,
                                  **shared_encoder_kwargs)
        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         embedding=tgt_embed,
                                         **shared_encoder_kwargs)

        num_directions = 2 if args.bidirectional else 1
        self.hidden_size = args.enc_hidden_dim
        self.hidden_factor = num_directions * args.enc_num_layers
        # Width of all encoder layers/directions flattened together.
        flat_hidden = self.hidden_size * self.hidden_factor
        self.hidden2mean = nn.Linear(flat_hidden, self.latent_size)
        self.hidden2logv = nn.Linear(flat_hidden, self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size, flat_hidden)

        self.enc_dim = args.enc_hidden_dim * num_directions
        # Decoder width follows the encoder width for a "link" mapper or when
        # attention is used; otherwise it is args.dec_hidden_dim.
        follow_encoder = args.mapper_type == "link" or args.use_attention
        self.dec_hidden = self.enc_dim if follow_encoder else args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        # The decoder produces tokens in the src vocabulary.
        src_vocab = vocab.src
        self.decoder = RNNDecoder(
            vocab=len(src_vocab),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=src_vocab.eos_id,
            sos_id=src_vocab.sos_id,
        )
Ejemplo n.º 4
0
class RNNPredictor(nn.Module):
    """Score or generate a token sequence from an externally supplied hidden state.

    An ``MLPBridger`` maps ``hidden`` to the initial state of an ``RNNDecoder``
    that runs without encoder outputs. ``mode == 'src'`` uses ``vocab.src`` /
    ``args.src_max_time_step``; any other value uses ``vocab.tgt`` /
    ``args.tgt_max_time_step``. ``pad_id`` is taken from the same vocabulary.
    """

    def __init__(self, args, vocab, enc_hidden_dim, dec_hidden_dim, embed, mode='src'):
        super().__init__()
        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=enc_hidden_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=dec_hidden_dim,
            decoder_layer=args.dec_num_layers,
        )

        # The two modes differ only in vocabulary and maximum decode length.
        if mode == 'src':
            side_vocab, max_len = vocab.src, args.src_max_time_step
        else:
            side_vocab, max_len = vocab.tgt, args.tgt_max_time_step

        self.decoder = RNNDecoder(
            vocab=len(side_vocab),
            max_len=max_len,
            input_size=args.dec_embed_dim,
            hidden_size=dec_hidden_dim,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=embed,
            eos_id=side_vocab.eos_id,
            sos_id=side_vocab.sos_id,
        )
        self.pad_id = side_vocab.pad_id

    def forward(self, hidden, tgt_var, div_by_word=False):
        """Return the negated sum of decoder scores for ``tgt_var``.

        Args:
            hidden: state passed to the bridger to initialise the decoder.
            tgt_var: target token tensor; padding uses ``self.pad_id``.
            div_by_word: if true, each score is first divided by the number of
                non-pad tokens along the last axis of ``tgt_var``.

        Returns:
            Scalar tensor loss.
        """
        decode_init = self.bridger.forward(input_tensor=hidden)
        _loss = self.decoder.score(
            inputs=tgt_var,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        if div_by_word:
            tgt_len = tgt_var.ne(self.pad_id).sum(-1).float()
            # Epsilon guards against an all-padding row.
            _loss = _loss.div(tgt_len + 1e-9)
        return -_loss.sum()

    def detach(self):
        """Freeze every parameter (``requires_grad = False``)."""
        for p in self.parameters():
            p.requires_grad = False

    def reload(self):
        """Unfreeze every parameter (``requires_grad = True``)."""
        for p in self.parameters():
            p.requires_grad = True

    def predict(self, hidden):
        """Decode freely from ``hidden`` and return generated token ids.

        Returns:
            A 2-D tensor with the decoding step on dim 0. Each step's output is
            flattened into dim 1; this assumes one token per batch element per
            step (TODO confirm against RNNDecoder).
        """
        decode_init = self.bridger.forward(input_tensor=hidden)

        _, _, ret_dict, _ = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=decode_init,
        )
        # An explicit view keeps both axes even when steps or batch is 1; the
        # old squeeze()/unsqueeze(1) transposed single-step batches and raised
        # on a 0-d tensor when both were 1.
        steps = torch.stack(ret_dict['sequence'])
        return steps.view(steps.size(0), -1)
Ejemplo n.º 5
0
class SyntaxGuideVAE(BaseVAE):
    """Sentence VAE whose decoder is also conditioned on a syntax encoding.

    The sentence (``example.src``) is encoded by ``self.encoder`` and projected
    to a Gaussian latent. The syntax sequence (``example.tgt``) is encoded by
    ``self.syntax_encoder``. The decoder's initial state is the average of the
    latent-derived hidden state and the syntax hidden state, and the decoder
    reconstructs the sentence in the ``src`` vocabulary.
    """

    def encode(self, input_var, length):
        """Not implemented for this model; returns None."""
        pass

    def conditional_generating(self, condition=None, **kwargs):
        """Not implemented for this model; returns None."""
        pass

    def generating(self, sample_num, batch_size=50):
        """Delegate to ``BaseVAE.generating``."""
        return super().generating(sample_num, batch_size)

    def unsupervised_generating(self, sample_num, batch_size=50):
        """Delegate to ``BaseVAE.unsupervised_generating``."""
        return super().unsupervised_generating(sample_num, batch_size)

    def predict(self, examples, to_word=True):
        """Delegate to ``BaseVAE.predict``."""
        return super().predict(examples, to_word)

    def base_information(self):
        """Delegate to ``BaseVAE.base_information``."""
        return super().base_information()

    def __init__(self, args, vocab, src_embed=None, tgt_embed=None):
        """Build encoders, VAE projections, bridge and decoder from ``args``.

        ``src_embed`` is shared by the sentence encoder and the decoder;
        ``tgt_embed`` is used by the syntax encoder.
        """
        super(SyntaxGuideVAE, self).__init__(args,
                                             vocab,
                                             name='SyntaxGuideVAE')

        self.latent_size = args.latent_size
        self.rnn_type = args.rnn_type
        self.bidirectional = args.bidirectional
        self.num_layers = args.num_layers
        self.unk_rate = args.unk_rate
        # Current word-dropout rate; updated every step in get_loss().
        self.step_unk_rate = 0.0

        self.encoder = RNNEncoder(vocab_size=len(vocab.src),
                                  max_len=args.src_max_time_step,
                                  input_size=args.enc_embed_dim,
                                  hidden_size=args.enc_hidden_dim,
                                  embed_droprate=args.enc_ed,
                                  rnn_droprate=args.enc_rd,
                                  n_layers=args.enc_num_layers,
                                  bidirectional=args.bidirectional,
                                  rnn_cell=args.rnn_type,
                                  variable_lengths=True,
                                  embedding=src_embed)

        self.syntax_encoder = RNNEncoder(vocab_size=len(vocab.tgt),
                                         max_len=args.tgt_max_time_step,
                                         input_size=args.enc_embed_dim,
                                         hidden_size=args.enc_hidden_dim,
                                         embed_droprate=args.enc_ed,
                                         rnn_droprate=args.enc_rd,
                                         n_layers=args.enc_num_layers,
                                         bidirectional=args.bidirectional,
                                         rnn_cell=args.rnn_type,
                                         variable_lengths=True,
                                         embedding=tgt_embed)

        self.hidden_size = args.enc_hidden_dim
        # Number of (layer, direction) hidden states produced by the encoder.
        self.hidden_factor = (2 if args.bidirectional else
                              1) * args.enc_num_layers
        self.hidden2mean = nn.Linear(self.hidden_size * self.hidden_factor,
                                     self.latent_size)
        self.hidden2logv = nn.Linear(self.hidden_size * self.hidden_factor,
                                     self.latent_size)
        self.latent2hidden = nn.Linear(self.latent_size,
                                       self.hidden_size * self.hidden_factor)

        self.enc_dim = args.enc_hidden_dim * (2 if args.bidirectional else 1)
        # Decoder width follows the encoder width for a "link" mapper or when
        # attention is used; otherwise it is args.dec_hidden_dim.
        if args.mapper_type == "link":
            self.dec_hidden = self.enc_dim
        elif args.use_attention:
            self.dec_hidden = self.enc_dim
        else:
            self.dec_hidden = args.dec_hidden_dim

        self.bridger = MLPBridger(
            rnn_type=args.rnn_type,
            mapper_type=args.mapper_type,
            encoder_dim=self.enc_dim,
            encoder_layer=args.enc_num_layers,
            decoder_dim=self.dec_hidden,
            decoder_layer=args.dec_num_layers,
        )

        self.decoder = RNNDecoder(
            vocab=len(vocab.src),
            max_len=args.src_max_time_step,
            input_size=args.dec_embed_dim,
            hidden_size=self.dec_hidden,
            embed_droprate=args.dec_ed,
            rnn_droprate=args.dec_rd,
            n_layers=args.dec_num_layers,
            rnn_cell=args.rnn_type,
            use_attention=args.use_attention,
            embedding=src_embed,
            eos_id=vocab.src.eos_id,
            sos_id=vocab.src.sos_id,
        )

    def syntax_encode(self, syntax_var, length):
        """Run the syntax encoder; returns ``(outputs, hidden)``."""
        syntax_outputs, syntax_hidden = self.syntax_encoder.forward(
            syntax_var, length)
        return syntax_outputs, syntax_hidden

    def sentence_encode(self, sent_words):
        """Encode token lists and return hidden states in the input order.

        Sentences are sorted longest-first before encoding (presumably because
        the encoder uses variable_lengths=True packing) and the batch axis
        (dim 1) of the resulting hidden state is permuted back afterwards.
        During training with ``args.src_wd`` the inputs get word dropout at
        ``self.step_unk_rate``.
        """
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]

        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -sent_lengths[x])

        # example_old_pos_map[original_index] = position in the sorted batch.
        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)

        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate,
                                          self.vocab.src)

        sorted_sent_lengths = [
            len(sent_word) for sent_word in sorted_sent_words
        ]

        _, sent_hidden = self.encoder.forward(sorted_sent_var,
                                              sorted_sent_lengths)

        hidden = sent_hidden[:, example_old_pos_map, :]

        return hidden

    def forward(self, examples):
        """Encode, sample a latent and compute the reconstruction loss.

        Returns:
            dict with ``mean``, ``logv``, ``z``, ``nll_loss`` (negated decoder
            scores) and ``batch_size``.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = ret['tgt_var']
        syntax_output = ret['syn_output']
        decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            # Word dropout on the teacher-forced decoder inputs; the loss is
            # still measured against the clean tgt_var.
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -self.decoder.score_decoding_results(
                tgt_token_scores, tgt_var)
        else:
            reconstruct_loss = -self.decoder.score(
                inputs=tgt_var,
                encoder_outputs=syntax_output,
                encoder_hidden=decode_init,
            )

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }

    def get_loss(self, examples, step):
        """Update the word-dropout rate for ``step`` and compute the VAE loss.

        NLL and KL are averaged over the batch; the KL term is scaled by the
        annealed weight from ``compute_kl_loss`` times ``args.kl_factor``.
        """
        self.step_unk_rate = wd_anneal_function(
            unk_max=self.unk_rate,
            anneal_function=self.args.unk_schedule,
            step=step,
            x0=self.args.x0,
            k=self.args.k)
        explore = self.forward(examples)
        kl_loss, kl_weight = self.compute_kl_loss(explore['mean'],
                                                  explore['logv'], step)
        kl_weight *= self.args.kl_factor
        nll_loss = torch.sum(explore['nll_loss']) / explore['batch_size']
        kl_loss = kl_loss / explore['batch_size']
        kl_item = kl_loss * kl_weight
        loss = kl_item + nll_loss
        return {
            'KL Loss': kl_loss,
            'NLL Loss': nll_loss,
            'Model Score': nll_loss + kl_loss,
            'Loss': loss,
            'ELBO': loss,
            'KL Weight': kl_weight,
            'WD Drop': self.step_unk_rate,
            'KL Item': kl_item,
        }

    def batch_beam_decode(self, **kwargs):
        """Not implemented for this model; returns None."""
        pass

    def latent_for_init(self, ret):
        """Map ``ret['latent']`` to a hidden state and mix in the syntax state.

        Sets ``ret['decode_init']`` to the mean of the latent-derived hidden
        state, shaped ``(hidden_factor, batch, enc_hidden_dim)``, and
        ``ret['syn_hidden']``.
        """
        z = ret['latent']
        batch_size = z.size(0)
        hidden = self.latent2hidden(z)

        if self.hidden_factor > 1:
            hidden = hidden.view(batch_size, self.hidden_factor,
                                 self.args.enc_hidden_dim)
            hidden = hidden.permute(1, 0, 2)
        else:
            hidden = hidden.unsqueeze(0)
        ret['decode_init'] = (hidden + ret['syn_hidden']) / 2
        return ret

    def sample_latent(self, batch_size):
        """Draw ``batch_size`` latents from a standard normal."""
        z = to_var(torch.randn([batch_size, self.latent_size]))
        return {"latent": z}

    def encode_to_hidden(self, examples):
        """Encode syntax and sentence of each example, sorted by syntax length.

        Returns:
            dict with ``hidden`` (sentence encoder state), ``syn_output``,
            ``syn_hidden``, ``tgt_var`` (sentence ids with boundary symbols),
            all in syntax-length-sorted order, and ``old_pos`` mapping each
            original index to its sorted position.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -len(examples[x].tgt))
        example_old_pos_map = [-1] * batch_size

        sorted_examples = [examples[i] for i in sorted_example_ids]

        syntax_word = [e.tgt for e in sorted_examples]
        syntax_var = to_input_variable(syntax_word,
                                       self.vocab.tgt,
                                       training=False,
                                       cuda=self.args.cuda,
                                       batch_first=True)
        length = [len(e.tgt) for e in sorted_examples]
        syntax_output, syntax_hidden = self.syntax_encode(syntax_var, length)

        sent_words = [e.src for e in sorted_examples]
        sentence_hidden = self.sentence_encode(sent_words)
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        return {
            'hidden': sentence_hidden,
            "syn_output": syntax_output,
            "syn_hidden": syntax_hidden,
            'tgt_var': tgt_var,
            'old_pos': example_old_pos_map
        }

    def hidden_to_latent(self, ret, is_sampling):
        """Project ``ret['hidden']`` to a Gaussian latent.

        ``ret['hidden']`` has the batch on dim 1 and ``hidden_factor`` states
        of width ``enc_hidden_dim`` per example. Adds ``latent``, ``mean`` and
        ``logv`` to ``ret``. When ``is_sampling`` the latent is
        ``mean + std * eps`` with ``std = exp(0.5 * logv)``; otherwise it is
        ``mean``.
        """
        hidden = ret['hidden']
        batch_size = hidden.size(1)
        # (states, batch, H) -> (batch, states * H). An explicit view replaces
        # the former squeeze() for hidden_factor == 1, which also dropped the
        # batch axis when batch_size == 1 and left mean/logv without a batch dim.
        hidden = hidden.permute(1, 0, 2).contiguous().view(
            batch_size, self.args.enc_hidden_dim * self.hidden_factor)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        if is_sampling:
            std = torch.exp(0.5 * logv)
            z = to_var(torch.randn([batch_size, self.latent_size]))
            z = z * std + mean
        else:
            z = mean
        ret["latent"] = z
        ret["mean"] = mean
        ret['logv'] = logv
        return ret

    def decode_to_sentence(self, ret):
        """Greedily decode from ``ret['decode_init']`` and restore input order.

        Returns:
            One ``[[words], [len(words)]]`` entry per example, reordered through
            ``ret['old_pos']`` (original index -> sorted position).
        """
        sentence_decode_init = self.bridger.forward(
            input_tensor=ret['decode_init'])

        # NOTE(review): forward() passes the syntax encoder outputs here, but
        # decoding passes None; with use_attention this may not be supported
        # by RNNDecoder - confirm.
        _, _, ret_dict, _ = self.decoder.forward(
            inputs=None,
            encoder_outputs=None,
            encoder_hidden=sentence_decode_init,
        )

        # Steps on dim 0, examples on dim 1 (assumes one token per example per
        # step). An explicit view keeps both axes even when either is 1; the
        # old squeeze()/unsqueeze(1) transposed single-step batches and raised
        # on a 0-d tensor for a single example and single step.
        result = torch.stack(ret_dict['sequence'])
        result = result.view(result.size(0), -1)

        temp_result = []
        for i in range(result.size(1)):
            res = id2word(result[:, i].tolist(), self.vocab.src)
            temp_result.append([[res], [len(res)]])

        return [temp_result[i] for i in ret['old_pos']]