Example no. 1
    def __init__(self,
                 rnn_type,
                 embedding_dim,
                 hidden_dim,
                 vocab_size,
                 max_seq_len,
                 n_layers=1,
                 dropout=0.5,
                 word_dropout=0.5,
                 gpu=True):
        super(AttnGRU_VNMT, self).__init__()

        #self.word_dropout = 1.0#0.75
        self.word_dropout = word_dropout
        self.word_drop = nn.Dropout(word_dropout)
        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # encoder for x
        self.encoder = EncoderRNN(rnn_type,
                                  embedding_dim,
                                  hidden_dim,
                                  vocab_size,
                                  max_seq_len,
                                  n_layers=n_layers,
                                  dropout=dropout,
                                  word_dropout=word_dropout)
        # encoder for y
        #self.encoder_post = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
        #   n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True
        #)

        ################################################
        #     Only supports 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN('CustomGRU',
                                            embedding_dim,
                                            hidden_dim,
                                            vocab_size,
                                            max_seq_len,
                                            n_layers=1,
                                            dropout=dropout,
                                            word_dropout=word_dropout)
Example no. 2
    def __init__(self,
                 rnn_type,
                 embedding_dim,
                 hidden_dim,
                 vocab_size,
                 max_seq_len,
                 n_layers=1,
                 dropout=0.5,
                 word_dropout=0.5,
                 gpu=True):
        super(VRAE_VNMT, self).__init__()

        self.word_dropout = word_dropout
        self.z_size = 1000  # the concatenated size is absorbed by linear_mu_post etc., so z_size just needs to equal hidden_dim
        self.mode = 'vnmt'
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_drop = nn.Dropout(word_dropout)

        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers

        self.linear_mu_prior = nn.Linear(
            hidden_dim, self.z_size)  # hidden_dim*1 because we only pass x
        self.linear_sigma_prior = nn.Linear(hidden_dim, self.z_size)
        self.linear_mu_post = nn.Linear(
            hidden_dim * 2,
            self.z_size)  # hidden_dim*2 because we pass x and y
        self.linear_sigma_post = nn.Linear(hidden_dim * 2, self.z_size)

        self.encoder_prior = EncoderRNN(rnn_type,
                                        embedding_dim,
                                        hidden_dim,
                                        vocab_size,
                                        max_seq_len,
                                        n_layers=n_layers,
                                        dropout=dropout,
                                        word_dropout=word_dropout,
                                        gpu=True)
        self.encoder_post = EncoderRNN(rnn_type,
                                       embedding_dim,
                                       hidden_dim,
                                       vocab_size,
                                       max_seq_len,
                                       n_layers=n_layers,
                                       dropout=dropout,
                                       word_dropout=word_dropout,
                                       gpu=True)

        ################################################
        #     Only supports 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN(
            'CustomGRU',
            embedding_dim,
            hidden_dim,
            vocab_size,
            max_seq_len,
            n_layers=1,
            dropout=dropout,
            word_dropout=word_dropout
        )  # > We use a fixed word dropout rate of 75%

        # for projecting z into the hidden dim of the decoder so that it can be added inside the GRU cells
        self.linear_z = nn.Linear(
            self.z_size, self.decoder.hidden_dim)  # W_z^(2) and b_z^(2)
Example no. 3
class VRAE_VNMT(nn.Module):
    def __init__(self,
                 rnn_type,
                 embedding_dim,
                 hidden_dim,
                 vocab_size,
                 max_seq_len,
                 n_layers=1,
                 dropout=0.5,
                 word_dropout=0.5,
                 gpu=True):
        super(VRAE_VNMT, self).__init__()

        self.word_dropout = word_dropout
        self.z_size = 1000  # the concatenated size is absorbed by linear_mu_post etc., so z_size just needs to equal hidden_dim
        self.mode = 'vnmt'
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_drop = nn.Dropout(word_dropout)

        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers

        self.linear_mu_prior = nn.Linear(
            hidden_dim, self.z_size)  # hidden_dim*1 because we only pass x
        self.linear_sigma_prior = nn.Linear(hidden_dim, self.z_size)
        self.linear_mu_post = nn.Linear(
            hidden_dim * 2,
            self.z_size)  # hidden_dim*2 because we pass x and y
        self.linear_sigma_post = nn.Linear(hidden_dim * 2, self.z_size)

        self.encoder_prior = EncoderRNN(rnn_type,
                                        embedding_dim,
                                        hidden_dim,
                                        vocab_size,
                                        max_seq_len,
                                        n_layers=n_layers,
                                        dropout=dropout,
                                        word_dropout=word_dropout,
                                        gpu=True)
        self.encoder_post = EncoderRNN(rnn_type,
                                       embedding_dim,
                                       hidden_dim,
                                       vocab_size,
                                       max_seq_len,
                                       n_layers=n_layers,
                                       dropout=dropout,
                                       word_dropout=word_dropout,
                                       gpu=True)

        ################################################
        #     Only supports 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN(
            'CustomGRU',
            embedding_dim,
            hidden_dim,
            vocab_size,
            max_seq_len,
            n_layers=1,
            dropout=dropout,
            word_dropout=word_dropout
        )  # > We use a fixed word dropout rate of 75%

        # for projecting z into the hidden dim of the decoder so that it can be added inside the GRU cells
        self.linear_z = nn.Linear(
            self.z_size, self.decoder.hidden_dim)  # W_z^(2) and b_z^(2)

    def reparam_trick(self, mu, log_sigma):
        # the reason of log_sigma: https://www.reddit.com/r/MachineLearning/comments/74dx67/d_why_use_exponential_term_rather_than_log_term/
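        # Reparameterisation: z = mu + std * eps with eps ~ N(0, I), where
        # std = exp(0.5 * log_sigma), i.e. log_sigma is treated as a log-variance.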
        epsilon = torch.zeros(self.z_size).cuda()
        epsilon.normal_(0, 1)  # 0 mean unit variance gaussian
        return Variable(epsilon * torch.exp(log_sigma.data * 0.5) + mu.data)

    def vnmt_loss(self, recon_x, target_x, mu_prior, log_sigma_prior, mu_post,
                  log_sigma_post):
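        # VNMT objective: per-timestep cross-entropy reconstruction loss over the
        # decoded sequence, plus the KL term between the posterior
        # N(mu_post, sigma_post) and the prior N(mu_prior, sigma_prior);
        # the two terms are returned separately.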
        seq_len, batch_size = target_x.size()
        loss_fn = nn.CrossEntropyLoss()
        loss = 0
        for t in range(seq_len):
            loss += loss_fn(recon_x[t], target_x[t])

        total_KLD = 0
        sigma_prior = torch.exp(log_sigma_prior)
        sigma_post = torch.exp(log_sigma_post)
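        # Closed-form KL between two diagonal Gaussians, computed per dimension:
        # KL(q || p) = log(sigma_p / sigma_q)
        #              + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2,
        # with q the posterior and p the prior.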

        KLD = (log_sigma_prior - log_sigma_post +
               (sigma_post * sigma_post +
                (mu_post - mu_prior) * (mu_post - mu_prior)) /
               (2.0 * sigma_prior * sigma_prior) - 0.5)

        #########################################################
        #  Be careful with the dimension when taking the sum!!!
        #########################################################
        total_KLD += 1.0 * torch.sum(KLD, 1).mean().squeeze()
        return loss, total_KLD

    def batchNLLLoss(self, s, s_lengths, t, t_lengths, device, train=False):
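        # s, t: source / target token-id batches of shape (batch_size, seq_len);
        # s_lengths, t_lengths: true sequence lengths used to sort the batches
        # before encoding.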
        loss = 0
        batch_size, seq_len = s.size()

        tt = t.clone()
        s_lengths, perm_idx = s_lengths.sort(
            0, descending=True)  # SORT YOUR TENSORS BY LENGTH!
        s.data = s.data[perm_idx]
        t.data = t.data[perm_idx]
        s = s.permute(1, 0).to(device)  # seq_len x batch_size
        t = t.permute(1, 0).to(device)  # seq_len x batch_size

        t_lengths, _perm_idx = t_lengths.sort(
            0, descending=True)  # SORT YOUR TENSORS BY LENGTH!
        tt.data = tt.data[_perm_idx]
        tt = tt.permute(1, 0).to(device)  # seq_len x batch_size

        emb_s = self.embeddings(s)
        emb_t = self.embeddings(t)
        emb_tt = self.embeddings(tt)  # for encoding target
        emb_t_shift = torch.zeros_like(emb_t)  # position 0 gets an all-zero (start-of-sequence) embedding
        emb_t_shift[1:, :, :] = emb_t[:-1, :, :]  # shift the target embeddings right by one step
        emb_t_shift = self.word_drop(emb_t_shift)

        ############################
        #     Encode x and y       #
        ############################
        # encode x for both the prior model and the posterior model.
        # the linear layers are independent, but the encoder that creates the annotation vectors is shared.
        enc_h_x = None
        encoder_outputs_x, encoder_hidden_x = self.encoder_prior(
            emb_s, s_lengths, enc_h_x)  # torch.Size([12, 250, 256])
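        # mean-pool the encoder outputs over time: h_f, used by both the prior and the posterior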
        enc_h_x_mean = encoder_outputs_x.mean(0)
        if self.rnn_type == 'LSTM':
            enc_h_x = encoder_hidden_x[0]

        enc_h = encoder_hidden_x
        if self.rnn_type == 'LSTM':
            dec_h = (enc_h[0][:self.decoder.n_layers].to(device),
                     enc_h[1][:self.decoder.n_layers].to(device))
        else:
            dec_h = enc_h[:self.decoder.n_layers].to(device)

        # encode y for the posterior model.
        #enc_h_y = self.encoder_post.init_hidden(batch_size) # (the very first hidden)
        enc_h_y = None
        encoder_outputs_y, encoder_hidden_y = self.encoder_post(
            emb_tt, t_lengths, enc_h_y)  # torch.Size([12, 250, 256])
        enc_h_y_mean = encoder_outputs_y.mean(0)  # mean pool y

        ############################
        #      Compute Prior       #
        ############################
        #print(enc_h_x_mean.size()) # 250, 6
        mu_prior = self.linear_mu_prior(enc_h_x_mean)
        log_sigma_prior = self.linear_sigma_prior(enc_h_x_mean)

        ############################
        #     Compute Posterior    #
        ############################
        # define these for evaluation times
        mu_post = Variable(torch.zeros(batch_size, self.z_size)).to(device)
        log_sigma_post = Variable(torch.zeros(batch_size,
                                              self.z_size)).to(device)

        # concat h
        enc_h = torch.cat((enc_h_x_mean, enc_h_y_mean), 1)  # h_z': batch_size x (2 * hidden_dim)

        # get mu and sigma using the last hidden layer's output
        mu_post = self.linear_mu_post(enc_h)
        log_sigma_post = self.linear_sigma_post(enc_h)

        #####################################
        # perform reparam trick and get z
        #####################################
        # Obtain h_z
        z = self.reparam_trick(mu_post, log_sigma_post)

        ## project z into the decoder's hidden_dim so that it can be added in the GRU cells
        he = self.linear_z(z)

        # dec_h already holds the encoder's final hidden state, sliced to the
        # decoder's number of layers above.

        ########################################################
        #  Decode using the last enc_h, context vectors, and z
        ########################################################
        #dec_inp = Variable(torch.LongTensor([[SOS_TOKEN]*batch_size])).long().to(device)
        #dec_inp = dec_inp.permute(1, 0) # 128x1
        target_length = t.size()[0]
        all_decoder_outputs = Variable(
            torch.zeros(seq_len, batch_size,
                        self.decoder.vocab_size)).to(device)

        use_target = True  #True if random.random() < self.word_dropout else False
        for i in range(target_length):
            if use_target:
                dec_emb = emb_t_shift[i]  # shape: batch_size x embedding_dim
            else:
                unk_inp = torch.LongTensor([UNK_TOKEN] * batch_size).to(device)
                dec_emb = self.embeddings(unk_inp)

            out, dec_h, dec_attn = self.decoder.forward(
                dec_emb, dec_h, encoder_outputs_x, he.unsqueeze(0))

            all_decoder_outputs[i] = out

        # Compute the VNMT objective
        loss = self.vnmt_loss(all_decoder_outputs, t, mu_prior,
                              log_sigma_prior, mu_post, log_sigma_post)
        return loss

    def sample(self, inp, max_seq_len):
        self.encoder_prior.eval()
        self.decoder.eval()
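        # Sampling is not implemented; generate() below is used instead.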
        pass

    def generate(self, inputs, ntokens, example, max_seq_len):
        """
        Generate example
        """
        batch_size = 1
        self.encoder_prior.eval()
        self.decoder.eval()
        out_seq = []
        dec_type = self.dec_type
        max_words = 100

        input = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(),
                         volatile=True)
        input.data = input.data.cuda()
        for i, wd_idx in enumerate(example):
            input.data[0][i] = wd_idx
        input_words = [
            inputs.vocab.itos[input.data[0][i]] for i in range(0, max_seq_len)
        ]

        # encoder initial h
        #h = self.encoder_prior.init_hidden(1) # (the very first hidden)
        inp = Variable(torch.rand(1, max_seq_len).mul(ntokens).long().cuda(),
                       volatile=True)
        for i in range(max_seq_len):
            inp.data[0][i] = EOS_TOKEN
        for i in range(len(example)):
            inp.data[0][i] = example[i]

        seq_lengths = torch.cuda.LongTensor([
            len(x) - list(x).count(1) for x in inp.data.cpu().numpy()
        ])  # 1: <pad>
        inp = inp.permute(1, 0)

        ############################
        #        Encode x             #
        ############################
        '''
        encoder_hiddens_x = Variable(torch.zeros(max_seq_len, batch_size, self.encoder_prior.hidden_dim)).cuda()
        if dec_type == 'vanilla':
            for i in range(max_seq_len):
                #enc_out, h = self.encoder_prior.forward(inp[i], h, seq_lengths)
                enc_out, h = self.encoder_prior.forward(inp[i], seq_lengths, h)
                encoder_hiddens_x[i] = h[0]
        elif dec_type == 'attn':
            enc_outs = Variable(torch.zeros(max_seq_len, 1, self.encoder_prior.hidden_dim)).cuda()
            for i in range(max_seq_len):
                #enc_out, h = self.encoder_prior.forward(inp[i], h, seq_lengths)
                enc_out, h = self.encoder_prior.forward(inp[i], seq_lengths, h)
                enc_outs[i] = enc_out
                encoder_hiddens_x[i] = h[0]
            ##encoder_outputs, enc_h = self.encoder(inp, inp_lengths.tolist(), None)
        '''
        emb = self.embeddings(inp)
        emb_shift = torch.zeros_like(emb)  # position 0 gets an all-zero (start-of-sequence) embedding
        emb_shift[1:, :, :] = emb[:-1, :, :]  # shift the input embeddings right by one step
        emb_shift = self.word_drop(emb_shift)

        encoder_outputs, encoder_hidden = self.encoder_prior(
            emb, seq_lengths, None)
        enc_h_x_mean = encoder_outputs.mean(dim=0)  # mean pool x: h_f
        # mean pool x
        #enc_h_x_mean = encoder_hiddens_x.mean(dim=0) # h_f

        #####################################
        # perform reparam trick and get z
        #####################################
        h = encoder_hidden
        if self.rnn_type == 'LSTM':
            h = (h[0].cuda(), h[1].cuda())
        else:
            h = h.cuda()
        mu_prior = self.linear_mu_prior(enc_h_x_mean)
        log_sigma_prior = self.linear_sigma_prior(enc_h_x_mean)

        # use the prior mean as z at generation time (no sampling)
        z = mu_prior
        he = self.linear_z(z)
        h = h[:self.decoder.n_layers].cuda()

        #####################################
        #       Decode
        #####################################
        dec_emb = emb_shift[0]
        decoder_attentions = torch.zeros(max_seq_len, max_seq_len)
        sample_type = 0
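        # Greedy decoding: at each step feed back the embedding of the argmax token.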
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_emb, h, None)
            elif dec_type == 'attn':
                #out, h, dec_attn = self.decoder.forward(dec_inp, h, encoder_outputs, z)
                out, h, dec_attn = self.decoder.forward(
                    dec_emb, h, encoder_outputs, None)  # decode w/o z
                padded_attn = F.pad(dec_attn.squeeze(0).squeeze(0),
                                    pad=(0, max_seq_len - dec_attn.size(2)),
                                    mode='constant',
                                    value=EOS_TOKEN)

                ##decoder_attentions[i,:] += dec_attn.squeeze(0).squeeze(0).cpu().data
                decoder_attentions[i, :] += padded_attn.cpu().data

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                dec_emb = self.embeddings(dec_inp)
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]
            # 1: temperature
            elif sample_type == 1:
                temperature = 1.0  #1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]

            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)

            if word_idx == EOS_TOKEN:
                break
        '''
        # create an input with the batch_size of 1
        dec_inp = Variable(torch.LongTensor([[SOS_TOKEN]])).cuda()
        sample_type = 0
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_inp, h, z)
            elif dec_type == 'attn':
                out, h, dec_attn = self.decoder.forward(dec_inp, h, enc_outs, he.unsqueeze(0))

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]

            # 1: tempreture
            elif sample_type == 1:
                temperature = 1.0#1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]



            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)

            if word_idx == EOS_TOKEN:
                break
        '''
        #decoder_attentions[:i+1, :len(example)]
        return out_seq, decoder_attentions[:i + 1, :len(example) - 2]
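
A minimal usage sketch for the VRAE_VNMT class above (not part of the original sources): it assumes the project's EncoderRNN, CustomAttnDecoderRNN, and token constants are importable, and that a CUDA device is available because reparam_trick calls .cuda() directly; the shapes, sizes, and vocabulary are illustrative only.

import torch

device = torch.device('cuda')
model = VRAE_VNMT(rnn_type='GRU', embedding_dim=256, hidden_dim=1000,
                  vocab_size=20000, max_seq_len=12).to(device)

batch_size, max_len = 250, 12
s = torch.randint(0, 20000, (batch_size, max_len))  # source token ids
t = torch.randint(0, 20000, (batch_size, max_len))  # target token ids
s_lengths = torch.full((batch_size,), max_len, dtype=torch.long)
t_lengths = torch.full((batch_size,), max_len, dtype=torch.long)

# batchNLLLoss returns the (reconstruction_loss, KL) pair from vnmt_loss
recon_loss, kld = model.batchNLLLoss(s, s_lengths, t, t_lengths, device)
(recon_loss + kld).backward()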
Example no. 4
class AttnGRU_VNMT(nn.Module):
    """
    Pretrains an attentive GRU for VNMT.
    """
    def __init__(self,
                 rnn_type,
                 embedding_dim,
                 hidden_dim,
                 vocab_size,
                 max_seq_len,
                 n_layers=1,
                 dropout=0.5,
                 word_dropout=0.5,
                 gpu=True):
        super(AttnGRU_VNMT, self).__init__()

        #self.word_dropout = 1.0#0.75
        self.word_dropout = word_dropout
        self.word_drop = nn.Dropout(word_dropout)
        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # encoder for x
        self.encoder = EncoderRNN(rnn_type,
                                  embedding_dim,
                                  hidden_dim,
                                  vocab_size,
                                  max_seq_len,
                                  n_layers=n_layers,
                                  dropout=dropout,
                                  word_dropout=word_dropout)
        # encoder for y
        #self.encoder_post = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
        #   n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True
        #)

        ################################################
        #     Only supports 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN('CustomGRU',
                                            embedding_dim,
                                            hidden_dim,
                                            vocab_size,
                                            max_seq_len,
                                            n_layers=1,
                                            dropout=dropout,
                                            word_dropout=word_dropout)

    def batchNLLLoss(self, s, s_lengths, t, t_lengths, device, train=False):
        loss = 0
        batch_size, seq_len = s.size()

        s_lengths, perm_idx = s_lengths.sort(
            0, descending=True)  # SORT YOUR TENSORS BY LENGTH!
        s.data = s.data[perm_idx]
        t.data = t.data[perm_idx]
        s = s.permute(1, 0).to(device)  # seq_len x batch_size
        t = t.permute(1, 0).to(device)  # seq_len x batch_size

        emb_s = self.embeddings(s)
        emb_t = self.embeddings(t)
        emb_t_shift = torch.zeros_like(emb_t)  # position 0 gets an all-zero (start-of-sequence) embedding
        emb_t_shift[1:, :, :] = emb_t[:-1, :, :]  # shift the target embeddings right by one step
        emb_t_shift = self.word_drop(emb_t_shift)

        ############################
        #        Encode x          #
        ############################
        # encode x for both the prior model and the posterior model.
        # the linear layers are independent, but the encoder that creates the annotation vectors is shared.
        #enc_h_x = self.encoder.init_hidden(batch_size).to(device) # (the very first hidden)
        enc_h_x = None
        encoder_outputs, encoder_hidden = self.encoder(
            emb_s, s_lengths, enc_h_x)  # torch.Size([12, 250, 256])
        enc_h_x_mean = encoder_outputs.mean(0)
        if self.rnn_type == 'LSTM':
            enc_h_x = encoder_hidden[0]

        enc_h = encoder_hidden
        if self.rnn_type == 'LSTM':
            dec_h = (enc_h[0][:self.decoder.n_layers].to(device),
                     enc_h[1][:self.decoder.n_layers].to(device))
        else:
            dec_h = enc_h[:self.decoder.n_layers].to(device)

        #########################################################
        #  Decode using the last enc_h, context vectors, and z  #
        #########################################################
        #dec_s = Variable(torch.LongTensor([[SOS_TOKEN]*batch_size])).long().to(device)
        #dec_s = dec_s.permute(1, 0) # 128x1
        t_length = t.size()[0]
        all_decoder_outputs = torch.zeros(seq_len, batch_size,
                                          self.decoder.vocab_size).to(device)
        use_target = True
        #use_target = True if random.random() < self.word_dropout else False

        for i in range(t_length):
            if use_target:
                dec_s = emb_t_shift[i]  # shape: batch_size x embedding_dim
            else:
                unk_inp = torch.LongTensor([UNK_TOKEN] * batch_size).to(device)
                dec_s = self.embeddings(unk_inp)

            out, dec_h, attn_weights = self.decoder.forward(
                dec_s, dec_h, encoder_outputs, None)  # decode w/o z
            all_decoder_outputs[i] = out

        # Compute masked cross entropy loss
        loss = masked_cross_entropy(  # bs x seq_len?
            all_decoder_outputs.transpose(0, 1).contiguous(),
            t.transpose(0, 1).contiguous(), t_lengths.to(device))
        return loss

    def generate(self,
                 inputs,
                 ntokens,
                 example,
                 max_seq_len,
                 device,
                 max_words=100):
        """
        Generate example
        """
        print('Generating...')
        self.encoder.eval()
        self.decoder.eval()
        dec_type = self.dec_type
        out_seq = []

        input = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(),
                         volatile=True).to(device)
        for i, wd_idx in enumerate(example):
            input.data[0][i] = wd_idx
        input_words = [
            inputs.vocab.itos[input.data[0][i]] for i in range(0, max_seq_len)
        ]

        # encoder initial h
        #h = self.encoder.init_hidden(1) # (the very first hidden)
        inp = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(),
                       volatile=True).to(device)
        for i in range(max_seq_len):
            inp.data[0][i] = EOS_TOKEN
        for i in range(len(example)):
            inp.data[0][i] = example[i]

        seq_lengths = torch.LongTensor([
            len(x) - list(x).count(1) for x in inp.data.cpu().numpy()
        ]).to(device)  # 1: <pad>
        inp = inp.permute(1, 0)

        ############################
        #         Encode x         #
        ############################
        emb = self.embeddings(inp)
        emb_shift = torch.zeros_like(emb)  # position 0 gets an all-zero (start-of-sequence) embedding
        emb_shift[1:, :, :] = emb[:-1, :, :]  # shift the input embeddings right by one step
        emb_shift = self.word_drop(emb_shift)

        encoder_outputs, encoder_hidden = self.encoder(emb, seq_lengths, None)
        #enc_h_x_mean = encoder_hiddens_x.mean(dim=0) # mean pool x: h_f

        #####################################
        #  Move the encoder state to device
        #####################################
        h = encoder_hidden
        if self.rnn_type == 'LSTM':
            h = (h[0].to(device), h[1].to(device))
        else:
            h = h.to(device)

        #####################################
        #       Decode
        #####################################
        # create an input with the batch_size of 1
        #dec_inp = Variable(torch.LongTensor([[SOS_TOKEN]])).to(device)
        dec_emb = emb_shift[0]
        decoder_attentions = torch.zeros(max_seq_len, max_seq_len)
        sample_type = 0
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_emb, h, None)
            elif dec_type == 'attn':
                #out, h, dec_attn = self.decoder.forward(dec_inp, h, encoder_outputs, z)
                out, h, dec_attn = self.decoder.forward(
                    dec_emb, h, encoder_outputs, None)  # decode w/o z
                padded_attn = F.pad(dec_attn.squeeze(0).squeeze(0),
                                    pad=(0, max_seq_len - dec_attn.size(2)),
                                    mode='constant',
                                    value=EOS_TOKEN)

                ##decoder_attentions[i,:] += dec_attn.squeeze(0).squeeze(0).cpu().data
                decoder_attentions[i, :] += padded_attn.cpu().data

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                dec_emb = self.embeddings(dec_inp)
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]
            # 1: temperature
            elif sample_type == 1:
                temperature = 1.0  #1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]

            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)

            if word_idx == EOS_TOKEN:
                break

        return out_seq, decoder_attentions[:i + 1, :len(example) - 2]
Example no. 5
class AttnGRU_VNMT(nn.Module):
    """
	Pretains attentive GRU for VNMT.
	"""
    def __init__(self,
                 rnn_type,
                 embedding_dim,
                 hidden_dim,
                 vocab_size,
                 max_seq_len,
                 n_layers=1,
                 dropout=0.5,
                 word_dropout=None,
                 gpu=True):
        super(AttnGRU_VNMT, self).__init__()

        self.word_dropout = 1.0  #0.75
        self.rnn_type = rnn_type
        self.dec_type = 'attn'
        self.n_layers = n_layers

        # encoder for x
        self.encoder = EncoderRNN(rnn_type,
                                  embedding_dim,
                                  hidden_dim,
                                  vocab_size,
                                  max_seq_len,
                                  n_layers=n_layers,
                                  dropout=dropout,
                                  word_dropout=word_dropout,
                                  gpu=True)
        # encoder for y
        #self.encoder_post = EncoderRNN(rnn_type, embedding_dim, hidden_dim, vocab_size, max_seq_len,
        #	n_layers=n_layers, dropout=dropout, word_dropout=word_dropout, gpu=True
        #)

        ################################################
        #     Only supports 1-layer decoder for now
        ################################################
        self.decoder = CustomAttnDecoderRNN('CustomGRU',
                                            embedding_dim,
                                            hidden_dim,
                                            vocab_size,
                                            max_seq_len,
                                            n_layers=1,
                                            dropout=dropout,
                                            word_dropout=word_dropout,
                                            gpu=True)

    def batchNLLLoss(self, inp, target, train=False):
        loss = 0
        batch_size, seq_len = inp.size()

        inp_lengths = torch.cuda.LongTensor([
            len(x) - list(x).count(1) + 1 for x in inp.data.cpu().numpy()
        ])  # 1: <pad>
        inp_lengths, perm_idx = inp_lengths.sort(
            0, descending=True)  # SORT YOUR TENSORS BY LENGTH!
        # make sure to align the target data along with the sorted input
        inp.data = inp.data[perm_idx]
        target.data = target.data[perm_idx]
        target_lengths = torch.cuda.LongTensor([
            len(x) - list(x).count(1) + 1 for x in target.data.cpu().numpy()
        ])  # 1: <pad>
        inp = inp.permute(1, 0)  # seq_len x batch_size
        target = target.permute(1, 0)  # seq_len x batch_size

        ############################
        #        Encode x          #
        ############################
        # encode x for both the prior model and the posterior model.
        # the linear layers are independent, but the encoder that creates the annotation vectors is shared.
        enc_h_x = self.encoder.init_hidden(
            batch_size)  # (the very first hidden)
        encoder_outputs = Variable(
            torch.zeros(seq_len, batch_size, self.encoder.hidden_dim)).cuda(
            )  ## max_len x batch_size x hidden_size
        encoder_hiddens_x = Variable(
            torch.zeros(seq_len, self.n_layers, batch_size,
                        self.encoder.hidden_dim)).cuda()
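        # Feed the source to the encoder one timestep at a time, collecting the
        # per-step outputs and hidden states.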
        for i in range(seq_len):
            #out, enc_h_x = self.encoder(inp[i], enc_h_x, inp_lengths) # enc_h_x: n_layers, batch_size, hidden_dim
            out, enc_h_x = self.encoder(
                inp[i], inp_lengths,
                enc_h_x)  # enc_h_x: n_layers, batch_size, hidden_dim
            encoder_outputs[i] = out
            encoder_hiddens_x[i] = enc_h_x
        if self.rnn_type == 'LSTM':
            enc_h_x = enc_h_x[0]
        # mean pool x
        enc_h_x_mean = encoder_hiddens_x.mean(dim=0)  # h_f

        enc_h = enc_h_x
        if self.rnn_type == 'LSTM':
            dec_h = (enc_h[0][:self.decoder.n_layers].cuda(),
                     enc_h[1][:self.decoder.n_layers].cuda())
        else:
            dec_h = enc_h[:self.decoder.n_layers].cuda()

        #########################################################
        #  Decode using the last enc_h, context vectors, and z  #
        #########################################################
        dec_inp = Variable(torch.LongTensor([[SOS_TOKEN] * batch_size
                                             ])).long().cuda()
        dec_inp = dec_inp.permute(1, 0)  # 128x1
        target_length = target.size()[0]
        all_decoder_outputs = Variable(
            torch.zeros(seq_len, batch_size, self.decoder.vocab_size)).cuda()

        use_target = True  #True if random.random() < self.word_dropout else False
        for i in range(target_length):
            #out, dec_h = self.decoder.forward(dec_inp, dec_h, z)
            out, dec_h, attn_weights = self.decoder.forward(
                dec_inp, dec_h, encoder_outputs, None)  # decode w/o z
            if use_target:
                dec_inp = target[i]  # shape: batch_size,
            else:
                dec_inp = Variable(torch.LongTensor([[UNK_TOKEN] * batch_size
                                                     ])).long().cuda()

            all_decoder_outputs[i] = out

        # apply the objective
        loss = masked_cross_entropy(  # bs x seq_len?
            all_decoder_outputs.transpose(0, 1).contiguous(),
            target.transpose(0, 1).contiguous(), Variable(target_lengths))

        return loss

    def generate(self, inputs, ntokens, example, max_seq_len):
        """
		Generate example
		"""

        print('Generating...')
        self.encoder.eval()
        self.decoder.eval()
        out_seq = []
        dec_type = self.dec_type
        max_words = 100

        input = Variable(torch.rand(1, max_seq_len).mul(ntokens).long(),
                         volatile=True)
        input.data = input.data.cuda()
        for i, wd_idx in enumerate(example):
            input.data[0][i] = wd_idx
        input_words = [
            inputs.vocab.itos[input.data[0][i]] for i in range(0, max_seq_len)
        ]

        # encoder initial h
        h = self.encoder.init_hidden(1)  # (the very first hidden)
        inp = Variable(torch.rand(1, max_seq_len).mul(ntokens).long().cuda(),
                       volatile=True)
        for i in range(max_seq_len):
            inp.data[0][i] = EOS_TOKEN
        for i in range(len(example)):
            inp.data[0][i] = example[i]

        seq_lengths = torch.cuda.LongTensor([
            len(x) - list(x).count(1) for x in inp.data.cpu().numpy()
        ])  # 1: <pad>
        inp = inp.permute(1, 0)

        ############################
        #         Encode x         #
        ############################
        encoder_hiddens_x = Variable(
            torch.zeros(max_seq_len, self.n_layers, 1,
                        self.encoder.hidden_dim)).cuda()
        if dec_type == 'vanilla':
            for i in range(max_seq_len):
                enc_out, h = self.encoder.forward(inp[i], seq_lengths, h)
                encoder_hiddens_x[i] = h
        elif dec_type == 'attn':
            enc_outs = Variable(
                torch.zeros(max_seq_len, 1, self.encoder.hidden_dim)).cuda()
            for i in range(max_seq_len):
                enc_out, h = self.encoder.forward(inp[i], seq_lengths, h)
                enc_outs[i] = enc_out
                encoder_hiddens_x[i] = h

        # mean pool x
        #enc_h_x_mean = encoder_hiddens_x.mean(dim=0) # h_f

        #####################################
        #  Move the encoder state to the GPU
        #####################################
        if self.rnn_type == 'LSTM':
            h = (h[0].cuda(), h[1].cuda())
        else:
            h = h.cuda()

        #####################################
        #       Decode
        #####################################
        # create an input with the batch_size of 1
        dec_inp = Variable(torch.LongTensor([[SOS_TOKEN]])).cuda()
        decoder_attentions = torch.zeros(max_seq_len, max_seq_len)
        sample_type = 0
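        # Greedy decoding: feed back the argmax token id at each step.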
        for i in range(max_seq_len):
            if dec_type == 'vanilla':
                out, h = self.decoder.forward(dec_inp, h, None)
            elif dec_type == 'attn':
                #out, h, dec_attn = self.decoder.forward(dec_inp, h, enc_outs, z)
                out, h, dec_attn = self.decoder.forward(
                    dec_inp, h, enc_outs, None)  # decode w/o z
                decoder_attentions[i, :] += dec_attn.squeeze(0).squeeze(
                    0).cpu().data

            # 0: argmax
            if sample_type == 0:
                dec_inp = out.max(1)[1]
                max_val, max_idx = out.data.squeeze().max(0)
                word_idx = max_idx[0]

            # 1: temperature
            elif sample_type == 1:
                temperature = 1.0  #1e-2
                word_weights = out.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]

            output_word = inputs.vocab.itos[word_idx]
            out_seq.append(output_word)

            if word_idx == EOS_TOKEN:
                #print(EOS_TOKEN)
                break

        #print(out_seq)
        #print('testtest')
        return out_seq, decoder_attentions[:i + 1, :len(example)]