def klWCatBNN(q, W, n_hid, dist='Gaussian'):
    """Estimate KL(q||p) as log p(w) - log q(w) for a categorical BNN.

    Starts from the shared-layer estimate returned by klWBNN, then adds the
    contribution of the output-layer weights/biases: a standard-normal prior
    term and a diagonal-Gaussian posterior term parameterized by the
    mean/log-variance entries in `q`.

    Args:
        q: dict of variational parameters; must contain 'Wout_mean',
           'bout_mean', 'Wout_logvar', 'bout_logvar'.
        W: dict of sampled weights; must contain 'Wout' and 'bout'.
        n_hid: hidden-layer sizes, forwarded to klWBNN.
        dist: posterior family label, forwarded to klWBNN.

    Returns:
        Scalar tensor log p(w) - log q(w).
    """
    log_pw, log_qw = klWBNN(q, W, n_hid, dist)
    w_out, b_out = W['Wout'], W['bout']
    w_mu, b_mu = q['Wout_mean'], q['bout_mean']
    w_lv, b_lv = q['Wout_logvar'], q['bout_logvar']
    # Prior: standard normal over the output-layer weights and biases.
    log_pw += (tf.reduce_sum(dgm.standardNormalLogDensity(w_out))
               + tf.reduce_sum(dgm.standardNormalLogDensity(b_out)))
    # Posterior: diagonal Gaussian with the variational mean / log-variance.
    log_qw += (tf.reduce_sum(dgm.gaussianLogDensity(w_out, w_mu, w_lv))
               + tf.reduce_sum(dgm.gaussianLogDensity(b_out, b_mu, b_lv)))
    return log_pw - log_qw
def lowerBound(self, x, y, z, z_m, z_lv):
    """Monte-Carlo ELBO estimate from precomputed samples.

    All inputs are rank-3 tensors shaped (mc_samps, n_obs, n_dim) —
    assumed from the original comment; confirm against callers.

    Args:
        x: observations.
        y: labels (scored under a uniform multinoulli prior).
        z: latent samples drawn from q(z|.).
        z_m, z_lv: mean and log-variance of q(z|.).

    Returns:
        Per-observation bound, averaged over the MC-sample axis (axis 0).
    """
    log_px = self.compute_logpx(x, y, z)
    log_py = dgm.multinoulliUniformLogDensity(y)
    log_pz = dgm.standardNormalLogDensity(z)
    log_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    bound = log_px + log_py + log_pz - log_qz
    return tf.reduce_mean(bound, axis=0)
def weight_prior(self):
    """Log-density of a standard-normal prior over the network weights.

    Collects every trainable variable whose name contains 'W' or 'kernel'
    (NOTE(review): the 'W' substring test can also match unrelated
    variables — confirm the naming convention) and sums their
    standard-normal log-densities.

    Returns:
        Scalar tensor: sum of log N(w; 0, I) over all matched variables.
    """
    weights = [
        v for v in tf.trainable_variables()
        if 'W' in v.name or 'kernel' in v.name
    ]
    # Original used np.sum over a list of symbolic tensors, which only works
    # via operator overloading and returns a plain float 0.0 (not a tensor)
    # when no variable matches. tf.add_n keeps the reduction in the graph;
    # guard the empty case explicitly since tf.add_n rejects an empty list.
    if not weights:
        return tf.constant(0.0)
    return tf.add_n(
        [tf.reduce_sum(dgm.standardNormalLogDensity(w)) for w in weights])
def compute_prior(self):
    """Compute the (scaled) log-prior term over the model's weights.

    Scores every trainable variable except those belonging to the 'py_xa'
    scope under a standard-normal prior, then scales by
    self.l2_reg / self.reg_term.

    Returns:
        Scalar tensor: (l2_reg * sum of log-densities) / reg_term.
    """
    weights = [
        v for v in tf.trainable_variables() if 'py_xa' not in v.name
    ]
    # Original used np.sum over a list of symbolic tensors, which only works
    # via operator overloading and yields a plain float 0.0 (not a tensor)
    # when the list is empty. tf.add_n keeps the reduction in the graph;
    # guard the empty case explicitly since tf.add_n rejects an empty list.
    if weights:
        weight_term = tf.add_n(
            [tf.reduce_sum(dgm.standardNormalLogDensity(w)) for w in weights])
    else:
        weight_term = tf.constant(0.0)
    return (self.l2_reg * weight_term) / self.reg_term
