def _loss(self):
    """Add the "v2" cycled-generator loss terms on top of the parent losses.

    If ``self.no_training_images`` is falsy, the parent class's ``_loss`` is
    run first and is expected to set ``self.encoder_cost`` and
    ``self.discriminator_loss``. Otherwise both are initialised to 0.

    The following terms are then added to ``self.encoder_cost``, each scaled
    by its own ``*_scale`` attribute:

    * a 'dcgan' generator (faking) loss on ``self.gen_cycled_disc``;
    * the mean-squared error between ``self.z_samples`` and
      ``self.generator_samples_latents``;
    * a margin (hinge) reconstruction loss between
      ``self.cycled_back_generator`` and ``self.generator_samples``.

    ``self.discriminator_loss`` is increased by a scaled 'dcgan'
    discriminator loss. It receives ``self.gen_samples_disc`` and
    ``self.gen_cycled_disc`` in the (real, fake) argument positions.

    Scalar/histogram summaries are emitted for the individual terms.
    """
    if not self.no_training_images:
        super(InvertorDefenseGAN, self)._loss()
    else:
        # Parent losses are skipped, so both accumulators start from zero.
        self.encoder_cost = 0
        self.discriminator_loss = 0

    # Fake samples should fool the discriminator.
    cycled_faking = generator_loss(
        'dcgan',
        self.gen_cycled_disc,
    )
    self.gen_samples_faking_loss = (
        self.gen_samples_faking_loss_scale * cycled_faking)

    # Latents of the encoded generator samples should be close to z_samples.
    latent_mse = tf.losses.mean_squared_error(
        self.z_samples,
        self.generator_samples_latents,
        reduction=Reduction.MEAN,
    )
    self.latents_to_sample_zs = self.latents_to_z_loss_scale * latent_mse
    tf.summary.scalar('losses/latents to zs loss', self.latents_to_sample_zs)

    # Mean absolute cycle error over axis 1, then flattened
    # (assumes batch-major tensors -- TODO confirm layout).
    cycle_gap = tf.abs(self.cycled_back_generator - self.generator_samples)
    raw_cycled_reconstruction_error = slim.flatten(
        tf.reduce_mean(cycle_gap, axis=1))
    tf.summary.histogram(
        'raw cycled reconstruction error',
        raw_cycled_reconstruction_error,
    )

    # Hinge: only the portion of the error above rec_margin is penalised.
    excess = tf.nn.relu(raw_cycled_reconstruction_error - self.rec_margin)
    self.cycled_reconstruction_loss = (
        self.rec_cycled_loss_scale * tf.reduce_mean(excess))
    tf.summary.scalar('losses/cycled_margin_rec',
                      self.cycled_reconstruction_loss)

    v2_encoder_terms = (self.cycled_reconstruction_loss
                        + self.gen_samples_faking_loss
                        + self.latents_to_sample_zs)
    self.encoder_cost += v2_encoder_terms

    # Discriminator: generator samples in the "real" slot, their cycled
    # counterparts in the "fake" slot.
    cycle_disc = discriminator_loss(
        'dcgan',
        self.gen_samples_disc,
        self.gen_cycled_disc,
    )
    self.gen_samples_disc_loss = self.gen_samples_disc_loss_scale * cycle_disc
    tf.summary.scalar('losses/gen_samples_disc_loss',
                      self.gen_samples_disc_loss)
    tf.summary.scalar('losses/gen_samples_faking_loss',
                      self.gen_samples_faking_loss)
    self.discriminator_loss += self.gen_samples_disc_loss
