def build_model(self):
    """ Define model components and variables.

    Builds, in order:
      * inference: class logits ``self.y_`` from network ``qy_x`` on ``self.x``,
        and ``(qz_mean, qz_lv, z_)`` from network ``qz_xy`` on ``[x, y]``;
      * generative: 100 standard-normal draws ``self.z_prior`` decoded by
        ``px_z`` into ``self.x_`` (Gaussian or Bernoulli likelihood), then the
        categorical network ``py_xz`` applied to ``[x_, z_prior]``;
      * ``self.predictions`` via ``self.predict(self.x)``.

    Raises:
        ValueError: if ``self.x_dist`` is neither 'Gaussian' nor 'Bernoulli'
            (previously this left ``self.x_`` unset and failed later with an
            unhelpful AttributeError).
    """
    self.create_placeholders()
    self.initialize_networks()
    ## model variables and relations ##
    # inference #
    self.y_ = dgm.forwardPassCatLogits(self.x, self.qy_x, self.n_hid,
                                       self.nonlinearity, self.bn,
                                       scope='qy_x', reuse=False)
    self.qz_in = tf.concat([self.x, self.y], axis=-1)
    self.qz_mean, self.qz_lv, self.z_ = dgm.samplePassGauss(
        self.qz_in, self.qz_xy, self.n_hid, self.nonlinearity, self.bn,
        scope='qz_xy', reuse=False)
    # generative #
    # A fixed batch of 100 prior draws drives the generative path.
    self.z_prior = tf.random_normal([100, self.n_z])
    if self.x_dist == 'Gaussian':
        self.px_mean, self.px_lv, self.x_ = dgm.samplePassGauss(
            self.z_prior, self.px_z, self.n_hid, self.nonlinearity, self.bn,
            scope='px_z', reuse=False)
        # samplePassGauss samples may carry a leading Monte-Carlo axis (other
        # call sites reshape/tile it); flatten so x_ is rank-2 like z_prior
        # for the concat below. A no-op if it is already rank-2.
        self.x_ = tf.reshape(self.x_, [-1, self.n_x])
    elif self.x_dist == 'Bernoulli':
        self.x_ = dgm.forwardPassBernoulli(self.z_prior, self.px_z, self.n_hid,
                                           self.nonlinearity, self.bn,
                                           scope='px_z', reuse=False)
    else:
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'".format(
                self.x_dist))
    self.py_in = tf.concat([self.x_, self.z_prior], axis=-1)
    self.py_ = dgm.forwardPassCat(self.py_in, self.py_xz, self.n_hid,
                                  self.nonlinearity, self.bn,
                                  scope='py_xz', reuse=False)
    self.predictions = self.predict(self.x)
def encode(self, x, y=None, n_iters=100):
    """ Encode a new example into z-space (labeled or unlabeled).

    Steps:
      1. sample ``a`` from network ``qa_x`` given ``x``;
      2. if ``y`` is None, predict a hard label: one-hot of the argmax of
         network ``qy_xa`` applied to ``[x, a]``;
      3. sample ``z`` from network ``qz_xya`` given ``[x, y, a]``.

    Args:
        x: input batch (concatenated with ``a``/``y`` along axis 1).
        y: optional labels in the same encoding as the one-hot produced in
            step 2; inferred when None.
        n_iters: unused; kept only for interface compatibility.

    Returns:
        The ``z`` samples from ``qz_xya``.
    """
    _, _, a = dgm.samplePassGauss(x, self.qa_x, self.n_hid, self.nonlinearity,
                                  self.bn, True, 'qa_x')
    if y is None:
        h = tf.concat([x, a], axis=1)
        # Run the classifier under its own scope ('qy_xa'), matching the
        # network/scope pairing used everywhere else; 'qa_x' belongs to the
        # auxiliary network above.
        y_out = dgm.forwardPassCat(h, self.qy_xa, self.n_hid, self.nonlinearity,
                                   self.bn, True, 'qy_xa')
        y = tf.one_hot(tf.argmax(y_out, axis=1), self.n_y)
    h = tf.concat([x, y, a], axis=1)
    _, _, z = dgm.samplePassGauss(h, self.qz_xya, self.n_hid, self.nonlinearity,
                                  self.bn, True, 'qz_xya')
    return z
