def discriminatorTerm(self, x, y, n_samps=5):
    """Compute log p(y|x,a) with a ~ p(a|x), averaged over BNN weight draws.

    An auxiliary sample ``a`` is drawn via ``self.sample_pa(x)``. ``x`` is
    tiled ``self.mc_samples`` times along a new leading axis, concatenated
    with ``a`` on the last axis and flattened to ``(-1, n_x + n_a)``.

    ``max(n_samps, 1)`` weight sets are drawn from ``self.py_xa``. For each
    draw, the network outputs are reshaped to ``(mc_samples, -1, n_y)`` and
    averaged over the Monte Carlo axis. The per-draw averages are then
    averaged again before ``dgm.multinoulliLogDensity`` scores ``y``.

    Side effect: ``self.W`` is left holding the last sampled weight set.
    """
    _, _, a = self.sample_pa(x)
    x_rep = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
    py_in = tf.reshape(tf.concat([x_rep, a], -1), [-1, self.n_x + self.n_a])

    per_draw = None
    # One mandatory draw plus (n_samps - 1) extra ones, i.e. at least one.
    # The graph ops are created in the same order as an unrolled version,
    # so op-level random seeds derived from the graph stay unchanged.
    for _draw in range(max(n_samps, 1)):
        self.W = bnn.sampleCatBNN(self.py_xa, self.n_hid)
        out = dgm.forwardPassCatLogits(py_in, self.W, self.n_hid,
                                       self.nonlinearity, self.bn,
                                       scope='py_xa')
        mc_mean = tf.reduce_mean(
            tf.reshape(out, [self.mc_samples, -1, self.n_y]), axis=0)
        mc_mean = tf.expand_dims(mc_mean, 0)
        if per_draw is None:
            per_draw = mc_mean
        else:
            per_draw = tf.concat([per_draw, mc_mean], axis=0)

    return dgm.multinoulliLogDensity(y, tf.reduce_mean(per_draw, 0))
def qy_loss(self, x, y):
    """Compute the labeled penalty term log q(y|x, z).

    ``z`` is drawn via ``self.sample_z(x)``. ``x`` is tiled
    ``self.mc_samples`` times along a new leading axis, concatenated with
    ``z`` and flattened to ``(-1, n_x + n_z)`` before the ``qy_xz`` network
    runs. Its outputs are reshaped to ``(mc_samples, -1, n_y)`` and averaged
    over the Monte Carlo axis before ``y`` is scored.
    """
    _, _, z = self.sample_z(x)
    x_rep = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
    flat_in = tf.reshape(tf.concat([x_rep, z], axis=-1),
                         [-1, self.n_x + self.n_z])
    logits = dgm.forwardPassCatLogits(flat_in, self.qy_xz, self.n_hid,
                                      self.nonlinearity, self.bn,
                                      scope='qy_xz')
    mc_logits = tf.reshape(logits, [self.mc_samples, -1, self.n_y])
    return dgm.multinoulliLogDensity(y, tf.reduce_mean(mc_logits, axis=0))
def qy_loss(self, x, y):
    """Compute the labeled penalty term log q(y|x).

    Runs the ``qy_x`` categorical network directly on ``x`` (no sampling)
    and scores ``y`` with ``dgm.multinoulliLogDensity``.
    """
    return dgm.multinoulliLogDensity(
        y,
        dgm.forwardPassCatLogits(x, self.qy_x, self.n_hid, self.nonlinearity,
                                 self.bn, scope='qy_x'))
def discriminatorTerm(self, x, y, n_samps=5):
    """Compute log p(y|x) averaged over BNN weight draws.

    ``max(n_samps, 1)`` weight sets are drawn from ``self.py_x``. The
    ``py_x`` network is run on ``x`` with each set, and the outputs are
    stacked along a new leading axis. They are averaged over that axis
    before ``dgm.multinoulliLogDensity`` scores ``y``.

    Side effect: ``self.W`` is left holding the last sampled weight set.
    """
    stacked = None
    # One mandatory draw plus (n_samps - 1) extra ones, i.e. at least one.
    # Graph ops are created in the same order as the unrolled form.
    for _draw in range(max(n_samps, 1)):
        self.W = bnn.sampleCatBNN(self.py_x, self.n_hid)
        out = dgm.forwardPassCatLogits(x, self.W, self.n_hid,
                                       self.nonlinearity, self.bn,
                                       scope='py_x')
        out = tf.expand_dims(out, axis=0)
        stacked = out if stacked is None else tf.concat([stacked, out], axis=0)
    return dgm.multinoulliLogDensity(y, tf.reduce_mean(stacked, 0))
