def _encoder(self, x, reuse=False):
    """Encode a batch of images into the Gaussian parameters of q(z | x).

    Layers (all created inside ``tf.variable_scope("encoder")``):
      * two ``conv2d`` layers called with ``3, 3, 2, 2`` (presumably 3x3
        kernels with stride 2 -- TODO confirm against conv2d's signature),
        widths ``self.n[0]`` and ``self.n[1]``, each followed by ``lrelu``;
      * a flatten to ``[self.batch_size, -1]``;
      * a ``linear`` layer of width ``self.n[2]`` followed by ``lrelu``;
      * a final ``linear`` layer of width ``2 * self.z_dim`` whose first half
        is the mean and whose second half parameterizes the stddev.

    The intermediate tensors are kept on ``self`` as ``conv1``, ``conv2``,
    ``reshaped_en`` and ``dense2_en``.

    Args:
        x: image batch; presumably (batch_size, height, width, channels) --
            TODO confirm.
        reuse: forwarded to ``tf.variable_scope``; pass True to share the
            weights created by an earlier call.

    Returns:
        (mean, stddev): the first ``z_dim`` columns of the final layer and
        ``1e-6 + softplus`` of the remaining ``z_dim`` columns.
    """
    with tf.variable_scope("encoder", reuse=reuse):
        # Convolutional feature extractor.
        self.conv1 = lrelu(conv2d(x, self.n[0], 3, 3, 2, 2, name='en_conv1'))
        self.conv2 = lrelu(conv2d(self.conv1, self.n[1], 3, 3, 2, 2, name='en_conv2'))

        # Flatten each sample, then one hidden fully connected layer.
        self.reshaped_en = tf.reshape(self.conv2, [self.batch_size, -1])
        self.dense2_en = lrelu(linear(self.reshaped_en, self.n[2], scope='en_fc3'))

        # A single head emits [mean | raw_stddev] side by side.
        head = linear(self.dense2_en, 2 * self.z_dim, scope='en_fc4')
        split_at = self.z_dim
        mean = head[:, :split_at]  # unconstrained
        # The stddev must be strictly positive: softplus maps to (0, inf), and
        # the epsilon keeps it away from zero for numerical stability.
        stddev = 1e-6 + tf.nn.softplus(head[:, split_at:])
        return mean, stddev
def decoder(self, z, reuse=False):
    """Decode latent codes into images with pixel values in (0, 1).

    Layers (all created inside ``tf.variable_scope("decoder")``):
      * ``linear`` to width ``self.n[2]`` + ``lrelu`` (scope ``de_fc1``);
      * ``linear`` to width ``self.n[1] * 7 * 7`` + ``lrelu``;
      * reshape to ``[batch_size, 7, 7, self.n[1]]``;
      * ``deconv2d`` to ``[batch_size, 14, 14, self.n[0]]`` + ``lrelu``;
      * ``deconv2d`` to ``[batch_size, 28, 28, 1]`` followed by a sigmoid.
    Both ``deconv2d`` calls pass ``3, 3, 2, 2`` (presumably 3x3 kernels with
    stride 2 -- TODO confirm against deconv2d's signature).

    Args:
        z: latent batch; presumably (batch_size, z_dim) -- TODO confirm.
        reuse: forwarded to ``tf.variable_scope``; pass True to share the
            weights created by an earlier call (e.g. for sampling).

    Returns:
        Tensor of shape ``[self.batch_size, 28, 28, 1]`` squashed by a sigmoid.
    """
    batch = self.batch_size
    with tf.variable_scope("decoder", reuse=reuse):
        hidden = lrelu(linear(z, self.n[2], scope='de_fc1'))
        # NOTE(review): unlike the other layers, no scope= is passed here;
        # adding one would presumably rename its variables, so it is kept as is.
        hidden = lrelu(linear(hidden, self.n[1] * 7 * 7))

        feature_map = tf.reshape(hidden, [batch, 7, 7, self.n[1]])
        upsampled = lrelu(
            deconv2d(feature_map, [batch, 14, 14, self.n[0]], 3, 3, 2, 2, name='de_dc3'))

        logits = deconv2d(upsampled, [batch, 28, 28, 1], 3, 3, 2, 2, name='de_dc4')
        return tf.nn.sigmoid(logits)