def sample_pa(self, x):
    """ Draw ``self.mc_samples`` samples of ``a`` given ``x``.

    Uses the auxiliary network ``self.pa_x`` under scope 'pa_x'.
    NOTE(review): the previous docstring called this q(a|x) although the
    network is named pa_x — confirm which distribution is intended.

    Returns:
        Whatever ``dgm.samplePassGauss`` returns (elsewhere in this code it
        is unpacked as mean, log-variance, samples).
    """
    draws = dgm.samplePassGauss(
        x,
        self.pa_x,
        self.n_hid,
        self.nonlinearity,
        self.bn,
        scope='pa_x',
        mc_samps=self.mc_samples,
    )
    return draws
def build_model(self):
    """ Define model components and variables.

    Builds, in order:
      * inference: ``(qc_mean, qc_lv, c_)`` from ``dgm.samplePassStatistic``
        with network ``qc_x`` on ``self.x``; then ``(qz_mean, qz_lv, qz_)``
        from network ``qz_xc`` on ``[x, tiled c_]``;
      * generative: one standard-normal draw ``self.c_prior``, 200 samples of
        z from ``self.sample_pz``, decoded by ``px_z`` into ``self.px_``
        (Gaussian or Bernoulli likelihood).

    Raises:
        ValueError: if ``self.x_dist`` is neither 'Gaussian' nor 'Bernoulli'
            (previously ``self.px_`` was silently left unset).
    """
    self.create_placeholders()
    self.initialize_networks()
    ## model variables and relations ##
    # inference #
    self.qc_mean, self.qc_lv, self.c_ = dgm.samplePassStatistic(
        self.x, self.qc_x, self.n_hid, self.nonlinearity, self.bn,
        self.mc_samples, scope='qc_x', reuse=False)
    self.c_ = tf.reshape(self.c_, [-1, self.n_c])
    # NOTE(review): c_ presumably has one row per Monte-Carlo sample; tiling it
    # by the batch size only matches x's row count when that is a single row
    # (e.g. mc_samples == 1) — confirm against samplePassStatistic.
    self.qz_in = tf.concat(
        [self.x, tf.tile(self.c_, [tf.shape(self.x)[0], 1])], axis=-1)
    self.qz_mean, self.qz_lv, self.qz_ = dgm.samplePassGauss(
        self.qz_in, self.qz_xc, self.n_hid, self.nonlinearity, self.bn,
        scope='qz_xc', reuse=False)
    # generative #
    self.c_prior = tf.random_normal([1, self.n_c])
    self.pz_mean, self.pz_lv, self.pz_ = self.sample_pz(self.c_prior, 200,
                                                        reuse=False)
    self.pz_ = tf.reshape(self.pz_, [-1, self.n_z])
    if self.x_dist == 'Gaussian':
        self.px_mean, self.px_lv, self.px_ = dgm.samplePassGauss(
            self.pz_, self.px_z, self.n_hid, self.nonlinearity, self.bn,
            scope='px_z', reuse=False)
    elif self.x_dist == 'Bernoulli':
        self.px_ = dgm.forwardPassBernoulli(self.pz_, self.px_z, self.n_hid,
                                            self.nonlinearity, self.bn,
                                            scope='px_z', reuse=False)
    else:
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'".format(
                self.x_dist))
def sample_z(self, x, y):
    """ Sample z from the ``qz_xy`` network given ``x`` and ``y``.

    ``x`` and ``y`` are joined along their last axis, and ``self.mc_samples``
    draws are requested from ``dgm.samplePassGauss``; its result is returned
    unchanged.
    """
    joint = tf.concat([x, y], axis=-1)
    return dgm.samplePassGauss(
        joint,
        self.qz_xy,
        self.n_hid,
        self.nonlinearity,
        self.bn,
        scope='qz_xy',
        mc_samps=self.mc_samples,
    )
