class VRAE_VNMT(nn.Module):
    def __init__(self, rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                 n_layers=1, dropout=0.5, word_dropout=0.5, gpu=True):
        super(VRAE_VNMT, self).__init__()
        self.word_dropout = word_dropout
        self.z_size = 1000  # the concat size is absorbed by linear_mu_post etc., so z_size only needs to match hidden_dim
        self.mode = 'vnmt'
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_drop = nn.Dropout(word_dropout)
        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers

        # hidden_dim * 1 for the prior because it only sees x
        self.linear_mu_prior = nn.Linear(hidden_dim, self.z_size)
        self.linear_sigma_prior = nn.Linear(hidden_dim, self.z_size)
        # hidden_dim * 2 for the posterior because it sees both x and y
        self.linear_mu_post = nn.Linear(hidden_dim * 2, self.z_size)
        self.linear_sigma_post = nn.Linear(hidden_dim * 2, self.z_size)

        self.encoder_prior = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                        n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True)
        self.encoder_post = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                       n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True)

        ################################################
        # Only supports a 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN('CustomGRU', embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                            n_layers=1, dropout=dropout, word_dropout=word_dropout)

        # > "We use a fixed word dropout rate of 75%"
        # Projects z into the decoder's hidden dim so that it can be added inside the GRU cells.
        self.linear_z = nn.Linear(self.z_size, self.decoder.hidden_dim)  # W_z^(2) and b_z^(2)

    def reparam_trick(self, mu, log_sigma):
        # log_sigma is treated as the log standard deviation, consistent with vnmt_loss below.
        # Why the network predicts a log term that gets exponentiated here:
        # https://www.reddit.com/r/MachineLearning/comments/74dx67/d_why_use_exponential_term_rather_than_log_term/
        epsilon = torch.zeros(self.z_size).cuda()
        epsilon.normal_(0, 1)  # zero-mean, unit-variance Gaussian noise
        return Variable(epsilon * torch.exp(log_sigma.data) + mu.data)

    def vnmt_loss(self, recon_x, target_x, mu_prior, log_sigma_prior, mu_post, log_sigma_post):
        seq_len, batch_size = target_x.size()
        loss_fn = nn.CrossEntropyLoss()
        loss = 0
        for t in range(seq_len):
            loss += loss_fn(recon_x[t], target_x[t])

        # KL divergence between the posterior N(mu_post, sigma_post^2) and the prior N(mu_prior, sigma_prior^2):
        # KL = log(sigma_prior / sigma_post) + (sigma_post^2 + (mu_post - mu_prior)^2) / (2 * sigma_prior^2) - 1/2
        sigma_prior = torch.exp(log_sigma_prior)
        sigma_post = torch.exp(log_sigma_post)
        KLD = (log_sigma_prior - log_sigma_post +
               (sigma_post * sigma_post + (mu_post - mu_prior) * (mu_post - mu_prior)) /
               (2.0 * sigma_prior * sigma_prior) - 0.5)
        #########################################################
        # Be careful with the dimension when taking the sum!
        #########################################################
        total_KLD = torch.sum(KLD, 1).mean().squeeze()  # sum over the latent dim, average over the batch
        return loss, total_KLD

    def batchNLLLoss(self, s, s_lengths, t, t_lengths, device, train=False):
        loss = 0
        batch_size, seq_len = s.size()
        tt = t.clone()

        # Sort the source tensors by length (required for packing)
        s_lengths, perm_idx = s_lengths.sort(0, descending=True)
        s.data = s.data[perm_idx]
        t.data = t.data[perm_idx]
        s = s.permute(1, 0).to(device)  # seq_len x batch_size
        t = t.permute(1, 0).to(device)  # seq_len x batch_size

        # Sort the target copy by target length for the posterior encoder
        t_lengths, _perm_idx = t_lengths.sort(0, descending=True)
        tt.data = tt.data[_perm_idx]
        tt = tt.permute(1, 0).to(device)  # seq_len x batch_size

        emb_s = self.embeddings(s)
        emb_t = self.embeddings(t)
        emb_tt = self.embeddings(tt)  # for encoding the target

        emb_t_shift = torch.zeros_like(emb_t)
        emb_t_shift[1:, :, :] = emb_t[:-1, :, :]  # shift the decoder inputs by one step (index 1 is EOS_TOKEN)
        emb_t_shift = self.word_drop(emb_t_shift)

        ############################
        #     Encode x and y       #
        ############################
        # Encode x for both the prior and the posterior model.
        # The linear layers are independent, but the encoder that produces the annotation vectors is shared.
        enc_h_x = None
        encoder_outputs_x, encoder_hidden_x = self.encoder_prior(emb_s, s_lengths, enc_h_x)  # e.g. torch.Size([12, 250, 256])
        enc_h_x_mean = encoder_outputs_x.mean(0)  # mean-pool x

        enc_h = encoder_hidden_x
        if self.rnn_type == 'LSTM':
            dec_h = (enc_h[0][:self.decoder.n_layers].to(device),
                     enc_h[1][:self.decoder.n_layers].to(device))
        else:
            dec_h = enc_h[:self.decoder.n_layers].to(device)

        # Encode y for the posterior model.
        # enc_h_y = self.encoder_post.init_hidden(batch_size)  # (the very first hidden)
        enc_h_y = None
        encoder_outputs_y, encoder_hidden_y = self.encoder_post(emb_tt, t_lengths, enc_h_y)  # e.g. torch.Size([12, 250, 256])
        enc_h_y_mean = encoder_outputs_y.mean(0)  # mean-pool y

        ############################
        #     Compute prior        #
        ############################
        mu_prior = self.linear_mu_prior(enc_h_x_mean)
        log_sigma_prior = self.linear_sigma_prior(enc_h_x_mean)

        ############################
        #   Compute posterior      #
        ############################
        # Defined up front so they also exist at evaluation time
        mu_post = Variable(torch.zeros(batch_size, self.z_size)).to(device)
        log_sigma_post = Variable(torch.zeros(batch_size, self.z_size)).to(device)

        # Concatenate the mean-pooled source and target representations (h_z')
        enc_h_cat = torch.cat((enc_h_x_mean, enc_h_y_mean), 1)
        # Get mu and sigma from the concatenated representation
        mu_post = self.linear_mu_post(enc_h_cat)
        log_sigma_post = self.linear_sigma_post(enc_h_cat)

        #####################################
        # Reparameterization trick to get z
        #####################################
        # Obtain h_z
        z = self.reparam_trick(mu_post, log_sigma_post)
        # Project z into the decoder's hidden dim so it can be added inside the GRU cells
        he = self.linear_z(z)

        # Take the last hidden state of the encoder and pass it to the decoder
        dec_h = encoder_hidden_x[:self.decoder.n_layers].to(device)

        ########################################################
        # Decode using the last enc_h, the context vectors, and z
        ########################################################
        # dec_inp = Variable(torch.LongTensor([[SOS_TOKEN] * batch_size])).long().to(device)
        # dec_inp = dec_inp.permute(1, 0)  # batch_size x 1
        target_length = t.size()[0]
        all_decoder_outputs = Variable(torch.zeros(seq_len, batch_size, self.decoder.vocab_size)).to(device)

        use_target = True  # True if random.random() < self.word_dropout else False
        for i in range(target_length):
            dec_emb = emb_t_shift[i]
            out, dec_h, dec_attn = self.decoder.forward(dec_emb, dec_h, encoder_outputs_x, he.unsqueeze(0))
            if use_target:
                dec_emb = emb_t_shift[i]  # teacher forcing with the shifted gold embedding
            else:
                dec_inp = Variable(torch.LongTensor([[UNK_TOKEN] * batch_size])).long().to(device)
            all_decoder_outputs[i] = out

        # Compute the VNMT objective: (reconstruction loss, KL divergence)
        loss = self.vnmt_loss(all_decoder_outputs, t, mu_prior, log_sigma_prior, mu_post, log_sigma_post)
        return loss

    def sample(self, inp, max_seq_len):
        self.encoder_prior.eval()
        self.decoder.eval()
        pass

    def generate(self, inputs, ntokens, example, max_seq_len):
        """Generate an example."""
        batch_size = 1
        self.encoder_prior.eval()
        self.decoder.eval()
        out_seq = []
        dec_type = self.dec_type
        max_words = 100

        input = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(), volatile=True)
        input.data = input.data.cuda()
        for i, wd_idx in enumerate(example):
            input.data[0][i] = wd_idx
        input_words = [inputs.vocab.itos[input.data[0][i]] for i in range(0, max_seq_len)]

        # encoder initial h
        # h = self.encoder_prior.init_hidden(1)  # (the very first hidden)
        inp = Variable(torch.rand(1, max_seq_len).mul(ntokens).long().cuda(), volatile=True)
        for i in range(max_seq_len):
            inp.data[0][i] = EOS_TOKEN
        for i in range(len(example)):
            inp.data[0][i] = example[i]
        seq_lengths = torch.cuda.LongTensor([len(x) - list(x).count(1) for x in inp.data.cpu().numpy()])  # 1: <pad>
        inp = inp.permute(1, 0)

        ############################
        #        Encode x          #
        ############################
        emb = self.embeddings(inp)
        emb_shift = torch.zeros_like(emb)
        emb_shift[1:, :, :] = emb[:-1, :, :]  # shift the input sentence (index 1 is EOS_TOKEN)
        emb_shift = self.word_drop(emb_shift)
        encoder_outputs, encoder_hidden = self.encoder_prior(emb, seq_lengths, None)
        enc_h_x_mean = encoder_outputs.mean(dim=0)  # mean-pool x: h_f

        #####################################
        # Reparameterization trick to get z
        #####################################
        h = encoder_hidden
        if self.rnn_type == 'LSTM':
            h = (h[0].cuda(), h[1].cuda())
        else:
            h = h.cuda()
        mu_prior = self.linear_mu_prior(enc_h_x_mean)
        log_sigma_prior = self.linear_sigma_prior(enc_h_x_mean)
        # At test time, use the prior mean (the most representative point) as z
        z = mu_prior
        he = self.linear_z(z)
        h = h[:self.decoder.n_layers].cuda()

        #####################################
        #             Decode                #
        #####################################
        dec_emb = emb_shift[0]
        decoder_attentions = torch.zeros(max_seq_len, max_seq_len)
        sample_type = 0
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_emb, h, None)
            elif dec_type == 'attn':
                # out, h, dec_attn = self.decoder.forward(dec_inp, h, encoder_outputs, z)
                out, h, dec_attn = self.decoder.forward(dec_emb, h, encoder_outputs, None)  # decode w/o z
                padded_attn = F.pad(dec_attn.squeeze(0).squeeze(0),
                                    pad=(0, max_seq_len - dec_attn.size(2)),
                                    mode='constant', value=EOS_TOKEN)
                decoder_attentions[i, :] += padded_attn.cpu().data

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                dec_emb = self.embeddings(dec_inp)
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]
            # 1: temperature
            elif sample_type == 1:
                temperature = 1.0  # 1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]

            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)
            if word_idx == EOS_TOKEN:
                break

        return out_seq, decoder_attentions[:i + 1, :len(example) - 2]
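
# ---------------------------------------------------------------------------
# Hedged sketch (added for illustration, not part of the original model code):
# one possible training step for VRAE_VNMT. The batch fields (batch.src,
# batch.src_lengths, batch.trg, batch.trg_lengths), the optimizer, and the
# linear KL-annealing schedule are assumptions about the surrounding training
# loop; only batchNLLLoss() and its (reconstruction, KL) return value come
# from the class above.
# ---------------------------------------------------------------------------
def _example_vnmt_train_step(model, batch, optimizer, device, step, anneal_steps=10000):
    """Illustrative single optimization step for VRAE_VNMT (assumed batch layout)."""
    model.train()
    optimizer.zero_grad()
    # batchNLLLoss returns the tuple produced by vnmt_loss: (reconstruction loss, KL term)
    recon_loss, kld = model.batchNLLLoss(batch.src, batch.src_lengths,
                                         batch.trg, batch.trg_lengths,
                                         device, train=True)
    kl_weight = min(1.0, step / float(anneal_steps))  # assumed linear KL annealing
    loss = recon_loss + kl_weight * kld
    loss.backward()
    optimizer.step()
    return loss
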
class AttnGRU_VNMT(nn.Module):
    """Pretrains the attentive GRU for VNMT."""

    def __init__(self, rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                 n_layers=1, dropout=0.5, word_dropout=0.5, gpu=True):
        super(AttnGRU_VNMT, self).__init__()
        # self.word_dropout = 1.0  # 0.75
        self.word_dropout = word_dropout
        self.word_drop = nn.Dropout(word_dropout)
        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # encoder for x
        self.encoder = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                  n_layers=n_layers, dropout=dropout, word_dropout=word_dropout)
        # encoder for y
        # self.encoder_post = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
        #                                n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True)

        ################################################
        # Only supports a 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN('CustomGRU', embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                            n_layers=1, dropout=dropout, word_dropout=word_dropout)

    def batchNLLLoss(self, s, s_lengths, t, t_lengths, device, train=False):
        loss = 0
        batch_size, seq_len = s.size()

        # Sort the tensors by source length (required for packing)
        s_lengths, perm_idx = s_lengths.sort(0, descending=True)
        s.data = s.data[perm_idx]
        t.data = t.data[perm_idx]
        s = s.permute(1, 0).to(device)  # seq_len x batch_size
        t = t.permute(1, 0).to(device)  # seq_len x batch_size

        emb_s = self.embeddings(s)
        emb_t = self.embeddings(t)
        emb_t_shift = torch.zeros_like(emb_t)
        emb_t_shift[1:, :, :] = emb_t[:-1, :, :]  # shift the decoder inputs by one step (index 1 is EOS_TOKEN)
        emb_t_shift = self.word_drop(emb_t_shift)

        ############################
        #        Encode x          #
        ############################
        # Encode x; the annotation-vector encoder is shared with the VNMT model,
        # while the linear layers stay independent.
        # enc_h_x = self.encoder.init_hidden(batch_size).to(device)  # (the very first hidden)
        enc_h_x = None
        encoder_outputs, encoder_hidden = self.encoder(emb_s, s_lengths, enc_h_x)  # e.g. torch.Size([12, 250, 256])
        enc_h_x_mean = encoder_outputs.mean(0)

        if self.rnn_type == 'LSTM':
            enc_h_x = encoder_hidden[0]
        enc_h = encoder_hidden
        if self.rnn_type == 'LSTM':
            dec_h = (enc_h[0][:self.decoder.n_layers].to(device),
                     enc_h[1][:self.decoder.n_layers].to(device))
        else:
            dec_h = enc_h[:self.decoder.n_layers].to(device)

        #########################################################
        # Decode using the last enc_h and the context vectors
        #########################################################
        # dec_s = Variable(torch.LongTensor([[SOS_TOKEN] * batch_size])).long().to(device)
        # dec_s = dec_s.permute(1, 0)  # batch_size x 1
        t_length = t.size()[0]
        all_decoder_outputs = torch.zeros(seq_len, batch_size, self.decoder.vocab_size).to(device)

        use_target = True
        # use_target = True if random.random() < self.word_dropout else False
        for i in range(t_length):
            if use_target:
                dec_s = emb_t_shift[i]  # teacher forcing with the shifted gold embedding
            else:
                dec_s = Variable(torch.LongTensor([[UNK_TOKEN] * batch_size])).long().to(device)
            out, dec_h, attn_weights = self.decoder.forward(dec_s, dec_h, encoder_outputs, None)  # decode w/o z
            all_decoder_outputs[i] = out

        # Compute the masked cross-entropy loss (batch_size x seq_len)
        loss = masked_cross_entropy(
            all_decoder_outputs.transpose(0, 1).contiguous(),
            t.transpose(0, 1).contiguous(),
            t_lengths.to(device))
        return loss

    def generate(self, inputs, ntokens, example, max_seq_len, device, max_words=100):
        """Generate an example."""
        print('Generating...')
        self.encoder.eval()
        self.decoder.eval()
        dec_type = self.dec_type
        out_seq = []

        input = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(), volatile=True).to(device)
        for i, wd_idx in enumerate(example):
            input.data[0][i] = wd_idx
        input_words = [inputs.vocab.itos[input.data[0][i]] for i in range(0, max_seq_len)]

        # encoder initial h
        # h = self.encoder.init_hidden(1)  # (the very first hidden)
        inp = Variable(torch.rand(1, max_seq_len).mul(ntokens).long().cuda(), volatile=True)
        for i in range(max_seq_len):
            inp.data[0][i] = EOS_TOKEN
        for i in range(len(example)):
            inp.data[0][i] = example[i]
        seq_lengths = torch.LongTensor([len(x) - list(x).count(1) for x in inp.data.cpu().numpy()]).to(device)  # 1: <pad>
        inp = inp.permute(1, 0)

        ############################
        #        Encode x          #
        ############################
        emb = self.embeddings(inp)
        emb_shift = torch.zeros_like(emb)
        emb_shift[1:, :, :] = emb[:-1, :, :]  # shift the input sentence (index 1 is EOS_TOKEN)
        emb_shift = self.word_drop(emb_shift)
        encoder_outputs, encoder_hidden = self.encoder(emb, seq_lengths, None)
        # enc_h_x_mean = encoder_hiddens_x.mean(dim=0)  # mean-pool x: h_f

        h = encoder_hidden
        if self.rnn_type == 'LSTM':
            h = (h[0].to(device), h[1].to(device))
        else:
            h = h.to(device)

        #####################################
        #             Decode                #
        #####################################
        # create an input with a batch size of 1
        # dec_inp = Variable(torch.LongTensor([[SOS_TOKEN]])).to(device)
        dec_emb = emb_shift[0]
        decoder_attentions = torch.zeros(max_seq_len, max_seq_len)
        sample_type = 0
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_emb, h, None)
            elif dec_type == 'attn':
                # out, h, dec_attn = self.decoder.forward(dec_inp, h, encoder_outputs, z)
                out, h, dec_attn = self.decoder.forward(dec_emb, h, encoder_outputs, None)  # decode w/o z
                padded_attn = F.pad(dec_attn.squeeze(0).squeeze(0),
                                    pad=(0, max_seq_len - dec_attn.size(2)),
                                    mode='constant', value=EOS_TOKEN)
                decoder_attentions[i, :] += padded_attn.cpu().data

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                dec_emb = self.embeddings(dec_inp)
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]
            # 1: temperature
            elif sample_type == 1:
                temperature = 1.0  # 1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]

            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)
            if word_idx == EOS_TOKEN:
                break

        return out_seq, decoder_attentions[:i + 1, :len(example) - 2]
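
# ---------------------------------------------------------------------------
# Hedged sketch: masked_cross_entropy() is imported from elsewhere in this
# repo, so its exact implementation is not shown here. The calls above appear
# to assume logits of shape (batch x max_len x vocab), targets of shape
# (batch x max_len), and per-sequence lengths; a common formulation with that
# interface averages the token-level cross entropy over non-padded positions:
# ---------------------------------------------------------------------------
def _example_masked_cross_entropy(logits, target, lengths):
    """Length-masked sequence cross entropy (illustrative, assumed semantics)."""
    batch_size, max_len, vocab_size = logits.size()
    log_probs = F.log_softmax(logits.view(-1, vocab_size), dim=1)         # (batch*max_len) x vocab
    nll = -log_probs.gather(1, target.view(-1, 1)).squeeze(1)             # (batch*max_len,)
    nll = nll.view(batch_size, max_len)
    # Mask out positions beyond each sequence's true length
    positions = torch.arange(max_len, device=logits.device).unsqueeze(0)  # 1 x max_len
    mask = (positions < lengths.unsqueeze(1)).float()                     # batch x max_len
    return (nll * mask).sum() / mask.sum()
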
class AttnGRU_VNMT(nn.Module):
    """Pretrains the attentive GRU for VNMT (older variant that consumes token indices directly)."""

    def __init__(self, rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                 n_layers=1, dropout=0.5, word_dropout=None, gpu=True):
        super(AttnGRU_VNMT, self).__init__()
        self.word_dropout = 1.0  # 0.75
        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers

        # encoder for x
        self.encoder = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                  n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True)
        # encoder for y
        # self.encoder_post = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
        #                                n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True)

        ################################################
        # Only supports a 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN('CustomGRU', embedding_dim, hidden_dim, vocab_size, max_seq_len,
                                            n_layers=1, dropout=dropout, word_dropout=word_dropout, gpu=True)

    def batchNLLLoss(self, inp, target, train=False):
        loss = 0
        batch_size, seq_len = inp.size()

        inp_lengths = torch.cuda.LongTensor([len(x) - list(x).count(1) + 1 for x in inp.data.cpu().numpy()])  # 1: <pad>
        inp_lengths, perm_idx = inp_lengths.sort(0, descending=True)  # sort the tensors by length
        # make sure to align the target data with the sorted input
        inp.data = inp.data[perm_idx]
        target.data = target.data[perm_idx]
        target_lengths = torch.cuda.LongTensor([len(x) - list(x).count(1) + 1 for x in target.data.cpu().numpy()])  # 1: <pad>
        inp = inp.permute(1, 0)        # seq_len x batch_size
        target = target.permute(1, 0)  # seq_len x batch_size

        ############################
        #        Encode x          #
        ############################
        # Encode x; the annotation-vector encoder is shared with the VNMT model,
        # while the linear layers stay independent.
        enc_h_x = self.encoder.init_hidden(batch_size)  # (the very first hidden)
        encoder_outputs = Variable(torch.zeros(seq_len, batch_size, self.encoder.hidden_dim)).cuda()  # max_len x batch_size x hidden_size
        encoder_hiddens_x = Variable(torch.zeros(seq_len, self.n_layers, batch_size, self.encoder.hidden_dim)).cuda()
        for i in range(seq_len):
            # enc_h_x: n_layers x batch_size x hidden_dim
            out, enc_h_x = self.encoder(inp[i], inp_lengths, enc_h_x)
            encoder_outputs[i] = out
            encoder_hiddens_x[i] = enc_h_x

        if self.rnn_type == 'LSTM':
            enc_h_x = enc_h_x[0]

        # mean-pool x
        enc_h_x_mean = encoder_hiddens_x.mean(dim=0)  # h_f

        enc_h = enc_h_x
        if self.rnn_type == 'LSTM':
            dec_h = (enc_h[0][:self.decoder.n_layers].cuda(),
                     enc_h[1][:self.decoder.n_layers].cuda())
        else:
            dec_h = enc_h[:self.decoder.n_layers].cuda()

        #########################################################
        # Decode using the last enc_h, context vectors, and z
        #########################################################
        dec_inp = Variable(torch.LongTensor([[SOS_TOKEN] * batch_size])).long().cuda()
        dec_inp = dec_inp.permute(1, 0)  # batch_size x 1
        target_length = target.size()[0]
        all_decoder_outputs = Variable(torch.zeros(seq_len, batch_size, self.decoder.vocab_size)).cuda()

        use_target = True  # True if random.random() < self.word_dropout else False
        for i in range(target_length):
            # out, dec_h = self.decoder.forward(dec_inp, dec_h, z)
            out, dec_h, attn_weights = self.decoder.forward(dec_inp, dec_h, encoder_outputs, None)  # decode w/o z
            if use_target:
                dec_inp = target[i]  # shape: (batch_size,)
            else:
                dec_inp = Variable(torch.LongTensor([[UNK_TOKEN] * batch_size])).long().cuda()
            all_decoder_outputs[i] = out

        # apply the objective (masked cross-entropy over batch_size x seq_len)
        loss = masked_cross_entropy(
            all_decoder_outputs.transpose(0, 1).contiguous(),
            target.transpose(0, 1).contiguous(),
            Variable(target_lengths))
        return loss

    def generate(self, inputs, ntokens, example, max_seq_len):
        """Generate an example."""
        print('Generating...')
        self.encoder.eval()
        self.decoder.eval()
        out_seq = []
        dec_type = self.dec_type
        max_words = 100

        input = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(), volatile=True)
        input.data = input.data.cuda()
        for i, wd_idx in enumerate(example):
            input.data[0][i] = wd_idx
        input_words = [inputs.vocab.itos[input.data[0][i]] for i in range(0, max_seq_len)]

        # encoder initial h
        h = self.encoder.init_hidden(1)  # (the very first hidden)
        inp = Variable(torch.rand(1, max_seq_len).mul(ntokens).long().cuda(), volatile=True)
        for i in range(max_seq_len):
            inp.data[0][i] = EOS_TOKEN
        for i in range(len(example)):
            inp.data[0][i] = example[i]
        seq_lengths = torch.cuda.LongTensor([len(x) - list(x).count(1) for x in inp.data.cpu().numpy()])  # 1: <pad>
        inp = inp.permute(1, 0)

        ############################
        #        Encode x          #
        ############################
        encoder_hiddens_x = Variable(torch.zeros(max_seq_len, self.n_layers, 1, self.encoder.hidden_dim)).cuda()
        if dec_type == 'vanilla':
            for i in range(max_seq_len):
                enc_out, h = self.encoder.forward(inp[i], seq_lengths, h)
                encoder_hiddens_x[i] = h
        elif dec_type == 'attn':
            enc_outs = Variable(torch.zeros(max_seq_len, 1, self.encoder.hidden_dim)).cuda()
            for i in range(max_seq_len):
                enc_out, h = self.encoder.forward(inp[i], seq_lengths, h)
                enc_outs[i] = enc_out
                encoder_hiddens_x[i] = h

        # mean-pool x
        # enc_h_x_mean = encoder_hiddens_x.mean(dim=0)  # h_f

        if self.rnn_type == 'LSTM':
            h = (h[0].cuda(), h[1].cuda())
        else:
            h = h.cuda()

        #####################################
        #             Decode                #
        #####################################
        # create an input with a batch size of 1
        dec_inp = Variable(torch.LongTensor([[SOS_TOKEN]])).cuda()
        decoder_attentions = torch.zeros(max_seq_len, max_seq_len)
        sample_type = 0
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_inp, h, None)
            elif dec_type == 'attn':
                # out, h, dec_attn = self.decoder.forward(dec_inp, h, enc_outs, z)
                out, h, dec_attn = self.decoder.forward(dec_inp, h, enc_outs, None)  # decode w/o z
                decoder_attentions[i, :] += dec_attn.squeeze(0).squeeze(0).cpu().data

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]
            # 1: temperature
            elif sample_type == 1:
                temperature = 1.0  # 1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]

            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)
            if word_idx == EOS_TOKEN:
                break

        return out_seq, decoder_attentions[:i + 1, :len(example)]
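
# ---------------------------------------------------------------------------
# Hedged sketch: how generate() might be called. `inputs` is assumed to be a
# torchtext-style Field whose vocab maps between tokens and indices (the code
# above relies on inputs.vocab.itos); `example_sentence`, the whitespace
# tokenization, and max_seq_len are illustrative. Note that the newer
# AttnGRU_VNMT.generate() above additionally takes a `device` argument.
# ---------------------------------------------------------------------------
def _example_generate(model, inputs, example_sentence, max_seq_len=50):
    """Greedy generation from a (pre)trained model with the generate() interface above."""
    ntokens = len(inputs.vocab)
    # Convert the raw sentence into vocabulary indices, terminated with EOS
    example = [inputs.vocab.stoi[w] for w in example_sentence.split()] + [EOS_TOKEN]
    out_seq, attentions = model.generate(inputs, ntokens, example, max_seq_len)
    return ' '.join(out_seq), attentions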