Code Example #1
class VAE(nn.Module):
    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(VAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorGaussian = torch.distributions.Normal(
            gpu_wrapper(torch.zeros(self.latent_dim)),
            gpu_wrapper(torch.ones(self.latent_dim)))
        self.PosteriorGaussian = Gaussian(in_dim=self.hid_dim * self.n_dir *
                                          self.enc_layers,
                                          out_dim=self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)
        self.BoW = nn.Linear(self.latent_dim, self.voc_size)

        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)
Code Example #2
class SEQ2SEQ(nn.Module):
    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(SEQ2SEQ, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.PostEncoder = Encoder(emb_dim=self.emb_dim,
                                   hid_dim=self.hid_dim,
                                   n_layer=self.enc_layers,
                                   dropout=self.dropout,
                                   bi=self.enc_bi,
                                   embedding=encoder_emb)
        self.PostRepr = nn.Linear(self.hid_dim * self.n_dir * self.enc_layers,
                                  self.emb_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)

        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)
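
Unlike the VAE above, SEQ2SEQ has no stochastic posterior: PostRepr is a plain linear projection of the flattened final encoder state. A small sketch of the shape arithmetic behind PostRepr's in_dim, using dummy tensors (all sizes are assumptions):

import torch

B, hid_dim, enc_layers, n_dir = 4, 512, 1, 2  # assumed; n_dir = 2 for a bidirectional encoder
# last_states as the encoder returns them: (layers * n_dir, n_batch, hid_dim)
last_states = torch.randn(enc_layers * n_dir, B, hid_dim)
flat = last_states.transpose(0, 1).contiguous().view(B, -1)
assert flat.shape == (B, enc_layers * n_dir * hid_dim)  # PostRepr's in_dim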
Code Example #3
class S_VAE(nn.Module):
    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout, enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb, decoder_emb, pad_id):
        super(S_VAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorUniform = HypersphericalUniform(dim=self.latent_dim)
        self.PosteriorVMF = VonMisesFisherModule(in_dim=self.hid_dim * self.n_dir * self.enc_layers, out_dim=self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)

        self.criterionSeq = SeqLoss(voc_size=self.voc_size, pad=pad_id, end=None, unk=None)
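
S_VAE swaps the Gaussian prior/posterior pair for hyperspherical counterparts: a uniform prior on the unit sphere (HypersphericalUniform) and a von Mises-Fisher posterior (VonMisesFisherModule), both project-local classes. As a point of reference only, a uniform draw from the unit sphere can be sketched in plain torch by normalizing an isotropic Gaussian sample:

import torch

def sample_sphere_uniform(n_sample, dim):
    # A normalized isotropic Gaussian is uniformly distributed on the unit sphere.
    x = torch.randn(n_sample, dim)
    return x / x.norm(dim=1, keepdim=True)

z = sample_sphere_uniform(8, 32)
assert torch.allclose(z.norm(dim=1), torch.ones(8))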
Code Example #4
class VAE(nn.Module):
    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(VAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorGaussian = torch.distributions.Normal(
            gpu_wrapper(torch.zeros(self.latent_dim)),
            gpu_wrapper(torch.ones(self.latent_dim)))
        self.PosteriorGaussian = Gaussian(in_dim=self.hid_dim * self.n_dir *
                                          self.enc_layers,
                                          out_dim=self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)
        self.BoW = nn.Linear(self.latent_dim, self.voc_size)

        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)

    def visualize(self, go, sent_len, bare):
        B = bare.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        # _.shape = (n_batch, latent_dim)

        samples = gaussian_dist.sample(torch.Size([1])).squeeze(
            0)  # shape = (n_batch, latent_dim)

        return samples

    def estimate_mi(self, go, sent_len, bare, n_sample):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        # _.shape = (n_batch, latent_dim)

        # ----- Importance sampling estimation -----
        mi, sampled_latents = self.importance_sampling_mi(
            gaussian_dist=gaussian_dist,
            n_sample=n_sample)  # shape = (n_batch, )

        return mi, sampled_latents

    def importance_sampling_mi(self, gaussian_dist, n_sample):
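        # NOTE: _n_sample is a module-level constant in the original file
        # (assumed): the per-chunk sample count. The n_sample draws are taken
        # in n_sample // _n_sample chunks to bound peak memory.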
        assert n_sample % _n_sample == 0

        B = gaussian_dist.mean.shape[0]

        samplify = {'log_qz': [], 'log_qzx': [], 'z': []}
        for sample_id in range(n_sample // _n_sample):
            # ----- Sampling -----
            _z = gaussian_dist.rsample(torch.Size(
                [_n_sample]))  # shape = (_n_sample, n_batch, latent_dim)
            assert tuple(_z.shape) == (_n_sample, B, self.latent_dim)

            _log_qzx = gaussian_dist.log_prob(_z).sum(
                2)  # shape = (_n_sample, n_batch)
            _log_qz = gaussian_dist.log_prob(
                _z.unsqueeze(2).expand(-1, -1, B, -1)).sum(
                    3)  # shape = (_n_sample, n_batch, n_batch)
            # Exclude itself.
            _log_qz.masked_fill_(
                gpu_wrapper(torch.eye(B).long()).eq(1).unsqueeze(0).expand(
                    _n_sample, -1, -1),
                -float('inf'))  # shape = (_n_sample, n_batch, n_batch)
            _log_qz = (log_sum_exp(_log_qz, dim=2) - np.log(B - 1)
                       )  # shape = (_n_sample, n_batch)

            samplify['log_qzx'].append(
                _log_qzx)  # shape = (_n_sample, n_batch)
            samplify['log_qz'].append(_log_qz)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_z)  # shape = (_n_sample, n_batch, out_dim)

        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key],
                                      dim=0)  # shape = (n_sample, ?)

        # ----- Importance sampling for MI -----
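        # Monte Carlo estimate of I(x; z) = E[log q(z|x)] - E[log q(z)], where
        # log q(z) is approximated by the leave-one-out minibatch aggregate
        # computed above (log-mean-exp over the other B - 1 posteriors).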
        mi = samplify['log_qzx'].mean(0) - samplify['log_qz'].mean(0)

        return mi, samplify['z'].transpose(0, 1)

    def test_lm(self, go, sent_len, bare, eos, n_sample):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        # _.shape = (n_batch, latent_dim)

        # ----- Importance sampling estimation -----
        xent, nll, kl, sampled_latents = self.importance_sampling(
            gaussian_dist=gaussian_dist, go=go, eos=eos, n_sample=n_sample)
        # xent.shape = (n_batch, )
        # nll.shape = (n_batch, )
        # kl.shape = (n_batch, )
        # sampled_latents.shape = (n_batch, n_sample, latent_dim)

        return xent, nll, kl, sampled_latents

    def importance_sampling(self, gaussian_dist, go, eos, n_sample):
        B = go.shape[0]
        assert n_sample % _n_sample == 0

        samplify = {
            'xent': [],
            'log_pz': [],
            'log_pxz': [],
            'log_qzx': [],
            'z': []
        }
        for sample_id in range(n_sample // _n_sample):

            # ----- Sampling -----
            _z = gaussian_dist.rsample(torch.Size(
                [_n_sample]))  # shape = (_n_sample, n_batch, latent_dim)
            assert tuple(_z.shape) == (_n_sample, B, self.latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            _init_states = gpu_wrapper(torch.zeros([
                self.enc_layers, _n_sample * B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, _n_sample * n_batch, n_dir * hid_dim)

            # ----- Importance sampling for NLL -----
            _logits = self.Decoder(
                init_states=_init_states,  # shape = (layers, _n_sample * n_batch, n_dir * hid_dim)
                latent_vector=_z.contiguous().view(
                    _n_sample * B, self.latent_dim),  # shape = (_n_sample * n_batch, out_dim)
                helper=go.unsqueeze(0).expand(_n_sample, -1, -1).contiguous().view(
                    _n_sample * B, -1),  # shape = (_n_sample * n_batch, 15)
                test_lm=True)  # shape = (_n_sample * n_batch, 16, V)
            _xent = self.criterionSeq(
                _logits,  # shape = (_n_sample * n_batch, 16, V)
                eos.unsqueeze(0).expand(_n_sample, -1, -1).contiguous().view(
                    _n_sample * B, -1),  # shape = (_n_sample * n_batch, 16)
                keep_batch=True).view(_n_sample,
                                      B)  # shape = (_n_sample, n_batch)

            _log_pz = self.PriorGaussian.log_prob(_z).sum(
                2)  # shape = (_n_sample, n_batch)
            _log_pxz = -_xent  # shape = (_n_sample, n_batch)
            _log_qzx = gaussian_dist.log_prob(_z).sum(
                2)  # shape = (_n_sample, n_batch)

            samplify['xent'].append(_xent)  # shape = (_n_sample, n_batch)
            samplify['log_pz'].append(_log_pz)  # shape = (_n_sample, n_batch)
            samplify['log_pxz'].append(
                _log_pxz)  # shape = (_n_sample, n_batch)
            samplify['log_qzx'].append(
                _log_qzx)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_z)  # shape = (_n_sample, n_batch, out_dim)

        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key],
                                      dim=0)  # shape = (n_sample, ?)

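        # Importance-weighted log-likelihood estimate, computed in log space:
        # log p(x) ≈ log-sum-exp_k[ log p(z_k) + log p(x|z_k) - log q(z_k|x) ] - log K,
        # with K = n_sample draws z_k ~ q(z|x).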
        ll = log_sum_exp(
            samplify['log_pz'] + samplify['log_pxz'] - samplify['log_qzx'],
            dim=0) - np.log(n_sample)  # shape = (n_batch, )
        nll = -ll  # shape = (n_batch, )

        # ----- Importance sampling for KL -----
        # kl = kl_with_isogaussian(gaussian_dist)  # shape = (n_batch, )
        kl = (samplify['log_qzx'] - samplify['log_pz']).mean(
            0)  # shape = (n_batch, )

        return samplify['xent'].mean(0), nll, kl, samplify['z'].transpose(0, 1)

    def generate_gaussian(self, B):
        return self.PriorGaussian.sample(torch.Size(
            [B]))  # shape = (n_batch, latent_dim)

    def gen_interps(self, bareA, sent_lenA, bareB, sent_lenB, go, n_interps):
        """

        :param bareA: shape = (n_batch, 15)
        :param sent_lenA: shape = (n_batch, )
        :param bareB: shape = (n_batch, 15)
        :param sent_lenB: shape = (n_batch, )
        :param go: shape = (n_batch, 16)
        :param n_interps: int.
        :return:
        """
        B = go.shape[0]

        # ---------- A ----------
        # ----- Encoding -----
        _, last_statesA = self.Encoder(bareA, sent_lenA)
        # _.shape = (n_batch, 15, n_dir * hid_dim)
        # last_statesA.shape = (layers * n_dir, n_batch, hid_dim)
        last_statesA = last_statesA.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussA, _ = self.PosteriorGaussian(last_statesA)
        z0A = gaussA.mean
        # z0A.shape = (n_batch, latent_dim)

        # ---------- B ----------
        # ----- Encoding -----
        _, last_statesB = self.Encoder(bareB, sent_lenB)
        # _.shape = (n_batch, 15, n_dir * hid_dim)
        # last_statesB.shape = (layers * n_dir, n_batch, hid_dim)
        last_statesB = last_statesB.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussB, _ = self.PosteriorGaussian(last_statesB)
        z0B = gaussB.mean
        # z0B.shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([
                self.enc_layers, B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        interps = [[] for _ in range(B)]
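        # Linear interpolation between the two posterior means, endpoints
        # included: step 0 yields z0A and step n_interps + 1 yields z0B.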
        for in_id in range(n_interps + 2):
            _zk = z0A * ((n_interps - in_id + 1) / (n_interps + 1)) + z0B * (
                in_id / (n_interps + 1))  # shape = (n_batch, latent_dim)
            _interp = self.Decoder(init_states=init_states,
                                   latent_vector=_zk,
                                   helper=go)
            for b_id, _b_interp in enumerate(_interp):
                interps[b_id].append(_b_interp)
        return interps

    def sample_from_prior(self, go):
        """

        :param go: shape = (n_batch, 16)
        :return:
        """
        B = go.shape[0]

        # ----- Prior Network -----
        latent_vector = self.generate_gaussian(
            B)  # shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([
                self.enc_layers, B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        return self.Decoder(init_states=init_states,
                            latent_vector=latent_vector,
                            helper=go)

    def sample_from_posterior(self, bare, sent_len, n_sample):
        """

        :param bare: shape = (n_batch, 15)
        :param sent_len: shape = (n_batch, )
        :param n_sample: int
        :return: shape = (n_batch, n_sample, latent_dim)
        """

        B = bare.shape[0]
        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)

        samples = gaussian_dist.sample(torch.Size(
            [n_sample]))  # shape = (n_sample, n_batch, latent_dim)
        samples = samples.transpose(
            0, 1).contiguous()  # shape = (n_batch, n_sample, latent_dim)

        return samples

    def decode_from(self, latents, go):
        """

        :param latents: shape = (n_batch, latent_dim)
        :param go: shape = (n_batch, 16)
        :return:
        """
        B = latents.shape[0]

        init_states = gpu_wrapper(
            torch.zeros([
                self.enc_layers, B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        return self.Decoder(init_states=init_states,
                            latent_vector=latents,
                            helper=go)

    def forward(self, go, sent_len=None, bare=None):
        """

        :param go: shape = (n_batch, 16)
        :param sent_len: shape = (n_batch, ) or None
        :param bare: shape = (n_batch, 15) or None
        :return:
        """
        B = go.shape[0]

        if not self.training:
            raise NotImplementedError()
        else:
            # ----- Encoding -----
            outputs, last_states = self.Encoder(bare, sent_len)
            # outputs.shape = (n_batch, 15, n_dir * hid_dim)
            # last_states.shape = (layers * n_dir, n_batch, hid_dim)
            last_states = last_states.transpose(0, 1).contiguous().view(
                B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

            # ----- Posterior Network -----
            gaussian_dist, latent_vector = self.PosteriorGaussian(last_states)
            # latent_vector.shape = (n_batch, latent_dim)

            # ----- Bag-of-Words logits -----
            BoW_logits = self.BoW(latent_vector)  # shape = (n_batch, voc_size)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            init_states = gpu_wrapper(
                torch.zeros([
                    self.enc_layers, B, self.n_dir * self.hid_dim
                ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

            return self.Decoder(
                init_states=init_states,
                latent_vector=latent_vector,
                helper=go), gaussian_dist, latent_vector, BoW_logits

    def saliency(self, go, sent_len=None, bare=None):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, latent_vector = self.PosteriorGaussian(last_states)
        # latent_vector.shape = (n_batch, latent_dim)

        # ----- Bag-of-Words logits -----
        BoW_logits = self.BoW(latent_vector)  # shape = (n_batch, voc_size)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([
                self.enc_layers, B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        logits = self.Decoder(init_states=init_states,
                              latent_vector=latent_vector,
                              helper=go)

        return logits, gaussian_dist, self.Decoder.toInit(
            latent_vector), last_states
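
At training time, forward returns the decoder logits, the posterior distribution, the sampled latent vector, and the bag-of-words logits. A hedged sketch of how one training step could combine them into a VAE objective; the closed-form Gaussian KL and the BoW formulation below are standard choices, not confirmed details of this repository, and eos denotes the shifted target batch as in test_lm (model and PAD_ID are the hypothetical names from the sketch under Example #1).

import torch
import torch.nn.functional as F

logits, gaussian_dist, latent_vector, BoW_logits = model(go, sent_len, bare)

rec_loss = model.criterionSeq(logits, eos)  # sequence cross-entropy (assumed to reduce to a scalar)
kl = torch.distributions.kl_divergence(
    gaussian_dist, model.PriorGaussian).sum(-1).mean()  # closed-form Gaussian KL
# Assumed BoW loss: predict every target token independently from the latent code.
bow_logp = F.log_softmax(BoW_logits, dim=-1)  # shape = (n_batch, voc_size)
bow_loss = -bow_logp.gather(1, eos).masked_fill(eos == PAD_ID, 0.).sum(1).mean()
loss = rec_loss + kl + bow_loss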
Code Example #5
File: dae.py  Project: ChenWu98/Coupled-VAE
class DAE(nn.Module):
    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(DAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorGaussian = torch.distributions.Normal(
            gpu_wrapper(torch.zeros(self.latent_dim)),
            gpu_wrapper(torch.ones(self.latent_dim)))
        self.toLatent = nn.Linear(self.hid_dim * self.n_dir * self.enc_layers,
                                  self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)

        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)

    def visualize(self, go, sent_len, bare):
        B = bare.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        samples = self.toLatent(last_states)  # shape = (n_batch, latent_dim)

        return samples

    def test_lm(self, go, sent_len, bare, eos, n_sample):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        latent_vector = self.toLatent(
            last_states.transpose(0, 1).contiguous().view(
                B, -1))  # shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([
                self.enc_layers, B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        logits = self.Decoder(init_states=init_states,
                              latent_vector=latent_vector,
                              helper=go,
                              test_lm=True)  # shape = (n_batch, 16, V)
        xent = self.criterionSeq(logits, eos,
                                 keep_batch=True)  # shape = (n_batch, )
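        # The DAE has no stochastic posterior, so no KL term exists; setting it
        # to +inf makes the reported NLL a sentinel value rather than a
        # comparable likelihood bound.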
        kl = torch.zeros_like(xent) + float('inf')  # shape = (n_batch, )

        nll = xent + kl  # shape = (n_batch, )

        return xent, nll, kl, latent_vector

    def generate_gaussian(self, B):
        return self.PriorGaussian.sample(torch.Size(
            [B]))  # shape = (n_batch, latent_dim)

    def forward(self, go, sent_len=None, bare=None):
        """

        :param go: shape = (n_batch, 16)
        :param sent_len: shape = (n_batch, ) or None
        :param bare: shape = (n_batch, 15) or None
        :return:
        """
        B = go.shape[0]

        if not self.training:
            # ----- Prior Network -----
            latent_vector = self.generate_gaussian(
                B)  # shape = (n_batch, latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            init_states = gpu_wrapper(
                torch.zeros([
                    self.enc_layers, B, self.n_dir * self.hid_dim
                ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

            return self.Decoder(init_states=init_states,
                                latent_vector=latent_vector,
                                helper=go)
        else:
            # ----- Encoding -----
            outputs, last_states = self.Encoder(bare, sent_len)
            # outputs.shape = (n_batch, 15, n_dir * hid_dim)
            # last_states.shape = (layers * n_dir, n_batch, hid_dim)
            latent_vector = self.toLatent(
                last_states.transpose(0, 1).contiguous().view(
                    B, -1))  # shape = (n_batch, latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            init_states = gpu_wrapper(
                torch.zeros([
                    self.enc_layers, B, self.n_dir * self.hid_dim
                ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

            return self.Decoder(init_states=init_states,
                                latent_vector=latent_vector,
                                helper=go), latent_vector

    def saliency(self, go, sent_len=None, bare=None):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(
            B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        latent_vector = self.toLatent(last_states)
        # latent_vector.shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([
                self.enc_layers, B, self.n_dir * self.hid_dim
            ])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        logits = self.Decoder(init_states=init_states,
                              latent_vector=latent_vector,
                              helper=go)

        return logits, self.Decoder.toInit(latent_vector), last_states
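
In eval mode, DAE.forward ignores the input sentence entirely and decodes from a draw of the standard-normal prior. A minimal, hypothetical usage sketch (go is the batch of decoding helper tokens, as elsewhere in these examples):

dae.eval()
with torch.no_grad():
    generated = dae(go)  # decodes from latent_vector ~ N(0, I)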