def encode(self, x, n_iters=100):
    """Iteratively infer a latent code z for inputs x.

    Starts from the inference network's hard class prediction q(y|x),
    samples (a, z) conditioned on (x, y), then alternates for n_iters
    rounds: re-predict y from (x, z) via the Pzx_y network, resample a
    and z. The z samples from every round are averaged.

    Args:
        x: input batch tensor.
        n_iters: number of refinement iterations (default 100).

    Returns:
        Tensor: mean over the collected z samples (same shape as one z).

    Note: the original accumulated `a` samples as well but never used
    them; that dead accumulation (quadratic graph-side tf.concat) has
    been removed. z samples are now collected in a Python list and
    stacked once, instead of repeated tf.concat of expand_dims.
    """
    # Initial hard class assignment from the inference network q(y|x).
    y_ = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
    y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
    _, _, a = self._sample_a(x, y_, self.Z_SAMPLES)
    _, _, z = self._sample_Z(x, y_, a, self.Z_SAMPLES)
    z_samps = [z]
    for _ in range(n_iters):
        # Re-predict y from the current (x, z) pair, then resample a and z.
        h = tf.concat([x, z], axis=1)
        y_ = dgm._forward_pass_Cat(h, self.Pzx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
        y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
        _, _, a = self._sample_a(x, y_, self.Z_SAMPLES)
        _, _, z = self._sample_Z(x, y_, a, self.Z_SAMPLES)
        z_samps.append(z)
    # tf.stack(..., axis=2) is equivalent to concatenating expand_dims(axis=2);
    # averaging over that axis reproduces the original return value.
    return tf.reduce_mean(tf.stack(z_samps, axis=2), axis=2)
def _unlabeled_loss(self, x):
    """ Compute necessary terms for unlabeled loss (per data point) """
    # q(y|x): class posterior used both to weight the per-class labeled
    # losses and to compute the entropy term.
    weights = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
    # Expected labeled loss: sum_y q(y|x) * L(x, y), built per class.
    per_class_terms = []
    for cls in range(self.NUM_CLASSES):
        y_cls = self._generate_class(cls, x.get_shape()[0])
        per_class_terms.append(weights[:, cls] * self._labeled_loss(x, y_cls))
    EL_l = tf.add_n(per_class_terms)
    # Entropy of q(y|x); 1e-10 guards the log against zero probabilities.
    ent_qy = -tf.reduce_sum(weights * tf.log(1e-10 + weights), axis=1)
    return EL_l + ent_qy, EL_l, ent_qy
def _sample_xy(self, n_samples=int(1e3)):
    """Draw n_samples (x, y) pairs from the generative model.

    Restores the latest checkpoint, samples z ~ N(0, I), decodes x via
    p(x|z) (Gaussian or Bernoulli head depending on TYPE_PX), then
    predicts y from the concatenation (x, z) via the Pzx_y network.

    Returns:
        (x, y): the evaluated sample arrays from a fresh session.
    """
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(self.ckpt_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Switch batchnorm (and any other phase-dependent ops built below)
        # to inference mode before constructing the decode graph.
        self.phase = False
        z_ = np.random.normal(size=(n_samples, self.Z_DIM)).astype('float32')
        if self.TYPE_PX == 'Gaussian':
            # Gaussian head returns (mean, ...); keep only the mean.
            x_ = dgm._forward_pass_Gauss(z_, self.Pz_x, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)[0]
        else:
            x_ = dgm._forward_pass_Bernoulli(z_, self.Pz_x, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
        h = tf.concat([x_, z_], axis=1)
        y_ = dgm._forward_pass_Cat(h, self.Pzx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
        sampled_x, sampled_y = sess.run([x_, y_])
        return sampled_x, sampled_y
def _predict_condition_W(self, x, n_iters=20):
    """Predict class scores by averaging predictions over n_iters rounds
    of alternating z-sampling and y-prediction (using sampled weights
    Wtilde in the Pzx_y network).

    Args:
        x: input batch tensor.
        n_iters: number of refinement rounds (default 20).

    Returns:
        (averaged predictions over the collected samples, raw q(y|x) output).
    """
    # Initial prediction from the inference network q(y|x).
    y_ = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
    yq = y_  # keep the raw q(y|x) output to return alongside the average
    y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
    y_samps = tf.expand_dims(y_, axis=2)
    for i in range(n_iters):
        # NOTE(review): _sample_Z is called here with (x, y_, self.Z_SAMPLES),
        # whereas elsewhere in this file it is called with an extra `a`
        # argument (x, y_, a, self.Z_SAMPLES) — confirm the intended arity.
        _, _, z = self._sample_Z(x, y_, self.Z_SAMPLES)
        h = tf.concat([x, z], axis=1)
        # Bayesian forward pass through Pzx_y using sampled weights Wtilde.
        y_ = dgm._forward_pass_Cat_bnn(h, self.Wtilde, self.Pzx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
        # NOTE(review): the soft prediction is appended BEFORE hardening to
        # one-hot, so y_samps mixes the initial one-hot with subsequent soft
        # distributions — confirm this mixing is intentional.
        y_samps = tf.concat([y_samps, tf.expand_dims(y_, axis=2)], axis=2)
        y_ = tf.one_hot(tf.argmax(y_, axis=1), self.NUM_CLASSES)
    return tf.reduce_mean(y_samps, axis=2), yq
def predict(self, x):
    """Return the inference network's class prediction q(y|x) for x,
    wrapped in a one-element tuple (matching the original interface)."""
    probs = dgm._forward_pass_Cat(x, self.Qx_y, self.NUM_HIDDEN, self.NONLINEARITY, self.batchnorm, self.phase)
    return (probs,)