Example #1
    def generate(self, context, sentence_length, n_context):
        # context: [batch_size, n_context, seq_len]
        batch_size = context.size(0)
        # n_context = context.size(1)
        samples = []

        # Encode the given context utterances through the hierarchical context encoder
        context_hidden = None
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            try:
                encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
                                                           sentence_length[:, i])
            except IndexError:
                print(context.shape)
                sys.exit(-1)

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
            # context_outputs: [batch_size, 1, context_hidden_size * direction]
            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
                                                                        context_hidden)

        # Generate n_sample_step utterances, feeding each one back into the context
        for j in range(self.config.n_sample_step):
            # context_outputs: [batch_size, context_hidden_size * direction]
            context_outputs = context_outputs.squeeze(1)
            """
            mu_prior, var_prior = self.prior(context_outputs)
            eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
            z_sent = mu_prior + torch.sqrt(var_prior) * eps
            """
            alpha_prior = self.prior(context_outputs)
            # Dirichlet sampling is done on CPU, then the sample is moved back to GPU
            if torch.cuda.is_available():
                alpha_prior = alpha_prior.cpu()
            dirichlet_dist = Dirichlet(alpha_prior)
            z_sent = dirichlet_dist.rsample()
            if torch.cuda.is_available():
                z_sent = z_sent.cuda()
            if self.config.mode == 'generate' and self.config.one_latent_z is not None:
                print('Generated z_sent: ' + str(z_sent))
                # Replace the sample with a fixed one-hot latent vector
                z_sent = torch.zeros(batch_size, self.config.z_sent_size)
                z_sent[:, self.config.one_latent_z] = 1.0
                if torch.cuda.is_available():
                    z_sent = z_sent.cuda()
                print('We use z_sent: ' + str(z_sent))
            
            latent_context = torch.cat([context_outputs, z_sent], 1)
            decoder_init = self.context2decoder(latent_context)
            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)

            if self.config.sample:
                prediction = self.decoder(None, decoder_init)
                p = prediction.data.cpu().numpy()
                length = torch.from_numpy(np.where(p == EOS_ID)[1])
            else:
                prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
                # prediction: [batch_size, seq_len]
                prediction = prediction[:, 0, :]
                # length: [batch_size]
                length = [l[0] for l in length]
                length = to_var(torch.LongTensor(length))

            samples.append(prediction)

            encoder_outputs, encoder_hidden = self.encoder(prediction,
                                                           length)

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)

            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
                                                                        context_hidden)

        samples = torch.stack(samples, 1)
        return samples
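
The CPU round-trip around the Dirichlet above reflects that sampling from torch.distributions.Dirichlet has historically been more reliable on CPU. A minimal, self-contained sketch of the same sampling pattern (the sizes and concentration values are made up for illustration, not taken from the source):

import torch
from torch.distributions import Dirichlet

batch_size, z_sent_size = 4, 10  # illustrative sizes
alpha = torch.full((batch_size, z_sent_size), 0.5)

# rsample() gives a reparameterized (differentiable) draw from the simplex;
# the code above performs it on CPU and moves the result back to the GPU.
z_sent = Dirichlet(alpha).rsample()
if torch.cuda.is_available():
    z_sent = z_sent.cuda()

print(z_sent.sum(dim=-1))  # each row sums to 1

Because rsample() is reparameterized, gradients flow from z_sent back into alpha, which is what lets the prior network in generate() be trained end to end.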
Example #2
def main():
    args = parse_arguments()

    x, c, y, S, antes = load_data(args, 'support2')

    # Sanity-check the loaded antecedents
    print(antes[:5])
    print(type(antes))
    print(len(antes))

    # create model
    args['n_features'] = c.shape[-1]

    model = create_model(args, antes, S.shape[0])

    # use whole dataset for now
    S = torch.tensor(S, dtype=torch.float)
    c = torch.tensor(c, dtype=torch.float)
    n_train = S.shape[0]
    print(n_train)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # optimizer = optim.Adam(model.parameters(), lr=0.001)

    start_time = time.time()
    for ep in range(50):
        optimizer.zero_grad()

        d_soft = model(c, S)

        # Pick the most probable antecedent for each rule slot
        maxes, argmaxes = d_soft.max(dim=-1)
        d = [antes[amax] for amax in argmaxes]
        print(d)

        # B[i, k, j] = 1 iff example i has label k and rule j is the first
        # whose antecedent it satisfies; the last column is the default rule
        n_classes = y.shape[-1]
        B = torch.zeros((n_train, n_classes, len(d) + 1))
        for i, xi in enumerate(x):
            for j, lhs in enumerate(d):
                if set(lhs).issubset(xi):
                    B[i, y[i, 0], j] = 1.
                    break

            B[i, y[i, 0], -1] = 1 - B[i, y[i, 0], :-1].sum()

        # every example is captured by exactly one rule (incl. the default)
        assert B.sum() == S.size(0)

        # Dirichlet posterior over per-rule label distributions:
        # symmetric prior alpha plus observed label counts (conjugate update)
        alpha = 1.
        priors = alpha + B.sum(0)

        thetas = torch.zeros((len(d) + 1, n_classes))
        for i in range(len(d) + 1):
            p_theta = Dirichlet(priors[:, i])
            thetas[i] = p_theta.rsample()

        # compute p(y | d)
        log_py = 0
        for i in range(len(y)):
            for j in range(len(d) + 1):
                log_py += B[i, y[i, 0], j] * torch.log(thetas[j, y[i, 0]])

        # compute p(d | input), as p(d_1 | input) p(d_2 | d_1, input) ...
        log_pd = maxes.sum()

        # negative joint log-probability (the loss being minimized)
        log_prob = -(log_py + log_pd)

        elapsed = time.time() - start_time
        print(
            f"Epoch {ep}: log-prob: {log_prob:.2f}, log p(d|x): {log_pd:.2f}, ",
            end='')
        print(f"log p(y|d): {log_py:.2f} (Elapsed: {elapsed:.2f}s)")

        log_prob.backward()
        optimizer.step()
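
The theta update above is a standard Dirichlet-multinomial conjugate step: per-rule label counts are added to a symmetric prior and one label distribution is sampled per rule. A minimal sketch of just that step, with made-up counts for two classes and three rules (one of them the default):

import torch
from torch.distributions import Dirichlet

alpha = 1.0                            # symmetric Dirichlet prior
counts = torch.tensor([[3., 0.],       # rule 0: three class-0 examples
                       [1., 4.],       # rule 1: mixed labels
                       [0., 2.]])      # default rule
posterior = alpha + counts             # conjugate update, as in the loop above

# One sampled label distribution per rule; each row lies on the simplex.
thetas = Dirichlet(posterior).rsample()
print(thetas)
print(thetas.sum(dim=-1))              # rows sum to 1

Rules that capture many examples of one class get a posterior sharply peaked on that class, while the default rule stays close to the prior.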
Example #3
    def forward(self, sentences, sentence_length,
                input_conversation_length, target_sentences, decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(
            1, 0).contiguous().view(num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat(
            [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input,
                                                                    input_conversation_length)
        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([context_outputs[i, :l, :]
                                     for i, l in enumerate(input_conversation_length.data)])

        alpha_prior = self.prior(context_outputs)
        if not decode:
            alpha_posterior = self.posterior(
                context_outputs, encoder_hidden_inference_flat)

            # reparameterized sample from the Dirichlet posterior
            # (old Gaussian version: z_sent = mu_posterior + torch.sqrt(var_posterior) * eps)
            if torch.cuda.is_available():
                alpha_posterior = alpha_posterior.cpu()
            
            dirichlet_dist = Dirichlet(alpha_posterior)
            z_sent = dirichlet_dist.rsample()
            if torch.cuda.is_available():
                z_sent = to_var(z_sent)
                alpha_posterior = to_var(alpha_posterior)

            # These two variables, log_q_zx and log_p_z, are not strictly needed here:
            # log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()
            # log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            # log_q_zx = dirichlet_logpdf(z_sent, alpha_posterior).sum()
            # log_p_z = dirichlet_logpdf(z_sent, alpha_prior).sum()
            log_q_zx = dirichlet_dist.log_prob(z_sent.cpu()).sum()
            log_p_z = Dirichlet(alpha_prior.cpu()).log_prob(z_sent.cpu()).sum()
            if torch.cuda.is_available():
                log_q_zx = log_q_zx.cuda()
                log_p_z = log_p_z.cuda()
            # kl_div: [num_sentences]
            # kl_div = normal_kl_div(mu_posterior, var_posterior, mu_prior, var_prior)
            kl_div = dirichlet_kl_div(alpha_posterior, alpha_prior)
            kl_div = torch.sum(kl_div)
        else:
            # sample from the prior (old Gaussian version: z_sent = mu_prior + torch.sqrt(var_prior) * eps)
            if torch.cuda.is_available():
                alpha_prior = alpha_prior.cpu()
            dirichlet_dist = Dirichlet(alpha_prior)
            z_sent = dirichlet_dist.rsample()
            if torch.cuda.is_available():
                z_sent = z_sent.cuda()
                alpha_prior = alpha_prior.cuda()
            
            kl_div = None
            # old Gaussian version: log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            log_p_z = dirichlet_logpdf(z_sent, alpha_prior).sum()
            log_q_zx = None
        
        self.z_sent = z_sent
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1,
                                         self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:

            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)

            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)

            return prediction, kl_div, log_p_z, log_q_zx
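
The snippet calls dirichlet_logpdf and dirichlet_kl_div helpers that are not shown. A plausible sketch of both under the assumption that they wrap the standard Dirichlet density and the closed-form Dirichlet-Dirichlet KL (torch.distributions registers that KL, so no manual lgamma/digamma formula is needed; the helper names and exact semantics here are assumptions):

import torch
from torch.distributions import Dirichlet, kl_divergence


def dirichlet_logpdf(x, alpha):
    # Log-density of x (points on the simplex) under Dirichlet(alpha);
    # returns one value per row when alpha is batched.
    return Dirichlet(alpha).log_prob(x)


def dirichlet_kl_div(alpha_posterior, alpha_prior):
    # Per-sample KL( Dirichlet(alpha_posterior) || Dirichlet(alpha_prior) ),
    # using PyTorch's registered closed-form KL for Dirichlet pairs.
    return kl_divergence(Dirichlet(alpha_posterior), Dirichlet(alpha_prior))


# tiny usage check with made-up concentration parameters
a = torch.tensor([[2.0, 3.0, 1.0]])
b = torch.ones(1, 3)
x = Dirichlet(a).rsample()
print(dirichlet_logpdf(x, a), dirichlet_kl_div(a, b))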