def qy_term(self, x, y, a):
    """Expected additional penalty under q(y|a, x), using samples of ``a``.

    ``x`` and ``a`` are concatenated on the last axis and flattened to
    ``(-1, n_x + n_a)``; they are not tiled here. The ``qy_xa`` outputs are
    reshaped to ``(mc_samples, -1, n_y)`` and scored against ``y``.
    NOTE(review): ``y`` is not tiled either, so it presumably already
    carries (or broadcasts over) the mc_samples axis. Confirm at call sites.
    """
    joint = tf.concat([x, a], axis=-1)
    qy_in = tf.reshape(joint, [-1, self.n_x + self.n_a])
    # The positional True / 'qy_xa' arguments are passed through unchanged;
    # their meaning is defined by dgm.forwardPassCatLogits.
    flat_logits = dgm.forwardPassCatLogits(qy_in, self.qy_xa, self.n_hid,
                                           self.nonlinearity, self.bn,
                                           True, 'qy_xa')
    logits = tf.reshape(flat_logits, [self.mc_samples, -1, self.n_y])
    return dgm.multinoulliLogDensity(y, logits)
def lowerBound(self, x, y, z, z_m, z_lv, z_m_p, z_lv_p):
    """Compute the densities and the lower bound for sampled latents.

    All inputs are expected to be shaped (mc_samps X n_obs X n_dim). The
    per-sample bound is log p(x|z) + log p(y) + log p(z) - log q(z), where
    p(z) uses (z_m_p, z_lv_p) and q(z) uses (z_m, z_lv). It is averaged
    over the leading (Monte Carlo) axis.
    """
    log_px = self.compute_logpx(x, z)
    log_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    log_pz = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)
    log_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    per_sample = log_px + log_py + log_pz - log_qz
    return tf.reduce_mean(per_sample, axis=0)
def qy_loss(self, x, y):
    """Compute the labeled penalty term log q(y|x).

    The ``qy_x`` categorical network is evaluated on ``x`` and its output
    is scored against ``y`` with ``dgm.multinoulliLogDensity``.
    """
    logits = dgm.forwardPassCatLogits(x, self.qy_x, self.n_hid,
                                      self.nonlinearity, self.bn,
                                      scope='qy_x')
    log_q = dgm.multinoulliLogDensity(y, logits)
    return log_q
def compute_loss(self, x, y):
    """Return -(log-likelihood term - KL term) for the batch (x, y).

    ``y`` is tiled ``self.wSamps`` times along a new leading axis to match
    the weight-sample predictions from ``self.predict``. The log-likelihood
    is averaged over that axis and summed over the rest. The KL term from
    ``dgm.klWCatBNN_exact`` is divided by ``self.n_train``.

    Side effects: sets ``self.l`` (log-likelihood term) and
    ``self.kl_term``.

    Raises:
        ValueError: if ``self.y_dist`` is neither 'gaussian' nor
            'categorical'.
    """
    yr = tf.tile(tf.expand_dims(y, 0), [self.wSamps, 1, 1])
    if self.y_dist == 'gaussian':
        ym, ylv = self.predict(x, self.wSamps, training=True)
        log_lik = dgm.gaussianLogDensity(yr, ym, ylv)
    elif self.y_dist == 'categorical':
        y_ = self.predict(x, self.wSamps, training=True)
        log_lik = dgm.multinoulliLogDensity(yr, y_)
    else:
        raise ValueError("unsupported y_dist: {!r}".format(self.y_dist))
    # Store on self for both distributions. The Gaussian branch used to
    # write a throwaway local, so the return read a missing or stale self.l.
    self.l = tf.reduce_sum(tf.reduce_mean(log_lik, axis=0))
    # KL is scaled by the training-set size rather than the batch size.
    self.kl_term = dgm.klWCatBNN_exact(self.q, self.n_hid) / self.n_train
    return -(self.l - self.kl_term)
