Example #1
def lowerBound(self, x, y, z, z_m, z_lv, z_m_p, z_lv_p):
    """ Compute densities and the lower bound; all inputs have shape
    (mc_samps x n_obs x n_dim)
    """
    l_px = self.compute_logpx(x, z)                                  # log p(x|z)
    l_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)  # log p(y)
    l_pz = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)                  # log p(z|.)
    l_qz = dgm.gaussianLogDensity(z, z_m, z_lv)                      # log q(z|.)
    return tf.reduce_mean(l_px + l_py + l_pz - l_qz, axis=0)
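These examples lean on a `dgm` helper module that is not shown. A minimal sketch of what `dgm.gaussianLogDensity` presumably computes, assuming a diagonal Gaussian parameterized by mean and log-variance and reduced over the last axis (the names here are hypothetical reimplementations, not the actual library code):

import numpy as np
import tensorflow as tf

def gaussian_log_density(x, mu, log_var):
    # log N(x; mu, diag(exp(log_var))), summed over the last axis
    c = -0.5 * np.log(2.0 * np.pi)
    return tf.reduce_sum(c - 0.5 * log_var - tf.square(x - mu) / (2.0 * tf.exp(log_var)), axis=-1)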
Example #2
def lowerBound(self, x):
    """ Compute the lower bound for one dataset """
    qc_m, qc_lv, c = dgm.samplePassStatistic(x, self.qc_x, self.n_hid, self.nonlinearity, self.bn, self.mc_samples, scope='qc_x')
    qz_m, qz_lv, z = self.sample_qz(x, c)
    pz_m, pz_lv, _ = self.sample_pz(c, 1)
    l_px = self.compute_logpx(x, z)                # log p(x|z)
    l_pz = dgm.gaussianLogDensity(z, pz_m, pz_lv)  # log p(z|c)
    l_qz = dgm.gaussianLogDensity(z, qz_m, qz_lv)  # log q(z|x,c)
    l_pc = dgm.standardNormalLogDensity(c)         # log p(c)
    l_qc = dgm.gaussianLogDensity(c, qc_m, qc_lv)  # log q(c|x)
    # per-observation terms are summed within the dataset; dataset-level c terms are added once
    return -tf.reduce_mean(tf.reduce_sum(l_px + l_pz - l_qz, axis=1) + l_pc - l_qc, axis=0)
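`dgm.standardNormalLogDensity` is presumably the same density with zero mean and unit variance. A one-line sketch building on the hypothetical `gaussian_log_density` after Example #1:

def standard_normal_log_density(x):
    # log N(x; 0, I): zero mean and zero log-variance
    return gaussian_log_density(x, tf.zeros_like(x), tf.zeros_like(x))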
Example #3
def klWCatBNN(q, W, n_hid, dist='Gaussian'):
    """ Estimate -KL(q||p) as log p(w) - log q(w) for a categorical BNN """
    l_pw, l_qw = klWBNN(q, W, n_hid, dist)
    w, b = W['Wout'], W['bout']
    wMean, bMean = q['Wout_mean'], q['bout_mean']
    wLv, bLv = q['Wout_logvar'], q['bout_logvar']
    # prior: standard normal over the output-layer weights and biases
    l_pw += tf.reduce_sum(dgm.standardNormalLogDensity(w)) + tf.reduce_sum(dgm.standardNormalLogDensity(b))
    # variational posterior: diagonal Gaussians over the same parameters
    l_qw += tf.reduce_sum(dgm.gaussianLogDensity(w, wMean, wLv)) + tf.reduce_sum(dgm.gaussianLogDensity(b, bMean, bLv))
    return l_pw - l_qw
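When q is a diagonal Gaussian and the prior is a standard normal, the KL also has a closed form, which avoids the Monte Carlo noise of the log p(w) - log q(w) estimate above. A sketch of the standard formula (this is not necessarily what `dgm.klWCatBNN_exact` in Example #10 does):

def kl_diag_gaussian_stdnormal(mean, log_var):
    # KL( N(mean, diag(exp(log_var))) || N(0, I) ), summed over all parameters
    return 0.5 * tf.reduce_sum(tf.exp(log_var) + tf.square(mean) - 1.0 - log_var)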
Example #4
def lowerBound(self, x, y, z, z_m, z_lv, a, a_m, a_lv, z_m_p, z_lv_p):
    """ Compute densities and the lower bound; all inputs have shape (mc_samps x n_obs x n_dim) """
    pa_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    a_m_p, a_lv_p = self.p_a_xyz_mean(pa_in), self.p_a_xyz_log_var(pa_in)
    a_m_p = tf.reshape(a_m_p, [self.mc_samples, -1, self.n_a])
    a_lv_p = tf.reshape(a_lv_p, [self.mc_samples, -1, self.n_a])
    l_px = self.compute_logpx(x, z)                                  # log p(x|z)
    l_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)  # log p(y)
    l_pz = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)                  # log p(z|.)
    l_pa = dgm.gaussianLogDensity(a, a_m_p, a_lv_p)                  # log p(a|y,z)
    l_qz = dgm.gaussianLogDensity(z, z_m, z_lv)                      # log q(z|.)
    l_qa = dgm.gaussianLogDensity(a, a_m, a_lv)                      # log q(a|.)
    return tf.reduce_mean(l_px + l_py + l_pz + l_pa - l_qz - l_qa, axis=0)
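Read off the code, the quantity averaged over the Monte Carlo axis is a single-sample estimate of an auxiliary-variable evidence lower bound. Hedging on the exact conditioning of the variational terms, which is not visible in this snippet:

\mathcal{L} = \mathbb{E}_{q(a, z \mid \cdot)}\big[\log p(x \mid z) + \log p(y) + \log p(z \mid \cdot) + \log p(a \mid y, z) - \log q(z \mid \cdot) - \log q(a \mid \cdot)\big]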
Example #5
def lowerBound(self, x, y, z, z_m, z_lv):
    """ Compute densities and the lower bound; all inputs have shape (mc_samps x n_obs x n_dim) """
    l_px = self.compute_logpx(x, y, z)           # log p(x|y,z)
    l_py = dgm.multinoulliUniformLogDensity(y)   # log p(y), uniform prior
    l_pz = dgm.standardNormalLogDensity(z)       # log p(z)
    l_qz = dgm.gaussianLogDensity(z, z_m, z_lv)  # log q(z|.)
    return tf.reduce_mean(l_px + l_py + l_pz - l_qz, axis=0)
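A uniform categorical prior over K classes assigns every label the constant probability 1/K, so `dgm.multinoulliUniformLogDensity` presumably just broadcasts -log K. A hypothetical sketch, assuming y is one-hot along the last axis:

def multinoulli_uniform_log_density(y):
    # log p(y) under a uniform categorical over K classes: the constant -log K per observation
    k = tf.cast(tf.shape(y)[-1], tf.float32)
    return -tf.math.log(k) * tf.ones(tf.shape(y)[:-1])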
Example #6
def compute_logpx(self, x, y, z):
    """ Compute the log density of x under p(x|y,z) """
    px_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn, scope='px_yz')
        mean = tf.reshape(mean, [self.mc_samples, -1, self.n_x])
        log_var = tf.reshape(log_var, [self.mc_samples, -1, self.n_x])
        return dgm.gaussianLogDensity(x, mean, log_var)
    elif self.x_dist == 'Bernoulli':
        logits = dgm.forwardPassCatLogits(px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn, scope='px_yz')
        logits = tf.reshape(logits, [self.mc_samples, -1, self.n_x])
        return dgm.bernoulliLogDensity(x, logits)
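For the Bernoulli branch, the log-density under independent Bernoullis is exactly the negative sigmoid cross-entropy, so `dgm.bernoulliLogDensity` can be sketched with a stock TensorFlow op (again an assumed implementation, not the library's):

def bernoulli_log_density(x, logits):
    # sigmoid_cross_entropy_with_logits is the elementwise negative Bernoulli log-likelihood
    nll = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits)
    return -tf.reduce_sum(nll, axis=-1)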
Example #7
def lowerBound(self, x, y, z, z_m, z_lv, a, qa_m, qa_lv):
    """ Helper function for loss computations; assumes each input is a rank-3 tensor """
    pa_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    pa_m, pa_lv = dgm.forwardPassGauss(pa_in, self.pa_yz, self.n_hid, self.nonlinearity, self.bn, scope='pa_yz')
    pa_m = tf.reshape(pa_m, [self.mc_samples, -1, self.n_a])
    pa_lv = tf.reshape(pa_lv, [self.mc_samples, -1, self.n_a])
    l_px = self.compute_logpx(x, y, z, a)          # log p(x|y,z,a)
    l_py = dgm.multinoulliUniformLogDensity(y)     # log p(y), uniform prior
    l_pz = dgm.standardNormalLogDensity(z)         # log p(z)
    l_pa = dgm.gaussianLogDensity(a, pa_m, pa_lv)  # log p(a|y,z)
    l_qz = dgm.gaussianLogDensity(z, z_m, z_lv)    # log q(z|.)
    l_qa = dgm.gaussianLogDensity(a, qa_m, qa_lv)  # log q(a|.)
    return tf.reduce_mean(l_px + l_py + l_pz + l_pa - l_qz - l_qa, axis=0)
Example #8
def compute_logpx(self, x, z):
    """ Compute the log density of x under p(x|z) """
    px_in = tf.reshape(z, [-1, self.n_z])
    if self.x_dist == 'Gaussian':
        mean, log_var = self.p_x_z_mean(px_in), self.p_x_z_log_var(px_in)
        mean = tf.reshape(mean, [self.mc_samples, -1, self.n_x])
        log_var = tf.reshape(log_var, [self.mc_samples, -1, self.n_x])
        return dgm.gaussianLogDensity(x, mean, log_var)
    elif self.x_dist == 'Bernoulli':
        logits = self.p_x_z_mean(px_in)  # the mean network outputs logits here
        logits = tf.reshape(logits, [self.mc_samples, -1, self.n_x])
        return dgm.bernoulliLogDensity(x, logits)
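The flatten/reshape pattern recurring through these examples merges the Monte Carlo axis into the batch so a plain dense network can process all samples at once. A minimal shape-flow sketch with hypothetical sizes and a stand-in Dense layer for the unshown p_x_z_mean network:

import tensorflow as tf

mc_samples, n_obs, n_z, n_x = 5, 128, 10, 784      # hypothetical sizes
z = tf.random.normal([mc_samples, n_obs, n_z])      # (mc_samples, n_obs, n_z)
px_in = tf.reshape(z, [-1, n_z])                    # (mc_samples * n_obs, n_z)
logits = tf.keras.layers.Dense(n_x)(px_in)          # stand-in for p_x_z_mean
logits = tf.reshape(logits, [mc_samples, -1, n_x])  # back to (mc_samples, n_obs, n_x)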
Example #9
def compute_logpx(self, x, z):
    """ Compute the log density of x under p(x|z) """
    px_in = tf.reshape(z, [-1, self.n_z])
    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(px_in, self.px_z, self.n_hid, self.nonlinearity, self.bn, scope='px_z')
        mean = tf.reshape(mean, [self.mc_samples, -1, self.n_x])
        log_var = tf.reshape(log_var, [self.mc_samples, -1, self.n_x])
        return dgm.gaussianLogDensity(x, mean, log_var)
    elif self.x_dist == 'Bernoulli':
        logits = dgm.forwardPassCatLogits(px_in, self.px_z, self.n_hid, self.nonlinearity, self.bn, scope='px_z')
        logits = tf.reshape(logits, [self.mc_samples, -1, self.n_x])
        return dgm.bernoulliLogDensity(x, logits)
Example #10
def compute_loss(self, x, y):
    """ Negative ELBO for a BNN: expected log-likelihood minus the weight KL """
    yr = tf.tile(tf.expand_dims(y, 0), [self.wSamps, 1, 1])
    if self.y_dist == 'gaussian':
        ym, ylv = self.predict(x, self.wSamps, training=True)
        self.l = tf.reduce_sum(tf.reduce_mean(dgm.gaussianLogDensity(yr, ym, ylv), axis=0))
    elif self.y_dist == 'categorical':
        y_ = self.predict(x, self.wSamps, training=True)
        self.l = tf.reduce_sum(tf.reduce_mean(dgm.multinoulliLogDensity(yr, y_), axis=0))
    # scale the weight KL by the training-set size so the minibatch objective is unbiased
    # (alternative: divide by the minibatch size, tf.cast(tf.shape(self.x)[0], tf.float32))
    self.kl_term = dgm.klWCatBNN_exact(self.q, self.n_hid) / self.n_train
    return -(self.l - self.kl_term)
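For the categorical branch, `dgm.multinoulliLogDensity` presumably scores one-hot labels against predicted class probabilities. A hypothetical sketch under that assumption (if the predictions were logits instead, a log-softmax would replace the clipped log):

def multinoulli_log_density(y, probs, eps=1e-8):
    # log p(y) for one-hot y under a categorical with probabilities `probs`
    return tf.reduce_sum(y * tf.math.log(probs + eps), axis=-1)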
Example #11
def compute_logpx(self, x, y, z):
    """ Compute the log density of x under p(x|y,z) """
    px_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    if self.x_dist == 'Gaussian':
        mean, log_var = self.p_x_yz_mean(px_in), self.p_x_yz_log_var(px_in)
        mean = tf.reshape(mean, [self.mc_samples, -1, self.n_x])
        log_var = tf.reshape(log_var, [self.mc_samples, -1, self.n_x])
        return dgm.gaussianLogDensity(x, mean, log_var)
    elif self.x_dist == 'Bernoulli':
        logits = self.p_x_yz_mean_model(px_in)
        logits = tf.reshape(logits, [self.mc_samples, -1, self.n_x])
        return dgm.bernoulliLogDensity(x, logits)
Example #12
def compute_logpx(self, x, y, z, a):
    """ Compute the log density of x under p(x|y,z,a) """
    px_in = tf.reshape(tf.concat([y, z, a], axis=-1), [-1, self.n_y + self.n_z + self.n_a])
    if self.x_dist == 'Gaussian':
        mean, logVar = dgm.forwardPassGauss(px_in, self.px_yza, self.n_hid, self.nonlinearity, self.bn, scope='px_yza')
        mean = tf.reshape(mean, [self.mc_samples, -1, self.n_x])
        logVar = tf.reshape(logVar, [self.mc_samples, -1, self.n_x])
        return dgm.gaussianLogDensity(x, mean, logVar)
    elif self.x_dist == 'Bernoulli':
        logits = dgm.forwardPassCatLogits(px_in, self.px_yza, self.n_hid, self.nonlinearity, self.bn, scope='px_yza')
        logits = tf.reshape(logits, [self.mc_samples, -1, self.n_x])
        return dgm.bernoulliLogDensity(x, logits)