Code example #1
0
    def set_input(self, input, epoch=0):
        """Unpack a batch from the data loader and prepare training tensors.

        Stores image/mask/caption fields on ``self``, moves tensors to the
        first configured GPU, derives the masked (I_m) and complement (I_c)
        images, builds the multi-scale pyramids, and vectorizes the captions.
        """
        self.input = input
        self.image_paths = self.input['img_path']
        self.img = input['img']
        self.mask = input['mask']
        self.caption_idx = input['caption_idx']
        self.caption_length = input['caption_len']

        use_gpu = len(self.gpu_ids) > 0
        if use_gpu:
            dev = self.gpu_ids[0]
            self.img = self.img.cuda(dev, True)
            self.mask = self.mask.cuda(dev, True)

        # Map pixels to [-1, 1], then split into masked / complement regions.
        self.img_truth = self.img * 2 - 1
        self.img_m = self.mask * self.img_truth
        self.img_c = (1 - self.mask) * self.img_truth

        # Ground-truth and mask pyramids for the multi-scale training loss.
        self.scale_img = task.scale_pyramid(self.img_truth, self.opt.output_scale)
        self.scale_mask = task.scale_pyramid(self.mask, self.opt.output_scale)

        # Text side: readable caption, word/sentence embeddings, padding mask,
        # and per-sample matching labels for the text-image matching loss.
        self.text_positive = util.idx_to_caption(
            self.ixtoword, self.caption_idx[0].tolist(), self.caption_length[0].item())
        self.word_embeddings, self.sentence_embedding = util.vectorize_captions_idx_batch(
            self.caption_idx, self.caption_length, self.text_encoder)
        self.text_mask = util.lengths_to_mask(self.caption_length, max_length=self.word_embeddings.size(-1))
        self.match_labels = torch.LongTensor(range(len(self.img_m)))
        if use_gpu:
            dev = self.gpu_ids[0]
            self.word_embeddings = self.word_embeddings.cuda(dev, True)
            self.sentence_embedding = self.sentence_embedding.cuda(dev, True)
            self.text_mask = self.text_mask.cuda(dev, True)
            self.match_labels = self.match_labels.cuda(dev, True)
Code example #2
0
    def backward_synthesis2real(self):
        """Train the s->t translator with multi-scale LSGAN + reconstruction losses.

        The image discriminator and task network are frozen so only the
        translator receives gradients; both input streams contribute.
        """
        network._freeze(self.net_img_D, self.net_img2task)
        network._unfreeze(self.net_s2t)

        self.img_s2t_1, self.img_t2t_1, self.img_f_s_1, self.img_f_t_1, size_1 = \
            self.foreward_G_basic(self.net_s2t, self.img_s_1, self.img_t_1)
        self.img_s2t_2, self.img_t2t_2, self.img_f_s_2, self.img_f_t_2, size_2 = \
            self.foreward_G_basic(self.net_s2t, self.img_s_2, self.img_t_2)

        img_real_1 = task.scale_pyramid(self.img_t_1, size_1 - 1)
        img_real_2 = task.scale_pyramid(self.img_t_2, size_2 - 1)

        gan_loss = 0
        recon_loss = 0
        streams = (
            (self.img_s2t_1, self.img_t2t_1, img_real_1, size_1),
            (self.img_s2t_2, self.img_t2t_2, img_real_2, size_2),
        )
        for fakes, identities, reals, n_scales in streams:
            for k in range(n_scales - 1):
                # L1 identity reconstruction at this scale.
                recon_loss += self.l1loss(identities[k], reals[k])
                # LSGAN generator term: push D outputs toward 1.
                for d_out in self.net_img_D(fakes[k]):
                    gan_loss += torch.mean((d_out - 1.0) ** 2)

        self.loss_img_G = gan_loss * self.opt.lambda_gan_img
        self.loss_img_rec = recon_loss * self.opt.lambda_rec_img

        # Graph retained — presumably reused by a later backward pass.
        (self.loss_img_G + self.loss_img_rec).backward(retain_graph=True)
