def lowerBound(self, x, y, z, z_m, z_lv, z_m_p, z_lv_p):
    """Return the Monte Carlo averaged lower bound for labelled inputs.

    All tensor inputs are expected in the (mc_samps X n_obs X n_dim) layout.

    The per-sample bound is assembled from four log-densities:
      * ``self.compute_logpx(x, z)``             -- reconstruction term,
      * log-density of ``y`` under ``self.prior`` (multinoulli, priors given),
      * Gaussian log-density of ``z`` with parameters ``(z_m_p, z_lv_p)``,
      * minus the Gaussian log-density of ``z`` with parameters ``(z_m, z_lv)``.
    The resulting tensor is averaged over axis 0 (the MC-sample axis, per the
    input layout above).
    """
    recon_ll = self.compute_logpx(x, z)
    label_ll = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    z_prior_ll = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)
    z_post_ll = dgm.gaussianLogDensity(z, z_m, z_lv)

    per_sample_bound = recon_ll + label_ll + z_prior_ll - z_post_ll
    return tf.reduce_mean(per_sample_bound, axis=0)
def lowerBound(self, x):
    """Return the NEGATED Monte Carlo estimate of the lower bound for one dataset.

    Despite the method name, the value returned is ``-bound`` (a quantity to
    minimise): the final expression is negated before being returned.

    Steps:
      * ``c`` and its Gaussian parameters ``(qc_m, qc_lv)`` are sampled via
        ``dgm.samplePassStatistic`` using the ``qc_x`` network
        (``self.mc_samples`` samples).
      * ``z`` and its parameters ``(qz_m, qz_lv)`` come from ``self.sample_qz(x, c)``.
      * Prior parameters ``(pz_m, pz_lv)`` for ``z`` come from ``self.sample_pz(c, 1)``.
      * Terms that depend on ``z`` (log p(x|z) + log p(z|c) - log q(z)) are
        summed over axis 1. The ``c`` terms (standard-normal prior minus
        Gaussian posterior) are then added, and the total is averaged over
        axis 0.

    NOTE(review): axis 1 is presumably the observations within the dataset and
    axis 0 the MC samples; confirm against ``sample_qz``/``compute_logpx``.
    """
    qc_m, qc_lv, c = dgm.samplePassStatistic(
        x, self.qc_x, self.n_hid, self.nonlinearity, self.bn,
        self.mc_samples, scope='qc_x')
    qz_m, qz_lv, z = self.sample_qz(x, c)
    pz_m, pz_lv, _ = self.sample_pz(c, 1)

    l_px = self.compute_logpx(x, z)
    l_pz = dgm.gaussianLogDensity(z, pz_m, pz_lv)
    l_qz = dgm.gaussianLogDensity(z, qz_m, qz_lv)
    l_pc = dgm.standardNormalLogDensity(c)
    l_qc = dgm.gaussianLogDensity(c, qc_m, qc_lv)

    # Sum the z-dependent terms over axis 1 so the c terms are counted once per
    # MC sample, then negate the MC average to obtain a loss.
    per_sample = tf.reduce_sum(l_px + l_pz - l_qz, axis=1) + l_pc - l_qc
    return -tf.reduce_mean(per_sample, axis=0)
def klWCatBNN(q, W, n_hid, dist='Gaussian'):
    """Return a single-sample estimate of ``log p(W) - log q(W)`` for a categorical BNN.

    NOTE: the returned value is ``log p(w) - log q(w)``. Its expectation under
    q is the NEGATIVE of KL(q||p), not KL(q||p) itself. Negate it if a
    KL is required.

    The (log p, log q) pair from ``klWBNN(q, W, n_hid, dist)`` is extended
    with output-layer terms:
      * prior: standard-normal log-density of ``W['Wout']`` and ``W['bout']``;
      * posterior: Gaussian log-density using ``q['Wout_mean']``/``q['Wout_logvar']``
        and ``q['bout_mean']``/``q['bout_logvar']``.
    Each log-density is reduced with ``tf.reduce_sum``.

    NOTE(review): the output-layer prior is always standard normal, regardless
    of ``dist``. Confirm this is intended.
    """
    l_pw, l_qw = klWBNN(q, W, n_hid, dist)
    w_out, b_out = W['Wout'], W['bout']
    w_mean, b_mean = q['Wout_mean'], q['bout_mean']
    w_logvar, b_logvar = q['Wout_logvar'], q['bout_logvar']

    # Output-layer prior term.
    l_pw += (tf.reduce_sum(dgm.standardNormalLogDensity(w_out))
             + tf.reduce_sum(dgm.standardNormalLogDensity(b_out)))
    # Output-layer variational posterior term.
    l_qw += (tf.reduce_sum(dgm.gaussianLogDensity(w_out, w_mean, w_logvar))
             + tf.reduce_sum(dgm.gaussianLogDensity(b_out, b_mean, b_logvar)))
    return l_pw - l_qw
