# Example #1
# (votes: 0)
def adversarial_loss(net_d,
                     real,
                     fake_abgr,
                     distorted,
                     gan_training="mixup_LSGAN",
                     **weights):
    """ Adversarial Loss Function from Shoanlu GAN.

    Parameters
    ----------
    net_d : callable
        Discriminator network. Called with a channel-wise concatenation of an
        image and the distorted input; returns a prediction tensor.
    real : tensor
        Ground-truth target images.
    fake_abgr : tensor
        Generator output: alpha mask in channel 0, BGR image in channels 1-3.
    distorted : tensor
        Distorted (warped) input images.
    gan_training : str, optional
        Training scheme, either ``"mixup_LSGAN"`` or
        ``"relativistic_avg_LSGAN"``. Default: ``"mixup_LSGAN"``.
    weights : dict
        Loss weights. ``weights['w_D']`` scales the generator's adversarial
        terms.

    Returns
    -------
    tuple
        ``(loss_d, loss_g)`` — discriminator and generator loss tensors.

    Raises
    ------
    ValueError
        If `gan_training` is not one of the two supported methods.
    """
    # Composite the generator output over the distorted input using the
    # predicted alpha mask.
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    if gan_training == "mixup_LSGAN":
        # Mixup: train the discriminator on a Beta-weighted blend of the
        # real and fake pairs, with the blend coefficient as the target.
        dist = Beta(0.2, 0.2)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate(
            [fake, distorted])
        pred_fake = net_d(concatenate([fake, distorted]))
        pred_mixup = net_d(mixup)
        loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
        loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake),
                                            "l2")
        # Repeat with the raw (un-masked) BGR output so the generator is also
        # pushed to produce realistic images before alpha compositing.
        mixup2 = lam * concatenate(
            [real, distorted]) + (1 - lam) * concatenate([fake_bgr, distorted])
        pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
        pred_mixup2 = net_d(mixup2)
        loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
        loss_g += weights['w_D'] * calc_loss(pred_fake_bgr,
                                             K.ones_like(pred_fake_bgr), "l2")
    elif gan_training == "relativistic_avg_LSGAN":
        # Relativistic average LSGAN: score each prediction relative to the
        # batch-mean of the opposing class.
        real_pred = net_d(concatenate([real, distorted]))
        fake_pred = net_d(concatenate([fake, distorted]))
        loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred))) / 2
        loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred))) / 2
        loss_g = weights['w_D'] * K.mean(
            K.square(fake_pred - K.ones_like(fake_pred)))

        # Same relativistic terms for the raw BGR output.
        fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
        loss_d += K.mean(
            K.square(real_pred - K.mean(fake_pred2, axis=0) -
                     K.ones_like(fake_pred2))) / 2
        loss_d += K.mean(
            K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                     K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(
            K.square(real_pred - K.mean(fake_pred2, axis=0) -
                     K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(
            K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                     K.ones_like(fake_pred2))) / 2
    else:
        # Bug fix: the original message was a plain string, so the
        # {gan_training} placeholder was never interpolated.
        raise ValueError(
            f"Receive an unknown GAN training method: {gan_training}")
    return loss_d, loss_g
# Example #2
# (votes: 0)
def perceptual_loss(real, fake_abgr, distorted, vggface_feats, **weights):
    """ Perceptual Loss Function from Shoanlu GAN """
    # Split the generator output into its alpha mask and BGR image, then
    # composite over the distorted input.
    mask = Lambda(lambda tensor: tensor[:, :, :, :1])(fake_abgr)
    raw_bgr = Lambda(lambda tensor: tensor[:, :, :, 1:])(fake_abgr)
    composited = mask * raw_bgr + (1 - mask) * distorted

    def preprocess_vggface(var_x):
        # Rescale from [-1, 1] to [0, 255] and subtract per-channel means.
        var_x = (var_x + 1.) / 2. * 255.  # channel order: BGR
        var_x -= [91.4953, 103.8827, 131.0912]
        return var_x

    target_224 = Lambda(preprocess_vggface)(
        tf.image.resize_images(real, [224, 224]))

    # use mixup trick here to reduce foward pass from 2 times to 1.
    lam = Beta(0.2, 0.2).sample()
    blended = lam * raw_bgr + (1 - lam) * composited
    output_224 = Lambda(preprocess_vggface)(
        tf.image.resize_images(blended, [224, 224]))

    real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(
        target_224)
    fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(
        fake_sz224 := output_224)

    # Apply instance norm on VGG(ResNet) features
    # From MUNIT https://github.com/NVlabs/MUNIT
    # weights['w_pl'] is ordered deepest-layer first: feat7 ... feat112.
    feature_pairs = ((fake_feat7, real_feat7),
                     (fake_feat28, real_feat28),
                     (fake_feat55, real_feat55),
                     (fake_feat112, real_feat112))
    loss_g = 0
    for weight, (fake_feat, real_feat) in zip(weights['w_pl'], feature_pairs):
        loss_g += weight * calc_loss(InstanceNormalization()(fake_feat),
                                     InstanceNormalization()(real_feat), "l2")
    return loss_g