def grow(res_increase, res_decrease):
    """Recursively build the discriminator trunk for one resolution stage.

    Free names used here (`inputs`, `res_training`, `from_rgb`, `block`,
    `layers`, `utils`) are not defined in this block; presumably they are
    supplied by an enclosing network-building function — confirm.

    `utils.cset(cur, cond, new)` is assumed to return a thunk that yields
    `new()` when `cond` holds and `cur()` otherwise (e.g. via tf.cond) —
    confirm in utils. Every candidate path is therefore wrapped in a lambda
    and only materialized by a call.

    Args:
        res_increase: Resolution index of the block built at this level; the
            recursive call uses res_increase + 1.
        res_decrease: Remaining recursion depth. The raw `inputs` are
            downscaled by 2**res_decrease at this level, and recursion stops
            when it reaches 0.

    Returns:
        The tensor produced by `y()`: either this stage's block output, or
        (when res_training < res_increase) that output linearly blended with
        a from-RGB path at res_increase - 1.
    """
    # Default input path: downscale the image by 2**res_decrease and map it
    # through this resolution's from-RGB layer. Kept lazy for utils.cset.
    x = lambda: from_rgb(layers.downscale2d(inputs, 2**res_decrease), res_increase)
    if res_decrease > 0:
        # If res_training exceeds this stage's index, feed this block from the
        # recursively built stage at res_increase + 1 instead of the image.
        x = utils.cset(
            x, (res_training > res_increase),
            lambda: grow(res_increase+1, res_decrease-1))
    # `x` is rebound here from a thunk to the block's output tensor. The
    # lambdas below close over this tensor, not over the thunk.
    x = block(x(), res_increase); y = lambda: x
    if res_increase > 2:
        # If res_training is below this stage's index, blend the block output
        # with a from-RGB path built from the image downscaled one extra
        # factor of 2, at res_increase - 1. The weight is
        # res_increase - res_training. Assumes utils.lerp(a, b, t) is linear
        # interpolation from a toward b — confirm.
        y = utils.cset(
            y, (res_training < res_increase),
            lambda: utils.lerp(
                x,
                from_rgb(layers.downscale2d(inputs, 2**(res_decrease+1)), res_increase-1),
                res_increase-res_training))
    return y()
def grow(x, res_increase, res_decrease):
    """Recursively build the generator's output image for one resolution stage.

    NOTE(review): this file also contains a `grow(res_increase, res_decrease)`
    with a different signature. If both are defined in the same scope, this
    one shadows it. Presumably they live in separate enclosing functions —
    confirm.

    Free names used here (`res_training`, `block`, `to_rgb`, `layers`,
    `utils`) are not defined in this block; presumably they come from an
    enclosing network-building function — confirm. `utils.cset(cur, cond,
    new)` is assumed to return a thunk yielding `new()` when `cond` holds,
    else `cur()` — confirm in utils.

    Args:
        x: Features from the previous stage; fed to `block` and, on the
            blending path, to `to_rgb(x, res_increase-3)`.
        res_increase: Resolution index of the block built here; the recursive
            call uses res_increase + 1.
        res_decrease: Remaining recursion depth. The image produced at this
            level is upscaled by 2**res_decrease, and recursion stops at 0.

    Returns:
        The image tensor selected by the chain of conditional thunks.
    """
    y = block(x, res_increase)
    # Default output: this stage's RGB image upscaled by 2**res_decrease.
    # The to_rgb index offset (res_increase - 2) is to_rgb's own convention.
    img = lambda: layers.upscale2d(to_rgb(y, res_increase-2), 2**res_decrease)
    if res_increase > 2:
        # If res_training is below this stage's index, blend this stage's RGB
        # with the previous stage's RGB (from `x`, upscaled once by
        # layers.upscale2d's default factor), using weight
        # res_increase - res_training. The blend is then upscaled by
        # 2**res_decrease.
        img = utils.cset(
            img, (res_training < res_increase),
            lambda: layers.upscale2d(
                utils.lerp(
                    to_rgb(y, res_increase-2),
                    layers.upscale2d(to_rgb(x, res_increase-3)),
                    res_increase-res_training),
                2**res_decrease))
    if res_decrease > 0:
        # If res_training exceeds this stage's index, defer to the recursively
        # built stage at res_increase + 1, which continues from `y`.
        img = utils.cset(
            img, (res_training > res_increase),
            lambda: grow(y, res_increase + 1, res_decrease - 1))
    return img()
def GradientPenalty(discriminator, image_inputs, fake_images, minibatch_size, wgan_target=1.0, wgan_lambda=10.0):
    """Compute a per-sample WGAN-GP style gradient penalty.

    Random interpolates between `image_inputs` and `fake_images` are passed
    through `discriminator`. The penalty is the squared deviation of the
    per-sample input-gradient norm from `wgan_target`, scaled by
    wgan_lambda / wgan_target**2.

    Args:
        discriminator: Callable mapping an image batch to scores. Scores are
            summed before differentiation.
        image_inputs: Real image batch, batch on axis 0. Assumed rank 4, since
            the mixing factors are shaped [minibatch_size, 1, 1, 1] and the
            norm reduces axes [1, 2, 3] — confirm with callers.
        fake_images: Generated image batch with the same shape as
            `image_inputs`.
        minibatch_size: Number of images in the batch; one mixing factor is
            drawn per image.
        wgan_target: Target gradient norm.
        wgan_lambda: Penalty weight.

    Returns:
        A tensor of shape [minibatch_size] holding the penalty for each sample.
    """
    with tf.compat.v1.name_scope('GradientPenalty'):
        # One uniform [0, 1) factor per sample, broadcast over the remaining axes.
        mixing_factors = tf.compat.v1.random_uniform(
            [minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=tf.float32)
        # Assumes utils.lerp(a, b, t) interpolates from a toward b — confirm.
        mixed_images = utils.lerp(image_inputs, fake_images, mixing_factors)
        mixed_loss = tf.compat.v1.math.reduce_sum(discriminator(mixed_images))
        # tf.gradients returns a list with one entry per tensor in `xs`; take
        # the single gradient tensor. Squaring the list itself would stack it
        # into a rank-5 tensor with a leading size-1 axis. The reduction over
        # axes [1, 2, 3] would then sum across the batch instead of per sample.
        mixed_grads = tf.compat.v1.gradients(mixed_loss, [mixed_images])[0]
        mixed_norms = tf.compat.v1.math.sqrt(
            tf.compat.v1.math.reduce_sum(
                tf.compat.v1.math.square(mixed_grads), axis=[1, 2, 3]))
        gradient_penalty = tf.compat.v1.math.square(mixed_norms - wgan_target)
        return gradient_penalty * (wgan_lambda / (wgan_target**2))