def compute_logpy(self, y, x, z):
    """Compute the log density of y under p(y|x, z).

    ``x`` and ``z`` are concatenated on the last axis and flattened to
    ``(-1, n_x + n_z)``. The ``py_xz`` network's outputs are reshaped to
    ``(mc_samples, -1, n_y)`` and scored against ``y``.
    """
    xz = tf.concat([x, z], axis=-1)
    flat_in = tf.reshape(xz, [-1, self.n_x + self.n_z])
    logits = dgm.forwardPassCatLogits(flat_in, self.py_xz, self.n_hid,
                                      self.nonlinearity, self.bn,
                                      scope='py_xz')
    logits = tf.reshape(logits, [self.mc_samples, -1, self.n_y])
    return dgm.multinoulliLogDensity(y, logits)
def lowerBound(self, x, y, z, z_m, z_lv, a, a_m, a_lv, z_m_p, z_lv_p):
    """Compute the densities and lower bound with an auxiliary variable a.

    All inputs are expected to be shaped (mc_samps X n_obs X n_dim). The
    prior over ``a`` is parameterized by ``self.p_a_xyz_mean`` and
    ``self.p_a_xyz_log_var``, both applied to [y, z] flattened to
    ``(-1, n_y + n_z)`` and reshaped back to ``(mc_samples, -1, n_a)``.
    NOTE(review): x is not part of that input despite the ``p_a_xyz``
    naming. Confirm this is intended.

    The per-sample bound is
    log p(x|z) + log p(y) + log p(z) + log p(a) - log q(z) - log q(a),
    averaged over the leading (Monte Carlo) axis.
    """
    mc_shape = [self.mc_samples, -1, self.n_a]
    pa_in = tf.reshape(tf.concat([y, z], axis=-1),
                       [-1, self.n_y + self.n_z])
    a_m_p = self.p_a_xyz_mean(pa_in)
    a_lv_p = self.p_a_xyz_log_var(pa_in)
    a_m_p = tf.reshape(a_m_p, mc_shape)
    a_lv_p = tf.reshape(a_lv_p, mc_shape)

    log_px = self.compute_logpx(x, z)
    log_py = dgm.multinoulliLogDensity(y, self.prior, on_priors=True)
    log_pz = dgm.gaussianLogDensity(z, z_m_p, z_lv_p)
    log_pa = dgm.gaussianLogDensity(a, a_m_p, a_lv_p)
    log_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    log_qa = dgm.gaussianLogDensity(a, a_m, a_lv)
    per_sample = log_px + log_py + log_pz + log_pa - log_qz - log_qa
    return tf.reduce_mean(per_sample, axis=0)
def qy_loss(self, x, y=None, a=None, expand_y=True):
    """Score labels (or a uniform target) under q(y|a, x).

    Args:
        x: Inputs. When ``a`` is None, ``x`` is tiled ``self.mc_samples``
            times along a new leading axis before concatenation. Otherwise
            it is concatenated with ``a`` as given, so it must already carry
            that axis.
        y: Labels. If None, ``dgm.multinoulliUniformLogDensity`` is applied
            to the predictions instead of ``dgm.multinoulliLogDensity``.
        a: Auxiliary samples. When None, they are drawn via
            ``self.sample_a(x)``.
        expand_y: When truthy and ``y`` is given, ``y`` is tiled
            ``self.mc_samples`` times along a new leading axis to match the
            predictions.

    Returns:
        The density returned by the ``dgm`` helper, evaluated on predictions
        reshaped to ``(mc_samples, -1, n_y)``.
    """
    if a is None:
        _, _, a = self.sample_a(x)
        x_in = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
    else:
        x_in = x
    qy_in = tf.reshape(tf.concat([x_in, a], axis=-1),
                       [-1, self.n_x + self.n_a])
    y_ = tf.reshape(self.q_y_ax_model(qy_in), [self.mc_samples, -1, self.n_y])
    if y is None:
        return dgm.multinoulliUniformLogDensity(y_)
    if expand_y:
        y = tf.tile(tf.expand_dims(y, 0), [self.mc_samples, 1, 1])
    return dgm.multinoulliLogDensity(y, y_)
def qy_loss(self, x, y=None):
    """Score labels (or a uniform target) under q(y|x).

    Runs ``self.q_y_x_model`` on ``x``. With ``y`` given, the result is
    scored with ``dgm.multinoulliLogDensity``. Without it,
    ``dgm.multinoulliUniformLogDensity`` is applied to the predictions.
    """
    logits = self.q_y_x_model(x)
    if y is not None:
        return dgm.multinoulliLogDensity(y, logits)
    return dgm.multinoulliUniformLogDensity(logits)