Code example #3
0
    def backward_task(self):
        """Supervised task update on source images: multi-scale L1 + smoothness."""
        self.output_s_g = self.net_img2task(self.img_s)

        n_outputs = len(self.output_s_g)
        # Index 0 is skipped; the remaining outputs are per-scale predictions.
        self.lab_s_g = self.output_s_g[1:]

        # Multi-scale L1 between predictions and the ground-truth pyramid.
        lab_real = task.scale_pyramid(self.lab_s, n_outputs - 1)
        rec = sum(self.l1loss(pred, target)
                  for pred, target in zip(self.lab_s_g, lab_real))
        self.loss_lab_s = rec * self.opt.lambda_rec_lab

        # Smoothness term weighted by the source-image pyramid.
        img_real = task.scale_pyramid(self.img_s, n_outputs - 1)
        self.loss_lab_smooth_s = task.get_smooth_weight(
            self.lab_s_g, img_real, n_outputs - 1) * self.opt.lambda_smooth

        total = self.loss_lab_s + self.loss_lab_smooth_s
        total.backward()
        del total
Code example #4
0
File: TPOVRModel.py  Project: scriptlee/GAN-PWCNet
    def backward_synthesis2real(self):
        """Generator update for the synthetic-to-real image translation.

        Freezes the image discriminator and the task network, trains only the
        s->t translator, and backpropagates a multi-scale LSGAN loss on the
        translated synthetic images plus an L1 reconstruction loss on the
        t->t identity path, for both input streams.

        Removed: commented-out dead experiment code (s2t2t cycle, visdom
        previews, debug prints) and the unused ``s_rec_loss`` accumulator.
        """
        # image to image transform: only the translator receives gradients
        network._freeze(self.net_img_D, self.net_img2task)
        network._unfreeze(self.net_s2t)

        self.img_s2t_1, self.img_t2t_1, self.img_f_s_1, self.img_f_t_1, size_1 = \
            self.foreward_G_basic(self.net_s2t, self.img_s_1, self.img_t_1)
        self.img_s2t_2, self.img_t2t_2, self.img_f_s_2, self.img_f_t_2, size_2 = \
            self.foreward_G_basic(self.net_s2t, self.img_s_2, self.img_t_2)

        # image GAN loss and reconstruction loss at every pyramid scale
        img_real_1 = task.scale_pyramid(self.img_t_1, size_1 - 1)
        img_real_2 = task.scale_pyramid(self.img_t_2, size_2 - 1)

        G_loss = 0
        rec_loss = 0

        for i in range(size_1 - 1):
            rec_loss += self.l1loss(self.img_t2t_1[i], img_real_1[i])
            D_fake = self.net_img_D(self.img_s2t_1[i])
            for D_fake_i in D_fake:
                # LSGAN generator objective: push D's score toward 1.
                G_loss += torch.mean((D_fake_i - 1.0) ** 2)

        for i in range(size_2 - 1):
            rec_loss += self.l1loss(self.img_t2t_2[i], img_real_2[i])
            D_fake = self.net_img_D(self.img_s2t_2[i])
            for D_fake_i in D_fake:
                G_loss += torch.mean((D_fake_i - 1.0) ** 2)

        self.loss_img_G = G_loss * self.opt.lambda_gan_img
        self.loss_img_rec = rec_loss * self.opt.lambda_rec_img

        total_loss = self.loss_img_G + self.loss_img_rec

        # Graph retained — presumably reused by a later backward pass; confirm.
        total_loss.backward(retain_graph=True)
