Example #1
def discriminatorTerm(self, x, y, n_samps=5):
    """ compute log p(y|x,a) with a ~ p(a|x) """
    _, _, a = self.sample_pa(x)
    x_ = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
    py_in = tf.reshape(tf.concat([x_, a], -1), [-1, self.n_x + self.n_a])
    self.W = bnn.sampleCatBNN(self.py_xa, self.n_hid)
    preds = dgm.forwardPassCatLogits(py_in, self.W, self.n_hid,
                                     self.nonlinearity, self.bn, scope='py_xa')
    preds = tf.reduce_mean(tf.reshape(preds, [self.mc_samples, -1, self.n_y]), axis=0)
    preds = tf.expand_dims(preds, axis=0)
    for sample in range(n_samps - 1):
        self.W = bnn.sampleCatBNN(self.py_xa, self.n_hid)
        preds_new = dgm.forwardPassCatLogits(py_in, self.W, self.n_hid,
                                             self.nonlinearity, self.bn, scope='py_xa')
        logits = tf.reduce_mean(tf.reshape(preds_new, [self.mc_samples, -1, self.n_y]), axis=0)
        preds = tf.concat([preds, tf.expand_dims(logits, 0)], axis=0)
    return dgm.multinoulliLogDensity(y, tf.reduce_mean(preds, 0))
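A recurring idiom in these examples: x is tiled across a Monte Carlo sample axis, concatenated with the per-sample latents, flattened so a dense network can consume 2-D input, and the sample axis is restored afterwards. A minimal self-contained sketch of that round trip (the shapes and variable names here are illustrative, not taken from the dgm/bnn modules):

import tensorflow as tf

mc_samples, n_obs, n_x, n_a = 5, 8, 3, 2
x = tf.random.normal([n_obs, n_x])                          # one minibatch
a = tf.random.normal([mc_samples, n_obs, n_a])              # MC samples of a ~ p(a|x)
x_ = tf.tile(tf.expand_dims(x, 0), [mc_samples, 1, 1])      # copy x along the sample axis
flat = tf.reshape(tf.concat([x_, a], -1), [-1, n_x + n_a])  # (mc_samples * n_obs, n_x + n_a)
# ... the network forward pass on `flat` would go here ...
back = tf.reshape(flat, [mc_samples, n_obs, n_x + n_a])     # restore the sample axis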
Example #2
def qy_loss(self, x, y):
    """ compute the labeled penalty term of q(y|x) """
    _, _, z = self.sample_z(x)
    x_ = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
    qy_in = tf.reshape(tf.concat([x_, z], axis=-1), [-1, self.n_x + self.n_z])
    y_ = dgm.forwardPassCatLogits(qy_in, self.qy_xz, self.n_hid,
                                  self.nonlinearity, self.bn, scope='qy_xz')
    y_ = tf.reduce_mean(tf.reshape(y_, [self.mc_samples, -1, self.n_y]), axis=0)
    return dgm.multinoulliLogDensity(y, y_)
Example #3
def qy_loss(self, x, y):
    """ compute the labeled penalty term of q(y|x) """
    y_ = dgm.forwardPassCatLogits(x, self.qy_x, self.n_hid,
                                  self.nonlinearity, self.bn, scope='qy_x')
    return dgm.multinoulliLogDensity(y, y_)
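The dgm helpers themselves are not shown on this page. For orientation, a categorical (multinoulli) log-density over logits, which is what these examples appear to pass around, is conventionally computed as below; treat this as a sketch of the idea, not the repository's actual multinoulliLogDensity:

import tensorflow as tf

def multinoulli_log_density(y, logits):
    """ log p(y) for one-hot y under Categorical(softmax(logits)) """
    return tf.reduce_sum(y * tf.nn.log_softmax(logits), axis=-1)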
Example #4
def discriminatorTerm(self, x, y, n_samps=5):
    """ compute log p(y|x) by averaging logits over n_samps weight draws """
    self.W = bnn.sampleCatBNN(self.py_x, self.n_hid)
    preds = dgm.forwardPassCatLogits(x, self.W, self.n_hid,
                                     self.nonlinearity, self.bn, scope='py_x')
    preds = tf.expand_dims(preds, axis=0)
    for sample in range(n_samps - 1):
        self.W = bnn.sampleCatBNN(self.py_x, self.n_hid)
        preds_new = dgm.forwardPassCatLogits(x, self.W, self.n_hid,
                                             self.nonlinearity, self.bn, scope='py_x')
        preds = tf.concat([preds, tf.expand_dims(preds_new, axis=0)], axis=0)
    return dgm.multinoulliLogDensity(y, tf.reduce_mean(preds, 0))
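Note that both discriminatorTerm variants average logits across the n_samps weight draws before evaluating the density. Averaging logits and then applying softmax yields a normalized geometric mean of the per-draw predictive distributions rather than the arithmetic-mean Monte Carlo estimate of the Bayesian predictive; the logit-space form is presumably chosen here for numerical convenience.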
Example #5
def qy_term(self, x, y, a):
    """ expected additional penalty under q(y|a,x) with samples from a """
    qy_in = tf.reshape(tf.concat([x, a], axis=-1), [-1, self.n_x + self.n_a])
    y_ = tf.reshape(dgm.forwardPassCatLogits(qy_in, self.qy_xa, self.n_hid,
                                             self.nonlinearity, self.bn, True, 'qy_xa'),
                    [self.mc_samples, -1, self.n_y])
    return dgm.multinoulliLogDensity(y, y_)
Example #6
def lowerBound(self, x, y, z, z_m, z_lv, z_m_p, z_lv_p):
    """ compute densities and lower bound; all inputs have shape
    (mc_samples x n_obs x n_dim)
    """
    l_px = self.compute_logpx(x, z)
    l_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    l_pz = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)
    l_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    return tf.reduce_mean(l_px + l_py + l_pz - l_qz, axis=0)
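The lower bound sums per-sample log densities. A diagonal Gaussian log-density parameterized by a mean and a log-variance, matching the (z, z_m, z_lv) signature used above, would typically look like the following; again, this is a sketch of the standard formula, not the repository's actual gaussianLogDensity:

import numpy as np
import tensorflow as tf

def gaussian_log_density(z, mean, log_var):
    """ log N(z | mean, diag(exp(log_var))), summed over the last axis """
    c = np.log(2.0 * np.pi)
    return -0.5 * tf.reduce_sum(c + log_var + tf.square(z - mean) / tf.exp(log_var),
                                axis=-1)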
Example #7
def qy_loss(self, x, y):
    """ compute the labeled penalty term of q(y|x) """
    y_ = dgm.forwardPassCatLogits(x, self.qy_x, self.n_hid,
                                  self.nonlinearity, self.bn, scope='qy_x')
    return dgm.multinoulliLogDensity(y, y_)
Example #8
def compute_loss(self, x, y):
    """ negative ELBO: expected log-likelihood minus the weight KL """
    yr = tf.tile(tf.expand_dims(y, 0), [self.wSamps, 1, 1])
    if self.y_dist == 'gaussian':
        ym, ylv = self.predict(x, self.wSamps, training=True)
        self.l = tf.reduce_sum(tf.reduce_mean(dgm.gaussianLogDensity(yr, ym, ylv), axis=0))
    elif self.y_dist == 'categorical':
        y_ = self.predict(x, self.wSamps, training=True)
        self.l = tf.reduce_sum(tf.reduce_mean(dgm.multinoulliLogDensity(yr, y_), axis=0))
    # exact KL for the sampled BNN weights, amortized over the training set
    # (a per-batch alternative: divide by tf.cast(tf.shape(self.x)[0], tf.float32))
    self.kl_term = dgm.klWCatBNN_exact(self.q, self.n_hid) / self.n_train
    return -(self.l - self.kl_term)
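Dividing the exact weight KL by self.n_train spreads the regularizer over the training set so each example carries an equal share of the complexity cost, which is standard practice in stochastic variational inference for Bayesian neural networks; the commented alternative normalizes by the minibatch size instead.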
Example #9
def compute_logpy(self, y, x, z):
    """ compute the log density of y under p(y|x,z) """
    py_in = tf.reshape(tf.concat([x, z], axis=-1), [-1, self.n_x + self.n_z])
    y_ = dgm.forwardPassCatLogits(py_in, self.py_xz, self.n_hid,
                                  self.nonlinearity, self.bn, scope='py_xz')
    y_ = tf.reshape(y_, [self.mc_samples, -1, self.n_y])
    return dgm.multinoulliLogDensity(y, y_)
Example #10
def lowerBound(self, x, y, z, z_m, z_lv, a, a_m, a_lv, z_m_p, z_lv_p):
    """ compute densities and lower bound; all inputs have shape (mc_samples x n_obs x n_dim) """
    pa_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    a_m_p = tf.reshape(self.p_a_xyz_mean(pa_in), [self.mc_samples, -1, self.n_a])
    a_lv_p = tf.reshape(self.p_a_xyz_log_var(pa_in), [self.mc_samples, -1, self.n_a])
    l_px = self.compute_logpx(x, z)
    l_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    l_pz = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)
    l_pa = dgm.gaussianLogDensity(a, a_m_p, a_lv_p)
    l_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    l_qa = dgm.gaussianLogDensity(a, a_m, a_lv)
    return tf.reduce_mean(l_px + l_py + l_pz + l_pa - l_qz - l_qa, axis=0)
Example #11
def qy_loss(self, x, y=None, a=None, expand_y=True):
    """ labeled or unlabeled penalty term of q(y|a,x); samples a if not given """
    if a is None:
        _, _, a = self.sample_a(x)
        x_ = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
        qy_in = tf.reshape(tf.concat([x_, a], axis=-1), [-1, self.n_x + self.n_a])
    else:
        qy_in = tf.reshape(tf.concat([x, a], axis=-1), [-1, self.n_x + self.n_a])
    y_ = tf.reshape(self.q_y_ax_model(qy_in), [self.mc_samples, -1, self.n_y])
    if y is not None and expand_y:
        y = tf.tile(tf.expand_dims(y, 0), [self.mc_samples, 1, 1])
    if y is None:
        return dgm.multinoulliUniformLogDensity(y_)
    else:
        return dgm.multinoulliLogDensity(y, y_)
Example #12
def qy_loss(self, x, y=None):
    """ labeled or unlabeled penalty term of q(y|x) """
    y_ = self.q_y_x_model(x)
    if y is None:
        return dgm.multinoulliUniformLogDensity(y_)
    else:
        return dgm.multinoulliLogDensity(y, y_)