def lowerBound(self, x, y, z, z_m, z_lv, a, a_m, a_lv, z_m_p, z_lv_p):
    """Return the Monte Carlo averaged lower bound for a model with auxiliary ``a``.

    All tensor inputs are expected in the (mc_samps X n_obs X n_dim) layout.

    The prior parameters for ``a`` are produced by running ``self.p_a_xyz_mean``
    and ``self.p_a_xyz_log_var`` on ``[y, z]``. That input is flattened to
    (-1, n_y + n_z), and the outputs are reshaped back to (mc_samples, -1, n_a).
    The per-sample bound is

        log p(x|z) + log p(y) + log p(z) + log p(a) - log q(z) - log q(a),

    and the result is averaged over axis 0.

    NOTE(review): the layer names mention x (``p_a_xyz``), but only ``y`` and
    ``z`` are fed to them. Confirm this is intended.
    """
    flat_yz = tf.reshape(tf.concat([y, z], axis=-1),
                         [-1, self.n_y + self.n_z])
    prior_a_mean_flat = self.p_a_xyz_mean(flat_yz)
    prior_a_logvar_flat = self.p_a_xyz_log_var(flat_yz)
    a_shape = [self.mc_samples, -1, self.n_a]
    a_m_p = tf.reshape(prior_a_mean_flat, a_shape)
    a_lv_p = tf.reshape(prior_a_logvar_flat, a_shape)

    recon_ll = self.compute_logpx(x, z)
    label_ll = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    z_prior_ll = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)
    a_prior_ll = dgm.gaussianLogDensity(a, a_m_p, a_lv_p)
    z_post_ll = dgm.gaussianLogDensity(z, z_m, z_lv)
    a_post_ll = dgm.gaussianLogDensity(a, a_m, a_lv)

    per_sample_bound = (recon_ll + label_ll + z_prior_ll + a_prior_ll
                        - z_post_ll - a_post_ll)
    return tf.reduce_mean(per_sample_bound, axis=0)
def lowerBound(self, x, y, z, z_m, z_lv):
    """Return the Monte Carlo averaged lower bound for labelled inputs.

    All tensor inputs are expected in the (mc_samps X n_obs X n_dim) layout.

    The per-sample bound is the sum of
      * ``self.compute_logpx(x, y, z)``,
      * ``dgm.multinoulliUniformLogDensity(y)``,
      * ``dgm.standardNormalLogDensity(z)``,
    minus the Gaussian log-density of ``z`` with parameters ``(z_m, z_lv)``.
    The result is averaged over axis 0.
    """
    recon_ll = self.compute_logpx(x, y, z)
    label_ll = dgm.multinoulliUniformLogDensity(y)
    z_prior_ll = dgm.standardNormalLogDensity(z)
    z_post_ll = dgm.gaussianLogDensity(z, z_m, z_lv)

    per_sample_bound = recon_ll + label_ll + z_prior_ll - z_post_ll
    return tf.reduce_mean(per_sample_bound, axis=0)
