def L(x_recon, x, y, z):
    """Build the labelled-data variational bound for a batch.

    Returns ``log_prior_y + log_lik + log_prior_z - log_post_z``. Each term is
    chosen by ``self.distributions``. Every term except ``log_prior_y`` is
    summed over axis 1.

    NOTE(review): ``self`` is a free variable here, so this function is
    presumably nested inside a method of the model class -- confirm in the
    full file.

    Args:
        x_recon: Indexable pair. ``x_recon[0]`` and ``x_recon[1]`` are passed
            as the 2nd and 3rd arguments of ``utils.tf_normal_logpdf``.
            They are presumably the mean and log-variance of p(x|z); verify
            against utils.
        x: Observed inputs, scored under p(x|z).
        y: Label tensor. It is passed as ``logits`` to the uniform p(y) term.
        z: Indexable triple. ``z[0]`` is scored under the prior/posterior,
            and ``z[1]``, ``z[2]`` parameterise q(z). These are presumably
            the mean and log-variance; verify against utils.

    Returns:
        Tensor of bound values, presumably one entry per batch row.

    Raises:
        ValueError: If ``p_z``, ``p_y``, ``p_x`` or ``q_z`` names a
            distribution this function does not implement. Previously such
            a configuration left a local unbound and failed later with an
            opaque ``UnboundLocalError``.
    """
    dists = self.distributions

    # log p(z)
    if dists['p_z'] == 'gaussian_marg':
        # Closed-form term computed from q's parameters rather than a sample.
        log_prior_z = tf.reduce_sum(utils.tf_gaussian_marg(z[1], z[2]), 1)
    elif dists['p_z'] == 'gaussian':
        log_prior_z = tf.reduce_sum(utils.tf_stdnormal_logpdf(z[0]), 1)
    else:
        raise ValueError('Unsupported p_z distribution: %r' % (dists['p_z'],))

    # log p(y)
    if dists['p_y'] == 'uniform':
        y_prior = (1. / self.dim_y) * tf.ones_like(y)
        # NOTE(review): y is fed as *logits*. This equals log(1 / dim_y) only
        # when y is constant across classes; for one-hot y it is a different
        # constant. It is kept as-is so reported bound values do not shift.
        log_prior_y = -tf.nn.softmax_cross_entropy_with_logits(
            labels=y_prior, logits=y)
    else:
        raise ValueError('Unsupported p_y distribution: %r' % (dists['p_y'],))

    # log p(x|z)
    if dists['p_x'] == 'gaussian':
        log_lik = tf.reduce_sum(
            utils.tf_normal_logpdf(x, x_recon[0], x_recon[1]), 1)
    else:
        raise ValueError('Unsupported p_x distribution: %r' % (dists['p_x'],))

    # log q(z|.)
    if dists['q_z'] == 'gaussian_marg':
        # Only the log-variance is needed for the closed-form entropy term.
        log_post_z = tf.reduce_sum(utils.tf_gaussian_ent(z[2]), 1)
    elif dists['q_z'] == 'gaussian':
        log_post_z = tf.reduce_sum(
            utils.tf_normal_logpdf(z[0], z[1], z[2]), 1)
    else:
        raise ValueError('Unsupported q_z distribution: %r' % (dists['q_z'],))

    _L = log_prior_y + log_lik + log_prior_z - log_post_z
    return _L
def _objective(self):
    """Build the training cost and evaluation log-likelihood tensors.

    Sets the following attributes on ``self``:
    ``z_sample``, ``z_mu``, ``z_lsgms``, ``x_recon``, ``x_recon_logits``,
    ``log_lik``, ``cost``, ``z_sample_eval``, ``x_recon_eval`` and
    ``eval_log_lik``.

    The cost is ``mean(post_z - prior_z - log_lik)``. A weight penalty of
    ``self.l2_loss`` times the summed ``tf.nn.l2_loss`` of every trainable
    variable is added to it. With the ``gaussian_marg`` choices,
    ``post_z - prior_z`` is presumably the closed-form KL(q(z|x) || p(z));
    verify against ``utils.tf_gaussian_ent`` / ``tf_gaussian_marg``.

    Raises:
        ValueError: If ``p_z``, ``q_z`` or ``p_x`` in ``self.distributions``
            is not one of the supported choices (``gaussian_marg``,
            ``gaussian_marg``, ``bernoulli``). Previously this surfaced as an
            ``UnboundLocalError`` or a missing ``self.log_lik``.
    """
    # ---- Cost ----
    self.z_sample, self.z_mu, self.z_lsgms = self._generate_zx(self.x)
    self.x_recon, self.x_recon_logits = self._generate_xz(self.z_sample)

    if self.distributions['p_z'] == 'gaussian_marg':
        prior_z = tf.reduce_sum(
            utils.tf_gaussian_marg(self.z_mu, self.z_lsgms), 1)
    else:
        raise ValueError('Unsupported p_z distribution: %r'
                         % (self.distributions['p_z'],))

    if self.distributions['q_z'] == 'gaussian_marg':
        post_z = tf.reduce_sum(utils.tf_gaussian_ent(self.z_lsgms), 1)
    else:
        raise ValueError('Unsupported q_z distribution: %r'
                         % (self.distributions['q_z'],))

    if self.distributions['p_x'] == 'bernoulli':
        # Flatten the input so it lines up element-wise with x_recon.
        origin_x = tf.reshape(self.x, [self.batch_size, self.dim_x])
        self.log_lik = -tf.reduce_sum(
            utils.tf_binary_xentropy(origin_x, self.x_recon), 1)
    else:
        raise ValueError('Unsupported p_x distribution: %r'
                         % (self.distributions['p_x'],))

    # Must run after the generators above so their variables are included.
    l2 = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    self.cost = (tf.reduce_mean(post_z - prior_z - self.log_lik)
                 + self.l2_loss * l2)

    # ---- Evaluation ----
    # Test-phase graph. reuse=True presumably shares the training networks'
    # weights -- confirm in _generate_zx / _generate_xz.
    self.z_sample_eval, _, _ = self._generate_zx(
        self.x, phase=pt.Phase.test, reuse=True)
    self.x_recon_eval, _ = self._generate_xz(
        self.z_sample_eval, phase=pt.Phase.test, reuse=True)

    # Flatten exactly as the training path does, so the comparison with
    # x_recon_eval is element-wise. -1 keeps any eval batch size valid, and
    # the reshape is a no-op when self.x is already (batch, dim_x).
    flat_x = tf.reshape(self.x, [-1, self.dim_x])
    self.eval_log_lik = -tf.reduce_mean(
        tf.reduce_sum(utils.tf_binary_xentropy(flat_x, self.x_recon_eval), 1))