Example #1
0
 def encode(self, x, n_iters=100):
     """Encode `x` as the average of z values collected along a chain.

     The chain starts from the class predicted by the ``Qx_y`` network.
     Each iteration re-predicts the class from ``[x, z]`` with the
     ``Pzx_y`` network, converts it to a one-hot argmax, and obtains new
     ``a`` and ``z`` values from ``_sample_a`` / ``_sample_Z``.

     Args:
         x: input tensor; it is concatenated with ``z`` along axis 1.
         n_iters: number of iterations run after the initial draw.

     Returns:
         The mean over the ``n_iters + 1`` collected ``z`` tensors, which
         are stacked along a new axis 2 before averaging.
     """
     y_ = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN,
                                self.NONLINEARITY, self.batchnorm,
                                self.phase)
     y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
     _, _, a = self._sample_a(x, y_, self.Z_SAMPLES)
     _, _, z = self._sample_Z(x, y_, a, self.Z_SAMPLES)
     # Gather the draws in a Python list and stack once at the end; growing
     # a tensor with tf.concat on every iteration re-copies all earlier draws.
     z_draws = [z]
     for _ in range(n_iters):
         h = tf.concat([x, z], axis=1)
         y_ = dgm._forward_pass_Cat(h, self.Pzx_y, self.NUM_HIDDEN,
                                    self.NONLINEARITY, self.batchnorm,
                                    self.phase)
         y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
         # `a` is only needed as an input to _sample_Z; it is not averaged.
         _, _, a = self._sample_a(x, y_, self.Z_SAMPLES)
         _, _, z = self._sample_Z(x, y_, a, self.Z_SAMPLES)
         z_draws.append(z)
     return tf.reduce_mean(tf.stack(z_draws, axis=2), axis=2)
Example #2
0
 def _unlabeled_loss(self, x):
     """Build the unlabeled-loss terms for `x` (per data point).

     Every class index is scored with ``_labeled_loss`` and weighted by the
     matching column of the ``Qx_y`` forward-pass output; the entropy of
     that output (summed over axis 1) is added on top.

     Returns:
         A tuple ``(weighted_loss + entropy, weighted_loss, entropy)``.
     """
     q_y = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN,
                                 self.NONLINEARITY, self.batchnorm,
                                 self.phase)
     # Static leading dimension of x, passed through to _generate_class.
     n_rows = x.get_shape()[0]
     weighted_loss = 0
     for k in range(self.NUM_CLASSES):
         labels_k = self._generate_class(k, n_rows)
         q_k = q_y[:, k]
         loss_k = self._labeled_loss(x, labels_k)
         weighted_loss += tf.multiply(q_k, loss_k)
     # The 1e-10 offset keeps the log finite when a weight is exactly zero.
     log_q_y = tf.log(1e-10 + q_y)
     entropy = -tf.reduce_sum(tf.multiply(q_y, log_q_y), axis=1)
     return weighted_loss + entropy, weighted_loss, entropy
Example #3
0
    def _sample_xy(self, n_samples=int(1e3)):
        """Restore the saved model and generate (x, y) from random z.

        Standard-normal ``z`` values of shape ``(n_samples, Z_DIM)`` are
        pushed through the ``Pz_x`` network (Gaussian or Bernoulli,
        depending on ``TYPE_PX``) to get ``x``; ``[x, z]`` is then pushed
        through the ``Pzx_y`` network to get ``y``.

        Side effect: sets ``self.phase`` to False and leaves it that way.

        Args:
            n_samples: number of z rows to draw.

        Returns:
            Tuple ``(x, y)`` of values evaluated by ``session.run``.
        """
        saver = tf.train.Saver()
        with tf.Session() as session:
            # NOTE(review): get_checkpoint_state returns None when no
            # checkpoint state is found in ckpt_dir; the attribute access
            # below then raises AttributeError.
            ckpt = tf.train.get_checkpoint_state(self.ckpt_dir)
            saver.restore(session, ckpt.model_checkpoint_path)
            self.phase = False
            z_ = np.random.normal(
                size=(n_samples, self.Z_DIM)).astype('float32')
            if self.TYPE_PX == 'Gaussian':
                # Only the first element of the Gaussian pass output is used.
                x_ = dgm._forward_pass_Gauss(
                    z_, self.Pz_x, self.NUM_HIDDEN, self.NONLINEARITY,
                    self.batchnorm, self.phase)[0]
            else:
                x_ = dgm._forward_pass_Bernoulli(
                    z_, self.Pz_x, self.NUM_HIDDEN, self.NONLINEARITY,
                    self.batchnorm, self.phase)
            h = tf.concat([x_, z_], axis=1)
            y_ = dgm._forward_pass_Cat(
                h, self.Pzx_y, self.NUM_HIDDEN, self.NONLINEARITY,
                self.batchnorm, self.phase)
            x, y = session.run([x_, y_])
        return x, y
Example #4
0
 def _predict_condition_W(self, x, n_iters=20):
     """Average class outputs over a chain that alternates z and y updates.

     The chain starts from the one-hot argmax of the ``Qx_y`` forward pass.
     Each iteration obtains ``z`` from ``_sample_Z`` given the current
     one-hot class, then runs ``_forward_pass_Cat_bnn`` on ``[x, z]`` with
     ``self.Wtilde`` and the ``Pzx_y`` network.

     Args:
         x: input tensor; it is concatenated with ``z`` along axis 1.
         n_iters: number of chain iterations.

     Returns:
         Tuple ``(y_mean, yq)``. ``y_mean`` is the mean over ``n_iters + 1``
         slices stacked on axis 2: the initial one-hot prediction followed
         by the un-hardened forward-pass output of every iteration. ``yq``
         is the raw ``Qx_y`` forward-pass output.
     """
     y_ = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN,
                                self.NONLINEARITY, self.batchnorm,
                                self.phase)
     yq = y_
     y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
     # Gather outputs in a list and stack once at the end; growing a tensor
     # with tf.concat on every iteration re-copies all earlier slices.
     y_draws = [y_]
     for _ in range(n_iters):
         _, _, z = self._sample_Z(x, y_, self.Z_SAMPLES)
         h = tf.concat([x, z], axis=1)
         y_ = dgm._forward_pass_Cat_bnn(h, self.Wtilde, self.Pzx_y,
                                        self.NUM_HIDDEN, self.NONLINEARITY,
                                        self.batchnorm, self.phase)
         # Record the output before hardening; the one-hot version only
         # conditions the next _sample_Z call.
         y_draws.append(y_)
         y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
     return tf.reduce_mean(tf.stack(y_draws, axis=2), axis=2), yq
Example #5
0
 def predict(self, x):
     """Run the ``Qx_y`` categorical forward pass on `x`.

     Returns:
         A 1-tuple holding the forward-pass output.
     """
     class_output = dgm._forward_pass_Cat(
         x, self.Qx_y, self.NUM_HIDDEN, self.NONLINEARITY,
         self.batchnorm, self.phase)
     return (class_output,)