Example #1
    def backward_G(self):
        # Style loss: MSE between the Gram matrices of the two feature sets
        self.gram_R = gram_matrix(self.feature_R)
        self.gram_B = gram_matrix(self.feature_B)
        self.loss_G_style = self.criterionMSE(self.gram_R, self.gram_B)

        # Content loss: MSE between the feature maps directly
        self.loss_G_content = self.criterionMSE(self.feature_R, self.feature_A)

        # self.loss_G_L1 = self.criterionL1(self.R, self.C)
        self.loss_G = self.loss_G_style + self.loss_G_content
        self.loss_G.backward()
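
Every example on this page calls a gram_matrix helper that is defined elsewhere in the respective project. A minimal batched sketch, assuming the common definition from the PyTorch fast-neural-style example (flatten each feature map to C x HW, multiply by its transpose, and normalize by C*H*W), could look like this:

import torch

def gram_matrix(y):
    # y: feature tensor of shape (B, C, H, W)
    b, c, h, w = y.size()
    features = y.view(b, c, h * w)                  # (B, C, H*W)
    features_t = features.transpose(1, 2)           # (B, H*W, C)
    gram = features.bmm(features_t) / (c * h * w)   # (B, C, C)
    return gram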
Example #2
def calc_style_loss(x, gram_name='real_A'):
    style_loss = 0
    n_batch = len(x)
    # Pick the precomputed Gram matrices of the requested style reference
    if 'real_A' in gram_name:
        gram_style = gram_style_real_A
    elif 'real_B' in gram_name:
        gram_style = gram_style_real_B
    features_y = vgg(normalize(x))
    for ft_y, gm_s in zip(features_y, gram_style):
        gm_y = util.gram_matrix(ft_y)
        style_loss += torch.nn.MSELoss()(gm_y, gm_s[:n_batch, :, :])
    return style_loss
Example #3
def extract_gram_feature(model, transform, img_path):
    input = torch.zeros(1, 1, 128, 128)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img.shape == (144, 144):
        img = img[8:136, 8:136]  # center-crop 144x144 -> 128x128
    img = np.reshape(img, (128, 128, 1))
    img = transform(img)
    input[0, :, :, :] = img
    # volatile=True is the legacy (pre-0.4) inference flag; current PyTorch uses torch.no_grad() instead
    input_var = torch.autograd.Variable(input, volatile=True)
    features = model.feat_network(input_var)

    # for m in range(len(features)):
    #     print(features._fields[m])

    gram_feat = [torch.autograd.Variable(gram_matrix(y).data, requires_grad=False) for y in features]
    return gram_feat
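
As noted in the comment above, torch.autograd.Variable(..., volatile=True) is the legacy pre-0.4 inference API and the volatile flag is ignored by current PyTorch releases. A hedged sketch of the same Gram-feature extraction with torch.no_grad(), keeping the model.feat_network and gram_matrix names from the example and assuming transform returns a (1, 128, 128) tensor:

import cv2
import numpy as np
import torch

def extract_gram_feature(model, transform, img_path):
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img.shape == (144, 144):
        img = img[8:136, 8:136]          # center-crop 144x144 -> 128x128
    img = np.reshape(img, (128, 128, 1))
    img = transform(img)                 # assumed to return a (1, 128, 128) float tensor
    input_batch = img.unsqueeze(0)       # add batch dimension -> (1, 1, 128, 128)
    with torch.no_grad():                # replaces Variable(..., volatile=True)
        features = model.feat_network(input_batch)
        gram_feat = [gram_matrix(y).detach() for y in features]
    return gram_feat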
Example #4
    def load_feature_style(self):
        if not os.path.exists(self.style_dir):
            os.makedirs(self.style_dir)
        # The style image must exist before its features can be extracted
        if not os.path.exists(os.path.join(self.style_dir, self.style_image_name)):
            raise Exception("[!] No image for style transfer")

        # Load, crop, resize and normalize the style image, then tile it across the batch
        image = load_image(os.path.join(self.style_dir, self.style_image_name),
                           size=self.image_size)
        image = transforms.Compose([
            transforms.CenterCrop(min(image.size[0], image.size[1])),
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])(image)
        image = image.repeat(self.batch_size, 1, 1, 1)
        image = image.to(self.device)

        # Precompute the Gram matrices of the style image's VGG feature maps
        style_image = self.vgg(image)
        self.gram_style = [gram_matrix(y) for y in style_image]
Example #5
    def train(self):
        total_step = len(self.data_loader)
        optimizer = Adam(self.transfer_net.parameters(), lr=self.lr)
        loss = nn.MSELoss()
        self.transfer_net.train()

        for epoch in range(self.epoch, self.num_epoch):
            for step, image in enumerate(self.data_loader):
                image = image.to(self.device)
                transformed_image = self.transfer_net(image)

                image_feature = self.vgg(image)
                transformed_image_feature = self.vgg(transformed_image)

                content_loss = self.content_weight * loss(
                    image_feature, transformed_image_feature)

                style_loss = 0
                for ft_y, gm_s in zip(transformed_image_feature,
                                      self.gram_style):
                    gm_y = gram_matrix(ft_y)
                    # MSE between the output's Gram matrix and the precomputed style Gram matrix
                    style_loss += loss(gm_y,
                                       gm_s[:self.batch_size, :, :])
                style_loss *= self.style_weight

                total_loss = content_loss + style_loss

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                if step % 10 == 0:
                    print(
                        f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                        f"[Style loss: {style_loss.item()}] [Content loss loss: {content_loss.item()}]"
                    )
            torch.save(
                self.transfer_net.state_dict(),
                os.path.join(self.checkpoint_dir, f"TransferNet_{epoch}.pth"))
Example #6
    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_B = self.fake_B
        # It has been verified that, for a square mask, letting D discriminate only the masked patch improves the results.
        if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
            # Using the cropped fake_B as the input of D.
            fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
                                            self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
            real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
                                            self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
        else:
            real_B = self.real_B

        pred_fake = self.netD(fake_B)

        if self.opt.gan_type == 'wgan_gp':
            self.loss_G_GAN = -torch.mean(pred_fake)
        else:
            if self.opt.gan_type in ['vanilla', 'lsgan']:
                self.loss_G_GAN = self.criterionGAN(pred_fake,
                                                    True) * self.opt.gan_weight

            elif self.opt.gan_type == 're_s_gan':
                pred_real = self.netD(real_B)
                self.loss_G_GAN = self.criterionGAN(pred_fake - pred_real,
                                                    True) * self.opt.gan_weight

            elif self.opt.gan_type == 're_avg_gan':
                self.pred_real = self.netD(real_B)
                # Relativistic average GAN: compare each prediction against the mean of the other
                self.loss_G_GAN = (self.criterionGAN(self.pred_real - torch.mean(pred_fake), False) \
                                   + self.criterionGAN(pred_fake - torch.mean(self.pred_real), True)) / 2.
                self.loss_G_GAN *= self.opt.gan_weight

        # If the mask is changed to 'center with random position', loss_G_L1_m can be replaced with a 'discounted L1' loss.
        self.loss_G_L1, self.loss_G_L1_m = 0, 0
        self.loss_G_L1 += self.criterionL1(self.fake_B,
                                           self.real_B) * self.opt.lambda_A
        # Calculate the masked-region reconstruction loss.
        # When mask_type is 'center' or 'random_with_rect', we can add an additional reconstruction loss (traditional L1) on the masked region.
        # Only when 'discounting_loss' is 1 does the masked-region loss switch to the 'discounting L1' instead of plain L1.
        if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
            mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
                                                self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
            mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
                                        self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
            # Using Discounting L1 loss
            self.loss_G_L1_m += self.criterionL1_mask(
                mask_patch_fake, mask_patch_real) * self.opt.mask_weight_G

        self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN

        # Then, add TV loss
        self.loss_tv = self.tv_criterion(self.fake_B *
                                         self.mask_global.float())

        # Finally, add style loss
        vgg_ft_fakeB = self.vgg16_extractor(fake_B.repeat(1, 3, 1, 1))
        vgg_ft_realB = self.vgg16_extractor(real_B.repeat(1, 3, 1, 1))
        self.loss_style = 0
        self.loss_content = 0

        for i in range(3):
            self.loss_style += self.criterionL2_style_loss(
                util.gram_matrix(vgg_ft_fakeB[i]),
                util.gram_matrix(vgg_ft_realB[i]))
            self.loss_content += self.criterionL2_content_loss(
                vgg_ft_fakeB[i], vgg_ft_realB[i])

        self.loss_style *= self.opt.style_weight
        self.loss_content *= self.opt.content_weight

        self.loss_G += (self.loss_style + self.loss_content + self.loss_tv)

        self.loss_G.backward()
Example #7
    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_B = self.fake_B
        real_B = self.real_GTimg

        pred_fake = self.netD(fake_B)
        # fake_AB = torch.cat((self.real_input, self.fake_B), 1)
        # pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake,
                                            True) * self.opt.gan_weight

        #print(self.mask_global.shape,self.real_GTimg.shape,self.fake_B.shape)

        self.output_comp = self.mask_global * self.real_GTimg + (
            1 - self.mask_global) * self.fake_B
        # if self.opt.gan_type == 'wgan_gp':
        #     self.loss_G_GAN = -torch.mean(pred_fake)
        # else:
        #     if self.opt.gan_type in ['vanilla', 'lsgan']:
        #         self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.gan_weight
        #
        #     elif self.opt.gan_type == 're_s_gan':
        #         pred_real = self.netD (real_B)
        #         self.loss_G_GAN = self.criterionGAN (pred_fake - pred_real, True) * self.opt.gan_weight
        #
        #     elif self.opt.gan_type == 're_avg_gan':
        #         self.pred_real = self.netD(real_B)
        #         self.loss_G_GAN =  (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), False) \
        #                        + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), True)) / 2.
        #         self.loss_G_GAN *=  self.opt.gan_weight

        # If the mask is changed to 'center with random position', loss_G_L1_m can be replaced with a 'discounted L1' loss.
        self.loss_G_L1, self.loss_G_L1_m = 0, 0
        self.loss_G_L1 += self.criterionL1(self.fake_B,
                                           self.real_GTimg) * self.opt.lambda_A
        # Calculate the masked-region reconstruction loss.
        # When mask_type is 'center' or 'random_with_rect', we can add an additional reconstruction loss (traditional L1) on the masked region.
        # Only when 'discounting_loss' is 1 does the masked-region loss switch to the 'discounting L1' instead of plain L1.
        # if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
        #     mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
        #                                         self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
        #     mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
        #                                 self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
        #     # Using Discounting L1 loss
        #     self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight_G

        self.loss_hole = self.criterionL1(
            (1 - self.mask_global) * self.fake_B,
            (1 - self.mask_global) * self.real_GTimg)
        self.loss_valid = self.criterionL1(self.mask_global * self.fake_B,
                                           self.mask_global * self.real_GTimg)

        self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN

        # Then, add TV loss
        self.loss_tv = self.tv_criterion(self.fake_B.float())

        # Finally, add style loss
        vgg_ft_fakeB = self.vgg16_extractor(fake_B)
        vgg_ft_realB = self.vgg16_extractor(real_B)
        self.loss_style = 0
        self.loss_content = 0

        for i in range(3):
            self.loss_style += self.criterionL2_style_loss(
                util.gram_matrix(vgg_ft_fakeB[i]),
                util.gram_matrix(vgg_ft_realB[i]))
            self.loss_content += self.criterionL2_content_loss(
                vgg_ft_fakeB[i], vgg_ft_realB[i])

        self.loss_style *= self.opt.style_weight
        self.loss_content *= self.opt.content_weight
        self.loss_hole *= 10.0
        self.loss_valid *= 1.0

        self.loss_G += (self.loss_valid + self.loss_hole + self.loss_style +
                        self.loss_content + self.loss_tv)

        self.loss_G.backward()
Example #8
    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        lambda_feature = 750
        lambda_style = 75

        # Identity loss
        if lambda_idt > 0:
            # C * H * W = 7 * 7 * 512
            CHW = 512 * 7 * 7

            #1. Loss idt A

            self.idt_A = self.netG_A(self.real_B)

            idt_A_features = self.vgg16.features(self.idt_A).cuda()
            real_B_features = self.vgg16.features(self.real_B).cuda()

            #print(idt_A_features.size())
            #print(real_B_features.size())

            distance = torch.dist(idt_A_features, real_B_features, 2)

            gramA = gram_matrix(idt_A_features)
            gramB = gram_matrix(real_B_features)

            self.loss_feature_reconstruction_A = (
                1 / CHW) * distance * lambda_feature
            self.loss_style_reconstruction_A = torch.norm(gramA -
                                                          gramB) * lambda_style
            self.loss_idt_A = ((self.loss_feature_reconstruction_A +
                                self.loss_style_reconstruction_A) * lambda_B *
                               lambda_idt) / 30

            #2. Loss idt B

            self.idt_B = self.netG_B(self.real_A)

            idt_B_features = self.vgg16.features(self.idt_B).cuda()
            real_A_features = self.vgg16.features(self.real_A).cuda()

            distance = torch.dist(idt_B_features, real_A_features, 2)

            gramB = gram_matrix(idt_B_features)
            gramA = gram_matrix(real_A_features)

            self.loss_feature_reconstruction_B = (
                1 / CHW) * distance * lambda_feature
            self.loss_style_reconstruction_B = torch.norm(gramA -
                                                          gramB) * lambda_style
            self.loss_idt_B = ((self.loss_feature_reconstruction_B +
                                self.loss_style_reconstruction_B) * lambda_A *
                               lambda_idt) / 30

        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()
Example #9
    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # GAN loss D_idtA_fakeB(G_A(A)) + D_idtA_fakeB(G_A(B)),  # like (gradient inverse ?)
        self.loss_G_A_idtA_fakeB = (self.criterionGAN(self.netD_idtA_fakeB(self.fake_B), True) + \
                                    self.criterionGAN(self.netD_idtA_fakeB(self.idt_A), False)) * 10
        # GAN loss D_idtB_fakeA(G_B(B)) + D_idtB_fakeA(G_B(A)),  # like (gradient inverse ?)
        self.loss_G_B_idtB_fakeA = (self.criterionGAN(self.netD_idtB_fakeA(self.fake_A), True) + \
                                    self.criterionGAN(self.netD_idtB_fakeA(self.idt_B), False)) * 10
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # netG_B perceptual content loss, inputs (fake_B / real_B / real_A) -> outputs (rec_A / fake_A / idt_B)
        vgg = Vgg16(requires_grad=False).to(self.device)
        normalize = lambda input: util.normalize_batch(((input+1)*127.5).expand(-1,3,-1,-1))  # [-1,1] -> [0,255] -> [vgg normalized], [N,1,H,W] -> [N,3,H,W]
        w_content = 1e0
        # MSELoss(input, target), the target should be (requires_grad=False), otherwise it will encourage all 0 predictions.
        self.loss_perceptual_content_rec_A = w_content * torch.nn.MSELoss()(vgg(normalize(self.rec_A)).relu1_2, \
                                                                        vgg(normalize(self.real_A)).relu1_2)  # fake_B / real_A
        self.loss_perceptual_content_fake_A = w_content * torch.nn.MSELoss()(vgg(normalize(self.fake_A)).relu1_2, \
                                                                        vgg(normalize(self.real_B)).relu1_2)
        self.loss_perceptual_content_idt_B = 0 # 0 if lambda_idt <= 0 else w_content * torch.nn.MSELoss()(vgg(normalize(self.idt_B)).relu1_2, \
        # netG_A perceptual content loss
        self.loss_perceptual_content_rec_B = w_content * torch.nn.MSELoss()(vgg(normalize(self.rec_B)).relu1_2, \
                                                                        vgg(normalize(self.real_B)).relu1_2)  # real_B / fake_A
        self.loss_perceptual_content_fake_B = w_content * torch.nn.MSELoss()(vgg(normalize(self.fake_B)).relu1_2, \
                                                                        vgg(normalize(self.real_A)).relu1_2)
        self.loss_perceptual_content_idt_A = 0 # 0 if lambda_idt <= 0 else w_content * torch.nn.MSELoss()(vgg(normalize(self.idt_A)).relu1_2, \
        #                                                                 vgg(normalize(self.real_A)).relu1_2)

        # netG_B perceptual style loss, transfer style from: real_A
        w_style = 1e5
        gram_style_real_A = [util.gram_matrix(y) for y in vgg(normalize(self.real_A))]  # relu1/2/3/4_3 layers embeddings
        gram_style_real_B = [util.gram_matrix(y) for y in vgg(normalize(self.real_B))]  # relu1/2/3/4_3 layers embeddings

        def calc_style_loss(x, gram_name='real_A'):
            style_loss = 0
            n_batch = len(x)
            if 'real_A' in gram_name:
                gram_style = gram_style_real_A
            elif 'real_B' in gram_name:
                gram_style = gram_style_real_B
            features_y = vgg(normalize(x))
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = util.gram_matrix(ft_y)
                style_loss += torch.nn.MSELoss()(gm_y, gm_s[:n_batch, :, :])
            return style_loss

        self.loss_perceptual_style_rec_A =  0 # w_style * calc_style_loss(self.rec_A)
        self.loss_perceptual_style_fake_A = w_style * calc_style_loss(self.fake_A, 'real_A')
        self.loss_perceptual_style_idt_B =  0 # 0 if lambda_idt <= 0 else w_style * calc_style_loss(self.idt_B)
        # netG_A perceptual style loss
        self.loss_perceptual_style_rec_B =  0 # w_style * calc_style_loss(self.rec_A)
        self.loss_perceptual_style_fake_B = 0  #w_style * calc_style_loss(self.fake_B, 'real_B')
        self.loss_perceptual_style_idt_A =  0 # 0 if lambda_idt <= 0 else w_style * calc_style_loss(self.idt_B)
        # SSIM loss
        w_ssim = 1.0
        self.loss_ssim_A = -1 * pytorch_ssim.SSIM()(self.rec_A, self.real_A) * w_ssim
        self.loss_ssim_B = -1 * pytorch_ssim.SSIM()(self.rec_B, self.real_B) * w_ssim
        # combined loss
        self.loss_G = 0 \
                    + self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B \
                    + self.loss_idt_A + self.loss_idt_B
                    # + self.loss_G_A_idtA_fakeB \
                    # + self.loss_perceptual_content_rec_A \
                    # + self.loss_G_B_idtB_fakeA \
                    # + self.loss_perceptual_content_rec_B \
                    # + self.loss_ssim_A + self.loss_ssim_B
                    # + self.loss_perceptual_style_rec_A + self.loss_perceptual_style_fake_A + self.loss_perceptual_style_idt_B \
                    # + self.loss_perceptual_content_rec_A + self.loss_perceptual_content_fake_A + self.loss_perceptual_content_idt_B \
                    # + self.loss_perceptual_style_rec_B + self.loss_perceptual_style_fake_B + self.loss_perceptual_style_idt_A \
                    # + self.loss_perceptual_content_rec_B + self.loss_perceptual_content_fake_B + self.loss_perceptual_content_idt_A \
        self.loss_G.backward()
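
Several of the examples above (calc_style_loss and the perceptual losses in this one) iterate over vgg(normalize(x)) and also access attributes such as .relu1_2, which suggests a VGG-16 feature extractor that returns a namedtuple of intermediate activations, as in the PyTorch fast-neural-style example. A minimal sketch under that assumption follows; the relu1_2/relu2_2/relu3_3/relu4_3 split points are not confirmed by the snippets above:

from collections import namedtuple

import torch
from torchvision import models

class Vgg16(torch.nn.Module):
    """Exposes the relu1_2 / relu2_2 / relu3_3 / relu4_3 activations of a pretrained VGG-16."""

    def __init__(self, requires_grad=False):
        super().__init__()
        features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential(*[features[i] for i in range(4)])       # conv1_1 .. relu1_2
        self.slice2 = torch.nn.Sequential(*[features[i] for i in range(4, 9)])    # .. relu2_2
        self.slice3 = torch.nn.Sequential(*[features[i] for i in range(9, 16)])   # .. relu3_3
        self.slice4 = torch.nn.Sequential(*[features[i] for i in range(16, 23)])  # .. relu4_3
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        VggOutputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])
        return VggOutputs(h1, h2, h3, h4)

Because the forward pass returns a namedtuple, both tuple-style iteration (as in calc_style_loss) and attribute access (as in the .relu1_2 content losses) work on the same object.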