def calc_z_prior(self, y):
    """Evaluate the two ``p_z_y`` heads on ``y`` and reshape both results.

    Args:
        y: Input handed unchanged to ``dgm.forwardPass`` for both heads.

    Returns:
        A ``(z_mean_prior, z_log_var_prior)`` tuple, each reshaped to
        ``[self.mc_samples, -1, self.n_z]``.
    """
    # Both forward passes run before any reshape, mean head first.
    raw_mean = dgm.forwardPass(self.p_z_y_mean, y)
    raw_log_var = dgm.forwardPass(self.p_z_y_log_var, y)

    # The leading axis is fixed to mc_samples and the trailing axis to n_z;
    # the middle axis is inferred by tf.reshape.
    prior_shape = [self.mc_samples, -1, self.n_z]
    prior_mean = tf.reshape(raw_mean, prior_shape)
    prior_log_var = tf.reshape(raw_log_var, prior_shape)
    return prior_mean, prior_log_var
def sample_z(self, x, y, n_samples=None):
    """Compute the ``q_z_xy`` outputs for ``(x, y)`` and sample from them.

    Args:
        x: First input; concatenated with ``y`` along the last axis.
        y: Second input; concatenated with ``x`` along the last axis.
        n_samples: Sample count forwarded to ``dgm.sampleNormal``. When
            ``None`` (the default), ``self.mc_samples`` is used.

    Returns:
        A ``(z_mean, z_log_var, z)`` tuple, where ``z`` is the result of
        ``dgm.sampleNormal(z_mean, z_log_var, n_samples)``.
    """
    # Identity check: `== None` would dispatch to an overloaded __eq__
    # (e.g. on array/tensor-like arguments) instead of testing for "unset".
    if n_samples is None:
        n_samples = self.mc_samples

    l_qz_in = tf.concat([x, y], axis=-1)
    z_mean = dgm.forwardPass(self.q_z_xy_mean, l_qz_in)
    z_log_var = dgm.forwardPass(self.q_z_xy_log_var, l_qz_in)
    # NOTE(review): presumably draws n_samples Gaussian samples parameterised
    # by (mean, log-variance) -- confirm against dgm.sampleNormal.
    z = dgm.sampleNormal(z_mean, z_log_var, n_samples)
    return z_mean, z_log_var, z
def sample_a(self, x, n_samples=None):
    """Compute the ``q_a_x`` outputs for ``x`` and sample from them.

    Args:
        x: Input handed unchanged to ``dgm.forwardPass`` for both heads.
        n_samples: Sample count forwarded to ``dgm.sampleNormal``. When
            ``None`` (the default), ``self.mc_samples`` is used.

    Returns:
        A ``(a_mean, a_log_var, a)`` tuple, where ``a`` is the result of
        ``dgm.sampleNormal(a_mean, a_log_var, n_samples)``.
    """
    # Identity check: `== None` would dispatch to an overloaded __eq__
    # (e.g. on array/tensor-like arguments) instead of testing for "unset".
    if n_samples is None:
        n_samples = self.mc_samples

    a_mean = dgm.forwardPass(self.q_a_x_mean, x)
    a_log_var = dgm.forwardPass(self.q_a_x_log_var, x)
    # NOTE(review): presumably draws n_samples Gaussian samples parameterised
    # by (mean, log-variance) -- confirm against dgm.sampleNormal.
    a = dgm.sampleNormal(a_mean, a_log_var, n_samples)
    return a_mean, a_log_var, a
def sample_z(self, x, y, a, n_samples=None):
    """Compute the ``q_z_axy`` outputs for ``(x, y, a)`` and sample from them.

    The three inputs are concatenated along the last axis and flattened to
    two dimensions with ``self.n_x + self.n_y + self.n_a`` columns before the
    forward passes. All three results are then reshaped to
    ``[self.mc_samples, -1, self.n_z]``.

    Args:
        x: First input to the concatenation.
        y: Second input to the concatenation.
        a: Third input to the concatenation.
        n_samples: Passed to ``dgm.sampleNormal`` as ``mc_samps``. When
            ``None`` (the default), 1 is used.

    Returns:
        A ``(z_mean, z_log_var, z)`` tuple, each reshaped to
        ``[self.mc_samples, -1, self.n_z]``.
    """
    # Identity check: `== None` would dispatch to an overloaded __eq__
    # (e.g. on array/tensor-like arguments) instead of testing for "unset".
    if n_samples is None:
        # NOTE(review): the default is 1 while the reshapes below use
        # self.mc_samples; presumably the inputs already carry the
        # Monte Carlo axis -- confirm against callers.
        n_samples = 1

    l_qz_in = tf.reshape(
        tf.concat([x, y, a], axis=-1),
        [-1, self.n_x + self.n_y + self.n_a],
    )
    z_mean = dgm.forwardPass(self.q_z_axy_mean, l_qz_in)
    z_log_var = dgm.forwardPass(self.q_z_axy_log_var, l_qz_in)
    z = dgm.sampleNormal(z_mean, z_log_var, mc_samps=n_samples)

    # Restore a leading mc_samples axis; the middle axis is inferred.
    out_shape = [self.mc_samples, -1, self.n_z]
    z_mean = tf.reshape(z_mean, out_shape)
    z_log_var = tf.reshape(z_log_var, out_shape)
    z = tf.reshape(z, out_shape)
    return z_mean, z_log_var, z
def calc_z_prior(self, y):
    """Evaluate the two ``p_z_y`` heads on ``y``.

    Args:
        y: Input handed unchanged to ``dgm.forwardPass`` for both heads.

    Returns:
        A ``(z_mean_prior, z_log_var_prior)`` tuple holding the raw outputs
        of the two forward passes; no reshaping is applied.
    """
    # Tuple elements evaluate left to right, so the mean head runs first.
    return (
        dgm.forwardPass(self.p_z_y_mean, y),
        dgm.forwardPass(self.p_z_y_log_var, y),
    )