def compute_logpx(self, x, y, z):
    """Compute the log density of x under p(x|y,z).

    y and z are concatenated along the last axis and flattened to
    [-1, n_y + n_z] before being fed to the decoder network ``self.px_yz``.
    The decoder outputs are reshaped to [mc_samples, -1, n_x] and the density
    of x is evaluated against them.

    Args:
        x: observations handed to the dgm log-density helper; presumably
            shaped [mc_samples, batch, n_x] -- TODO confirm against callers.
        y: tensor concatenated with z on the last axis (assumed size n_y).
        z: latent samples (assumed last-axis size n_z).

    Returns:
        ``dgm.gaussianLogDensity(x, mean, log_var)`` when
        ``self.x_dist == 'Gaussian'``, or ``dgm.bernoulliLogDensity(x, logits)``
        when ``self.x_dist == 'Bernoulli'``.

    Raises:
        ValueError: if ``self.x_dist`` is any other value. Failing here makes
            a misconfigured model error out at graph-construction time instead
            of returning None and breaking later in the caller's arithmetic.
    """
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'".format(
                self.x_dist))

    px_in = tf.reshape(tf.concat([y, z], axis=-1), [-1, self.n_y + self.n_z])
    # The decoder sees a flat batch; reshape its outputs to restore a leading
    # mc_samples axis.
    out_shape = [self.mc_samples, -1, self.n_x]

    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(
            px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn,
            scope='px_yz')
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)

    # self.x_dist == 'Bernoulli'
    logits = dgm.forwardPassCatLogits(
        px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn,
        scope='px_yz')
    logits = tf.reshape(logits, out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def compute_logpx(self, x, z):
    """Compute the log density of x under p(x|z).

    z is flattened to [-1, n_z] and fed to the decoder network ``self.px_z``.
    The decoder outputs are reshaped to [mc_samples, -1, n_x] and the density
    of x is evaluated against them.

    Args:
        x: observations handed to the dgm log-density helper; presumably
            shaped [mc_samples, batch, n_x] -- TODO confirm against callers.
        z: latent samples (assumed last-axis size n_z).

    Returns:
        ``dgm.gaussianLogDensity(x, mean, log_var)`` when
        ``self.x_dist == 'Gaussian'``, or ``dgm.bernoulliLogDensity(x, logits)``
        when ``self.x_dist == 'Bernoulli'``.

    Raises:
        ValueError: if ``self.x_dist`` is any other value. Failing here makes
            a misconfigured model error out at graph-construction time instead
            of returning None and breaking later in the caller's arithmetic.
    """
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'".format(
                self.x_dist))

    px_in = tf.reshape(z, [-1, self.n_z])
    # The decoder sees a flat batch; reshape its outputs to restore a leading
    # mc_samples axis.
    out_shape = [self.mc_samples, -1, self.n_x]

    if self.x_dist == 'Gaussian':
        mean, log_var = dgm.forwardPassGauss(
            px_in, self.px_z, self.n_hid, self.nonlinearity, self.bn,
            scope='px_z')
        mean = tf.reshape(mean, out_shape)
        log_var = tf.reshape(log_var, out_shape)
        return dgm.gaussianLogDensity(x, mean, log_var)

    # self.x_dist == 'Bernoulli'
    logits = dgm.forwardPassCatLogits(
        px_in, self.px_z, self.n_hid, self.nonlinearity, self.bn,
        scope='px_z')
    logits = tf.reshape(logits, out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def build_model(self):
    """Build the graph: placeholders, the q network, one weight sample, outputs.

    ``self.y_dist`` selects the output family:

    * 'gaussian': uses ``dgm.initGaussBNN`` / ``dgm.sampleGaussBNN`` /
      ``dgm.forwardPassGauss``. It sets ``self.y_m`` and ``self.y_lv``.
    * 'categorical': uses ``dgm.initCatBNN`` / ``dgm.sampleCatBNN`` /
      ``dgm.forwardPassCatLogits``. It sets ``self.y_logits``.

    In both cases ``self.q`` holds the network and ``self.wTilde`` the sampled
    weights. Afterwards ``self.predictions`` is the mean over axis 0 of
    ``self.predict(self.x, 10, training=True)``.
    """
    # NOTE(review): both counters are hard-coded to 1; their purpose is not
    # visible here.
    self.n_train = 1
    self.n = 1
    self.create_placeholders()

    if self.y_dist == 'gaussian':
        net = dgm.initGaussBNN(self.n_x, self.n_hid, self.n_y, 'network',
                               initVar=self.initVar, bn=self.bn)
        self.q = net
        w_sample = dgm.sampleGaussBNN(net, self.n_hid)
        self.wTilde = w_sample
        self.y_m, self.y_lv = dgm.forwardPassGauss(
            self.x, w_sample, net, self.n_hid, self.nonlinearity, self.bn,
            training=True, scope='q', reuse=False)
    elif self.y_dist == 'categorical':
        net = dgm.initCatBNN(self.n_x, self.n_hid, self.n_y, 'network',
                             initVar=self.initVar, bn=self.bn)
        self.q = net
        w_sample = dgm.sampleCatBNN(net, self.n_hid)
        self.wTilde = w_sample
        self.y_logits = dgm.forwardPassCatLogits(
            self.x, w_sample, net, self.n_hid, self.nonlinearity, self.bn,
            training=True, scope='q', reuse=False)

    # 10 is passed to self.predict; presumably a number of weight samples
    # averaged over below -- TODO confirm against predict().
    sampled_predictions = self.predict(self.x, 10, training=True)
    self.predictions = tf.reduce_mean(sampled_predictions, 0)
def lowerBound(self, x, y, z, z_m, z_lv, a, qa_m, qa_lv):
    """Helper for loss computations: a Monte-Carlo lower-bound estimate.

    Every input is assumed to be a rank-3 tensor, presumably
    [mc_samples, batch, dim]. For each sample it evaluates

        log p(x|y,z,a) + log p(y) + log p(z) + log p(a|y,z)
            - log q(z) - log q(a)

    and averages over axis 0.

    * p(a|y,z) is a Gaussian whose mean and log-variance come from the
      ``self.pa_yz`` network.
    * q(z) and q(a) are Gaussians with the supplied (z_m, z_lv) and
      (qa_m, qa_lv).
    * log p(y) and log p(z) come from ``dgm.multinoulliUniformLogDensity`` and
      ``dgm.standardNormalLogDensity``.
    """
    mc_shape = [self.mc_samples, -1, self.n_a]

    # Conditional prior over the auxiliary variable a, given (y, z).
    yz = tf.concat([y, z], axis=-1)
    prior_a_input = tf.reshape(yz, [-1, self.n_y + self.n_z])
    prior_a_mean, prior_a_logvar = dgm.forwardPassGauss(
        prior_a_input, self.pa_yz, self.n_hid, self.nonlinearity, self.bn,
        scope='pa_yz')
    prior_a_mean = tf.reshape(prior_a_mean, mc_shape)
    prior_a_logvar = tf.reshape(prior_a_logvar, mc_shape)

    log_px = self.compute_logpx(x, y, z, a)
    log_py = dgm.multinoulliUniformLogDensity(y)
    log_pz = dgm.standardNormalLogDensity(z)
    log_pa = dgm.gaussianLogDensity(a, prior_a_mean, prior_a_logvar)
    log_qz = dgm.gaussianLogDensity(z, z_m, z_lv)
    log_qa = dgm.gaussianLogDensity(a, qa_m, qa_lv)

    per_sample = log_px + log_py + log_pz + log_pa - log_qz - log_qa
    return tf.reduce_mean(per_sample, axis=0)
def compute_logpx(self, x, y, z, a):
    """Compute the log density of x under p(x|y,z,a).

    y, z and a are concatenated along the last axis and flattened to
    [-1, n_y + n_z + n_a] before being fed to the decoder network
    ``self.px_yza``. The decoder outputs are reshaped to [mc_samples, -1, n_x]
    and the density of x is evaluated against them.

    Args:
        x: observations handed to the dgm log-density helper; presumably
            shaped [mc_samples, batch, n_x] -- TODO confirm against callers.
        y: tensor concatenated on the last axis (assumed size n_y).
        z: latent samples (assumed last-axis size n_z).
        a: auxiliary-variable samples (assumed last-axis size n_a).

    Returns:
        ``dgm.gaussianLogDensity(x, mean, logVar)`` when
        ``self.x_dist == 'Gaussian'``, or ``dgm.bernoulliLogDensity(x, logits)``
        when ``self.x_dist == 'Bernoulli'``.

    Raises:
        ValueError: if ``self.x_dist`` is any other value. Failing here makes
            a misconfigured model error out at graph-construction time instead
            of returning None and breaking later in the caller's arithmetic.
    """
    if self.x_dist not in ('Gaussian', 'Bernoulli'):
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'".format(
                self.x_dist))

    px_in = tf.reshape(tf.concat([y, z, a], axis=-1),
                       [-1, self.n_y + self.n_z + self.n_a])
    # The decoder sees a flat batch; reshape its outputs to restore a leading
    # mc_samples axis.
    out_shape = [self.mc_samples, -1, self.n_x]

    if self.x_dist == 'Gaussian':
        mean, logVar = dgm.forwardPassGauss(
            px_in, self.px_yza, self.n_hid, self.nonlinearity, self.bn,
            scope='px_yza')
        mean = tf.reshape(mean, out_shape)
        logVar = tf.reshape(logVar, out_shape)
        return dgm.gaussianLogDensity(x, mean, logVar)

    # self.x_dist == 'Bernoulli'
    logits = dgm.forwardPassCatLogits(
        px_in, self.px_yza, self.n_hid, self.nonlinearity, self.bn,
        scope='px_yza')
    logits = tf.reshape(logits, out_shape)
    return dgm.bernoulliLogDensity(x, logits)
def predictConditionW(self, x, training=True):
    """Run x through the q network using the current weight sample.

    This is the predictive distribution conditioned on ``self.wTilde``, which
    must already have been drawn from q(W).

    * For ``self.y_dist == 'gaussian'`` it returns whatever
      ``dgm.forwardPassGauss`` returns.
    * For ``'categorical'`` it returns the logits from
      ``dgm.forwardPassCatLogits``.
    * For any other value it returns None.

    ``training`` is forwarded positionally to the dgm helper.
    """
    if self.y_dist == 'gaussian':
        forward = dgm.forwardPassGauss
    elif self.y_dist == 'categorical':
        forward = dgm.forwardPassCatLogits
    else:
        return None
    return forward(x, self.wTilde, self.q, self.n_hid, self.nonlinearity,
                   self.bn, training, scope='q')