"""Latent-variable sequence models: SEQ2SEQ, S_VAE, VAE, and DAE.

NOTE: the import paths below are assumptions. Encoder, Decoder, Gaussian,
VonMisesFisherModule, HypersphericalUniform, SeqLoss, gpu_wrapper, and
log_sum_exp are referenced throughout this module but are defined elsewhere
in the repository.
"""
import numpy as np
import torch
import torch.nn as nn

from modules import (Encoder, Decoder, Gaussian, VonMisesFisherModule,
                     HypersphericalUniform)  # assumed module path
from utils import SeqLoss, gpu_wrapper, log_sum_exp  # assumed module path

# Per-chunk sample count for the importance-sampling loops below. The
# original module-level constant is referenced but not shown; 50 is a
# placeholder value.
_n_sample = 50

class SEQ2SEQ(nn.Module):

    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(SEQ2SEQ, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.PostEncoder = Encoder(emb_dim=self.emb_dim,
                                   hid_dim=self.hid_dim,
                                   n_layer=self.enc_layers,
                                   dropout=self.dropout,
                                   bi=self.enc_bi,
                                   embedding=encoder_emb)
        self.PostRepr = nn.Linear(self.hid_dim * self.n_dir * self.enc_layers,
                                  self.emb_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)
        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)

class S_VAE(nn.Module):

    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(S_VAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorUniform = HypersphericalUniform(dim=self.latent_dim)
        self.PosteriorVMF = VonMisesFisherModule(
            in_dim=self.hid_dim * self.n_dir * self.enc_layers,
            out_dim=self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)
        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)
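
# ---------------------------------------------------------------------------
# Construction sketch (added for illustration; not part of the original
# module). All models in this file share one constructor signature. Every
# hyperparameter value below is a placeholder, and 'dot' is only a guess at
# a valid WEAtt_type.
# ---------------------------------------------------------------------------
def _demo_build_model():
    pad_id = 0
    # Tying the encoder and decoder embeddings satisfies the constructor's
    # asserts on vocabulary size and embedding dimension.
    shared_emb = nn.Embedding(num_embeddings=10000, embedding_dim=300,
                              padding_idx=pad_id)
    return SEQ2SEQ(hid_dim=500, latent_dim=32, enc_layers=1, dec_layers=1,
                   dropout=0.5, enc_bi=True, dec_max_len=16, beam_size=5,
                   WEAtt_type='dot', encoder_emb=shared_emb,
                   decoder_emb=shared_emb, pad_id=pad_id)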

class VAE(nn.Module):

    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(VAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorGaussian = torch.distributions.Normal(
            gpu_wrapper(torch.zeros(self.latent_dim)),
            gpu_wrapper(torch.ones(self.latent_dim)))
        self.PosteriorGaussian = Gaussian(
            in_dim=self.hid_dim * self.n_dir * self.enc_layers,
            out_dim=self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)
        self.BoW = nn.Linear(self.latent_dim, self.voc_size)
        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)

    def visualize(self, go, sent_len, bare):
        B = bare.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        # _.shape = (n_batch, latent_dim)
        samples = gaussian_dist.sample(torch.Size([1])).squeeze(0)  # shape = (n_batch, latent_dim)
        return samples

    def estimate_mi(self, go, sent_len, bare, n_sample):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        # _.shape = (n_batch, latent_dim)

        # ----- Importance sampling estimation -----
        mi, sampled_latents = self.importance_sampling_mi(
            gaussian_dist=gaussian_dist, n_sample=n_sample)  # mi.shape = (n_batch, )
        return mi, sampled_latents

    def importance_sampling_mi(self, gaussian_dist, n_sample):
        assert n_sample % _n_sample == 0
        B = gaussian_dist.mean.shape[0]
        samplify = {'log_qz': [], 'log_qzx': [], 'z': []}
        for sample_id in range(n_sample // _n_sample):
            # ----- Sampling -----
            _z = gaussian_dist.rsample(torch.Size([_n_sample]))  # shape = (_n_sample, n_batch, latent_dim)
            assert tuple(_z.shape) == (_n_sample, B, self.latent_dim)
            _log_qzx = gaussian_dist.log_prob(_z).sum(2)  # shape = (_n_sample, n_batch)
            _log_qz = gaussian_dist.log_prob(
                _z.unsqueeze(2).expand(-1, -1, B, -1)).sum(3)  # shape = (_n_sample, n_batch, n_batch)
            # Exclude itself.
            _log_qz.masked_fill_(
                gpu_wrapper(torch.eye(B).long()).eq(1).unsqueeze(0).expand(_n_sample, -1, -1),
                -float('inf'))  # shape = (_n_sample, n_batch, n_batch)
            _log_qz = log_sum_exp(_log_qz, dim=2) - np.log(B - 1)  # shape = (_n_sample, n_batch)
            samplify['log_qzx'].append(_log_qzx)  # shape = (_n_sample, n_batch)
            samplify['log_qz'].append(_log_qz)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_z)  # shape = (_n_sample, n_batch, out_dim)
        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key], dim=0)  # shape = (n_sample, ?)

        # ----- Importance sampling for MI -----
        mi = samplify['log_qzx'].mean(0) - samplify['log_qz'].mean(0)
        return mi, samplify['z'].transpose(0, 1)

    def test_lm(self, go, sent_len, bare, eos, n_sample):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        # _.shape = (n_batch, latent_dim)

        # ----- Importance sampling estimation -----
        xent, nll, kl, sampled_latents = self.importance_sampling(
            gaussian_dist=gaussian_dist, go=go, eos=eos, n_sample=n_sample)
        # xent.shape = (n_batch, )
        # nll.shape = (n_batch, )
        # kl.shape = (n_batch, )
        # sampled_latents.shape = (n_batch, n_sample, latent_dim)
        return xent, nll, kl, sampled_latents

    def importance_sampling(self, gaussian_dist, go, eos, n_sample):
        B = go.shape[0]
        assert n_sample % _n_sample == 0
        samplify = {'xent': [], 'log_pz': [], 'log_pxz': [], 'log_qzx': [], 'z': []}
        for sample_id in range(n_sample // _n_sample):
            # ----- Sampling -----
            _z = gaussian_dist.rsample(torch.Size([_n_sample]))  # shape = (_n_sample, n_batch, latent_dim)
            assert tuple(_z.shape) == (_n_sample, B, self.latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            _init_states = gpu_wrapper(
                torch.zeros([self.enc_layers, _n_sample * B,
                             self.n_dir * self.hid_dim])).float()  # shape = (layers, _n_sample * n_batch, n_dir * hid_dim)

            # ----- Importance sampling for NLL -----
            _logits = self.Decoder(
                init_states=_init_states,  # shape = (layers, _n_sample * n_batch, n_dir * hid_dim)
                latent_vector=_z.contiguous().view(_n_sample * B, self.latent_dim),  # shape = (_n_sample * n_batch, out_dim)
                helper=go.unsqueeze(0).expand(_n_sample, -1, -1).contiguous().view(_n_sample * B, -1),  # shape = (_n_sample * n_batch, 16)
                test_lm=True)  # shape = (_n_sample * n_batch, 16, V)
            _xent = self.criterionSeq(
                _logits,  # shape = (_n_sample * n_batch, 16, V)
                eos.unsqueeze(0).expand(_n_sample, -1, -1).contiguous().view(_n_sample * B, -1),  # shape = (_n_sample * n_batch, 16)
                keep_batch=True).view(_n_sample, B)  # shape = (_n_sample, n_batch)
            _log_pz = self.PriorGaussian.log_prob(_z).sum(2)  # shape = (_n_sample, n_batch)
            _log_pxz = -_xent  # shape = (_n_sample, n_batch)
            _log_qzx = gaussian_dist.log_prob(_z).sum(2)  # shape = (_n_sample, n_batch)
            samplify['xent'].append(_xent)  # shape = (_n_sample, n_batch)
            samplify['log_pz'].append(_log_pz)  # shape = (_n_sample, n_batch)
            samplify['log_pxz'].append(_log_pxz)  # shape = (_n_sample, n_batch)
            samplify['log_qzx'].append(_log_qzx)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_z)  # shape = (_n_sample, n_batch, out_dim)
        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key], dim=0)  # shape = (n_sample, ?)
        ll = log_sum_exp(
            samplify['log_pz'] + samplify['log_pxz'] - samplify['log_qzx'],
            dim=0) - np.log(n_sample)  # shape = (n_batch, )
        nll = -ll  # shape = (n_batch, )

        # ----- Importance sampling for KL -----
        # kl = kl_with_isogaussian(gaussian_dist)  # shape = (n_batch, )
        kl = (samplify['log_qzx'] - samplify['log_pz']).mean(0)  # shape = (n_batch, )
        return samplify['xent'].mean(0), nll, kl, samplify['z'].transpose(0, 1)

    def generate_gaussian(self, B):
        return self.PriorGaussian.sample(torch.Size([B]))  # shape = (n_batch, latent_dim)

    def gen_interps(self, bareA, sent_lenA, bareB, sent_lenB, go, n_interps):
        """
        :param bareA: shape = (n_batch, 15)
        :param sent_lenA: shape = (n_batch, )
        :param bareB: shape = (n_batch, 15)
        :param sent_lenB: shape = (n_batch, )
        :param go: shape = (n_batch, 16)
        :param n_interps: int.
        :return:
        """
        B = go.shape[0]

        # ---------- A ----------
        # ----- Encoding -----
        _, last_statesA = self.Encoder(bareA, sent_lenA)
        # _.shape = (n_batch, 15, n_dir * hid_dim)
        # last_statesA.shape = (layers * n_dir, n_batch, hid_dim)
        last_statesA = last_statesA.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)
        # ----- Posterior Network -----
        gaussA, _ = self.PosteriorGaussian(last_statesA)
        z0A = gaussA.mean  # z0A.shape = (n_batch, latent_dim)

        # ---------- B ----------
        # ----- Encoding -----
        _, last_statesB = self.Encoder(bareB, sent_lenB)
        # _.shape = (n_batch, 15, n_dir * hid_dim)
        # last_statesB.shape = (layers * n_dir, n_batch, hid_dim)
        last_statesB = last_statesB.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)
        # ----- Posterior Network -----
        gaussB, _ = self.PosteriorGaussian(last_statesB)
        z0B = gaussB.mean  # z0B.shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)

        interps = [[] for _ in range(B)]
        for in_id in range(n_interps + 2):
            # Linear interpolation between the two posterior means.
            _zk = z0A * ((n_interps - in_id + 1) / (n_interps + 1)) \
                + z0B * (in_id / (n_interps + 1))  # shape = (n_batch, latent_dim)
            _interp = self.Decoder(init_states=init_states, latent_vector=_zk, helper=go)
            for b_id, _b_interp in enumerate(_interp):
                interps[b_id].append(_b_interp)
        return interps

    def sample_from_prior(self, go):
        """
        :param go: shape = (n_batch, 16)
        :return:
        """
        B = go.shape[0]

        # ----- Prior Network -----
        latent_vector = self.generate_gaussian(B)  # shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
        return self.Decoder(init_states=init_states, latent_vector=latent_vector, helper=go)

    def sample_from_posterior(self, bare, sent_len, n_sample):
        """
        :param bare: shape = (n_batch, 15)
        :param sent_len: shape = (n_batch, )
        :param n_sample: int
        :return: shape = (n_batch, n_sample, latent_dim)
        """
        B = bare.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, _ = self.PosteriorGaussian(last_states)
        samples = gaussian_dist.sample(torch.Size([n_sample]))  # shape = (n_sample, n_batch, latent_dim)
        samples = samples.transpose(0, 1).contiguous()  # shape = (n_batch, n_sample, latent_dim)
        return samples

    def decode_from(self, latents, go):
        """
        :param latents: shape = (n_batch, latent_dim)
        :param go: shape = (n_batch, 16)
        :return:
        """
        B = latents.shape[0]
        init_states = gpu_wrapper(
            torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
        return self.Decoder(init_states=init_states, latent_vector=latents, helper=go)

    def forward(self, go, sent_len=None, bare=None):
        """
        :param go: shape = (n_batch, 16)
        :param sent_len: shape = (n_batch, ) or None
        :param bare: shape = (n_batch, 15) or None
        :return:
        """
        B = go.shape[0]
        if not self.training:
            raise NotImplementedError()
        else:
            # ----- Encoding -----
            outputs, last_states = self.Encoder(bare, sent_len)
            # outputs.shape = (n_batch, 15, n_dir * hid_dim)
            # last_states.shape = (layers * n_dir, n_batch, hid_dim)
            last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

            # ----- Posterior Network -----
            gaussian_dist, latent_vector = self.PosteriorGaussian(last_states)
            # latent_vector.shape = (n_batch, latent_dim)

            # ----- Bag-of-Words logits -----
            BoW_logits = self.BoW(latent_vector)  # shape = (n_batch, voc_size)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            init_states = gpu_wrapper(
                torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
            return self.Decoder(init_states=init_states,
                                latent_vector=latent_vector,
                                helper=go), gaussian_dist, latent_vector, BoW_logits

    def saliency(self, go, sent_len=None, bare=None):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        gaussian_dist, latent_vector = self.PosteriorGaussian(last_states)
        # latent_vector.shape = (n_batch, latent_dim)

        # ----- Bag-of-Words logits -----
        BoW_logits = self.BoW(latent_vector)  # shape = (n_batch, voc_size)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
        logits = self.Decoder(init_states=init_states, latent_vector=latent_vector, helper=go)
        return logits, gaussian_dist, self.Decoder.toInit(latent_vector), last_states
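
# ---------------------------------------------------------------------------
# Hedged training-step sketch for the VAE above (added for illustration; not
# part of the original module). It assumes (i) `go`, `bare`, and `eos` are
# the GO-prefixed, bare, and EOS-suffixed token-id batches produced by this
# repo's data pipeline, (ii) PosteriorGaussian returns a
# torch.distributions.Normal, so the KL to the isotropic prior has a closed
# form, and (iii) the pad id is 0.
# ---------------------------------------------------------------------------
def _demo_vae_train_step(vae, go, sent_len, bare, eos, kl_weight=1.0):
    vae.train()
    logits, gaussian_dist, latent_vector, BoW_logits = vae(go, sent_len=sent_len, bare=bare)
    # Reconstruction term: per-sentence sequence cross-entropy.
    rec = vae.criterionSeq(logits, eos, keep_batch=True)  # shape = (n_batch, )
    # KL(q(z|x) || p(z)), summed over latent dimensions.
    kl = torch.distributions.kl_divergence(gaussian_dist, vae.PriorGaussian).sum(1)  # shape = (n_batch, )
    # Bag-of-words auxiliary loss: predict every target token from z alone.
    bow = nn.functional.cross_entropy(
        BoW_logits.unsqueeze(1).expand(-1, bare.shape[1], -1).reshape(-1, vae.voc_size),
        bare.reshape(-1),
        ignore_index=0)  # assumed pad id
    return (rec + kl_weight * kl).mean() + bow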

class DAE(nn.Module):

    def __init__(self, hid_dim, latent_dim, enc_layers, dec_layers, dropout,
                 enc_bi, dec_max_len, beam_size, WEAtt_type, encoder_emb,
                 decoder_emb, pad_id):
        super(DAE, self).__init__()
        assert encoder_emb.num_embeddings == decoder_emb.num_embeddings
        assert encoder_emb.embedding_dim == decoder_emb.embedding_dim
        self.voc_size = encoder_emb.num_embeddings
        self.emb_dim = encoder_emb.embedding_dim
        self.hid_dim = hid_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dropout = dropout
        self.enc_bi = enc_bi
        self.n_dir = 2 if self.enc_bi else 1
        self.dec_max_len = dec_max_len
        self.beam_size = beam_size
        self.WEAtt_type = WEAtt_type
        self.latent_dim = latent_dim

        self.Encoder = Encoder(emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim,
                               n_layer=self.enc_layers,
                               dropout=self.dropout,
                               bi=self.enc_bi,
                               embedding=encoder_emb)
        self.PriorGaussian = torch.distributions.Normal(
            gpu_wrapper(torch.zeros(self.latent_dim)),
            gpu_wrapper(torch.ones(self.latent_dim)))
        self.toLatent = nn.Linear(self.hid_dim * self.n_dir * self.enc_layers,
                                  self.latent_dim)
        self.Decoder = Decoder(voc_size=self.voc_size,
                               latent_dim=self.latent_dim,
                               emb_dim=self.emb_dim,
                               hid_dim=self.hid_dim * self.n_dir,
                               n_layer=self.dec_layers,
                               dropout=self.dropout,
                               max_len=self.dec_max_len,
                               beam_size=self.beam_size,
                               WEAtt_type=self.WEAtt_type,
                               embedding=decoder_emb)
        self.criterionSeq = SeqLoss(voc_size=self.voc_size,
                                    pad=pad_id,
                                    end=None,
                                    unk=None)

    def visualize(self, go, sent_len, bare):
        B = bare.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        samples = self.toLatent(last_states)  # shape = (n_batch, latent_dim)
        return samples

    def test_lm(self, go, sent_len, bare, eos, n_sample):
        # n_sample is unused: the DAE encoding is deterministic.
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        latent_vector = self.toLatent(
            last_states.transpose(0, 1).contiguous().view(B, -1))  # shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
        logits = self.Decoder(init_states=init_states,
                              latent_vector=latent_vector,
                              helper=go,
                              test_lm=True)  # shape = (n_batch, 16, V)
        xent = self.criterionSeq(logits, eos, keep_batch=True)  # shape = (n_batch, )
        # The DAE has no probabilistic posterior, so its KL (and hence NLL) is undefined.
        kl = torch.zeros_like(xent) + float('inf')  # shape = (n_batch, )
        nll = xent + kl  # shape = (n_batch, )
        return xent, nll, kl, latent_vector

    def generate_gaussian(self, B):
        return self.PriorGaussian.sample(torch.Size([B]))  # shape = (n_batch, latent_dim)

    def forward(self, go, sent_len=None, bare=None):
        """
        :param go: shape = (n_batch, 16)
        :param sent_len: shape = (n_batch, ) or None
        :param bare: shape = (n_batch, 15) or None
        :return:
        """
        B = go.shape[0]
        if not self.training:
            # ----- Prior Network -----
            latent_vector = self.generate_gaussian(B)  # shape = (n_batch, latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            init_states = gpu_wrapper(
                torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
            return self.Decoder(init_states=init_states,
                                latent_vector=latent_vector,
                                helper=go)
        else:
            # ----- Encoding -----
            outputs, last_states = self.Encoder(bare, sent_len)
            # outputs.shape = (n_batch, 15, n_dir * hid_dim)
            # last_states.shape = (layers * n_dir, n_batch, hid_dim)
            latent_vector = self.toLatent(
                last_states.transpose(0, 1).contiguous().view(B, -1))  # shape = (n_batch, latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            init_states = gpu_wrapper(
                torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
            return self.Decoder(init_states=init_states,
                                latent_vector=latent_vector,
                                helper=go), latent_vector

    def saliency(self, go, sent_len=None, bare=None):
        B = go.shape[0]

        # ----- Encoding -----
        outputs, last_states = self.Encoder(bare, sent_len)
        # outputs.shape = (n_batch, 15, n_dir * hid_dim)
        # last_states.shape = (layers * n_dir, n_batch, hid_dim)
        last_states = last_states.transpose(0, 1).contiguous().view(B, -1)  # shape = (n_batch, layers * n_dir * hid_dim)

        # ----- Posterior Network -----
        latent_vector = self.toLatent(last_states)
        # latent_vector.shape = (n_batch, latent_dim)

        # ----- Initial Decoding States -----
        assert self.enc_bi
        init_states = gpu_wrapper(
            torch.zeros([self.enc_layers, B, self.n_dir * self.hid_dim])).float()  # shape = (layers, n_batch, n_dir * hid_dim)
        logits = self.Decoder(init_states=init_states, latent_vector=latent_vector, helper=go)
        return logits, self.Decoder.toInit(latent_vector), last_states
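
# ---------------------------------------------------------------------------
# Hedged sampling sketch (added for illustration; not part of the original
# module): decoding from the prior with a trained DAE, or a VAE via
# sample_from_prior. The zero-filled `go` batch is a stand-in for the real
# GO-token inputs produced by the repo's data loader.
# ---------------------------------------------------------------------------
def _demo_sample_from_prior(dae, batch_size=4, go_len=16):
    dae.eval()  # DAE.forward decodes z ~ N(0, I) whenever self.training is False
    go = gpu_wrapper(torch.zeros(batch_size, go_len).long())  # placeholder GO inputs
    with torch.no_grad():
        return dae(go)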