def make_decoder(z, x_shape=(1, 20, 1)):
    ''' Decoder: p(x|z) '''
    net = make_nn(z, 20)
    logits = tf.reshape(net, tf.concat([[-1], x_shape], axis=0))
    return tfd.Independent(tfd.Bernoulli(logits))
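# Assumed preamble (added for illustration; not part of the original snippets):
# these examples rely on TF1.x-style APIs. `tfd` is taken to alias the
# TensorFlow Probability distributions module (older code may instead use
# tf.contrib.distributions), `snt` is DeepMind Sonnet, and helpers such as
# make_nn, _prod, compute_log_z_given_y, and subsample_x are project-specific.
import numpy as np
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions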
def get_inference_dist(n_dims):
    with tf.name_scope('InferenceDistribution'):
        loc = tf.get_variable('loc', [n_dims], 'float32')
        scale = tf.get_variable('scale', [n_dims], 'float32')
        inference_dist = tfd.Independent(
            tfd.NormalWithSoftplusScale(loc, scale))
    return inference_dist
def _make_decoder(code, data_shape):
    with tf.variable_scope('decoder'):
        x = code
        x = tf.layers.dense(x, 200, tf.nn.relu)
        x = tf.layers.dense(x, 200, tf.nn.relu)
        logit = tf.layers.dense(x, _prod(data_shape))
        logit = tf.reshape(logit, [-1] + data_shape)
        return tfd.Independent(tfd.Bernoulli(logit), 2)
def get_log_p_2(n_dims):
    dist = tfd.Independent(
        tfd.Gamma(1.2 * tf.ones([n_dims]), 1.0 * tf.ones([n_dims])))

    def log_p(x):
        # Evaluates the Gamma density at exp(x), i.e. x is given in log-space.
        return dist.log_prob(tf.exp(x))

    return log_p
def _build(self, inputs, nb_samples=10, seed=0, encoder_param_type='natural'):
    ### vae encode
    emb = self._encoder(inputs)
    enc_eta1 = self._mu_net(emb)
    enc_eta2_diag = self._sigma_net(emb)
    if encoder_param_type == 'natural':
        enc_eta2_diag *= -1. / 2
        # enc_eta2_diag -= 1e-8
    enc_eta2 = tf.matrix_diag(enc_eta2_diag)

    ### GMM natural parameters
    gmm_pi, gmm_eta1, gmm_eta2 = self.phi_gmm()

    ### combined GMM and VAE latent parameters
    # eta1_tilde.shape = (N, K, D); eta2_tilde.shape = (N, K, D, D)
    # with tf.control_dependencies([util.matrix_is_pos_def_op(-2 * enc_eta2)]):
    eta1_tilde = tf.expand_dims(enc_eta1, axis=1) + tf.expand_dims(gmm_eta1, axis=0)
    eta2_tilde = tf.expand_dims(enc_eta2, axis=1) + tf.expand_dims(gmm_eta2, axis=0)
    log_z_given_y_phi = compute_log_z_given_y(enc_eta1, enc_eta2, gmm_eta1,
                                              gmm_eta2, gmm_pi)

    # with tf.control_dependencies([util.matrix_is_pos_def_op(-2 * gmm_eta2)]):
    mu, cov = gaussian.natural_to_standard(eta1_tilde, eta2_tilde)
    posterior_mixture_distribution = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(tf.exp(log_z_given_y_phi)),
        components_distribution=tfd.MultivariateNormalFullCovariance(
            loc=mu, covariance_matrix=cov))

    # sample x for each of the K components
    # latent_k_samples.shape == nb_samples, batch_size, nb_components, latent_dim
    latent_k_samples = posterior_mixture_distribution.components_distribution.sample(
        [nb_samples])

    ### vae decode
    output_mean = snt.BatchApply(self._decoder, n_dims=3)(latent_k_samples)
    output_variance = tf.get_variable(
        'output_variance',
        dtype=tf.float32,
        initializer=tf.zeros(output_mean.get_shape().as_list()),
        trainable=True)  # learned parameter for output distribution
    output_distribution = tfd.Independent(
        tfd.MultivariateNormalDiagWithSoftplusScale(
            loc=output_mean, scale_diag=output_variance),
        reinterpreted_batch_ndims=2)

    # subsample for each datum in minibatch
    # (go from `nb_samples` per component to `nb_samples` total)
    latent_samples = subsample_x(
        tf.transpose(latent_k_samples, [1, 0, 2, 3]), log_z_given_y_phi, seed)

    return (output_distribution, posterior_mixture_distribution,
            latent_k_samples, latent_samples, log_z_given_y_phi)
def make_decoder(code, data_shape):
    x = code
    x = tf.layers.dense(x, hidden, tf.nn.relu)
    x = tf.layers.dense(x, hidden, tf.nn.relu)
    logit = tf.layers.dense(x, np.prod(data_shape))
    logit = tf.reshape(logit, [-1] + data_shape)
    return tfd.Independent(tfd.Bernoulli(logit), 2)
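# Hypothetical usage of make_decoder above (illustration only, not from the
# original source), assuming `hidden = 200` and binarized 28x28 images.
# Independent(Bernoulli(logit), 2) folds the two trailing image dimensions
# into the event shape, so log_prob returns one value per example.
hidden = 200
code = tf.placeholder(tf.float32, [None, 32])      # latent codes, [batch, 32]
images = tf.placeholder(tf.float32, [None, 28, 28])  # binarized data
decoder = make_decoder(code, data_shape=[28, 28])  # batch_shape=[batch], event_shape=[28, 28]
log_likelihood = decoder.log_prob(images)          # shape [batch]: log p(x|z) per example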
def make_decoder(z, x_shape=(x_dim,)):
    ''' Decoder: p(x|z) '''
    with tf.variable_scope("decoder"):
        net = make_nn(z, x_dim)
        print('decoder net', net)
        logits = tf.reshape(
            net, tf.concat([[nb_z_samples, -1], x_shape], axis=0))  # For the batch
        print('logits', logits)
        return tfd.Independent(tfd.Bernoulli(logits), reinterpreted_batch_ndims=1)
def _build(self, inputs, hvar_labels, n_samples=10, analytic_kl=True):
    datum_shape = inputs.get_shape().as_list()[1:]
    enc_repr = self._encoder(inputs)

    self.hvar_prior = tfd.ExpRelaxedOneHotCategorical(
        temperature=self._temperature, logits=hvar_labels)
    self.hvar_posterior = tfd.ExpRelaxedOneHotCategorical(
        temperature=self._temperature, logits=self._hvar(enc_repr))
    hvar_sample_shape = ([n_samples] +
                         self.hvar_posterior.batch_shape.as_list() +
                         self.hvar_posterior.event_shape.as_list())
    hvar_sample = tf.reshape(self.hvar_posterior.sample(n_samples),
                             hvar_sample_shape)

    self.latent_posterior = self._latent_posterior_fn(
        self._loc(enc_repr), self._scale(enc_repr))
    latent_posterior_sample = self.latent_posterior.sample(n_samples)

    joint_sample = tf.concat([hvar_sample, latent_posterior_sample], axis=-1)

    sample_decoder = snt.BatchApply(self._decoder)
    self.output_distribution = tfd.Independent(
        tfd.Bernoulli(logits=sample_decoder(joint_sample)),
        reinterpreted_batch_ndims=len(datum_shape))

    distortion = -self.output_distribution.log_prob(inputs)
    if analytic_kl and n_samples == 1:
        rate = tfd.kl_divergence(self.latent_posterior, self.latent_prior)
    else:
        rate = (self.latent_posterior.log_prob(latent_posterior_sample) -
                self.latent_prior.log_prob(latent_posterior_sample))
    hrate = (self.hvar_posterior.log_prob(hvar_sample) -
             self.hvar_prior.log_prob(hvar_sample))
    # hrate = tf.Print(hrate, [temperature])
    # hrate = tf.Print(hrate, [hvar_sample], summarize=10)
    # hrate = tf.Print(hrate, [self.hvar_posterior.log_prob(hvar_sample)])
    # hrate = tf.Print(hrate, [self.hvar_prior.log_prob(hvar_sample)])
    # hrate = tf.Print(hrate, [hrate], summarize=10)

    elbo_local = -(rate + hrate + distortion)
    self.elbo = tf.reduce_mean(elbo_local)
    self.importance_weighted_elbo = tf.reduce_mean(
        tf.reduce_logsumexp(elbo_local, axis=0) -
        tf.log(tf.to_float(n_samples)))

    self.hvar_sample = tf.exp(tf.split(hvar_sample, n_samples)[0])
    self.hvar_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=hvar_labels, logits=tf.split(hvar_sample, n_samples)[0])
    self.hvar_labels = hvar_labels
    self.distortion = distortion
    self.rate = rate
    self.hrate = hrate
def _reconstruction_loss(self):
    # E[log p(x|z)]
    if self.loss_function == 'cross_entropy':
        # `tfcd` is assumed to alias the TF contrib distributions module.
        d_p_x_z = tf.distributions.Bernoulli(self.decoder)
        d_p_x_z = tfcd.Independent(d_p_x_z, 3)
        log_likelihood = d_p_x_z.log_prob(self.x)
        log_likelihood = tf.reduce_mean(log_likelihood, name='log_likelihood')
        # log_likelihood = tf.div(log_likelihood, tf.reduce_prod(self.x.shape))
    elif self.loss_function == 'mse':
        log_likelihood = tf.reduce_sum(
            tf.squared_difference(self.x, self.decoder), axis=(1, 2, 3))
    else:
        raise NotImplementedError()
    return log_likelihood
def make_decoder(self, z):
    x = tf.layers.dense(z, 200, tf.nn.relu)
    x = tf.layers.dense(x, 200, tf.nn.relu)
    logits = tf.layers.dense(x, self.m_variants)
    return tfd.Independent(tfd.Binomial(logits=logits, total_count=2.),
                           reinterpreted_batch_ndims=1)
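# Standalone sketch of the pattern above (illustration only, not from the
# original source): with total_count=2 each Binomial emits a count in {0, 1, 2}
# per variant, and reinterpreted_batch_ndims=1 folds the variant axis into the
# event shape, so log_prob returns one value per latent code.
m_variants = 5
z = tf.placeholder(tf.float32, [None, 3])                # [batch, latent_dim]
logits = tf.layers.dense(z, m_variants)
dist = tfd.Independent(tfd.Binomial(logits=logits, total_count=2.),
                       reinterpreted_batch_ndims=1)      # batch=[batch], event=[m_variants]
counts = tf.placeholder(tf.float32, [None, m_variants])  # observed counts in {0, 1, 2}
per_example_log_prob = dist.log_prob(counts)             # shape [batch]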
def get_log_p_1(n_dims):
    target_dist = tfd.Independent(
        tfd.NormalWithSoftplusScale(loc=tf.zeros(n_dims),
                                    scale=10 * tf.ones(n_dims)))
    return target_dist.log_prob