def _loss(self):
    """Build the GAN losses, optimizer ops and merged summary op.

    Sets on ``self``:
      generator_cost: ``generator_loss(self.mode, self.disc_fake)``.
      discriminator_cost: ``discriminator_loss(self.mode, self.disc_real,
        self.disc_fake)``. In ``'wgan-gp'`` mode it also includes a gradient
        penalty weighted by ``self.gradient_penalty_lambda``.
      clip_disc_weights: in ``'wgan'`` mode, a grouped op that clips every
        variable from ``tflib.params_with_name('Discriminator')`` to
        [-0.01, 0.01]. Otherwise ``None``. (Presumably run by the training
        loop after discriminator updates -- not shown here.)
      gen_train_op / disc_train_op: Adam (beta1=0.5) ``minimize`` ops
        restricted to ``self.g_vars`` / ``self.d_vars``.
      merged_summary_op: ``tf.summary.merge_all()``.
    """
    self.generator_cost = generator_loss(self.mode, self.disc_fake)
    self.discriminator_cost = discriminator_loss(
        self.mode, self.disc_real, self.disc_fake)

    self.clip_disc_weights = None
    if self.mode == 'wgan':
        # Weight clipping to a fixed symmetric range.
        clip_min, clip_max = -.01, .01
        clip_ops = [
            tf.assign(var, tf.clip_by_value(var, clip_min, clip_max))
            for var in tflib.params_with_name('Discriminator')
        ]
        self.clip_disc_weights = tf.group(*clip_ops)
    elif self.mode == 'wgan-gp':
        # Gradient penalty on random interpolates between real and fake data.
        # alpha's rank-4 shape implies rank-4, batch-major data tensors.
        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1)
        differences = self.fake_data - self.real_data
        interpolates = self.real_data + (alpha * differences)
        gradients = tf.gradients(self.discriminator_fn(interpolates),
                                 [interpolates])[0]
        # The penalty needs the L2 norm of each sample's whole gradient.
        # Reducing only axis 1 of a rank-4 tensor (as before) produced one
        # value per remaining (batch, axis-2, axis-3) position instead, so
        # flatten every non-batch axis first.
        flat_gradients = tf.reshape(gradients, [self.batch_size, -1])
        slopes = tf.sqrt(tf.reduce_sum(tf.square(flat_gradients), axis=1))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
        self.discriminator_cost += self.gradient_penalty_lambda * gradient_penalty

    # Optimizers, each restricted to its own network's variables.
    self.gen_train_op = tf.train.AdamOptimizer(
        learning_rate=self.generator_lr,
        beta1=0.5,
    ).minimize(self.generator_cost, var_list=self.g_vars)
    self.disc_train_op = tf.train.AdamOptimizer(
        learning_rate=self.discriminator_lr,
        beta1=0.5,
    ).minimize(self.discriminator_cost, var_list=self.d_vars)

    # Summary ops register themselves in the graph's summary collection, so
    # their return values are not needed before merge_all().
    tf.summary.scalar('g_loss', self.generator_cost)
    tf.summary.scalar('d_loss', self.discriminator_cost)
    self.merged_summary_op = tf.summary.merge_all()
def _loss(self):
    """Build the encoder and discriminator losses for encoder reconstructions.

    Sets on ``self``:
      enc_rec_faking_loss: unscaled 'dcgan' generator loss on
        ``self.disc_enc_rec``.
      enc_rec_disc_loss: ``rec_disc_loss_scale`` times the 'dcgan'
        discriminator loss with ``self.disc_real`` / ``self.disc_enc_rec`` in
        the (real, fake) positions.
      latent_reg_loss: ``latent_reg_loss_scale`` times the mean square of
        ``self.encoder_latent_before``.
      encoder_cost: margin reconstruction loss + ``rec_disc_loss_scale`` *
        enc_rec_faking_loss + latent_reg_loss.
      discriminator_loss: the same tensor as ``enc_rec_disc_loss``.

    Scalar/histogram summaries are emitted for the components.
    """
    # Mean absolute reconstruction error over axis 1, then flattened
    # (assumes batch-major tensors -- TODO confirm layout).
    rec_gap = tf.abs(self.enc_reconstruction - self.real_data)
    raw_reconstruction_error = slim.flatten(tf.reduce_mean(rec_gap, axis=1))
    tf.summary.histogram('raw reconstruction error', raw_reconstruction_error)

    # Hinge: errors below rec_margin contribute nothing.
    above_margin = tf.nn.relu(raw_reconstruction_error - self.rec_margin)
    image_rec_loss = self.rec_loss_scale * tf.reduce_mean(above_margin)
    tf.summary.scalar('losses/margin_rec', image_rec_loss)

    # Adversarial terms. The faking loss is stored unscaled;
    # rec_disc_loss_scale is applied to it only when summing encoder_cost.
    self.enc_rec_faking_loss = generator_loss('dcgan', self.disc_enc_rec)
    rec_disc = discriminator_loss('dcgan', self.disc_real, self.disc_enc_rec)
    self.enc_rec_disc_loss = self.rec_disc_loss_scale * rec_disc
    tf.summary.scalar('losses/enc_recon_faking_disc', self.enc_rec_faking_loss)

    # Mean-square penalty on the encoder latent.
    latent_sq = tf.square(self.encoder_latent_before)
    self.latent_reg_loss = self.latent_reg_loss_scale * tf.reduce_mean(latent_sq)
    tf.summary.scalar('losses/latent_reg', self.latent_reg_loss)

    weighted_faking = self.rec_disc_loss_scale * self.enc_rec_faking_loss
    self.encoder_cost = image_rec_loss + weighted_faking + self.latent_reg_loss
    self.discriminator_loss = self.enc_rec_disc_loss

    tf.summary.scalar('losses/encoder_loss', self.encoder_cost)
    tf.summary.scalar('losses/discriminator_loss', self.enc_rec_disc_loss)