def build_wgan_discriminator(self, batch_local, batch_global, reuse=False,
                             training=True, calc_perceptual_loss=False,
                             losses=None):
    """Build the combined local + global WGAN critic and return its scores.

    Args:
        batch_local: Input for the local discriminator branch.
        batch_global: Input for the global discriminator branch.
        reuse: Forwarded to every ``variable_scope`` (and the branch
            builders) so variables can be shared between calls.
        training: Forwarded to the branch builders.
        calc_perceptual_loss: When True, the ``*_verbose`` branch builders
            are used and a perceptual-loss term computed from the third
            value each verbose builder returns is added to
            ``losses['perceptual_loss']``.
        losses: Mutable dict that receives the accumulated
            ``'perceptual_loss'`` entry (created as 0 if absent). Required
            when ``calc_perceptual_loss`` is True; unused otherwise.

    Returns:
        Tuple ``(dout_local, dout_global)``: the outputs of a 1-unit
        ``tf.layers.dense`` applied to each branch's final features.

    Raises:
        TypeError: If ``calc_perceptual_loss`` is True and ``losses`` is None.
    """
    if calc_perceptual_loss and losses is None:
        # Fail fast, before any graph ops are created. Previously this
        # surfaced mid-build as "argument of type 'NoneType' is not iterable".
        raise TypeError(
            "losses must be a dict when calc_perceptual_loss=True")
    with tf.variable_scope('discriminator', reuse=reuse):
        if calc_perceptual_loss:
            _, _, feats_local, _, dlocal = \
                self.build_wgan_local_discriminator_verbose(
                    batch_local, reuse=reuse, training=training)
            _, _, feats_global, _, dglobal = \
                self.build_wgan_global_discriminator_verbose(
                    batch_global, reuse=reuse, training=training)
            with tf.variable_scope('perceptual_loss', reuse=reuse):
                feats_local_flat = flatten(feats_local, "flatten_local")
                feats_global_flat = flatten(feats_global, "flatten_global")
                # NOTE(review): the first half of the batch is treated as
                # "neg" and the second as "pos"; confirm this matches the
                # order in which callers concatenate real/fake samples.
                fl_neg, fl_pos = tf.split(feats_local_flat, 2)
                losses.setdefault('perceptual_loss', 0)
                losses['perceptual_loss'] += self.get_perceptual_loss(
                    fl_pos, fl_neg, name="loss_local")
                fg_neg, fg_pos = tf.split(feats_global_flat, 2)
                losses['perceptual_loss'] += self.get_perceptual_loss(
                    fg_pos, fg_neg, name="loss_global")
        else:
            dlocal = self.build_wgan_local_discriminator(
                batch_local, reuse=reuse, training=training)
            dglobal = self.build_wgan_global_discriminator(
                batch_global, reuse=reuse, training=training)
        # One scalar critic score per sample for each branch.
        dout_local = tf.layers.dense(dlocal, 1, name='dout_local_fc')
        dout_global = tf.layers.dense(dglobal, 1, name='dout_global_fc')
        return dout_local, dout_global
def build_wgan_global_discriminator(self, x, reuse=False, training=True):
    """Build the global branch of the WGAN critic.

    Applies four ``dis_conv`` layers (widths 64, 128, 256, 256) inside the
    ``discriminator_global`` variable scope, then flattens the result.

    Args:
        x: Input tensor for the global branch.
        reuse: Forwarded to ``tf.variable_scope`` for variable sharing.
        training: Forwarded to every ``dis_conv`` call.

    Returns:
        The flattened output of the last convolution.
    """
    base_width = 64
    # (layer name, channel multiplier) in application order.
    layer_plan = (('conv1', 1), ('conv2', 2), ('conv3', 4), ('conv4', 4))
    with tf.variable_scope('discriminator_global', reuse=reuse):
        h = x
        for layer_name, mult in layer_plan:
            h = dis_conv(h, base_width * mult, name=layer_name,
                         training=training)
        return flatten(h, name='flatten')
def build_sn_patch_gan_discriminator(self, x, reuse=False, training=True):
    """Build the patch-GAN discriminator.

    Applies six ``dis_conv`` layers (widths 64, 128, 256, 256, 256, 256)
    inside the ``sn_patch_gan`` variable scope, then flattens the result.

    NOTE(review): the "sn" prefix suggests spectral normalisation; whether
    that is applied depends on ``dis_conv``, which is defined elsewhere.

    Args:
        x: Input tensor.
        reuse: Forwarded to ``tf.compat.v1.variable_scope``.
        training: Forwarded to every ``dis_conv`` call.

    Returns:
        The flattened output of the last convolution.
    """
    base_width = 64
    # (layer name, channel multiplier) in application order.
    layer_plan = (
        ('conv1', 1),
        ('conv2', 2),
        ('conv3', 4),
        ('conv4', 4),
        ('conv5', 4),
        ('conv6', 4),
    )
    with tf.compat.v1.variable_scope('sn_patch_gan', reuse=reuse):
        h = x
        for layer_name, mult in layer_plan:
            h = dis_conv(h, base_width * mult, name=layer_name,
                         training=training)
        return flatten(h, name='flatten')
def build_wgan_global_discriminator_verbose(self, x, reuse=False,
                                            training=True):
    """Build the global WGAN critic branch and expose every intermediate.

    Uses the same ``discriminator_global`` scope and the same layer names
    and widths (64, 128, 256, 256) as ``build_wgan_global_discriminator``,
    but also returns the activation after each convolution.

    Args:
        x: Input tensor for the global branch.
        reuse: Forwarded to ``tf.variable_scope`` for variable sharing.
        training: Forwarded to every ``dis_conv`` call.

    Returns:
        5-tuple ``(conv1, conv2, conv3, conv4, flattened_conv4)``.
    """
    base_width = 64
    # (layer name, channel multiplier) in application order.
    layer_plan = (('conv1', 1), ('conv2', 2), ('conv3', 4), ('conv4', 4))
    with tf.variable_scope('discriminator_global', reuse=reuse):
        outputs = []
        h = x
        for layer_name, mult in layer_plan:
            h = dis_conv(h, base_width * mult, name=layer_name,
                         training=training)
            outputs.append(h)
        outputs.append(flatten(h, name='flatten'))
        return tuple(outputs)