def compute_logpx(self, x, y, z):
    """Return the log-density of ``x`` under the decoder ``self.px_yz``.

    ``y`` and ``z`` are concatenated on the last axis and flattened to
    (-1, n_y + n_z). They are passed through the decoder network (scope
    ``'px_yz'``), and the outputs are reshaped to (mc_samples, -1, n_x).

    * ``'Gaussian'``: ``dgm.forwardPassGauss`` gives (mean, log_var), which
      are evaluated with ``dgm.gaussianLogDensity``.
    * ``'Bernoulli'``: ``dgm.forwardPassCatLogits`` gives logits, which are
      evaluated with ``dgm.bernoulliLogDensity``.

    Raises:
        ValueError: if ``self.x_dist`` is not ``'Gaussian'`` or ``'Bernoulli'``.
    """
    # Validate up front so an unsupported distribution fails loudly instead of
    # silently returning None.
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'"
            .format(self.x_dist))
    px_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    out_shape = [self.mc_samples, -1, self.n_x]
    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(
            px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn,
            scope='px_yz')
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)
    logits = dgm.forwardPassCatLogits(
        px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn,
        scope='px_yz')
    logits = tf.reshape(logits, out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def lowerBound(self, x, y, z, z_m, z_lv, a, qa_m, qa_lv):
    """Return the Monte Carlo averaged lower bound with an auxiliary variable ``a``.

    Helper for the loss computations. Every input is assumed to be a rank-3
    tensor (mc_samples X n_obs X n_dim).

    The prior parameters for ``a`` come from ``dgm.forwardPassGauss`` on the
    ``self.pa_yz`` network (scope ``'pa_yz'``). That network is fed ``[y, z]``
    flattened to (-1, n_y + n_z), and its outputs are reshaped back to
    (mc_samples, -1, n_a). The per-sample bound is

        log p(x|y,z,a) + log p(y) + log p(z) + log p(a|y,z) - log q(z) - log q(a),

    with a uniform multinoulli p(y) and a standard-normal p(z). The result is
    averaged over axis 0.
    """
    flat_yz = tf.reshape(tf.concat([y, z], axis=-1),
                         [-1, self.n_y + self.n_z])
    prior_a_mean, prior_a_logvar = dgm.forwardPassGauss(
        flat_yz, self.pa_yz, self.n_hid, self.nonlinearity, self.bn,
        scope='pa_yz')
    a_shape = [self.mc_samples, -1, self.n_a]
    prior_a_mean = tf.reshape(prior_a_mean, a_shape)
    prior_a_logvar = tf.reshape(prior_a_logvar, a_shape)

    recon_ll = self.compute_logpx(x, y, z, a)
    label_ll = dgm.multinoulliUniformLogDensity(y)
    z_prior_ll = dgm.standardNormalLogDensity(z)
    a_prior_ll = dgm.gaussianLogDensity(a, prior_a_mean, prior_a_logvar)
    z_post_ll = dgm.gaussianLogDensity(z, z_m, z_lv)
    a_post_ll = dgm.gaussianLogDensity(a, qa_m, qa_lv)

    per_sample_bound = (recon_ll + label_ll + z_prior_ll + a_prior_ll
                        - z_post_ll - a_post_ll)
    return tf.reduce_mean(per_sample_bound, axis=0)
def compute_logpx(self, x, z):
    """Return the log-density of ``x`` under the decoder layers ``p_x_z_*``.

    ``z`` is flattened to (-1, n_z) and fed to the decoder layers, and the
    outputs are reshaped to (mc_samples, -1, n_x).

    * ``'Gaussian'``: the mean comes from ``self.p_x_z_mean`` and the
      log-variance from ``self.p_x_z_log_var``. They are evaluated with
      ``dgm.gaussianLogDensity``.
    * ``'Bernoulli'``: the output of ``self.p_x_z_mean`` is used as logits and
      evaluated with ``dgm.bernoulliLogDensity``.

    Raises:
        ValueError: if ``self.x_dist`` is not ``'Gaussian'`` or ``'Bernoulli'``.
    """
    # Validate up front so an unsupported distribution fails loudly instead of
    # silently returning None.
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'"
            .format(self.x_dist))
    px_in = tf.reshape(z, [-1, self.n_z])
    out_shape = [self.mc_samples, -1, self.n_x]
    if self.x_dist == 'Gaussian':
        mean, log_var = self.p_x_z_mean(px_in), self.p_x_z_log_var(px_in)
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)
    # Bernoulli: the "mean" layer's output is interpreted as logits.
    logits = tf.reshape(self.p_x_z_mean(px_in), out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def compute_logpx(self, x, z):
    """Return the log-density of ``x`` under p(x|z) (decoder ``self.px_z``).

    ``z`` is flattened to (-1, n_z) and passed through the decoder network
    (scope ``'px_z'``), and the outputs are reshaped to (mc_samples, -1, n_x).

    * ``'Gaussian'``: ``dgm.forwardPassGauss`` gives (mean, log_var), which
      are evaluated with ``dgm.gaussianLogDensity``.
    * ``'Bernoulli'``: ``dgm.forwardPassCatLogits`` gives logits, which are
      evaluated with ``dgm.bernoulliLogDensity``.

    Raises:
        ValueError: if ``self.x_dist`` is not ``'Gaussian'`` or ``'Bernoulli'``.
    """
    # Validate up front so an unsupported distribution fails loudly instead of
    # silently returning None.
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'"
            .format(self.x_dist))
    px_in = tf.reshape(z, [-1, self.n_z])
    out_shape = [self.mc_samples, -1, self.n_x]
    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(
            px_in, self.px_z, self.n_hid, self.nonlinearity, self.bn,
            scope='px_z')
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)
    logits = dgm.forwardPassCatLogits(
        px_in, self.px_z, self.n_hid, self.nonlinearity, self.bn,
        scope='px_z')
    logits = tf.reshape(logits, out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def compute_loss(self, x, y):
    """Return the loss ``-(self.l - self.kl_term)`` for the BNN.

    ``y`` is tiled ``self.wSamps`` times along a new leading axis so that it
    lines up with the predictions from ``self.predict(x, self.wSamps,
    training=True)``. The per-sample log-likelihood is averaged over that
    leading axis and summed over the remaining axes.

    * ``'gaussian'``: ``predict`` returns (mean, log_var), which are evaluated
      with ``dgm.gaussianLogDensity``.
    * ``'categorical'``: ``predict`` returns class predictions, which are
      evaluated with ``dgm.multinoulliLogDensity``.

    Side effects: sets ``self.l`` (the log-likelihood term) and
    ``self.kl_term`` (``dgm.klWCatBNN_exact(self.q, self.n_hid) / self.n_train``).

    Raises:
        ValueError: if ``self.y_dist`` is not ``'gaussian'`` or ``'categorical'``.
    """
    if self.y_dist not in ('gaussian', 'categorical'):
        raise ValueError(
            "Unsupported y_dist {!r}; expected 'gaussian' or 'categorical'"
            .format(self.y_dist))
    # One copy of the targets per weight sample.
    yr = tf.tile(tf.expand_dims(y, 0), [self.wSamps, 1, 1])
    if self.y_dist == 'gaussian':
        ym, ylv = self.predict(x, self.wSamps, training=True)
        log_lik = dgm.gaussianLogDensity(yr, ym, ylv)
    else:
        y_ = self.predict(x, self.wSamps, training=True)
        log_lik = dgm.multinoulliLogDensity(yr, y_)
    # Both branches now set self.l. Previously the gaussian branch wrote a
    # local `l` that was never used, so the return read a missing or stale
    # self.l.
    self.l = tf.reduce_sum(tf.reduce_mean(log_lik, axis=0))
    # NOTE(review): the categorical-output exact KL is used for both y_dist
    # values. Confirm this is intended for the gaussian case.
    self.kl_term = dgm.klWCatBNN_exact(self.q, self.n_hid) / self.n_train
    return -(self.l - self.kl_term)
def compute_logpx(self, x, y, z):
    """Return the log-density of ``x`` under the decoder layers ``p_x_yz_*``.

    ``y`` and ``z`` are concatenated on the last axis and flattened to
    (-1, n_y + n_z). They are fed to the decoder layers, and the outputs are
    reshaped to (mc_samples, -1, n_x).

    * ``'Gaussian'``: the mean comes from ``self.p_x_yz_mean`` and the
      log-variance from ``self.p_x_yz_log_var``. They are evaluated with
      ``dgm.gaussianLogDensity``.
    * ``'Bernoulli'``: logits come from ``self.p_x_yz_mean_model`` and are
      evaluated with ``dgm.bernoulliLogDensity``.

    NOTE(review): the Bernoulli branch uses ``p_x_yz_mean_model``, while the
    Gaussian branch uses ``p_x_yz_mean``. The sibling ``p_x_z_*`` decoder reuses
    its mean layer for logits. Confirm the ``_model`` attribute exists on the
    class.

    Raises:
        ValueError: if ``self.x_dist`` is not ``'Gaussian'`` or ``'Bernoulli'``.
    """
    # Validate up front so an unsupported distribution fails loudly instead of
    # silently returning None.
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'"
            .format(self.x_dist))
    px_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    out_shape = [self.mc_samples, -1, self.n_x]
    if self.x_dist == 'Gaussian':
        mean, log_var = self.p_x_yz_mean(px_in), self.p_x_yz_log_var(px_in)
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)
    logits = tf.reshape(self.p_x_yz_mean_model(px_in), out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def compute_logpx(self, x, y, z, a):
    """Return the log-density of ``x`` under p(x|y,z,a) (decoder ``self.px_yza``).

    ``y``, ``z`` and ``a`` are concatenated on the last axis and flattened to
    (-1, n_y + n_z + n_a). They are passed through the decoder network (scope
    ``'px_yza'``), and the outputs are reshaped to (mc_samples, -1, n_x).

    * ``'Gaussian'``: ``dgm.forwardPassGauss`` gives (mean, log_var), which
      are evaluated with ``dgm.gaussianLogDensity``.
    * ``'Bernoulli'``: ``dgm.forwardPassCatLogits`` gives logits, which are
      evaluated with ``dgm.bernoulliLogDensity``.

    Raises:
        ValueError: if ``self.x_dist`` is not ``'Gaussian'`` or ``'Bernoulli'``.
    """
    # Validate up front so an unsupported distribution fails loudly instead of
    # silently returning None.
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'"
            .format(self.x_dist))
    px_in = tf.reshape(tf.concat([y, z, a], axis=-1),
                       [-1, self.n_y + self.n_z + self.n_a])
    out_shape = [self.mc_samples, -1, self.n_x]
    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(
            px_in, self.px_yza, self.n_hid, self.nonlinearity, self.bn,
            scope='px_yza')
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)
    logits = dgm.forwardPassCatLogits(
        px_in, self.px_yza, self.n_hid, self.nonlinearity, self.bn,
        scope='px_yza')
    logits = tf.reshape(logits, out_shape)
    return dgm.bernoulliLogDensity(x, logits)