def _build_model(self):
    """Build the training and sampling graph for the semi-supervised VAE.

    Creates the input placeholders, encodes ``self.inputs`` into
    ``self.mu`` / ``self.sigma``, samples ``self.z`` with the
    reparameterization trick, attaches a classifier head trained only on
    manually annotated samples, decodes ``self.z`` back to images, and
    assembles the loss::

        loss = NLL + beta * KL + supervise_weight * supervised_loss

    It also creates the Adam training op ``self.optim``, the sampling op
    ``self.fake_images`` and the merged summary op
    ``self.merged_summary_op``.
    """
    image_dims = [self.input_height, self.input_width, self.c_dim]
    bs = self.batch_size

    # ---- Graph inputs -------------------------------------------------
    self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
    # Noise fed at test time to sample images: zero mean, identity covariance.
    self.standard_normal = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')
    # Per-sample weight of the supervised loss: whether the sample was manually
    # annotated (presumably 1.0 / 0.0 -- confirm with the feeding code).
    self.is_manual_annotated = tf.placeholder(tf.float32, [bs], name="is_manual_annotated")
    self.labels = tf.placeholder(tf.float32, [bs, self.label_dim], name='manual_label')

    # ---- Encoding + reparameterized sampling ---------------------------
    self.mu, self.sigma = self._encoder(self.inputs, reuse=False)
    # z = mu + sigma * eps, eps ~ N(0, I): keeps z differentiable w.r.t. mu/sigma.
    self.z = self.mu + self.sigma * tf.random_normal(
        tf.shape(self.mu), 0, 1, dtype=tf.float32)

    # ---- Supervised head (labelled samples only) -----------------------
    # The logits width must match the one-hot labels, i.e. label_dim
    # (previously hard-coded to 10).
    self.y_pred = linear(self.z, self.label_dim)
    self.supervised_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=self.labels,
        logits=self.y_pred,
        weights=self.is_manual_annotated)

    # ---- Decoding ------------------------------------------------------
    out = self.decoder(self.z, reuse=False)
    # Keep outputs strictly inside (0, 1) so the log terms below stay finite.
    self.out = tf.clip_by_value(out, 1e-8, 1 - 1e-8)

    # ---- Loss ----------------------------------------------------------
    # Bernoulli log-likelihood summed over every non-batch axis (H, W, C),
    # giving one value per sample.
    marginal_likelihood = tf.reduce_sum(
        self.inputs * tf.log(self.out) + (1 - self.inputs) * tf.log(1 - self.out),
        [1, 2, 3])
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims;
    # the 1e-8 guards log(0).
    kl_divergence = 0.5 * tf.reduce_sum(
        tf.square(self.mu) + tf.square(self.sigma)
        - tf.log(1e-8 + tf.square(self.sigma)) - 1,
        [1])

    self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood)
    self.KL_divergence = tf.reduce_mean(kl_divergence)

    evidence_lower_bound = -self.neg_loglikelihood - self.beta * self.KL_divergence
    self.loss = -evidence_lower_bound + self.supervise_weight * self.supervised_loss

    # ---- Training ------------------------------------------------------
    t_vars = tf.trainable_variables()
    # Run any pending UPDATE_OPS before each optimizer step.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        self.optim = tf.train.AdamOptimizer(
            self.learning_rate * 5, beta1=self.beta1
        ).minimize(self.loss, var_list=t_vars)

    # ---- Testing: decode externally supplied noise with shared weights --
    self.fake_images = self.decoder(self.standard_normal, reuse=True)

    # ---- Summaries -----------------------------------------------------
    # tf.summary.scalar registers each op in the summaries collection, so
    # merge_all() picks them up; the return values are not needed.
    tf.summary.scalar("Negative Log Likelihood", self.neg_loglikelihood)
    tf.summary.scalar("K L Divergence", self.KL_divergence)
    tf.summary.scalar("Supervised Loss", self.supervised_loss)
    tf.summary.scalar("Total Loss", self.loss)

    self.merged_summary_op = tf.summary.merge_all()