Example #1
File: vae2.py  Project: seanxu1015/vae
def make_decoder(z, x_shape=(1, 20, 1)):
    '''
    Decoder: p(x|z)
    '''
    net = make_nn(z, 20)
    logits = tf.reshape(net, tf.concat([[-1], x_shape], axis=0))
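    # Independent folds all but the leading (batch) dimension of the per-pixel
    # Bernoullis into a single event, so log_prob returns one value per example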
    return tfd.Independent(tfd.Bernoulli(logits))
Example #2
def get_inference_dist(n_dims):
    with tf.name_scope('InferenceDistribution'):
        loc = tf.get_variable('loc', [n_dims], 'float32')
        scale = tf.get_variable('scale', [n_dims], 'float32')
        inference_dist = tfd.Independent(
            tfd.NormalWithSoftplusScale(loc, scale))
    return inference_dist
Example #3
def _make_decoder(code, data_shape):
    with tf.variable_scope('decoder'):
        x = code
        x = tf.layers.dense(x, 200, tf.nn.relu)
        x = tf.layers.dense(x, 200, tf.nn.relu)
        logit = tf.layers.dense(x, _prod(data_shape))
        logit = tf.reshape(logit, [-1] + data_shape)
        return tfd.Independent(tfd.Bernoulli(logit), 2)
Example #4
    def get_log_p_2(n_dims):
        dist = tfd.Independent(
            tfd.Gamma(1.2 * tf.ones([n_dims]), 1.0 * tf.ones([n_dims])))

        def log_p(x):
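            # the input x is given in log space, so the Gamma density is
            # evaluated at exp(x) (no log-Jacobian correction is applied here)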
            return dist.log_prob(tf.exp(x))

        return log_p
Example #5
    def _build(self,
               inputs,
               nb_samples=10,
               seed=0,
               encoder_param_type='natural'):
        ### vae encode
        emb = self._encoder(inputs)
        enc_eta1 = self._mu_net(emb)
        enc_eta2_diag = self._sigma_net(emb)
        if encoder_param_type == 'natural':
            enc_eta2_diag *= -1. / 2
            # enc_eta2_diag -= 1e-8
        enc_eta2 = tf.matrix_diag(enc_eta2_diag)

        ### GMM natural parameters
        gmm_pi, gmm_eta1, gmm_eta2 = self.phi_gmm()

        ### combined GMM and VAE latent parameters
        # eta1_tilde.shape = (N, K, D); eta2_tilde.shape = (N, K, D, D)
        # with tf.control_dependencies([util.matrix_is_pos_def_op(-2 * enc_eta2)]):
        eta1_tilde = (tf.expand_dims(enc_eta1, axis=1) +
                      tf.expand_dims(gmm_eta1, axis=0))
        eta2_tilde = (tf.expand_dims(enc_eta2, axis=1) +
                      tf.expand_dims(gmm_eta2, axis=0))
        log_z_given_y_phi = compute_log_z_given_y(enc_eta1, enc_eta2, gmm_eta1,
                                                  gmm_eta2, gmm_pi)
        # with tf.control_dependencies([util.matrix_is_pos_def_op(-2 * gmm_eta2)]):
        mu, cov = gaussian.natural_to_standard(eta1_tilde, eta2_tilde)
        posterior_mixture_distribution = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(tf.exp(log_z_given_y_phi)),
            components_distribution=tfd.MultivariateNormalFullCovariance(
                loc=mu, covariance_matrix=cov))

        # sample x for each of the K components
        # latent_k_samples.shape == nb_samples, batch_size, nb_components, latent_dim
        latent_k_samples = posterior_mixture_distribution.components_distribution.sample(
            [nb_samples])

        ### vae decode
        output_mean = snt.BatchApply(self._decoder, n_dims=3)(latent_k_samples)
        output_variance = tf.get_variable(
            'output_variance',
            dtype=tf.float32,
            initializer=tf.zeros(output_mean.get_shape().as_list()),
            trainable=True)  # learned parameter for output distribution
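        # diagonal Gaussian likelihood; the softplus transform keeps the learned
        # scale positive, and Independent folds the batch and component dims into the event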
        output_distribution = tfd.Independent(
            tfd.MultivariateNormalDiagWithSoftplusScale(
                loc=output_mean, scale_diag=output_variance),
            reinterpreted_batch_ndims=2)

        # subsample for each datum in minibatch (go from `nb_samples` per component to `nb_samples` total)
        latent_samples = subsample_x(
            tf.transpose(latent_k_samples, [1, 0, 2, 3]), log_z_given_y_phi,
            seed)

        return output_distribution, posterior_mixture_distribution, latent_k_samples, latent_samples, log_z_given_y_phi
Example #6
def make_decoder(code, data_shape):

    x = code
    x = tf.layers.dense(x, hidden, tf.nn.relu)
    x = tf.layers.dense(x, hidden, tf.nn.relu)
    logit = tf.layers.dense(x, np.prod(data_shape))
    logit = tf.reshape(logit, [-1] + data_shape)

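    # treat the two trailing (image) dimensions as one event of independent Bernoullis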
    return tfd.Independent(tfd.Bernoulli(logit), 2)
Example #7
def make_decoder(z, x_shape=(x_dim,)):
    '''
    Decoder: p(x|z)
    '''
    with tf.variable_scope("decoder"):
        net = make_nn(z, x_dim)
        print('decoder net', net)
        logits = tf.reshape(net, tf.concat([[nb_z_samples, -1], x_shape], axis=0))  # For the batch
        print('logits', logits)
        return tfd.Independent(tfd.Bernoulli(logits), reinterpreted_batch_ndims=1)
Example #8
    def _build(self, inputs, hvar_labels, n_samples=10, analytic_kl=True):
        datum_shape = inputs.get_shape().as_list()[1:]
        enc_repr = self._encoder(inputs)

        self.hvar_prior = tfd.ExpRelaxedOneHotCategorical(
            temperature=self._temperature, logits=hvar_labels)
        self.hvar_posterior = tfd.ExpRelaxedOneHotCategorical(
            temperature=self._temperature, logits=self._hvar(enc_repr))
        hvar_sample_shape = ([n_samples] +
                             self.hvar_posterior.batch_shape.as_list() +
                             self.hvar_posterior.event_shape.as_list())
        hvar_sample = tf.reshape(self.hvar_posterior.sample(n_samples),
                                 hvar_sample_shape)

        self.latent_posterior = self._latent_posterior_fn(
            self._loc(enc_repr), self._scale(enc_repr))
        latent_posterior_sample = self.latent_posterior.sample(n_samples)

        joint_sample = tf.concat([hvar_sample, latent_posterior_sample],
                                 axis=-1)

        sample_decoder = snt.BatchApply(self._decoder)
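        # decode every joint sample; the decoder outputs Bernoulli logits and all
        # datum dimensions are folded into a single event per reconstruction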
        self.output_distribution = tfd.Independent(
            tfd.Bernoulli(logits=sample_decoder(joint_sample)),
            reinterpreted_batch_ndims=len(datum_shape))

        distortion = -self.output_distribution.log_prob(inputs)
        if analytic_kl and n_samples == 1:
            rate = tfd.kl_divergence(self.latent_posterior, self.latent_prior)
        else:
            rate = (self.latent_posterior.log_prob(latent_posterior_sample) -
                    self.latent_prior.log_prob(latent_posterior_sample))
        hrate = self.hvar_posterior.log_prob(
            hvar_sample) - self.hvar_prior.log_prob(hvar_sample)
        # hrate = tf.Print(hrate, [temperature])
        # hrate = tf.Print(hrate, [hvar_sample], summarize=10)
        # hrate = tf.Print(hrate, [self.hvar_posterior.log_prob(hvar_sample)])
        # hrate = tf.Print(hrate, [self.hvar_prior.log_prob(hvar_sample)])
        # hrate = tf.Print(hrate, [hrate], summarize=10)
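        # per-sample ELBO: reconstruction log-likelihood minus both KL terms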
        elbo_local = -(rate + hrate + distortion)
        self.elbo = tf.reduce_mean(elbo_local)
        self.importance_weighted_elbo = tf.reduce_mean(
            tf.reduce_logsumexp(elbo_local, axis=0) -
            tf.log(tf.to_float(n_samples)))

        self.hvar_sample = tf.exp(tf.split(hvar_sample, n_samples)[0])
        self.hvar_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=hvar_labels, logits=tf.split(hvar_sample, n_samples)[0])
        self.hvar_labels = hvar_labels
        self.distortion = distortion
        self.rate = rate
        self.hrate = hrate
Example #9
    def _reconstruction_loss(self):
        # E[log p(x|z)]
        if self.loss_function == 'cross_entropy':
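            # pixel-wise Bernoulli likelihood (self.decoder supplies the logits);
            # Independent(..., 3) folds the three image dimensions into one event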
            d_p_x_z = tf.distributions.Bernoulli(self.decoder)
            d_p_x_z = tfcd.Independent(d_p_x_z, 3)
            log_likelihood = d_p_x_z.log_prob(self.x)
            log_likelihood = tf.reduce_mean(log_likelihood, name='log_likelihood')
            # log_likelihood = tf.div(log_likelihood, tf.reduce_prod(self.x.shape))
        elif self.loss_function == 'mse':
            log_likelihood = tf.reduce_sum(tf.squared_difference(self.x, self.decoder), axis=(1, 2, 3))
        else:
            raise NotImplementedError()

        return log_likelihood
Example #10
def make_decoder(self, z):
    x = tf.layers.dense(z, 200, tf.nn.relu)
    x = tf.layers.dense(x, 200, tf.nn.relu)
    logits = tf.layers.dense(x, self.m_variants)
    return tfd.Independent(tfd.Binomial(logits=logits, total_count=2.),
                           reinterpreted_batch_ndims=1)
Example #11
def get_log_p_1(n_dims):
    target_dist = tfd.Independent(
        tfd.NormalWithSoftplusScale(loc=tf.zeros(n_dims),
                                    scale=10 * tf.ones(n_dims)))
    return target_dist.log_prob