Exemplo n.º 1
0
def define_loss(netD,
                real,
                fake,
                vggface_feat=None,
                mixup_alpha=None,
                use_lsgan=True):
    """Build discriminator and generator losses for a real/fake image batch.

    With a truthy ``mixup_alpha`` the discriminator scores a Beta-mixup of
    ``real`` and ``fake`` and is regressed toward the mixing weight ``lam``;
    the generator term regresses that same prediction toward ``1 - lam``.
    Otherwise a plain real -> 1 / fake -> 0 objective is used. In both cases
    an L1 term ``mean(|fake - real|)`` is added to the generator loss.

    Args:
        netD: discriminator, called directly on image tensors.
        real: real images.
        fake: generated images (combined elementwise with ``real``).
        vggface_feat: optional feature extractor; when not None the generator
            loss is passed through ``add_perceptual_loss``.
        mixup_alpha: Beta concentration for mixup; falsy disables mixup.
        use_lsgan: forwarded to ``get_loss_fun`` to pick the criterion.

    Returns:
        Tuple ``(loss_D, loss_G)``.
    """
    criterion = get_loss_fun(use_lsgan)

    if not mixup_alpha:
        pred_real = netD(real)  # positive sample
        pred_fake = netD(fake)  # negative sample
        d_on_real = criterion(pred_real, K.ones_like(pred_real))
        d_on_fake = criterion(pred_fake, K.zeros_like(pred_fake))
        loss_D = d_on_real + d_on_fake
        loss_G = .5 * criterion(pred_fake, K.ones_like(pred_fake))
    else:
        lam = Beta(mixup_alpha, mixup_alpha).sample()
        pred_mix = netD(lam * real + (1 - lam) * fake)
        loss_D = criterion(pred_mix, lam * K.ones_like(pred_mix))
        # Extra forward pass on the fake batch; its output feeds no loss term.
        netD(fake)
        loss_G = .5 * criterion(pred_mix, (1 - lam) * K.ones_like(pred_mix))

    # L1 reconstruction term.
    loss_G += K.mean(K.abs(fake - real))

    if vggface_feat is not None:
        loss_G = add_perceptual_loss(loss_G,
                                     real=real,
                                     fake=fake,
                                     vggface_feat=vggface_feat)
    return loss_D, loss_G
Exemplo n.º 2
0
def perceptual_loss(real, fake_abgr, distorted, vggface_feats, weights):
    """Instance-normalised perceptual loss on VGGFace features.

    ``fake_abgr`` carries a mask in channel 0 and a BGR image in channels 1:.
    The feature extractor is not run on both the raw BGR output and its
    mask-composited version. Instead, a single Beta(0.2, 0.2)-weighted blend
    of the two is passed through ``vggface_feats``, giving one forward pass
    instead of two. Each feature level is instance-normalised (as in MUNIT,
    https://github.com/NVlabs/MUNIT) before an "l2" ``calc_loss``.

    Args:
        real: target images; presumably in [-1, 1], since preprocessing maps
            ``(x + 1) / 2 * 255``. TODO confirm.
        fake_abgr: generator output, mask channel first.
        distorted: images the mask composites the BGR output over.
        vggface_feats: callable returning four feature maps, unpacked as
            ``(feat112, feat55, feat28, feat7)``.
        weights: mapping whose ``'w_pl'`` entry holds (at least) four
            weights, applied to feat7, feat28, feat55 and feat112 in that
            order.

    Returns:
        The weighted sum of the four per-level losses.
    """
    mask = Lambda(lambda t: t[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda t: t[:, :, :, 1:])(fake_abgr)
    fake = mask * fake_bgr + (1 - mask) * distorted

    def preprocess_vggface(img):
        img = (img + 1) / 2 * 255  # channel order: BGR
        img -= [91.4953, 103.8827, 131.0912]
        return img

    real_sz224 = Lambda(preprocess_vggface)(
        tf.image.resize_images(real, [224, 224]))
    # Mixup trick: one forward pass through the feature extractor instead of two.
    lam = Beta(0.2, 0.2).sample()
    blended = lam * fake_bgr + (1 - lam) * fake
    fake_sz224 = Lambda(preprocess_vggface)(
        tf.image.resize_images(blended, [224, 224]))
    real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224)
    fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224)

    # Level order matches the order of weights['w_pl'].
    levels = ((fake_feat7, real_feat7),
              (fake_feat28, real_feat28),
              (fake_feat55, real_feat55),
              (fake_feat112, real_feat112))
    loss_G = 0
    for idx, (fake_level, real_level) in enumerate(levels):
        loss_G += weights['w_pl'][idx] * calc_loss(
            InstanceNormalization()(fake_level),
            InstanceNormalization()(real_level), "l2")
    return loss_G