def lowerBound(self, x):
    """Negative variational lower bound for one dataset.

    Samples a per-dataset statistic c from q(c|x), latents z from q(z|x,c),
    and the conditional prior parameters p(z|c), then assembles the bound

        sum_z [log p(x|z) + log p(z|c) - log q(z|x,c)] + log p(c) - log q(c|x)

    Note: despite the name, this returns the NEGATED mean of the bound,
    i.e. a quantity suitable for direct minimization.

    Args:
        x: observations for a single dataset.

    Returns:
        Scalar-per-batch tensor: -mean over axis 0 of the bound.
    """
    qc_m, qc_lv, c = dgm.samplePassStatistic(
        x, self.qc_x, self.n_hid, self.nonlinearity, self.bn,
        self.mc_samples, scope='qc_x')
    qz_m, qz_lv, z = self.sample_qz(x, c)
    pz_m, pz_lv, _ = self.sample_pz(c, 1)
    log_px = self.compute_logpx(x, z)
    log_pz = dgm.gaussianLogDensity(z, pz_m, pz_lv)
    log_qz = dgm.gaussianLogDensity(z, qz_m, qz_lv)
    log_pc = dgm.standardNormalLogDensity(c)
    log_qc = dgm.gaussianLogDensity(c, qc_m, qc_lv)
    # Per-sample z-terms are summed over axis 1 before adding the c-terms.
    per_z = tf.reduce_sum(log_px + log_pz - log_qz, axis=1)
    return -tf.reduce_mean(per_z + log_pc - log_qc, axis=0)
def lowerBound(self, x, y, z, z_m, z_lv, a, a_m, a_lv):
    """Monte-Carlo ELBO with an auxiliary latent a ~ p(a|x,y,z).

    All inputs are rank-3 tensors shaped (mc_samps, n_obs, n_dim) —
    assumed from the original comment; confirm against callers.

    Args:
        x, y, z: observations, labels, and latent samples.
        z_m, z_lv: mean / log-variance of q(z|.).
        a: auxiliary latent samples.
        a_m, a_lv: mean / log-variance of q(a|.).

    Returns:
        Per-observation bound, averaged over the MC-sample axis (axis 0).
    """
    # Flatten (mc, batch, dim) -> (mc*batch, dim) for the p(a|x,y,z) nets,
    # then restore the MC axis on their outputs.
    flat = tf.reshape(tf.concat([x, y, z], axis=-1),
                      [-1, self.n_x + self.n_y + self.n_z])
    pa_mean = self.p_a_xyz_mean(flat)
    pa_logvar = self.p_a_xyz_log_var(flat)
    mc_shape = [self.mc_samples, -1, self.n_a]
    pa_mean = tf.reshape(pa_mean, mc_shape)
    pa_logvar = tf.reshape(pa_logvar, mc_shape)
    log_px = self.compute_logpx(x, y, z)
    log_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    log_pz = dgm.standardNormalLogDensity(z)
    log_pa = dgm.gaussianLogDensity(a, pa_mean, pa_logvar)
    log_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    log_qa = dgm.gaussianLogDensity(a, a_m, a_lv)
    elbo = log_px + log_py + log_pz + log_pa - log_qz - log_qa
    return tf.reduce_mean(elbo, axis=0)
def lowerBound(self, x, y, z, z_m, z_lv, a, qa_m, qa_lv):
    """Monte-Carlo ELBO with an auxiliary latent a ~ p(a|y,z).

    Helper for loss computation; each input is assumed rank-3
    (mc_samps, n_obs, n_dim) per the original comment — confirm
    against callers.

    Args:
        x, y, z: observations, labels, and latent samples.
        z_m, z_lv: mean / log-variance of q(z|.).
        a: auxiliary latent samples.
        qa_m, qa_lv: mean / log-variance of q(a|.).

    Returns:
        Per-observation bound, averaged over the MC-sample axis (axis 0).
    """
    # Flatten (mc, batch, dim) -> (mc*batch, dim) for the p(a|y,z) network,
    # then restore the MC axis on its outputs.
    flat = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    pa_mean, pa_logvar = dgm.forwardPassGauss(
        flat, self.pa_yz, self.n_hid, self.nonlinearity, self.bn,
        scope='pa_yz')
    mc_shape = [self.mc_samples, -1, self.n_a]
    pa_mean = tf.reshape(pa_mean, mc_shape)
    pa_logvar = tf.reshape(pa_logvar, mc_shape)
    log_px = self.compute_logpx(x, y, z, a)
    log_py = dgm.multinoulliUniformLogDensity(y)
    log_pz = dgm.standardNormalLogDensity(z)
    log_pa = dgm.gaussianLogDensity(a, pa_mean, pa_logvar)
    log_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    log_qa = dgm.gaussianLogDensity(a, qa_m, qa_lv)
    elbo = log_px + log_py + log_pz + log_pa - log_qz - log_qa
    return tf.reduce_mean(elbo, axis=0)