Code example #5
0
    def backward_G(self):
        """Generator losses for the dehazing network (synthetic branch).

        Sums a supervised multi-scale dehazing loss with TV, dark-channel
        and adversarial terms on the r2s dehazed image, then backpropagates.
        """
        w_dehaze = self.opt.lambda_Dehazing

        # Supervised loss: every output scale (index 0 skipped) against the
        # clear-image pyramid; only the first self.num samples are compared.
        n_out = len(self.out)
        clear_imgs = task.scale_pyramid(self.clear_img, n_out - 1)
        self.loss_S_Dehazing = 0.0
        for pred, target in zip(self.out[1:], clear_imgs):
            self.loss_S_Dehazing += self.criterionDehazing(
                pred[:self.num, :, :, :], target) * w_dehaze

        # Total-variation regularizer on the r2s dehazed image.
        self.loss_R2S_Dehazing_TV = self.TVLoss(
            self.r2s_dehazing_img) * self.opt.lambda_Dehazing_TV

        # Dark-channel loss; input rescaled from [-1, 1] to [0, 1].
        self.loss_R2S_Dehazing_DC = DCLoss(
            (self.r2s_dehazing_img + 1) / 2,
            self.opt.patch_size) * self.opt.lambda_Dehazing_DC

        # Adversarial term: make the r2s dehazed image look real to D.
        self.loss_G = self.criterionGAN(self.netD(self.r2s_dehazing_img),
                                        True) * self.opt.lambda_gan

        self.loss_GS_Dehazing = (self.loss_S_Dehazing
                                 + self.loss_R2S_Dehazing_TV
                                 + self.loss_R2S_Dehazing_DC
                                 + self.loss_G)
        self.loss_GS_Dehazing.backward()
Code example #6
0
    def backward_translated2depth(self):
        """Task-network update on translated images: feature GAN + label L1."""
        predictions = self.net_img2task.forward(self.img_s2t[-1])

        n_scales = len(predictions)
        # Index 0 is the feature map fed to the feature discriminator;
        # the rest are per-scale label predictions.
        self.lab_f_s = predictions[0]
        self.lab_s_g = predictions[1:]

        # Feature-level LSGAN loss: push discriminator scores toward 1.
        feat_gan = 0
        for d_out in self.net_f_D(self.lab_f_s):
            feat_gan += torch.mean((d_out - 1.0) ** 2)
        self.loss_f_G = feat_gan * self.opt.lambda_gan_feature

        # Multi-scale L1 against the ground-truth label pyramid.
        lab_real = task.scale_pyramid(self.lab_s, n_scales - 1)
        rec = sum(self.l1loss(pred, target)
                  for pred, target in zip(self.lab_s_g, lab_real))
        self.loss_lab_s = rec * self.opt.lambda_rec_lab

        (self.loss_f_G + self.loss_lab_s).backward()
Code example #7
0
 def backward_D_image(self):
     """Discriminator update: pooled fake pyramid vs. real image pyramid."""
     n_scales = len(self.img_s2t)
     # Route every translated scale through the fake-image history pool.
     fake = [self.fake_img_pool.query(scale) for scale in self.img_s2t]
     real = task.scale_pyramid(self.img_t, n_scales)
     self.loss_img_D = self.backward_D_basic(self.net_img_D, real, fake)
Code example #8
0
    def fill_mask(self):
        """Run the inpainting generator on the current canvas and show the result."""
        img_m, img_c, img_truth, mask = self.set_input()
        if self.PaintPanel.iteration < 100:
            with torch.no_grad():
                # Encode the masked image; sample a latent code from the
                # last predicted distribution.
                distributions, f = self.model.net_E(img_m)
                mu, sigma = distributions[-1][0], distributions[-1][1]
                z = torch.distributions.Normal(mu, sigma).sample()

                # Decode, then keep the original pixels in the known region.
                scale_mask = task.scale_pyramid(mask, 4)
                self.img_g, self.atten = self.model.net_G(
                    z,
                    f_m=f[-1],
                    f_e=f[2],
                    mask=scale_mask[0].chunk(3, dim=1)[0])
                self.img_out = (1 - mask) * self.img_g[-1].detach() + mask * img_m

                # Show the discriminator's mean score in the UI.
                score = self.model.net_D(self.img_out).mean()
                self.label_6.setText(str(round(score.item(), 3)))
                self.PaintPanel.iteration += 1

        self.show_result_flag = True
        self.show_result()