Exemplo n.º 3
0
def adversarial_loss(netD, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
    """Adversarial losses for a generator that outputs a mask plus a BGR image.

    ``fake_abgr`` carries the mask in channel 0 and the BGR image in channels
    1:. Two fakes are scored against ``real``, each conditioned on
    ``distorted`` by channel concatenation: the mask-composited ``fake`` and
    the raw ``fake_bgr``.

    Args:
        netD: discriminator, called on ``concatenate([image, distorted])``.
        real: target images.
        fake_abgr: generator output, mask channel first.
        distorted: generator input images; also the discriminator condition.
        gan_training: ``"mixup_LSGAN"`` or ``"relativistic_avg_LSGAN"``.
        **weights: must contain ``'w_D'``, the generator's adversarial weight.

    Returns:
        ``(loss_D, loss_G)``.

    Raises:
        ValueError: if ``gan_training`` is not a supported mode.
    """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    if gan_training == "mixup_LSGAN":
        # D regresses a real/fake mixup toward the mixing weight lam;
        # G pushes the same prediction toward 1 - lam.
        dist = Beta(0.2, 0.2)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
        output_mixup = netD(mixup)
        loss_D = calc_loss(output_mixup, lam * K.ones_like(output_mixup), "l2")
        loss_G = weights['w_D'] * calc_loss(output_mixup, (1 - lam) * K.ones_like(output_mixup), "l2")
        mixup2 = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake_bgr, distorted])
        output_mixup2 = netD(mixup2)
        loss_D += calc_loss(output_mixup2, lam * K.ones_like(output_mixup2), "l2")
        loss_G += weights['w_D'] * calc_loss(output_mixup2, (1 - lam) * K.ones_like(output_mixup2), "l2")
    elif gan_training == "relativistic_avg_LSGAN":
        real_pred = netD(concatenate([real, distorted]))

        def ra_lsgan_terms(fake_pred):
            """Return (D loss, unweighted G loss) for one batch of fake predictions."""
            ones = K.ones_like(fake_pred)
            # Relativistic average: each prediction is compared with the batch
            # mean of the opposite class. Both the real-relative and the
            # fake-relative terms are needed; they are not interchangeable.
            real_rel = real_pred - K.mean(fake_pred, axis=0)
            fake_rel = fake_pred - K.mean(real_pred, axis=0)
            d_loss = (K.mean(K.square(real_rel - ones))
                      + K.mean(K.square(fake_rel + ones)))
            g_loss = (K.mean(K.square(real_rel + ones))
                      + K.mean(K.square(fake_rel - ones)))
            return d_loss, g_loss

        fake_pred = netD(concatenate([fake, distorted]))
        loss_D, g_term = ra_lsgan_terms(fake_pred)
        loss_G = weights['w_D'] * g_term

        fake_pred2 = netD(concatenate([fake_bgr, distorted]))
        d_term2, g_term2 = ra_lsgan_terms(fake_pred2)
        loss_D += d_term2
        loss_G += weights['w_D'] * g_term2
    else:
        raise ValueError("Receive an unknown GAN training method: {}".format(gan_training))
    return loss_D, loss_G
def D_loss(netD, real, fake, rec, alpha=None, idt=None, is_cyclic=False):
    """Discriminator, generator and cycle/identity losses for a paired GAN.

    ``real`` is sliced along the last axis into three 3-channel images:
    ``x_i`` (channels 0:3), ``y_i`` (3:6) and ``y_j`` (6:). ``fake`` plays the
    role of ``x_i_j``. The discriminator scores channel-concatenated pairs:

      * positive: ``(x_i, y_i)``. With mixup, this is instead a Beta-mixup of
        ``(x_i, y_i)`` with ``(x_i_j, y_j)``, regressed toward the mixing
        weight.
      * negative: ``(x_i_j, y_j)``.
      * negative 2: ``(x_i, y_j)``, a mismatched pair.

    Reads the module-level settings ``use_mixup``, ``mixup_alpha`` and
    ``use_lsgan``, which are not defined in this function.

    Args:
        netD: discriminator.
        real: stacked ``(x_i, y_i, y_j)`` tensor as described above.
        fake: generated image ``x_i_j``.
        rec: reconstruction of ``x_i``; only read when ``is_cyclic`` is True.
        alpha: accepted for backward compatibility; currently unused.
        idt: optional identity-mapping output; adds ``mean(|idt - x_i|)``.
        is_cyclic: add the cycle term ``mean(|rec - x_i|)``.

    Returns:
        ``(loss_D, loss_G, loss_cyc)``.
    """
    x_i = Lambda(lambda x: x[:, :, :, 0:3])(real)
    y_i = Lambda(lambda x: x[:, :, :, 3:6])(real)
    y_j = Lambda(lambda x: x[:, :, :, 6:])(real)
    x_i_j = fake

    if use_mixup:
        # lam is only meaningful (and mixup_alpha only required) in mixup mode.
        lam = Beta(mixup_alpha, mixup_alpha).sample()
        mixup_x = lam * x_i + (1 - lam) * x_i_j
        mixup_y = lam * y_i + (1 - lam) * y_j
        # positive sample + negative sample
        output_real = netD(concatenate([mixup_x, mixup_y]))
        real_target = lam * K.ones_like(output_real)
    else:
        output_real = netD(concatenate([x_i, y_i]))  # positive sample
        # Un-mixed real pairs are genuine positives: target 1.
        real_target = K.ones_like(output_real)
    # Negative sample; in mixup mode it only feeds the generator loss.
    output_fake = netD(concatenate([x_i_j, y_j]))
    output_fake2 = netD(concatenate([x_i, y_j]))  # negative sample 2

    if use_lsgan:
        def loss_fn(output, target):
            return K.mean(K.square(output - target))
    else:
        def loss_fn(output, target):
            # Binary cross-entropy; the epsilon guards log(0).
            return -K.mean(
                K.log(output + 1e-12) * target
                + K.log(1 - output + 1e-12) * (1 - target))

    loss_D_real = loss_fn(output_real, real_target)
    loss_D_fake2 = loss_fn(output_fake2, K.zeros_like(output_fake2))
    loss_G = loss_fn(output_fake, K.ones_like(output_fake))

    if use_mixup:
        # The mixed positive already contains the (x_i_j, y_j) negative.
        loss_D = loss_D_real + loss_D_fake2
    else:
        loss_D_fake = loss_fn(output_fake, K.zeros_like(output_fake))
        loss_D = loss_D_real + (loss_D_fake + loss_D_fake2)

    # cyclic consistency loss
    loss_cyc = K.mean(K.abs(rec - x_i)) if is_cyclic else 0

    # identity loss
    if idt is not None:
        loss_cyc += K.mean(K.abs(idt - x_i))

    return loss_D, loss_G, loss_cyc