def sample_z(self, x, y):
    """ Return the parameters of, and samples from, q(z|x,y).

    The condition ``[x, y]`` (concatenated on the last axis) is passed to the
    ``qz_xy`` network, requesting ``self.mc_samples`` draws.
    """
    condition = tf.concat([x, y], axis=-1)
    result = dgm.samplePassGauss(condition, self.qz_xy, self.n_hid,
                                 self.nonlinearity, self.bn,
                                 scope='qz_xy', mc_samps=self.mc_samples)
    return result
def build_model(self):
    """ Define model components and variables.

    Builds, in order:
      * inference: class logits ``self.y_`` from network ``qy_x`` on ``self.x``,
        and ``(qz_mean, qz_lv, z_)`` from network ``qz_xy`` on ``[x, y]``;
      * generative: one standard-normal z per row of ``self.y``, decoded from
        ``[y, z_prior]`` by ``px_yz`` into ``self.x_`` (Gaussian or Bernoulli);
      * classifier: weights ``self.W`` sampled by ``bnn.sampleCatBNN`` from
        ``self.py_x`` and applied to ``self.x_``;
      * ``self.predictions`` via ``self.predict(self.x, n_samps=50,
        training=False)``.

    Raises:
        ValueError: if ``self.x_dist`` is neither 'Gaussian' nor 'Bernoulli'
            (previously ``self.x_`` was left unset and the classifier pass
            failed with an unhelpful AttributeError).
    """
    self.create_placeholders()
    self.initialize_networks()
    ## model variables and relations ##
    # inference #
    self.y_ = dgm.forwardPassCatLogits(self.x, self.qy_x, self.n_hid,
                                       self.nonlinearity, self.bn,
                                       scope='qy_x', reuse=False)
    self.qz_in = tf.concat([self.x, self.y], axis=-1)
    self.qz_mean, self.qz_lv, self.z_ = dgm.samplePassGauss(
        self.qz_in, self.qz_xy, self.n_hid, self.nonlinearity, self.bn,
        scope='qz_xy', reuse=False)
    # generative #
    # One prior draw per row of the label placeholder.
    self.z_prior = tf.random_normal([tf.shape(self.y)[0], self.n_z])
    self.px_in = tf.concat([self.y, self.z_prior], axis=-1)
    if self.x_dist == 'Gaussian':
        self.px_mean, self.px_lv, self.x_ = dgm.samplePassGauss(
            self.px_in, self.px_yz, self.n_hid, self.nonlinearity, self.bn,
            scope='px_yz', reuse=False)
        # Flatten any sample axis so x_ is a rank-2 (rows, n_x) batch.
        self.x_ = tf.reshape(self.x_, [-1, self.n_x])
    elif self.x_dist == 'Bernoulli':
        self.x_ = dgm.forwardPassBernoulli(self.px_in, self.px_yz, self.n_hid,
                                           self.nonlinearity, self.bn,
                                           scope='px_yz', reuse=False)
    else:
        raise ValueError(
            "Unsupported x_dist {!r}; expected 'Gaussian' or 'Bernoulli'".format(
                self.x_dist))
    self.W = bnn.sampleCatBNN(self.py_x, self.n_hid)
    self.py = dgm.forwardPassCat(self.x_, self.W, self.n_hid, self.nonlinearity,
                                 self.bn, scope='py_x', reuse=False)
    self.predictions = self.predict(self.x, n_samps=50, training=False)