Code example #9
0
    def backward_G(self):
        """Generator losses for the dehazing network (real branch).

        Combines a supervised multi-scale dehazing loss on the first
        ``self.num`` samples of the batch with unsupervised TV, dark-channel
        and GAN losses on the real dehazed image, then backpropagates.

        Removed: commented-out dead code (decay schedule experiment, unused
        zero-tensor placeholders, alternative TV loss call).
        """
        lambda_Dehazing = self.opt.lambda_Dehazing

        # Supervised loss: compare every output scale (index 0 skipped)
        # against the clear-image pyramid; only the first self.num samples
        # of the batch are compared — presumably the labeled subset.
        size = len(self.out)
        clear_imgs = task.scale_pyramid(self.clear_img, size - 1)
        self.loss_S2R_Dehazing = 0.0
        for (dehazing_img, clear_img) in zip(self.out[1:], clear_imgs):
            self.loss_S2R_Dehazing += self.criterionDehazing(
                dehazing_img[:self.num, :, :, :], clear_img) * lambda_Dehazing

        # TV LOSS: smoothness regularizer on the real dehazed image.
        self.loss_R_Dehazing_TV = self.TVLoss(
            self.r_dehazing_img) * self.opt.lambda_Dehazing_TV

        # DC LOSS: dark-channel prior; input rescaled from [-1, 1] to [0, 1].
        self.loss_R_Dehazing_DC = DCLoss(
            (self.r_dehazing_img + 1) / 2,
            self.opt.patch_size) * self.opt.lambda_Dehazing_DC

        # GAN LOSS: make the real dehazed image look real to the discriminator.
        self.loss_G = self.criterionGAN(self.netD(self.r_dehazing_img),
                                        True) * self.opt.lambda_gan

        self.loss_GR_Dehazing = (self.loss_S2R_Dehazing
                                 + self.loss_R_Dehazing_TV
                                 + self.loss_R_Dehazing_DC
                                 + self.loss_G)
        self.loss_GR_Dehazing.backward()
Code example #10
0
    def validation_target(self):
        """Validation loss: multi-scale reconstruction on target-domain labels."""
        lab_real = task.scale_pyramid(self.lab_t, len(self.lab_t_g))
        total = sum(task.rec_loss(pred, target)
                    for pred, target in zip(self.lab_t_g, lab_real))
        self.loss_lab_t = total * self.opt.lambda_rec_lab
Code example #11
0
    def backward_D_image(self):
        """Discriminator update on both translated streams (translator frozen)."""
        network._freeze(self.net_s2t)
        network._unfreeze(self.net_img_D)

        # Stream 1: pooled fakes at every scale vs. real pyramid.
        fake_1 = [self.fake_img_pool.query(scale) for scale in self.img_s2t_1]
        real_1 = task.scale_pyramid(self.img_t_1, len(self.img_s2t_1))
        # Stream 2: same construction.
        fake_2 = [self.fake_img_pool.query(scale) for scale in self.img_s2t_2]
        real_2 = task.scale_pyramid(self.img_t_2, len(self.img_s2t_2))

        loss_1 = self.backward_D_basic(self.net_img_D, real_1, fake_1)
        loss_2 = self.backward_D_basic(self.net_img_D, real_2, fake_2)
        self.loss_img_D = loss_1 + loss_2