Exemplo n.º 5
0
def define_loss(netD,
                real,
                fake_argb,
                fake_sz64,
                distorted,
                vggface_feat=None):
    """GAN, L1 and optional perceptual losses for a masked generator.

    ``fake_argb`` holds a mask in channel 0 and the generated image in the
    remaining channels. ``fake`` is that image composited over ``distorted``
    with the mask. The discriminator is conditioned on ``distorted`` via
    channel concatenation. Reads the module-level names ``use_mixup``,
    ``mixup_alpha`` and ``loss_fn``.

    Returns:
        ``(loss_D, loss_G)``.
    """
    mask = Lambda(lambda t: t[:, :, :, :1])(fake_argb)
    fake_rgb = Lambda(lambda t: t[:, :, :, 1:])(fake_argb)
    fake = mask * fake_rgb + (1 - mask) * distorted

    if not use_mixup:
        pred_real = netD(concatenate([real, distorted]))  # positive sample
        pred_fake = netD(concatenate([fake, distorted]))  # negative sample
        loss_D = (loss_fn(pred_real, K.ones_like(pred_real))
                  + loss_fn(pred_fake, K.zeros_like(pred_fake)))
        loss_G = loss_fn(pred_fake, K.ones_like(pred_fake))
    else:
        lam = Beta(mixup_alpha, mixup_alpha).sample()
        real_pair = concatenate([real, distorted])
        fake_pair = concatenate([fake, distorted])
        pred_mix = netD(lam * real_pair + (1 - lam) * fake_pair)
        loss_D = loss_fn(pred_mix, lam * K.ones_like(pred_mix))
        loss_G = loss_fn(pred_mix, (1 - lam) * K.ones_like(pred_mix))

    # L1 at full resolution and against the 64x64 auxiliary output.
    loss_G += K.mean(K.abs(fake_rgb - real))
    loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))

    if vggface_feat is None:
        return loss_D, loss_G

    def preprocess_vggface(x):
        x = (x + 1) / 2 * 255  # channel order: BGR
        x -= [93.5940, 104.7624, 129.]
        return x

    real_in = Lambda(preprocess_vggface)(tf.image.resize_images(real, [224, 224]))
    fake_in = Lambda(preprocess_vggface)(tf.image.resize_images(fake_rgb, [224, 224]))
    real_feat55, real_feat28, real_feat7 = vggface_feat(real_in)
    fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_in)

    pl_params = (0.02, 0.3, 0.5)
    feature_pairs = ((fake_feat7, real_feat7),
                     (fake_feat28, real_feat28),
                     (fake_feat55, real_feat55))
    for weight, (fake_f, real_f) in zip(pl_params, feature_pairs):
        loss_G += weight * K.mean(K.abs(fake_f - real_f))
    return loss_D, loss_G
