Example #1
    def logprobs_all(self, num_samples):
        # Sample z ~ Bernoulli(0.5), decode to Bernoulli means mu, then draw
        # a relaxed binary h via the Gumbel-sigmoid trick.
        z = sample_from_bernoulli(
            tf.constant(.5,
                        shape=(num_samples, self.dims[0]),
                        dtype=tf.float32))
        mu = self.inference(z)
        h = gumbel_sigmoid_sample(mu, self.temp, self.hard)
        # Joint log-probability: log p(z, h) = log p(z) + log p(h | z).
        logp_z = bernoulli_log_likelihood(
            z, tf.constant(.5, shape=z.shape, dtype=tf.float32))
        logp_h_given_z = bernoulli_log_likelihood(h, mu)
        return logp_z + logp_h_given_z, z, h
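Every example on this page calls bernoulli_log_likelihood, and Example #1 also calls sample_from_bernoulli; neither helper is shown. A minimal TensorFlow sketch of what such helpers typically look like, assuming the mean is clipped for numerical stability and log-probabilities are summed over the last axis (the original code base may reduce differently):

import tensorflow as tf

def bernoulli_log_likelihood(x, mu, eps=1e-7):
    # Elementwise Bernoulli log-probability, summed over the last axis:
    # sum_i [ x_i * log(mu_i) + (1 - x_i) * log(1 - mu_i) ]
    mu = tf.clip_by_value(mu, eps, 1.0 - eps)
    return tf.reduce_sum(
        x * tf.math.log(mu) + (1.0 - x) * tf.math.log(1.0 - mu), axis=-1)

def sample_from_bernoulli(mu):
    # Exact binary samples: 1 where uniform noise falls below the mean.
    return tf.cast(tf.random.uniform(tf.shape(mu)) < mu, tf.float32)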
Example #2
    def loss(self, outputs, weight):
        c_outputs, z_outputs, x_outputs = outputs

        # 1. Reconstruction term: Bernoulli log-likelihood of the inputs
        #    (higher is better), averaged over batch and sample dimensions
        x, x_mean = x_outputs
        recon_loss = bernoulli_log_likelihood(x.view(-1, 1, 28, 28), x_mean)
        recon_loss /= (self.batch_size * self.sample_size)

        # 2. KL Divergence terms
        kl = 0

        # a) Context divergence
        c_mean, c_logvar = c_outputs
        kl_c = kl_diagnormal_stdnormal(c_mean, c_logvar)
        kl += kl_c

        # b) Latent divergences
        qz_params, pz_params = z_outputs
        shapes = ((self.batch_size, self.sample_size, self.z_dim),
                  (self.batch_size, 1, self.z_dim))
        for i in range(self.n_stochastic):
            args = (qz_params[i][0].view(shapes[0]),
                    qz_params[i][1].view(shapes[0]),
                    pz_params[i][0].view(shapes[1] if i == 0 else shapes[0]),
                    pz_params[i][1].view(shapes[1] if i == 0 else shapes[0]))
            kl_z = kl_diagnormal_diagnormal(*args)
            kl += kl_z

        kl /= (self.batch_size * self.sample_size)

        # Variational lower bound and weighted loss
        vlb = recon_loss - kl
        loss = -((weight * recon_loss) - (kl / weight))

        return loss, vlb
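Example #2 is PyTorch rather than TensorFlow, and its two KL helpers are not shown. A minimal sketch of the standard closed-form diagonal-Gaussian KL terms, assuming the (mean, log-variance) parameterization implied by the calls above; the original helpers may reduce or batch differently:

import torch

def kl_diagnormal_stdnormal(mean, logvar):
    # KL( N(mean, diag(exp(logvar))) || N(0, I) ), summed over all elements.
    return 0.5 * torch.sum(mean.pow(2) + logvar.exp() - logvar - 1)

def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar):
    # KL( N(q_mean, diag(exp(q_logvar))) || N(p_mean, diag(exp(p_logvar))) ).
    return 0.5 * torch.sum(
        p_logvar - q_logvar
        + (q_logvar.exp() + (q_mean - p_mean).pow(2)) / p_logvar.exp()
        - 1)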
Example #3
    def logprobs(self, x):
        # log p(x) = log sum_k pi_k p(x | mu_k)
        #          = logsumexp_k ( log pi_k + log p(x | mu_k) )
        log_pi = self.pi_logits - tf.reduce_logsumexp(self.pi_logits, axis=1)
        logp_x_given_mu_all = bernoulli_log_likelihood(
            tf.expand_dims(x, 1), tf.expand_dims(self.mus, 0))
        return tf.reduce_logsumexp(log_pi + logp_x_given_mu_all, axis=1)
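The reduce_logsumexp pattern in Example #3 evaluates a mixture log-likelihood without ever leaving log space. A small self-contained NumPy check of the identity it relies on, with made-up shapes (K = 3 components, D = 4 Bernoulli dimensions):

import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
K, D = 3, 4
log_pi = np.log(np.full(K, 1.0 / K))           # uniform mixture weights
mus = rng.uniform(0.1, 0.9, size=(K, D))       # component means
x = rng.integers(0, 2, size=D).astype(float)   # one binary observation

logp_x_given_mu = (x * np.log(mus) + (1 - x) * np.log(1 - mus)).sum(axis=1)
logp_mixture = logsumexp(log_pi + logp_x_given_mu)

# Same quantity computed naively in probability space:
p_naive = np.sum(np.exp(log_pi) * np.exp(logp_x_given_mu))
assert np.isclose(logp_mixture, np.log(p_naive))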
Example #4
    def logprobs(self, h):
        # Exact marginal: enumerate every binary z (feasible only for small
        # dims[0], hence the cap) and logsumexp the prior-weighted terms.
        assert self.dims[0] <= 15
        z_all = np.arange(2**self.dims[0], dtype=np.int32)
        # Expand each integer into its binary digits, one latent per row.
        z_all = ((z_all.reshape(-1, 1) &
                  (2**np.arange(self.dims[0]))) != 0).astype(np.float32)
        z_all = tf.constant(z_all[:, ::-1], dtype=tf.float32)
        h_mu_all = self.inference(z_all)

        logp_z_all = bernoulli_log_likelihood(
            z_all,
            tf.constant(.5,
                        shape=(2**self.dims[0], self.dims[0]),
                        dtype=tf.float32))
        logp_h_given_z_all = bernoulli_log_likelihood(
            tf.expand_dims(h, 1), tf.expand_dims(h_mu_all, 0))
        # log p(h) = logsumexp_z ( log p(z) + log p(h | z) )
        logp_h = tf.reduce_logsumexp(tf.expand_dims(logp_z_all, 0) +
                                     logp_h_given_z_all,
                                     axis=1)

        return logp_h
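Examples #4 and #5 enumerate every binary latent vector with a bit-masking trick, which is why the assert caps the dimensionality at 15 (2**15 = 32768 rows). The same enumeration, demonstrated standalone for 3 dimensions:

import numpy as np

d = 3
z_all = np.arange(2**d, dtype=np.int32)
z_all = ((z_all.reshape(-1, 1) & (2**np.arange(d))) != 0).astype(np.float32)
z_all = z_all[:, ::-1]  # most-significant bit first, as in the examples
print(z_all)
# [[0. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [0. 1. 1.]
#  ...one row per binary vector, counting up to [1. 1. 1.]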
Example #5
    def _logprob_v(self, h):
        # Same exact enumeration as Example #4, applied to the h2v layer.
        assert self.dim_h2v[0] <= 15
        assert not self.has_h
        z_all = np.arange(2**self.dim_h2v[0], dtype=np.int32)
        z_all = ((z_all.reshape(-1, 1) &
                  (2**np.arange(self.dim_h2v[0]))) != 0).astype(np.float32)
        z_all = tf.constant(z_all[:, ::-1], dtype=tf.float32)
        h_mu_all = self.h2v(z_all)

        logp_z_all = bernoulli_log_likelihood(
            z_all,
            tf.constant(.5,
                        shape=(2**self.dim_h2v[0], self.dim_h2v[0]),
                        dtype=tf.float32))
        logp_h_given_z_all = bernoulli_log_likelihood(
            tf.expand_dims(h, 1), tf.expand_dims(h_mu_all, 0))
        logp_h = tf.reduce_logsumexp(tf.expand_dims(logp_z_all, 0) +
                                     logp_h_given_z_all,
                                     axis=1)
        return logp_h
Example #6
    def log_conditional_prob_evaluate(self, x, h):
        # One term of E_{Q(h|x)} log Q(h|x), with h supplied by the caller.
        h_mu = self.inference(x)
        logp_h_given_x = bernoulli_log_likelihood(h, h_mu)
        return logp_h_given_x
Example #7
    def log_conditional_prob(self, x):
        # Single-sample estimate of E_{Q(h|x)} log Q(h|x): draw a relaxed
        # h ~ Q(h|x) and score it under the same distribution.
        h_mu = self.inference(x)
        h = gumbel_sigmoid_sample(h_mu, self.temp, self.hard)
        logp_h_given_x = bernoulli_log_likelihood(h, h_mu)
        return logp_h_given_x, h
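Examples #1 and #7 draw relaxed binary samples with gumbel_sigmoid_sample, which is not shown either. A sketch of the usual binary Concrete / Gumbel-sigmoid relaxation, assuming mu is a Bernoulli mean, temp is the temperature, and hard enables straight-through discretization; the original helper may be parameterized differently:

import tensorflow as tf

def gumbel_sigmoid_sample(mu, temp, hard=False, eps=1e-7):
    # Logistic noise (difference of two Gumbels) added to the Bernoulli
    # logits, squashed by a tempered sigmoid.
    mu = tf.clip_by_value(mu, eps, 1.0 - eps)
    logits = tf.math.log(mu) - tf.math.log(1.0 - mu)
    u = tf.random.uniform(tf.shape(mu), minval=eps, maxval=1.0 - eps)
    noise = tf.math.log(u) - tf.math.log(1.0 - u)  # Logistic(0, 1)
    y = tf.sigmoid((logits + noise) / temp)
    if hard:
        # Straight-through: discrete forward pass, soft gradient.
        y = tf.stop_gradient(tf.round(y) - y) + y
    return y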
Example #8
    def logprob_v_and_h(self, h, v):
        # Joint log p(v, h) = log p(h) + log p(v | h), with p(h) scored by
        # the prior model (e.g. Example #4's exact enumeration).
        logp_h = self.prior_h.logprobs(h)
        logp_v_given_h = bernoulli_log_likelihood(v, self.h2v(h))
        return logp_v_given_h + logp_h
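Example #8's joint log-probability is the kind of quantity usually combined with the encoder's log Q(h|v) from Example #7 to importance-sample log p(v). A hedged sketch of that combination; model and v are hypothetical stand-ins for an instance exposing these methods and a single binary observation:

import tensorflow as tf

K = 50                                    # number of posterior samples
v_rep = tf.repeat(v[None, :], K, axis=0)  # tile one observation K times
logq_h, h = model.log_conditional_prob(v_rep)    # Example #7
logw = model.logprob_v_and_h(h, v_rep) - logq_h  # joint minus proposal
# log p(v) ~= logsumexp_k(logw_k) - log K
logp_v = tf.reduce_logsumexp(logw, axis=0) - tf.math.log(float(K))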
Example #9
    def logprob_v_and_h_and_z(self, v, h, z):
        # Joint over the two-layer chain z -> h -> v:
        # log p(v, h, z) = log p(z) + log p(h | z) + log p(v | h).
        logp_z = bernoulli_log_likelihood(
            z, tf.constant(.5, shape=z.shape, dtype=tf.float32))
        logp_h_given_z = bernoulli_log_likelihood(h, self.z2h(z))
        logp_v_given_h = bernoulli_log_likelihood(v, self.h2v(h))
        return logp_v_given_h + logp_h_given_z + logp_z
Example #10
    def logprob_v_and_h_approx(self, h, v):
        # Approximate log p(h) with a factorized Bernoulli whose means are
        # the empirical means over the sample axis (axis 0).
        h_mu = tf.reduce_mean(h, axis=0)
        logp_h = bernoulli_log_likelihood(h, h_mu)
        logp_v_given_h = bernoulli_log_likelihood(v, self.h2v(h))
        return logp_v_given_h + logp_h
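Example #10 replaces the exact prior over h with a fully factorized Bernoulli whose per-dimension means come from the sample batch itself. A tiny numeric illustration with made-up samples, expanding bernoulli_log_likelihood inline as sketched after Example #1:

import tensorflow as tf

h = tf.constant([[1., 0., 1.],
                 [0., 1., 0.],
                 [1., 0., 0.],
                 [0., 1., 1.]])
h_mu = tf.reduce_mean(h, axis=0)  # empirical means: [0.5, 0.5, 0.5]
logp_h = tf.reduce_sum(h * tf.math.log(h_mu) +
                       (1. - h) * tf.math.log(1. - h_mu), axis=-1)
print(logp_h.numpy())  # each row scores 3 * log(0.5) ~= -2.079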
Example #11
    def log_conditional_prob_with_sampels(self, x, h):
        # Same scoring as log_conditional_prob_evaluate: log Q(h | x).
        h_mu = self.inference(x)
        logp_h_given_x = bernoulli_log_likelihood(h, h_mu)
        return logp_h_given_x