def sample_z(self, x, y, a):
    """ Return mean, log-variance and samples from q(z|x,y,a).

    ``[x, y, a]`` is concatenated on the last axis and flattened to
    ``(-1, n_x + n_y + n_z)`` rows before the ``qz_xya`` pass. Each of the
    three outputs is then reshaped to ``(mc_samples, -1, n_z)``.

    NOTE(review): the flattened width assumes ``a`` has ``self.n_z`` columns;
    confirm the auxiliary dimension really equals n_z.
    """
    in_width = self.n_x + self.n_y + self.n_z
    flat_in = tf.reshape(tf.concat([x, y, a], axis=-1), [-1, in_width])
    mean, log_var, draws = dgm.samplePassGauss(flat_in, self.qz_xya, self.n_hid,
                                               self.nonlinearity, self.bn,
                                               scope='qz_xya')
    out_shape = [self.mc_samples, -1, self.n_z]
    mean = tf.reshape(mean, out_shape)
    log_var = tf.reshape(log_var, out_shape)
    draws = tf.reshape(draws, out_shape)
    return mean, log_var, draws
def sample_z(self, x, y):
    """ Return mean, log-variance and samples from the ``qz_xy`` network.

    ``self.mc_samples`` draws are requested. The mean and log-variance get a
    new leading axis and are tiled ``self.mc_samples`` times along it; the
    samples are returned as produced.
    """
    qz_input = tf.concat([x, y], axis=-1)
    mean, log_var, draws = dgm.samplePassGauss(
        qz_input, self.qz_xy, self.n_hid, self.nonlinearity, self.bn,
        mc_samps=self.mc_samples, scope='qz_xy')

    def _repeat_over_samples(t):
        # Add a leading axis and repeat along it once per Monte-Carlo sample.
        return tf.tile(tf.expand_dims(t, 0), [self.mc_samples, 1, 1])

    tiled_mean = _repeat_over_samples(mean)
    tiled_log_var = _repeat_over_samples(log_var)
    return tiled_mean, tiled_log_var, draws
def sample_qz(self, x, c, reuse=True):
    """ Sample z from q(z|x,c).

    ``x`` is repeated ``self.mc_samples`` times along a new leading axis, and
    ``c`` is tiled along its second axis by the batch size of ``x``. They are
    then concatenated, flattened to ``(-1, n_x + n_c)`` and passed through the
    ``qz_xc`` network. Mean, log-variance and samples are each returned with
    shape ``(mc_samples, -1, n_z)``.
    """
    x_rep = tf.tile(tf.expand_dims(x, 0), [self.mc_samples, 1, 1])
    c_rep = tf.tile(c, [1, tf.shape(x)[0], 1])
    flat_in = tf.reshape(tf.concat([x_rep, c_rep], axis=-1),
                         [-1, self.n_x + self.n_c])
    mean, log_var, draws = dgm.samplePassGauss(flat_in, self.qz_xc, self.n_hid,
                                               self.nonlinearity, self.bn,
                                               scope='qz_xc', reuse=reuse)
    target = [self.mc_samples, -1, self.n_z]
    return tuple(tf.reshape(t, target) for t in (mean, log_var, draws))
def sample_pz(self, c, n_samps, reuse=True):
    """ Generate ``n_samps`` samples of z from p(z|c).

    ``c`` is flattened to ``(-1, n_c)`` rows and passed through the ``pz_c``
    network. Mean and log-variance are returned as produced. The samples are
    reshaped to ``(mc_samples, -1, n_z)``.

    NOTE(review): the reshape's leading axis is ``self.mc_samples``, not
    ``n_samps``; confirm this matches samplePassGauss's output layout.
    """
    pz_input = tf.reshape(c, [-1, self.n_c])
    mean, log_var, draws = dgm.samplePassGauss(
        pz_input, self.pz_c, self.n_hid, self.nonlinearity, self.bn,
        mc_samps=n_samps, scope='pz_c', reuse=reuse)
    draws = tf.reshape(draws, [self.mc_samples, -1, self.n_z])
    return mean, log_var, draws
def sample_z(self, x):
    """ Get the parameters of, and samples from, q(z|x).

    Only ``x`` is conditioned on (network ``qz_x``); ``self.mc_samples`` draws
    are requested and the result of ``dgm.samplePassGauss`` is returned as-is.
    """
    posterior = dgm.samplePassGauss(
        x,
        self.qz_x,
        self.n_hid,
        self.nonlinearity,
        self.bn,
        scope='qz_x',
        mc_samps=self.mc_samples,
    )
    return posterior