Exemplo n.º 6
0
def define_loss_masked(netD,
                       real,
                       fake_argb,
                       distorted,
                       vggface_feat=None,
                       mixup_alpha=None,
                       use_lsgan=True):
    """GAN, L1, edge and optional perceptual losses for a masked generator.

    ``fake_argb`` holds a mask in channel 0 and the generated image in the
    remaining channels. ``fake`` composites that image over ``distorted``.
    The discriminator is conditioned on ``distorted`` by channel
    concatenation.

    Args:
        netD: discriminator.
        real: target images.
        fake_argb: generator output, mask channel first.
        distorted: generator input images.
        vggface_feat: optional feature extractor; when not None the generator
            loss is passed through ``add_perceptual_loss_masked``.
        mixup_alpha: Beta concentration for mixup; falsy disables mixup.
        use_lsgan: forwarded to ``get_loss_fun``.

    Returns:
        ``(loss_D, loss_G)``.
    """
    # Loss weights: adversarial term in G, L1 reconstruction, edge term.
    w_D, w_recon, w_edge = 0.5, 1., 1.

    criterion = get_loss_fun(use_lsgan)

    mask = Lambda(lambda t: t[:, :, :, :1])(fake_argb)
    fake_rgb = Lambda(lambda t: t[:, :, :, 1:])(fake_argb)
    fake = mask * fake_rgb + (1 - mask) * distorted

    if not mixup_alpha:
        pred_real = netD(concatenate([real, distorted]))  # positive sample
        pred_fake = netD(concatenate([fake, distorted]))  # negative sample
        loss_D = (criterion(pred_real, K.ones_like(pred_real))
                  + criterion(pred_fake, K.zeros_like(pred_fake)))
        loss_G = w_D * criterion(pred_fake, K.ones_like(pred_fake))
    else:
        lam = Beta(mixup_alpha, mixup_alpha).sample()
        blended = (lam * concatenate([real, distorted])
                   + (1 - lam) * concatenate([fake, distorted]))
        pred_mix = netD(blended)
        loss_D = criterion(pred_mix, lam * K.ones_like(pred_mix))
        # Extra forward pass on the composite; its output feeds no loss term.
        netD(concatenate([fake, distorted]))
        loss_G = w_D * criterion(pred_mix, (1 - lam) * K.ones_like(pred_mix))

    # Reconstruction loss
    loss_G += w_recon * K.mean(K.abs(fake_rgb - real))

    # Edge loss (similar to total variation): match first-order differences
    # along axis 1, then axis 2.
    for axis in (1, 2):
        loss_G += w_edge * K.mean(
            K.abs(first_order(fake_rgb, axis=axis) - first_order(real, axis=axis)))

    if vggface_feat is not None:
        loss_G = add_perceptual_loss_masked(loss_G,
                                            real=real,
                                            fake=fake,
                                            vggface_feat=vggface_feat,
                                            fake_rgb=fake_rgb)
    return loss_D, loss_G
Exemplo n.º 7
0
def adversarial_loss(net_d,
                     real,
                     fake_abgr,
                     distorted,
                     gan_training="mixup_LSGAN",
                     **weights):
    """ Adversarial Loss Function from Shoanlu GAN

    ``fake_abgr`` holds a mask in channel 0 and a BGR image in channels 1:.
    Both the mask-composited ``fake`` and the raw ``fake_bgr`` are scored.
    Each is concatenated with ``distorted`` along the last axis before being
    passed to ``net_d``.

    Args:
        net_d: discriminator.
        real: target images.
        fake_abgr: generator output, mask channel first.
        distorted: generator input images, used for compositing and as the
            discriminator's conditioning input.
        gan_training: "mixup_LSGAN" or "relativistic_avg_LSGAN".
        **weights: must provide 'w_D', the weight of the adversarial term in
            the generator loss.

    Returns:
        (loss_d, loss_g)

    Raises:
        ValueError: for an unsupported ``gan_training`` value.
    """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    if gan_training == "mixup_LSGAN":
        # D regresses the mixup prediction toward lam; G pushes the
        # prediction on the un-mixed fake toward 1.
        dist = Beta(0.2, 0.2)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate(
            [fake, distorted])
        pred_fake = net_d(concatenate([fake, distorted]))
        pred_mixup = net_d(mixup)
        loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
        loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake),
                                            "l2")
        mixup2 = lam * concatenate(
            [real, distorted]) + (1 - lam) * concatenate([fake_bgr, distorted])
        pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
        pred_mixup2 = net_d(mixup2)
        loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
        loss_g += weights['w_D'] * calc_loss(pred_fake_bgr,
                                             K.ones_like(pred_fake_bgr), "l2")
    elif gan_training == "relativistic_avg_LSGAN":
        # Composited fake: plain LSGAN with targets real -> 1, fake -> 0.
        real_pred = net_d(concatenate([real, distorted]))
        fake_pred = net_d(concatenate([fake, distorted]))
        loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred))) / 2
        loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred))) / 2
        loss_g = weights['w_D'] * K.mean(
            K.square(fake_pred - K.ones_like(fake_pred)))

        # Raw fake: relativistic-average LSGAN (each prediction relative to
        # the batch mean of the other class), same 1 / 0 targets.
        fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
        loss_d += K.mean(
            K.square(real_pred - K.mean(fake_pred2, axis=0) -
                     K.ones_like(fake_pred2))) / 2
        loss_d += K.mean(
            K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                     K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(
            K.square(real_pred - K.mean(fake_pred2, axis=0) -
                     K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(
            K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                     K.ones_like(fake_pred2))) / 2
    else:
        # The placeholder must be filled in explicitly; a plain string literal
        # would print "{gan_training}" verbatim.
        raise ValueError(
            "Receive an unknown GAN training method: {}".format(gan_training))
    return loss_d, loss_g
Exemplo n.º 8
0
    def define_loss(self, netD, real, fake_argb, fake_sz64, distorted, vggface_feat=None):
        """GAN, L1 and optional perceptual losses for a masked generator.

        Reads ``self.use_mixup``, ``self.mixup_alpha``, ``self.loss_fn`` and
        ``self.use_mask_refinement``. With mask refinement enabled, the L1
        and perceptual terms use the mask-composited image ``fake``.
        Otherwise they use the raw generated image ``fake_rgb``.

        Returns:
            ``(loss_D, loss_G)``.
        """
        mask = Lambda(lambda t: t[:, :, :, :1])(fake_argb)
        fake_rgb = Lambda(lambda t: t[:, :, :, 1:])(fake_argb)
        fake = mask * fake_rgb + (1 - mask) * distorted

        if not self.use_mixup:
            pred_real = netD(concatenate([real, distorted]))  # positive sample
            pred_fake = netD(concatenate([fake, distorted]))  # negative sample
            loss_D = (self.loss_fn(pred_real, K.ones_like(pred_real))
                      + self.loss_fn(pred_fake, K.zeros_like(pred_fake)))
            loss_G = self.loss_fn(pred_fake, K.ones_like(pred_fake))
        else:
            lam = Beta(self.mixup_alpha, self.mixup_alpha).sample()
            pred_mix = netD(lam * concatenate([real, distorted])
                            + (1 - lam) * concatenate([fake, distorted]))
            loss_D = self.loss_fn(pred_mix, lam * K.ones_like(pred_mix))
            loss_G = self.loss_fn(pred_mix, (1 - lam) * K.ones_like(pred_mix))

        # Image compared against `real` by the L1 and perceptual terms.
        recon_src = fake if self.use_mask_refinement else fake_rgb
        loss_G += K.mean(K.abs(recon_src - real))
        loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))

        if vggface_feat is None:
            return loss_D, loss_G

        def preprocess_vggface(x):
            x = (x + 1)/2 * 255 # channel order: BGR
            x -= [93.5940, 104.7624, 129.]
            return x

        real_in = Lambda(preprocess_vggface)(tf.image.resize_images(real, [224, 224]))
        fake_in = Lambda(preprocess_vggface)(tf.image.resize_images(recon_src, [224, 224]))
        real_feat55, real_feat28, real_feat7 = vggface_feat(real_in)
        fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_in)

        pl_params = (0.02, 0.3, 0.5)
        for weight, fake_f, real_f in zip(pl_params,
                                          (fake_feat7, fake_feat28, fake_feat55),
                                          (real_feat7, real_feat28, real_feat55)):
            loss_G += weight * K.mean(K.abs(fake_f - real_f))
        return loss_D, loss_G
