def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
    """ Adversarial Loss Function from Shoanlu GAN

    Builds the discriminator and generator adversarial losses for a generator output whose
    channel 0 is an alpha mask and whose remaining channels are a BGR image.

    Parameters
    ----------
    net_d: callable
        The discriminator. It is called on the channel-wise concatenation of an image with
        ``distorted``.
    real: tensor
        The real (target) images.
    fake_abgr: tensor
        The generator output: channel 0 is the alpha mask, channels 1+ are BGR.
    distorted: tensor
        The distorted input images. They are used as the background when compositing the
        masked output and as the conditioning input concatenated for ``net_d``.
    gan_training: str, optional
        Either ``"mixup_LSGAN"`` or ``"relativistic_avg_LSGAN"``. Default: ``"mixup_LSGAN"``
    weights: dict
        Loss weights. ``weights['w_D']`` scales the generator's adversarial loss.

    Returns
    -------
    tuple
        ``(loss_d, loss_g)``: the discriminator loss and the weighted generator loss.

    Raises
    ------
    ValueError
        If ``gan_training`` is not one of the supported methods.
    """
    loss_builders = {"mixup_LSGAN": _adversarial_loss_mixup_lsgan,
                     "relativistic_avg_LSGAN": _adversarial_loss_relativistic_avg_lsgan}
    if gan_training not in loss_builders:
        # Fail before any layers are built. The message used to lack its ``f`` prefix and
        # printed "{gan_training}" literally instead of the offending value.
        raise ValueError(f"Receive an unknown GAN training method: {gan_training}")

    # Channel 0 of the generator output is the alpha mask, the remaining channels are BGR
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    # Composite the generated face onto the distorted input using the mask
    fake = alpha * fake_bgr + (1 - alpha) * distorted
    return loss_builders[gan_training](net_d, real, fake, fake_bgr, distorted, weights["w_D"])


def _adversarial_loss_mixup_lsgan(net_d, real, fake, fake_bgr, distorted, weight_d):
    """ Mixup LSGAN losses, summed over the composited (``fake``) and raw (``fake_bgr``)
    generator outputs.

    A single mixing coefficient ``lam`` is sampled from Beta(0.2, 0.2) and shared by both
    terms. The discriminator's prediction on the real/fake mixup is compared (``"l2"``)
    against ``lam``. The generator term compares the prediction on the unmixed fake
    against 1.

    Returns
    -------
    tuple
        ``(loss_d, loss_g)`` where ``loss_g`` is already scaled by ``weight_d``.
    """
    lam = Beta(0.2, 0.2).sample()
    loss_d = 0
    loss_g = 0
    for candidate in (fake, fake_bgr):
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate(
            [candidate, distorted])
        pred_fake = net_d(concatenate([candidate, distorted]))
        pred_mixup = net_d(mixup)
        loss_d += calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
        loss_g += weight_d * calc_loss(pred_fake, K.ones_like(pred_fake), "l2")
    return loss_d, loss_g


def _adversarial_loss_relativistic_avg_lsgan(net_d, real, fake, fake_bgr, distorted, weight_d):
    """ Relativistic-average LSGAN variant.

    The composited output (``fake``) is scored with plain least-squares GAN terms. The raw
    BGR output (``fake_bgr``) is scored with relativistic-average terms: each prediction is
    offset by the batch mean (axis 0) of the opposing prediction.

    Returns
    -------
    tuple
        ``(loss_d, loss_g)`` where ``loss_g`` is already scaled by ``weight_d``.
    """
    real_pred = net_d(concatenate([real, distorted]))
    fake_pred = net_d(concatenate([fake, distorted]))
    # Plain least-squares terms on the composited output
    loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred))) / 2
    loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred))) / 2
    loss_g = weight_d * K.mean(K.square(fake_pred - K.ones_like(fake_pred)))

    # Relativistic-average terms on the raw BGR output
    fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
    real_rel = real_pred - K.mean(fake_pred2, axis=0)
    fake_rel = fake_pred2 - K.mean(real_pred, axis=0)
    loss_d += K.mean(K.square(real_rel - K.ones_like(fake_pred2))) / 2
    loss_d += K.mean(K.square(fake_rel - K.zeros_like(fake_pred2))) / 2
    loss_g += weight_d * K.mean(K.square(real_rel - K.zeros_like(fake_pred2))) / 2
    loss_g += weight_d * K.mean(K.square(fake_rel - K.ones_like(fake_pred2))) / 2
    return loss_d, loss_g
def perceptual_loss(real, fake_abgr, distorted, vggface_feats, **weights):
    """ Perceptual Loss Function from Shoanlu GAN

    Compares VGGFace features of the real images with those of a random mixup between the
    raw and the mask-composited generator output. The mixup lets one forward pass stand in
    for two.

    Parameters
    ----------
    real: tensor
        The real (target) images. The preprocessing maps ``(x + 1) / 2 * 255``, so these
        presumably lie in [-1, 1] — confirm against caller.
    fake_abgr: tensor
        The generator output: channel 0 is the alpha mask, channels 1+ are BGR.
    distorted: tensor
        The distorted input images, used as the background when compositing.
    vggface_feats: callable
        Called on a 224x224 preprocessed batch. It must return exactly four feature maps,
        unpacked in the order named ``feat112, feat55, feat28, feat7``.
    weights: dict
        ``weights['w_pl']`` holds the weights for the ``feat7``, ``feat28``, ``feat55`` and
        ``feat112`` terms, in that order.

    Returns
    -------
    tensor
        The weighted perceptual loss for the generator.
    """
    mask = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    face_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    composited = mask * face_bgr + (1 - mask) * distorted

    def preprocess_vggface(var_x):
        # Rescale to [0, 255], then subtract the per-channel means (channel order: BGR)
        var_x = (var_x + 1.) / 2. * 255.
        var_x -= [91.4953, 103.8827, 131.0912]
        return var_x

    def to_vggface_input(images):
        # Resize to the 224x224 VGGFace input size, then apply the preprocessing above
        resized = tf.image.resize_images(images, [224, 224])
        return Lambda(preprocess_vggface)(resized)

    real_input = to_vggface_input(real)
    # Mixup of raw and composited output: reduces the forward passes from two to one
    mix_ratio = Beta(0.2, 0.2).sample()
    fake_input = to_vggface_input(mix_ratio * face_bgr + (1 - mix_ratio) * composited)

    real_112, real_55, real_28, real_7 = vggface_feats(real_input)
    fake_112, fake_55, fake_28, fake_7 = vggface_feats(fake_input)

    # Instance-normalise the VGG (ResNet) features before comparing them,
    # following MUNIT: https://github.com/NVlabs/MUNIT
    feature_pairs = ((fake_7, real_7),
                     (fake_28, real_28),
                     (fake_55, real_55),
                     (fake_112, real_112))
    total = 0
    for level, (fake_feat, real_feat) in enumerate(feature_pairs):
        total += weights['w_pl'][level] * calc_loss(InstanceNormalization()(fake_feat),
                                                     InstanceNormalization()(real_feat),
                                                     "l2")
    return total