Code example #12
0
    def set_input(self, input, epoch=0):
        """Unpack input data from the data loader and perform necessary pre-process steps.

        Stores image/mask on ``self``, moves them to the first configured GPU,
        derives the masked (I_m) and complement (I_c) images in [-1, 1], and
        builds the multi-scale pyramids used by the training losses.
        """
        self.input = input
        self.image_paths = self.input['img_path']
        self.img = input['img']
        self.mask = input['mask']

        if len(self.gpu_ids) > 0:
            # Fix: `.cuda(..., async=True)` is a SyntaxError on Python 3.7+
            # (`async` became a keyword); PyTorch renamed the argument to
            # `non_blocking`.
            self.img = self.img.cuda(self.gpu_ids[0], non_blocking=True)
            self.mask = self.mask.cuda(self.gpu_ids[0], non_blocking=True)

        # get I_m and I_c for image with mask and complement regions for training
        self.img_truth = self.img * 2 - 1
        self.img_m = self.mask * self.img_truth
        self.img_c = (1 - self.mask) * self.img_truth

        # get multiple scales image ground truth and mask for training
        self.scale_img = task.scale_pyramid(self.img_truth, self.opt.output_scale)
        self.scale_mask = task.scale_pyramid(self.mask, self.opt.output_scale)
Code example #13
0
    def fill_mask(self):
        """Forward to get the text-guided generation results and display them.

        Reads the current canvas from ``set_input()``, vectorizes the typed
        caption, runs the text-conditioned encoder/decoder, composites the
        generated pixels into the masked region, and shows the PSNR score.

        Fix: removed a leftover ``import ipdb; ipdb.set_trace()`` debugger
        breakpoint that froze the UI on every call (and required a non-
        production dependency).
        """
        img_m, img_c, img_truth, mask, text_idx, text_len = self.set_input()
        if self.comboBox.currentIndex() == 0:
            return
        if text_len < 1:
            self.textEdit.setText(
                'Input some words about this bird or the bird you want.')
            return
        print(self.textEdit.toPlainText())
        if self.PaintPanel.iteration < 100:
            print(self.PaintPanel.iteration)
            with torch.no_grad():
                # encoder process: embed the caption, then encode the masked
                # image conditioned on the text.
                word_embeddings, sentence_embedding = util.vectorize_captions_idx_batch(
                    text_idx, text_len, self.model.text_encoder)
                img_mask = torch.ones_like(img_m)
                img_mask[img_m == 0.] = 0.
                distributions, f, f_text = self.model.net_E(
                    img_m, sentence_embedding, word_embeddings, None, img_mask)
                # Scale is zeroed when the checkbox is off — presumably to
                # make sampling deterministic (z equals the mean); confirm
                # Normal accepts a zero scale in this torch version.
                variation_factor = 1. if self.checkBox.isChecked() else 0.
                q_distribution = torch.distributions.Normal(
                    distributions[-1][0],
                    distributions[-1][1] * variation_factor)
                z = q_distribution.sample()

                # decoder process; keep original pixels in the known region
                scale_mask = task.scale_pyramid(mask, 4)
                self.img_g, self.atten = self.model.net_G(
                    z, f_text, f_e=f[2], mask=scale_mask[0].chunk(3, dim=1)[0])
                self.img_out = (1 -
                                mask) * self.img_g[-1].detach() + mask * img_m

                # get score: PSNR between ground truth and composited output
                l1, PSNR, TV = compute_errors(
                    util.tensor2im(self.img_truth),
                    util.tensor2im(self.img_out.detach()))

                self.label_6.setText(str(PSNR))

                self.PaintPanel.iteration += 1

        self.show_result_flag = True
        self.show_result()
Code example #14
0
    def backward_real2depth(self):
        """Smoothness-only update for labels predicted on real images."""
        # image2depth forward pass
        outputs = self.net_img2task.forward(self.img_t)
        n_scales = len(outputs)

        # Index 0 is a feature map; the rest are per-scale predictions.
        self.lab_f_t = outputs[0]
        self.lab_t_g = outputs[1:]

        # Smoothness term weighted by the real-image pyramid.
        img_real = task.scale_pyramid(self.img_t, n_scales - 1)
        self.loss_lab_smooth = task.get_smooth_weight(
            self.lab_t_g, img_real, n_scales - 1) * self.opt.lambda_smooth

        self.loss_lab_smooth.backward()