Exemplo n.º 9
0
def kumaraswamy_kl(prior_alpha, prior_beta, a, b, x):
    r"""
    Pointwise KL integrand between a Kumaraswamy posterior and a Beta prior.

    Returns ``log q(x) - log p(x)``, where ``q = Kumaraswamy(a, b)`` (via
    ``kumaraswamy_log_pdf``) and ``p = Beta(prior_alpha, prior_beta)``. When
    ``x`` is drawn from q, the expectation of this quantity over x equals
    KL(q || p), so it serves as a Monte Carlo estimate of that divergence.
    No reduction is performed here.

    Note: the "Stick-breaking DGMs" paper, which this was refactored from,
    instead approximates the KL in closed form with a 10-term truncated
    Taylor series. That approximation is NOT what this function computes.

    Parameters:
        prior_alpha: float/1d, 2d
            The parameter \alpha of the prior distribution Beta(\alpha, \beta).
        prior_beta: float/1d, 2d
            The parameter \beta of the prior distribution Beta(\alpha, \beta).
        a: float/1d, 2d
            The parameter a of the posterior distribution Kumaraswamy(a, b).
        b: float/1d, 2d
            The parameter b of the posterior distribution Kumaraswamy(a, b).
        x: tensor
            Point(s) at which both densities are evaluated; presumably samples
            from Kumaraswamy(a, b) in (0, 1). Confirm against the caller.

    Returns:
        kl: tensor
            Elementwise ``log q(x) - log p(x)`` (same broadcast shape as the
            log-probabilities).
    """
    q_log_prob = kumaraswamy_log_pdf(a, b, x)
    p_log_prob = Beta(prior_alpha, prior_beta).log_prob(x)

    return -(p_log_prob - q_log_prob)
