예제 #1
0
    def backward_D_B(self):
        """Backpropagate the A-domain discriminator losses (D_B and D_B_object).

        The whole-image term compares ``self.real_A`` with the sample returned
        by ``self.fake_A_pool.query(self.fake_A)`` through ``self.netD_B``.
        The object term compares ROI crops of ``self.real_A`` (at
        ``self.real_A_bboxes``) with the sample returned by
        ``self.fake_A_object_pool.query(self.fake_A_object)`` through
        ``self.netD_B_object``. Both scalar values are stored on ``self`` for
        logging, and ``5 * image + object`` is backpropagated.

        NOTE(review): assumes ``backward_D_basic`` returns the loss without
        calling ``backward()`` on it itself. Confirm this.
        NOTE(review): the 5 : 1 image/object weighting differs from the
        1 : 0.1 weighting used for the other object terms. Confirm it is
        intentional.
        """
        # Whole-image discriminator term.
        sampled_fake = self.fake_A_pool.query(self.fake_A)
        image_term = self.backward_D_basic(self.netD_B, self.real_A, sampled_fake)
        self.loss_D_B = image_term.data[0]

        # Object discriminator term on ROI-pooled crops of the real A images.
        sampled_fake_crops = self.fake_A_object_pool.query(self.fake_A_object)
        real_crops = roi_pooling(self.real_A, self.real_A_bboxes,
                                 size=self.object_size)
        object_term = self.backward_D_basic(self.netD_B_object, real_crops,
                                            sampled_fake_crops)
        self.loss_D_B_object = object_term.data[0]

        # One backward pass over the weighted sum of both terms.
        (5 * image_term + object_term).backward()
예제 #2
0
    def backward_D_A(self):
        """Backpropagate the B-domain discriminator losses (D_A and D_A_object).

        The whole-image term compares ``self.real_B`` with the sample returned
        by ``self.fake_B_pool.query(self.fake_B)`` through ``self.netD_A``.
        The object term compares ROI crops of ``self.real_B`` (at
        ``self.real_B_bboxes``) with the sample returned by
        ``self.fake_B_object_pool.query(self.fake_B_object)`` through
        ``self.netD_A_object``. Both scalar values are stored on ``self`` for
        logging, and ``image + 0.1 * object`` is backpropagated.

        NOTE(review): assumes ``backward_D_basic`` returns the loss without
        calling ``backward()`` on it itself. Confirm this.
        """
        # Whole-image discriminator term.
        sampled_fake = self.fake_B_pool.query(self.fake_B)
        image_term = self.backward_D_basic(self.netD_A, self.real_B, sampled_fake)
        self.loss_D_A = image_term.data[0]

        # Object discriminator term on ROI-pooled crops of the real B images.
        sampled_fake_crops = self.fake_B_object_pool.query(self.fake_B_object)
        real_crops = roi_pooling(self.real_B, self.real_B_bboxes,
                                 size=self.object_size)
        object_term = self.backward_D_basic(self.netD_A_object, real_crops,
                                            sampled_fake_crops)
        self.loss_D_A_object = object_term.data[0]

        # One backward pass; the object term is down-weighted by 0.1.
        (image_term + 0.1 * object_term).backward()
예제 #3
0
    def backward_G(self):
        """Compute all generator losses, backpropagate their sum, cache results.

        Loss terms:
          * identity (only when ``self.opt.identity > 0``):
            G_A(real_B) ~ real_B and G_B(real_A) ~ real_A;
          * whole-image adversarial: D_A(G_A(real_A)) and D_B(G_B(real_B))
            are pushed towards "real";
          * object adversarial: ROI crops of G_A(real_A) go to D_A_object, and
            ROI crops of G_B(real_B) go to D_B_object. Each generated image is
            cropped at the bounding boxes of the real image it was translated
            from. This is the same domain pairing as D_A / D_B above;
          * cycle consistency on full images and on their ROI crops.

        The object-level terms are scaled by ``lambda_object`` (0.1). After
        ``loss_G.backward()``, the tensors' ``.data`` and the scalar losses are
        stored on ``self``.
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Relative weight of every object-level (ROI) term in loss_G.
        lambda_object = 0.1

        def crop(images, bboxes):
            # ROI-pool the boxed regions of `images` to self.object_size.
            return roi_pooling(images, bboxes, size=self.object_size)

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        loss_G_A = self.criterionGAN(self.netD_A(fake_B), True)

        # GAN loss D_A_object(G_A(A)): fake_B is B-domain output translated
        # from real_A, so its objects are cropped at real_A's boxes and judged
        # by the B-domain object discriminator D_A_object (as D_A judges fake_B).
        fake_B_object = crop(fake_B, self.real_A_bboxes)
        loss_G_A_object = self.criterionGAN(self.netD_A_object(fake_B_object),
                                            True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        loss_G_B = self.criterionGAN(self.netD_B(fake_A), True)

        # GAN loss D_B_object(G_B(B)): A-domain object crops (at real_B's
        # boxes) are judged by the A-domain object discriminator D_B_object.
        fake_A_object = crop(fake_A, self.real_B_bboxes)
        loss_G_B_object = self.criterionGAN(self.netD_B_object(fake_A_object),
                                            True)

        # Forward cycle loss: G_B(G_A(A)) ~ A, on full images and object crops.
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        rec_A_object = crop(rec_A, self.real_A_bboxes)
        real_A_object = crop(self.real_A, self.real_A_bboxes)
        loss_cycle_A_object = self.criterionCycle(rec_A_object,
                                                  real_A_object) * lambda_A

        # Backward cycle loss: G_A(G_B(B)) ~ B, on full images and object crops.
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        rec_B_object = crop(rec_B, self.real_B_bboxes)
        real_B_object = crop(self.real_B, self.real_B_bboxes)
        loss_cycle_B_object = self.criterionCycle(rec_B_object,
                                                  real_B_object) * lambda_B

        # Combined loss; object-level terms are scaled by lambda_object.
        loss_G = (loss_G_A + loss_G_B
                  + loss_cycle_A + loss_cycle_B
                  + loss_idt_A + loss_idt_B
                  + lambda_object * (loss_G_A_object + loss_G_B_object)
                  + lambda_object * (loss_cycle_A_object + loss_cycle_B_object))
        loss_G.backward()

        # Cache the tensors' .data (no autograd history). fake_B / fake_A and
        # their *_object crops are later read via the fake_*_pool queries.
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data
        self.real_A_object = real_A_object.data
        self.real_B_object = real_B_object.data
        self.fake_B_object = fake_B_object.data
        self.fake_A_object = fake_A_object.data

        # NOTE(review): `.data[0]` is the PyTorch <=0.3 way to read a scalar
        # loss. Newer PyTorch rejects indexing a 0-dim tensor; use `.item()`
        # there.
        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_G_A_object = loss_G_A_object.data[0]
        self.loss_G_B_object = loss_G_B_object.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]
        self.loss_cycle_A_object = loss_cycle_A_object.data[0]
        self.loss_cycle_B_object = loss_cycle_B_object.data[0]