Code example #15
0
    def backward_synthesis2real(self):
        """Translator update: multi-scale LSGAN + identity reconstruction losses."""
        # image-to-image transform
        self.img_s2t, self.img_t2t, self.img_f_s, self.img_f_t, size = \
            self.foreward_G_basic(self.net_s2t, self.img_s, self.img_t)

        # Per-scale GAN loss on translated images, L1 on the identity path.
        img_real = task.scale_pyramid(self.img_t, size - 1)
        gan_loss = 0
        recon_loss = 0
        for k in range(size - 1):
            recon_loss += self.l1loss(self.img_t2t[k], img_real[k])
            # LSGAN generator term: push discriminator outputs toward 1.
            for d_out in self.net_img_D(self.img_s2t[k]):
                gan_loss += torch.mean((d_out - 1.0) ** 2)

        self.loss_img_G = gan_loss * self.opt.lambda_gan_img
        self.loss_img_rec = recon_loss * self.opt.lambda_rec_img

        # Graph retained — presumably reused by a later backward pass.
        (self.loss_img_G + self.loss_img_rec).backward(retain_graph=True)
Code example #16
0
    def backward_G(self):
        """Generator update for the bidirectional dehazing translation model.

        Runs two backward passes in sequence: first the synthetic branch
        (S->R translation with GAN/cycle/identity terms plus supervised
        dehazing losses), then the real branch (R->S translation with
        GAN/cycle/identity terms plus TV, dark-channel and cross-domain
        consistency regularizers). Finally averages the two dehazed outputs
        per domain for display.

        NOTE(review): statement order is load-bearing — ``self.out_s`` and
        ``self.out_r`` are overwritten by the real branch, so the display
        images at the end use the real-branch values of ``r_dehazing_img``/
        ``r2s_dehazing_img`` and the synthetic-branch values of
        ``s_dehazing_img``/``s2r_dehazing_img`` saved earlier.
        """
        lambda_Dehazing = self.opt.lambda_Dehazing
        lambda_Dehazing_Con = self.opt.lambda_Dehazing_Con
        lambda_gan_feat = self.opt.lambda_gan_feat
        lambda_idt = self.opt.lambda_identity
        lambda_S = self.opt.lambda_S
        lambda_R = self.opt.lambda_R

        # =========================== synthetic ==========================
        # Translate synthetic haze to the real domain; also compute the
        # identity mapping and the cycle reconstruction, then dehaze both
        # the translated and the original synthetic images.
        self.img_s2r = self.netS2R(self.syn_haze_img)
        self.idt_S = self.netR2S(self.syn_haze_img)
        self.s_rec_img = self.netR2S(self.img_s2r)
        self.out_r = self.netR_Dehazing(self.img_s2r)
        self.out_s = self.netS_Dehazing(self.syn_haze_img)
        # Dehazing networks return a list: index 0 is a feature map, the
        # last element the dehazed image (intermediate entries are compared
        # against the clear-image pyramid below).
        self.s2r_dehazing_feat = self.out_r[0]
        self.s_dehazing_feat = self.out_s[0]
        self.s2r_dehazing_img = self.out_r[-1]
        self.s_dehazing_img = self.out_s[-1]
        # GAN terms: translated image fools D_R; dehazing feature fools D_Rfeat.
        self.loss_G_S2R = self.criterionGAN(self.netD_R(self.img_s2r), True)
        self.loss_G_Rfeat = self.criterionGAN(
            self.netD_Rfeat(self.s2r_dehazing_feat), True) * lambda_gan_feat
        # Cycle (S -> R -> S) and identity losses for the synthetic domain.
        self.loss_cycle_S = self.criterionCycle(self.s_rec_img,
                                                self.syn_haze_img) * lambda_S
        self.loss_idt_S = self.criterionIdt(
            self.idt_S, self.syn_haze_img) * lambda_S * lambda_idt
        # Supervised dehazing: every output scale (index 0 skipped) against
        # the clear-image pyramid, for both the S and S->R dehazed outputs.
        size = len(self.out_s)
        self.loss_S_Dehazing = 0.0
        clear_imgs = task.scale_pyramid(self.clear_img, size - 1)
        for (s_dehazing_img, clear_img) in zip(self.out_s[1:], clear_imgs):
            self.loss_S_Dehazing += self.criterionDehazing(
                s_dehazing_img, clear_img) * lambda_Dehazing
        self.loss_S2R_Dehazing = 0.0
        for (s2r_dehazing_img, clear_img) in zip(self.out_r[1:], clear_imgs):
            self.loss_S2R_Dehazing += self.criterionDehazing(
                s2r_dehazing_img, clear_img) * lambda_Dehazing
        # First backward pass: synthetic-branch losses only.
        self.loss = self.loss_G_S2R + self.loss_G_Rfeat + self.loss_cycle_S + self.loss_idt_S + self.loss_S_Dehazing + self.loss_S2R_Dehazing
        self.loss.backward()

        # ============================= real =============================
        # Same construction for the real domain (no clear ground truth, so
        # only unsupervised regularizers apply). Overwrites out_s/out_r.
        self.img_r2s = self.netR2S(self.real_haze_img)
        self.idt_R = self.netS2R(self.real_haze_img)
        self.r_rec_img = self.netS2R(self.img_r2s)
        self.out_s = self.netS_Dehazing(self.img_r2s)
        self.out_r = self.netR_Dehazing(self.real_haze_img)
        self.r_dehazing_feat = self.out_r[0]
        self.r2s_dehazing_feat = self.out_s[0]
        self.r_dehazing_img = self.out_r[-1]
        self.r2s_dehazing_img = self.out_s[-1]
        self.loss_G_R2S = self.criterionGAN(self.netD_S(self.img_r2s), True)
        self.loss_G_Sfeat = self.criterionGAN(
            self.netD_Sfeat(self.r2s_dehazing_feat), True) * lambda_gan_feat
        self.loss_cycle_R = self.criterionCycle(self.r_rec_img,
                                                self.real_haze_img) * lambda_R
        self.loss_idt_R = self.criterionIdt(
            self.idt_R, self.real_haze_img) * lambda_R * lambda_idt

        # TV LOSS: smoothness regularizers on both real-domain dehazed images.

        self.loss_R2S_Dehazing_TV = self.TVLoss(
            self.r2s_dehazing_img) * self.opt.lambda_Dehazing_TV
        self.loss_R_Dehazing_TV = self.TVLoss(
            self.r_dehazing_img) * self.opt.lambda_Dehazing_TV

        # DC LOSS: dark-channel prior; inputs rescaled from [-1, 1] to [0, 1].

        self.loss_R2S_Dehazing_DC = DCLoss(
            (self.r2s_dehazing_img + 1) / 2,
            self.opt.patch_size) * self.opt.lambda_Dehazing_DC
        self.loss_R_Dehazing_DC = DCLoss(
            (self.r_dehazing_img + 1) / 2,
            self.opt.patch_size) * self.opt.lambda_Dehazing_DC

        # dehazing consistency: the two networks should agree scale-by-scale
        # on the real image.
        self.loss_Dehazing_Con = 0.0
        for (out_s1, out_r2) in zip(self.out_s, self.out_r):
            self.loss_Dehazing_Con += self.criterionCons(
                out_s1, out_r2) * lambda_Dehazing_Con

        # Second backward pass: real-branch losses only.
        self.loss_G = self.loss_G_R2S + self.loss_G_Sfeat + self.loss_cycle_R + self.loss_idt_R + self.loss_R2S_Dehazing_TV \
             + self.loss_R_Dehazing_TV + self.loss_R2S_Dehazing_DC + self.loss_R_Dehazing_DC + self.loss_Dehazing_Con
        self.loss_G.backward()
        # Display outputs: average of the two dehazed estimates per domain.
        self.real_dehazing_img = (self.r_dehazing_img +
                                  self.r2s_dehazing_img) / 2.0
        self.syn_dehazing_img = (self.s_dehazing_img +
                                 self.s2r_dehazing_img) / 2.0