Exemplo n.º 10
0
def adversarial_loss(netD, real, fake_abgr, distorted, **weights):
    """Mixup-LSGAN adversarial losses for a mask + BGR generator output.

    ``fake_abgr`` holds a mask in channel 0 and a BGR image in channels 1:.
    Both the mask-composited image and the raw BGR image are mixed with
    ``real`` using one shared Beta(0.2, 0.2) weight ``lam``. Each mixture is
    conditioned on ``distorted`` by channel concatenation. The discriminator
    is regressed toward ``lam`` and the generator toward ``1 - lam``
    (weighted by ``weights['w_D']``), using ``calc_loss(..., "l2")``.

    Returns:
        ``(loss_D, loss_G)``.
    """
    mask = Lambda(lambda t: t[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda t: t[:, :, :, 1:])(fake_abgr)
    composite = mask * fake_bgr + (1 - mask) * distorted

    lam = Beta(0.2, 0.2).sample()

    def mixup_terms(candidate):
        """Score a real/candidate mixup; return (D term, weighted G term)."""
        blended = (lam * concatenate([real, distorted])
                   + (1 - lam) * concatenate([candidate, distorted]))
        pred = netD(blended)
        d_term = calc_loss(pred, lam * K.ones_like(pred), "l2")
        g_term = weights['w_D'] * calc_loss(
            pred, (1 - lam) * K.ones_like(pred), "l2")
        return d_term, g_term

    d_composite, g_composite = mixup_terms(composite)
    d_raw, g_raw = mixup_terms(fake_bgr)
    return d_composite + d_raw, g_composite + g_raw
Exemplo n.º 11
0
def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
    """ Adversarial Loss Function from Shoanlu GAN

    ``fake_abgr`` holds a mask in channel 0 and a BGR image in channels 1:.
    Both the mask-composited ``fake`` and the raw ``fake_bgr`` are scored.
    Each is concatenated with ``distorted`` along the last axis for ``net_d``.

    Args:
        net_d: discriminator.
        real: target images.
        fake_abgr: generator output, mask channel first.
        distorted: generator input images (compositing base and D condition).
        gan_training: "mixup_LSGAN" or "relativistic_avg_LSGAN".
        **weights: must provide 'w_D', the generator's adversarial weight.

    Returns:
        (loss_d, loss_g)

    Raises:
        ValueError: for an unsupported ``gan_training`` value.
    """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1-alpha) * distorted

    if gan_training == "mixup_LSGAN":
        # D regresses the mixup prediction toward lam; G pushes the
        # prediction on the un-mixed fake toward 1.
        dist = Beta(0.2, 0.2)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
        pred_fake = net_d(concatenate([fake, distorted]))
        pred_mixup = net_d(mixup)
        loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
        loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2")
        mixup2 = lam * concatenate([real,
                                    distorted]) + (1 - lam) * concatenate([fake_bgr,
                                                                           distorted])
        pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
        pred_mixup2 = net_d(mixup2)
        loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
        loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2")
    elif gan_training == "relativistic_avg_LSGAN":
        # Composited fake: plain LSGAN with targets real -> 1, fake -> 0.
        real_pred = net_d(concatenate([real, distorted]))
        fake_pred = net_d(concatenate([fake, distorted]))
        loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred)))/2
        loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred)))/2
        loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred)))

        # Raw fake: relativistic-average LSGAN with the same 1 / 0 targets.
        fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
        loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
                                  K.ones_like(fake_pred2)))/2
        loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                                  K.zeros_like(fake_pred2)))/2
        loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
                                                   K.zeros_like(fake_pred2)))/2
        loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                                                   K.ones_like(fake_pred2)))/2
    else:
        # Fill the placeholder explicitly; a plain literal would print it verbatim.
        raise ValueError("Receive an unknown GAN training method: {}".format(gan_training))
    return loss_d, loss_g
Exemplo n.º 12
0
def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights):
    """ Perceptual Loss Function from Shoanlu GAN

    ``fake_abgr`` carries a mask in channel 0 and a BGR image in channels 1:.
    A single Beta(0.2, 0.2)-weighted blend of the raw BGR image and its
    mask-composited version is fed through ``vggface_feats``, so only one
    fake forward pass is needed. Each of the four feature levels is
    instance-normalised (as in MUNIT, https://github.com/NVlabs/MUNIT) and
    compared with ``calc_loss(..., "l2")``.

    Args:
        real: target images; presumably in [-1, 1], given the
            ``(x + 1) / 2 * 255`` preprocessing. TODO confirm.
        fake_abgr: generator output, mask channel first.
        distorted: images the mask composites the BGR output over.
        mask_eyes: accepted but not used by this implementation.
        vggface_feats: callable returning four feature maps, unpacked as
            ``(feat112, feat55, feat28, feat7)``.
        **weights: must provide ``'w_pl'`` with (at least) four weights for
            feat7, feat28, feat55 and feat112, in that order.

    Returns:
        The weighted sum of the four per-level losses.
    """
    mask = Lambda(lambda t: t[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda t: t[:, :, :, 1:])(fake_abgr)
    fake = mask * fake_bgr + (1-mask) * distorted

    def preprocess_vggface(img):
        img = (img + 1) / 2 * 255  # channel order: BGR
        img -= [91.4953, 103.8827, 131.0912]
        return img

    real_sz224 = Lambda(preprocess_vggface)(
        tf.image.resize_images(real, [224, 224]))
    # Mixup trick: one forward pass through the feature extractor instead of two.
    lam = Beta(0.2, 0.2).sample()
    blended = lam*fake_bgr + (1-lam)*fake
    fake_sz224 = Lambda(preprocess_vggface)(
        tf.image.resize_images(blended, [224, 224]))
    real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224)
    fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224)

    # Level order matches the order of weights['w_pl'].
    levels = ((fake_feat7, real_feat7),
              (fake_feat28, real_feat28),
              (fake_feat55, real_feat55),
              (fake_feat112, real_feat112))
    loss_g = 0
    for idx, (fake_level, real_level) in enumerate(levels):
        loss_g += weights['w_pl'][idx] * calc_loss(
            InstanceNormalization()(fake_level),
            InstanceNormalization()(real_level), "l2")
    return loss_g
Exemplo n.º 13
0
    def define_loss(self, netD, real, fake_argb, distorted, vggface_feat=None):
        """GAN, L1, edge and optional perceptual losses for a masked generator.

        ``fake_argb`` holds a mask in channel 0 and the generated image in the
        remaining channels. Reads ``self.use_mixup``, ``self.mixup_alpha``,
        ``self.loss_fn`` and ``self.first_order``.

        Returns:
            ``(loss_D, loss_G)``.
        """
        mask = Lambda(lambda t: t[:, :, :, :1])(fake_argb)
        fake_rgb = Lambda(lambda t: t[:, :, :, 1:])(fake_argb)
        fake = mask * fake_rgb + (1 - mask) * distorted

        if self.use_mixup:
            lam = Beta(self.mixup_alpha, self.mixup_alpha).sample()
            blended = (lam * concatenate([real, distorted])
                       + (1 - lam) * concatenate([fake, distorted]))
            pred_mix = netD(blended)
            loss_D = self.loss_fn(pred_mix, lam * K.ones_like(pred_mix))
            # Extra forward pass on the composite; its output feeds no loss term.
            netD(concatenate([fake, distorted]))
            loss_G = .5 * self.loss_fn(pred_mix,
                                       (1 - lam) * K.ones_like(pred_mix))
        else:
            pred_real = netD(concatenate([real, distorted]))  # positive sample
            pred_fake = netD(concatenate([fake, distorted]))  # negative sample
            loss_D = (self.loss_fn(pred_real, K.ones_like(pred_real))
                      + self.loss_fn(pred_fake, K.zeros_like(pred_fake)))
            loss_G = .5 * self.loss_fn(pred_fake, K.ones_like(pred_fake))

        # L1 reconstruction on the raw generated image.
        loss_G += K.mean(K.abs(fake_rgb - real))

        # Edge loss (similar to total variation): match first-order
        # differences along axis 1, then axis 2.
        for axis in (1, 2):
            loss_G += K.mean(
                K.abs(self.first_order(fake_rgb, axis=axis) -
                      self.first_order(real, axis=axis)))

        if vggface_feat is None:
            return loss_D, loss_G

        def preprocess_vggface(x):
            x = (x + 1) / 2 * 255  # channel order: BGR
            x -= [91.4953, 103.8827, 131.0912]
            return x

        real_in = Lambda(preprocess_vggface)(tf.image.resize_images(real, [224, 224]))
        fake_in = Lambda(preprocess_vggface)(tf.image.resize_images(fake_rgb, [224, 224]))
        real_feat55, real_feat28, real_feat7 = vggface_feat(real_in)
        fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_in)

        pl_params = (0.011, 0.11, 0.1919)
        for weight, fake_f, real_f in zip(pl_params,
                                          (fake_feat7, fake_feat28, fake_feat55),
                                          (real_feat7, real_feat28, real_feat55)):
            loss_G += weight * K.mean(K.abs(fake_f - real_f))
        return loss_D, loss_G
def beta_loss(target, output):
    """Negative weighted Beta log-likelihood.

    ``output[:, 0]`` and ``output[:, 1]`` are the two parameters of a
    ``Beta`` distribution. Its log-probability is evaluated at
    ``target[:, 0]``, multiplied by ``target[:, 1]``, summed over the last
    axis and negated.
    """
    first_param = output[:, 0]
    second_param = output[:, 1]
    observed = target[:, 0]
    row_weight = target[:, 1]
    log_likelihood = Beta(first_param, second_param).log_prob(observed)
    return -tf.reduce_sum(log_likelihood * row_weight, axis=-1)
Exemplo n.º 15
0
# Fragment of a two-domain (A/B) training graph. It relies on names defined
# before this point: the reconstructions (x_a_recon, s_a_recon, ...), x_aba,
# x_ba, x_ab, the discriminators netDA / netDB, loss_fn, the w_* weights, and
# the running accumulators loss_GA / loss_GB / loss_DA / loss_DB / loss_adv_GA.

# reconstruction loss
# L1 terms for the image (x_*), and for s_* / c_* codes.
# NOTE(review): reading s_* / c_* as style / content codes is inferred from
# names only; confirm.
loss_x_a_recon = w_recon * K.mean(K.abs(x_a_recon - real_A))
loss_x_b_recon = w_recon * K.mean(K.abs(x_b_recon - real_B))
loss_s_a_recon = w_recon_latent * K.mean(K.abs(s_a_recon - s_a))
loss_s_b_recon = w_recon_latent * K.mean(K.abs(s_b_recon - s_b))
loss_c_a_recon = w_recon_latent * K.mean(K.abs(c_a_recon - c_a))
loss_c_b_recon = w_recon_latent * K.mean(K.abs(c_b_recon - c_b))
loss_cycrecon_x_a = w_cycrecon * K.mean(K.abs(x_aba - real_A))
loss_cycrecon_x_b = w_cycrecon * K.mean(K.abs(x_bab - real_B))
# NOTE(review): the two cycle-reconstruction terms above are not added to
# loss_GA / loss_GB in this fragment; confirm they are consumed later.
loss_GA += (loss_x_a_recon + loss_s_a_recon + loss_c_a_recon)
loss_GB += (loss_x_b_recon + loss_s_b_recon + loss_c_b_recon)

# GAN loss
# With mixup, each discriminator is regressed toward the mixing weight on a
# blend of the real image and x_ba / x_ab. For domain A, the generator term
# pushes netDA's outputs on x_ba toward 1.
if use_mixup:
    dist_beta = Beta(mixup_alpha, mixup_alpha)
    lam_A = dist_beta.sample()
    mixup_A = lam_A * real_A + (1 - lam_A) * x_ba
    # netDA's result is iterated below, so it presumably returns a list of
    # outputs (e.g. multi-scale); confirm.
    outputs_xba_DA = netDA(x_ba)
    outputs_mixup_DA = netDA(mixup_A)
    for output in outputs_mixup_DA:
        loss_DA += loss_fn(output, lam_A * K.ones_like(output))  
    for output in outputs_xba_DA:
        loss_adv_GA += w_D * loss_fn(output, K.ones_like(output)) 

    lam_B = dist_beta.sample()
    mixup_B = lam_B * real_B + (1 - lam_B) * x_ab
    outputs_xab_DB = netDB(x_ab)
    outputs_mixup_DB = netDB(mixup_B)
    for output in outputs_mixup_DB:
        loss_DB += loss_fn(output, lam_B * K.ones_like(output))  
    # NOTE(review): unlike the A branch, outputs_xab_DB is not consumed here.
    # This fragment appears truncated; confirm the domain-B generator term
    # (loss_adv_GB) follows.
Exemplo n.º 16
0
    def define_loss(self,
                    netD,
                    netD2,
                    netD_feat,
                    netD_code,
                    netG,
                    real,
                    fake_argb,
                    fake_sz64,
                    distorted,
                    domain,
                    vggface_feat=None):
        """Build all losses for one domain of the domain-adversarial model.

        Components:
          * mixup GAN loss on the mask-composited output, scored by ``netD``
            and conditioned on ``distorted``;
          * mixup GAN loss on the raw generated image, scored by ``netD2``;
          * domain-adversarial loss on ``netG``'s second output (the code)
            for ``real``, scored by ``netD_code``;
          * consistency loss ``self.cos_distance`` between the codes of
            ``fake_rgb`` and ``real``;
          * L1 losses at full resolution and on the 64x64 auxiliary output;
          * optional perceptual loss on ``vggface_feat`` features.

        ``netD_feat`` is currently unused and ``loss_D_feat`` is always 0.

        Args:
            domain: ``"A"`` or ``"B"``. Selects the code-discriminator target
                (A -> 1, B -> 0); the generator term uses the opposite target.

        Returns:
            ``(loss_D, loss_D2, loss_G, loss_D_feat, loss_D_code)``.

        Raises:
            ValueError: if ``domain`` is neither ``"A"`` nor ``"B"``.
        """
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_argb)
        fake_rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_argb)
        fake = alpha * fake_rgb + (1 - alpha) * distorted

        # Use mixup - Loss of masked output
        dist = Beta(self.mixup_alpha, self.mixup_alpha)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate(
            [fake, distorted])
        output_mixup = netD(mixup)
        loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
        loss_G = .5 * self.loss_fn(output_mixup,
                                   (1 - lam) * K.ones_like(output_mixup))

        # Loss of raw output, with its own mixing weight lam2.
        lam2 = dist.sample()
        mixup2 = lam2 * real + (1 - lam2) * fake_rgb
        output2_mixup = netD2(mixup2)
        loss_D2 = self.loss_fn(output2_mixup,
                               lam2 * K.ones_like(output2_mixup))
        # The generator target must use lam2, the weight that built mixup2,
        # so that it mirrors the discriminator target above.
        loss_G += .5 * self.loss_fn(output2_mixup,
                                    (1 - lam2) * K.ones_like(output2_mixup))

        # Domain adversarial loss
        real_code = netG([real])[1]
        rec_code = netG([fake_rgb])[1]
        output_real_code = netD_code([real_code])
        # Target of domain A = 1, domain B = 0
        if domain == "A":
            loss_D_code = self.loss_fn_bce(output_real_code,
                                           K.ones_like(output_real_code))
            loss_G += .03 * self.loss_fn(output_real_code,
                                         K.zeros_like(output_real_code))
        elif domain == "B":
            loss_D_code = self.loss_fn_bce(output_real_code,
                                           K.zeros_like(output_real_code))
            loss_G += .03 * self.loss_fn(output_real_code,
                                         K.ones_like(output_real_code))
        else:
            # Without this, loss_D_code would be unbound at the return.
            raise ValueError(
                "domain must be 'A' or 'B', got {!r}".format(domain))

        # semantic consistency loss
        loss_G += 1. * self.cos_distance(rec_code, real_code)

        # L1 loss at full resolution and against the 64x64 auxiliary output.
        loss_G += 3 * K.mean(K.abs(fake_rgb - real))
        loss_G += 3 * K.mean(
            K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))

        # Placeholder: no feature-discriminator loss is built yet.
        loss_D_feat = 0
        # Perceptual Loss
        if vggface_feat is not None:

            def preprocess_vggface(x):
                x = (x + 1) / 2 * 255  # channel order: BGR
                x -= [93.5940, 104.7624, 129.]
                return x

            pl_params = (0.02, 0.3, 0.5)
            real_sz224 = tf.image.resize_images(real, [224, 224])
            real_sz224 = Lambda(preprocess_vggface)(real_sz224)
            fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
            fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
            real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
            fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
            # NOTE(review): the feat7 term uses squared error while the other
            # two use absolute error; confirm this is intentional.
            loss_G += pl_params[0] * K.mean(K.square(fake_feat7 - real_feat7))
            loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
            loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))

        return loss_D, loss_D2, loss_G, loss_D_feat, loss_D_code