def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')
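
The networks.print_network helper used in these banners is not part of the excerpt. A minimal sketch consistent with how it is called here, assuming it just dumps the module structure and a parameter count:

def print_network(net):
    # Print the module structure followed by the total parameter count.
    num_params = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)
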
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
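
Every training branch above allocates an ImagePool. This is the history buffer of previously generated images (Shrivastava et al., 2017) that CycleGAN uses to stabilize discriminator training. A minimal sketch, assuming the 50% replay policy of the reference implementation:

import random
import torch

class ImagePool:
    """Buffer of generated images; query() returns a mix of fresh images
    and images replayed from the buffer, so the discriminator also sees
    older generator outputs."""
    def __init__(self, pool_size):
        self.pool_size = pool_size  # a size of 0 disables the pool
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)       # fill the buffer first
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randint(0, self.pool_size - 1)
                old = self.images[idx].clone()  # replay a stored image
                self.images[idx] = image
                out.append(old)
            else:
                out.append(image)               # pass the new image through
        return torch.cat(out, 0)
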
Example #3
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')
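
This variant attaches one scheduler per optimizer through networks.get_scheduler, which is not shown. A plausible sketch of its default 'lambda' policy (constant rate for opt.niter epochs, then linear decay to zero over the next opt.niter_decay epochs; the field names are taken from the options used above):

from torch.optim import lr_scheduler

def get_scheduler(optimizer, opt):
    # Hold the learning rate for opt.niter epochs, then decay it
    # linearly so it reaches zero after opt.niter_decay more epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1)
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
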
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)
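
The opt.init_type argument threaded through define_G and define_D selects a weight-initialization scheme. The helper itself is outside this excerpt; the sketch below is one common reading of it, with 'normal', 'xavier', and 'kaiming' as assumed choices:

import torch.nn as nn

def init_weights(net, init_type='normal', gain=0.02):
    # Re-initialize conv/linear layers with the chosen scheme and
    # batchnorm affine parameters with N(1, gain); biases go to zero.
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, gain)
            nn.init.constant_(m.bias.data, 0.0)
    net.apply(init_func)
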
class Pix2PixModel(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):

        # changing the default values to match the pix2pix paper
        # (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(pool_size=0)
        parser.set_defaults(no_lsgan=True)
        parser.set_defaults(norm='batch')
        parser.set_defaults(dataset_mode='aligned')
        parser.set_defaults(which_model_netG='unet_256')
        if is_train:
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
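
optimize_parameters above relies on self.set_requires_grad, defined in the base model rather than here. It is a small helper that toggles gradient tracking so the discriminator accumulates no gradients during the generator step; roughly:

def set_requires_grad(nets, requires_grad=False):
    # Freeze or unfreeze one network or a list of networks.
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
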
Example #6
class VIGANModel(BaseModel):
    def name(self):
        return 'VIGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
        self.AE = networks.define_AE(28*28, 28*28, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.AE, 'AE', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionAE = torch.nn.MSELoss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_D_A_AE = torch.optim.Adam(self.netD_A.parameters(),
                                                     lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B_AE = torch.optim.Adam(self.netD_B.parameters(),
                                                     lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(self.AE.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_AE_GA_GB = torch.optim.Adam(
                itertools.chain(self.AE.parameters(), self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG_A)
            networks.print_network(self.netG_B)
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
            networks.print_network(self.AE)
            print('-----------------------------------------------')

    def set_input(self, images_a, images_b):
        input_A = images_a
        input_B = images_b

        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)


    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B  = self.netG_A.forward(self.fake_A)

        # Autoencoder loss: fakeA
        self.AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        # Autoencoder loss: fakeB
        AErealA, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)


    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    ############################################################################
    # Define backward function for VIGAN
    ############################################################################

    def backward_AE_pretrain(self):
        # Autoencoder loss
        AErealA, AErealB = self.AE.forward(self.real_A, self.real_B)
        self.loss_AE_pre = self.criterionAE(AErealA, self.real_A) + self.criterionAE(AErealB, self.real_B)
        self.loss_AE_pre.backward()

    def backward_AE(self):

        # fake data
        self.fake_B = self.netG_A.forward(self.real_A)
        self.fake_A = self.netG_B.forward(self.real_B)

        # Autoencoder loss: fakeA
        AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = self.criterionAE(AEfakeA, self.real_A) + self.criterionAE(AErealB, self.real_B)

        # Autoencoder loss: fakeB
        AErealA, AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = self.criterionAE(AErealA, self.real_A) + self.criterionAE(AEfakeB, self.real_B)

        # combined loss
        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB) * 0.5
        self.loss_AE.backward()


    # input is vector
    def backward_D_A_AE(self):
        fake_B = self.AEfakeB
        self.loss_D_A_AE = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B_AE(self):
        fake_A = self.AEfakeA
        self.loss_D_B_AE =  self.backward_D_basic(self.netD_B, self.real_A, fake_A)


    def backward_AE_GA_GB(self):

        lambda_C = self.opt.lambda_C
        lambda_D = self.opt.lambda_D

        # fake data
        # G_A(A)
        self.fake_B = self.netG_A.forward(self.real_A)
        # G_B(B)
        self.fake_A = self.netG_B.forward(self.real_B)

        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A_AE = self.criterionCycle(self.rec_A, self.real_A)
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B_AE = self.criterionCycle(self.rec_B, self.real_B)

        # Autoencoder loss: fakeA
        self.AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = (self.criterionAE(self.AEfakeA, self.real_A) + self.criterionAE(AErealB, self.real_B)) * 1

        # Autoencoder loss: fakeB
        AErealA, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = (self.criterionAE(AErealA, self.real_A) + self.criterionAE(self.AEfakeB, self.real_B)) * 1
        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB)

        # D loss
        pred_fake = self.netD_A.forward(self.AEfakeB)
        self.loss_AE_GA = self.criterionGAN(pred_fake, True)
        pred_fake = self.netD_B.forward(self.AEfakeA)
        self.loss_AE_GB = self.criterionGAN(pred_fake, True)

        self.loss_AE_GA_GB = lambda_C * ( self.loss_AE_GA + self.loss_AE_GB) + \
                             lambda_D * self.loss_AE + 1 * (self.loss_cycle_A_AE + self.loss_cycle_B_AE)
        self.loss_AE_GA_GB.backward()


    #########################################################################################################

    def optimize_parameters_pretrain_cycleGAN(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    ############################################################################
    # Define optimize function for VIGAN
    ############################################################################
    def optimize_parameters_pretrain_AE(self):
        # forward
        self.forward()
        # AE
        self.optimizer_AE.zero_grad()
        self.backward_AE_pretrain()
        self.optimizer_AE.step()

    def optimize_parameters(self):
        # forward
        self.forward()

        # AE+G_A+G_B
        for i in range(2):
            self.optimizer_AE_GA_GB.zero_grad()
            self.backward_AE_GA_GB()
            self.optimizer_AE_GA_GB.step()

        for i in range(1):
            # D_A
            self.optimizer_D_A_AE.zero_grad()
            self.backward_D_A_AE()
            self.optimizer_D_A_AE.step()
            # D_B
            self.optimizer_D_B_AE.zero_grad()
            self.backward_D_B_AE()
            self.optimizer_D_B_AE.step()

    ############################################################################################
    # Get errors for visualization
    ############################################################################################
    def get_current_errors_cycle(self):
        AE_D_A = self.loss_D_A.data[0]
        AE_G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        AE_D_B = self.loss_D_B.data[0]
        AE_G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', AE_D_A), ('G_A', AE_G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', AE_D_B), ('G_B', AE_G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', AE_D_A), ('G_A', AE_G_A), ('Cyc_A', Cyc_A),
                                ('D_B', AE_D_B), ('G_B', AE_G_B), ('Cyc_B', Cyc_B)])

    def get_current_errors(self):
        D_A = self.loss_D_A_AE.data[0]
        G_A = self.loss_AE_GA.data[0]
        Cyc_A = self.loss_cycle_A_AE.data[0]
        D_B = self.loss_D_B_AE.data[0]
        G_B = self.loss_AE_GB.data[0]
        Cyc_B = self.loss_cycle_B_AE.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A  = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B  = util.tensor2im(self.rec_B.data)

        AE_fake_A = util.tensor2im(self.AEfakeA.view(1,1,28,28).data)
        AE_fake_B = util.tensor2im(self.AEfakeB.view(1,1,28,28).data)


        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A),
                                ('AE_fake_A', AE_fake_A), ('AE_fake_B', AE_fake_B)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B),
                                ('AE_fake_A', AE_fake_A), ('AE_fake_B', AE_fake_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.AE, 'AE', label, self.gpu_ids)

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
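
update_learning_rate implements a plain linear schedule: every call subtracts opt.lr / opt.niter_decay, so the rate hits zero after exactly opt.niter_decay calls. A standalone check of that arithmetic with illustrative values:

# With lr = 2e-4 and niter_decay = 100 the rate drops by 2e-6 per epoch
# and reaches zero after the 100th decay epoch.
lr, niter_decay = 2e-4, 100
lrd = lr / niter_decay
schedule = [lr - lrd * (epoch + 1) for epoch in range(niter_decay)]
assert abs(schedule[-1]) < 1e-12
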
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
            'D_C_A', 'D_C_B', 'G_C', 'cycle_C', 'idt_C_A', 'idt_C_B'
        ]
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        visual_names_C = [
            'real_C', 'fake_C_A', 'fake_C_B', 'rec_C_A', 'rec_C_B'
        ]
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
            visual_names_C.append('idt_C_A')
            visual_names_C.append('idt_C_B')

        self.visual_names = visual_names_A + visual_names_B + visual_names_C
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = [
                'G_A', 'G_B', 'D_A', 'D_B', 'G_C_A', 'G_C_B', 'D_C'
            ]
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B', 'G_C_A', 'G_C_B']

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        self.netG_C_A = networks.define_G(opt.input_nc, opt.input_nc, opt.ngf,
                                          opt.netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
        self.netG_C_B = networks.define_G(opt.output_nc, opt.output_nc,
                                          opt.ngf, opt.netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

            self.netD_C = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            self.fake_C_A_pool = ImagePool(opt.pool_size)
            self.fake_C_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters(),
                self.netG_C_A.parameters(), self.netG_C_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters(),
                self.netD_C.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
Example #8
class gea_ganModel(BaseModel):
    def name(self):
        return 'gea_ganModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.batchSize = opt.batchSize
        self.fineSize = opt.fineSize

        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        if self.opt.rise_sobelLoss:
            self.sobelLambda = 0
        else:
            self.sobelLambda = self.opt.lambda_sobel

        # load/define networks

        which_netG = opt.which_model_netG
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      which_netG, opt.norm, opt.use_dropout,
                                      self.gpu_ids)
        if self.isTrain:

            self.D_channel = opt.input_nc + opt.output_nc
            use_sigmoid = opt.no_lsgan

            self.netD = networks.define_D(self.D_channel, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)
        if not self.isTrain:
            self.netG.eval()

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            if self.opt.labelSmooth:
                self.criterionGAN = networks.GANLoss_smooth(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG)
            networks.print_network(self.netD)
            print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']

        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B

        self.fake_sobel = networks.sobelLayer(self.fake_B)
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))

        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real

        self.real_sobel = networks.sobelLayer(self.real_B).detach()
        real_AB = torch.cat((self.real_A, self.real_B), 1)

        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator

        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B

        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_sobelL1 = self.criterionL1(
            self.fake_sobel, self.real_sobel) * self.sobelLambda
        self.loss_G += self.loss_sobelL1

        self.loss_G.backward()

    def optimize_parameters(self):

        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):

        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('G_sobelL1', self.loss_sobelL1.data[0]),
                            ('D_GAN', self.loss_D.data[0])])

    def get_current_visuals(self):
        real_A = util.tensor2array(self.real_A.data)
        fake_B = util.tensor2array(self.fake_B.data)
        real_B = util.tensor2array(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def get_current_img(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

    def update_sobel_lambda(self, epochNum):
        self.sobelLambda = self.opt.lambda_sobel / 20 * epochNum
        print('update sobel lambda: %f' % (self.sobelLambda))
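
networks.sobelLayer, which feeds the edge-consistency L1 term above, is not included in this excerpt. A hypothetical stand-in that computes a per-channel Sobel gradient magnitude with fixed depthwise convolutions:

import torch
import torch.nn.functional as F

def sobel_layer(img):
    # img: (N, C, H, W) tensor; returns the Sobel gradient magnitude
    # per channel, using fixed (non-learned) 3x3 kernels.
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]],
                      device=img.device).view(1, 1, 3, 3)
    ky = kx.transpose(2, 3)
    c = img.size(1)
    gx = F.conv2d(img, kx.expand(c, 1, 3, 3), padding=1, groups=c)
    gy = F.conv2d(img, ky.expand(c, 1, 3, 3), padding=1, groups=c)
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)  # eps keeps the gradient finite at 0
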
Example #9
class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,  
                                   opt.fineSize, opt.fineSize).cuda(device=opt.gpu_ids[0])
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, 
                                   opt.fineSize, opt.fineSize).cuda(device=opt.gpu_ids[0])

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)

        # If this is training phase
        if self.isTrain:
            use_sigmoid = opt.no_lsgan  # plain sigmoid output only when LSGAN is disabled
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)

        # If this is non-training phase/continue training phase
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            # build up the so-called history pools of generated images
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, gpu_ids=opt.gpu_ids)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            if opt.use_prcp:
                self.criterionPrcp = networks.PrcpLoss(opt.weight_path, opt.bias_path, opt.perceptual_level, tensor=self.Tensor, gpu_ids=opt.gpu_ids)

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG_A)
            networks.print_network(self.netG_B)
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
            print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True) # no back propagation
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B) # A recover

        self.real_B = Variable(self.input_B, volatile=True) # no back propagation
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A) # B recover

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        # detach the fake so no gradient flows back to the generator; only the discriminator is updated here
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)

        # Cycle loss
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()

        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        D_A = self.loss_D_A.data[0]
        G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        D_B = self.loss_D_B.data[0]
        G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
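
util.tensor2im, called throughout these get_current_visuals methods, converts a network output back into a displayable image. A sketch of its presumed behaviour (take the first image of the batch, map [-1, 1] back to [0, 255], return an HxWxC uint8 array):

import numpy as np

def tensor2im(image_tensor):
    # Assumes the generator output lives in [-1, 1], as tanh outputs do.
    image = image_tensor[0].cpu().float().numpy()
    image = (np.transpose(image, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image.astype(np.uint8)
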
class Pix2PixModel(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])
                            ])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
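
# Hedged sketch (not part of the original listing): a minimal training driver
# for the model above. `create_model`, `opt`, and `dataset` are hypothetical
# stand-ins for the surrounding framework.
model = create_model(opt)
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    for data in dataset:                 # dict with 'A', 'B', 'A_paths', 'B_paths'
        model.set_input(data)            # copy the batch into input_A / input_B
        model.optimize_parameters()      # forward, then one D step, then one G step
    if epoch > opt.niter:                # afterwards, decay lr linearly towards zero
        model.update_learning_rate()
    model.save('latest')
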
class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D
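
    # note: halving (loss_D_real + loss_D_fake) slows the discriminator relative
    # to the generator, as in the original papers; with LSGAN (the default here)
    # criterionGAN is a mean-squared error against 1.0 for real predictions and
    # 0.0 for fake ones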

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        D_A = self.loss_D_A.data[0]
        G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        D_B = self.loss_D_B.data[0]
        G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
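
# Hedged sketch of the ImagePool history buffer both models query above
# (simplified; the actual ImagePool used by this code may differ in details).
# It returns the incoming fakes half of the time and previously stored fakes
# the other half, which stabilizes the discriminator updates.
import random
import torch
from torch.autograd import Variable

class ImagePoolSketch:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:              # buffer disabled: pass-through
            return images
        out = []
        for img in images.data:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)      # fill the buffer first
                out.append(img)
            elif random.random() > 0.5:      # 50%: return a stored fake, keep the new one
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:                            # 50%: return the current fake unchanged
                out.append(img)
        return Variable(torch.cat(out, 0))
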
Example #12
class StackGANModel(BaseModel):
    def name(self):
        return 'StackGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # define tensors
        self.input_A0 = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B0 = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        self.input_base = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)


        # load/define networks
        if self.opt.conv3d:
            # one layer for considering a conv filter for each of the 26 channels
            self.netG_3d = networks.define_G_3d(opt.input_nc, opt.input_nc, norm=opt.norm, groups=opt.grps, gpu_ids=self.gpu_ids)

        # Generator of the GlyphNet
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)

        
        # Generator of the OrnaNet as an encoder and a decoder
        self.netE1 = networks.define_Enc(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                    opt.which_model_netG, opt.norm, opt.use_dropout1, self.gpu_ids)
        
        self.netDE1 = networks.define_Dec(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                    opt.which_model_netG, opt.norm, opt.use_dropout1, self.gpu_ids)
                            

        if self.opt.conditional:
            # not applicable for non-conditional case
            use_sigmoid = opt.no_lsgan
            if opt.which_model_preNet != 'none':
                self.preNet_A = networks.define_preNet(self.opt.input_nc_1+self.opt.output_nc_1, self.opt.input_nc_1+self.opt.output_nc_1, which_model_preNet=opt.which_model_preNet,norm=opt.norm, gpu_ids=self.gpu_ids)

            nif = opt.input_nc_1+opt.output_nc_1

            
            netD_norm = opt.norm

            self.netD1 = networks.define_D(nif, opt.ndf,
                                         opt.which_model_netD,
                                         opt.n_layers_D, netD_norm, use_sigmoid, True, self.gpu_ids)



        if self.isTrain:
            if self.opt.conv3d:
                 self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)

            self.load_network(self.netG, 'G', opt.which_epoch)

            if self.opt.print_weights:
                for key in self.netE1.state_dict().keys():
                    print(key, 'random_init, mean, std:', torch.mean(self.netE1.state_dict()[key]), torch.std(self.netE1.state_dict()[key]))
                for key in self.netDE1.state_dict().keys():
                    print(key, 'random_init, mean, std:', torch.mean(self.netDE1.state_dict()[key]), torch.std(self.netDE1.state_dict()[key]))


        if not self.isTrain:
            print "Load generators from their pretrained models..."
            if opt.no_Style2Glyph:
                if self.opt.conv3d:
                     self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
            else:
                if self.opt.conv3d:
                     self.load_network(self.netG_3d, 'G_3d', str(int(opt.which_epoch)+int(opt.which_epoch1)))
                self.load_network(self.netG, 'G', str(int(opt.which_epoch)+int(opt.which_epoch1)))
                self.load_network(self.netE1, 'E1', str(int(opt.which_epoch1)))
                self.load_network(self.netDE1, 'DE1', str(int(opt.which_epoch1)))
                self.load_network(self.netD1, 'D1', str(int(opt.which_epoch1)))
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)


        if self.isTrain:
            if opt.continue_train:
                print "Load StyleNet from its pretrained model..."
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)


        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        if self.isTrain:
            self.fake_AB1_pool = ImagePool(opt.pool_size)

            self.old_lr = opt.lr
            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()


            # initialize optimizers
            if self.opt.conv3d:
                 self.optimizer_G_3d = torch.optim.Adam(self.netG_3d.parameters(),
                                                     lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_E1 = torch.optim.Adam(self.netE1.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.which_model_preNet != 'none':
                self.optimizer_preA = torch.optim.Adam(self.preNet_A.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))

                                            
            self.optimizer_DE1 = torch.optim.Adam(self.netDE1.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))


            print('---------- Networks initialized -------------')
            if self.opt.conv3d:
                networks.print_network(self.netG_3d)
            networks.print_network(self.netG)
            networks.print_network(self.netE1)
            networks.print_network(self.netDE1)
            if opt.which_model_preNet != 'none':
                networks.print_network(self.preNet_A)

            networks.print_network(self.netD1)
            print('-----------------------------------------------')

            self.initial = True

    def set_input(self, input):
        input_A0 = input['A']
        input_B0 = input['B']
        self.input_A0.resize_(input_A0.size()).copy_(input_A0)
        self.input_B0.resize_(input_B0.size()).copy_(input_B0)
        self.image_paths = input['B_paths']


        if self.opt.base_font:
            input_base = input['A_base']
            self.input_base.resize_(input_base.size()).copy_(input_base)
        
            b,c,m,n = self.input_base.size()
            
            real_base = self.Tensor(self.opt.output_nc,self.opt.input_nc_1, m,n)
            for batch in range(self.opt.output_nc):
                if not self.opt.rgb_in and self.opt.rgb_out:
                    real_base[batch,0,:,:] = self.input_base[0,batch,:,:]
                    real_base[batch,1,:,:] = self.input_base[0,batch,:,:]
                    real_base[batch,2,:,:] = self.input_base[0,batch,:,:]
            
            self.real_base = Variable(real_base, requires_grad=False)

        if self.opt.isTrain:

            self.id_ = {}
            self.obs = []
            for i, im in enumerate(self.image_paths):
                # the glyph id is encoded as the trailing _<id> in the .png file name
                glyph_id = int(im.split('/')[-1].split('.png')[0].split('_')[-1])
                self.id_[glyph_id] = i
                self.obs += [glyph_id]
            for i in list(set(range(self.opt.output_nc))-set(self.obs)):
                self.id_[i] = np.random.randint(low=0, high=len(self.image_paths))

            self.num_disc = self.opt.output_nc +1


    def all2observed(self, tensor_all):
        b,c,m,n = self.real_A0.size()

        self.out_id = self.obs
        tensor_gt = self.Tensor(b,self.opt.input_nc_1, m,n)
        for batch in range(b):
            if not self.opt.rgb_in and self.opt.rgb_out:
                tensor_gt[batch,0,:,:] = tensor_all.data[batch,self.out_id[batch],:,:]
                tensor_gt[batch,1,:,:] = tensor_all.data[batch,self.out_id[batch],:,:]
                tensor_gt[batch,2,:,:] = tensor_all.data[batch,self.out_id[batch],:,:]
            else:
                #TODO
                tensor_gt[batch,:,:,:] = tensor_all.data[batch,self.out_id[batch]*np.array(self.opt.input_nc_1):(self.out_id[batch]+1)*np.array(self.opt.input_nc_1),:,:]
        return tensor_gt

    def forward0(self):
        self.real_A0 = Variable(self.input_A0)
        if self.opt.conv3d:
            self.real_A0_indep = self.netG_3d.forward(self.real_A0.unsqueeze(2))
            self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
        else:
            self.fake_B0 = self.netG.forward(self.real_A0)
        if self.initial:
            if self.opt.orna:
                self.fake_B0_init = self.real_A0
            else:
                self.fake_B0_init = self.fake_B0



                                
    def forward1(self, inp_grad=False):
        b,c,m,n = self.real_A0.size()
        
        self.batch_ = b
        self.out_id = self.obs
        real_A1 = self.Tensor(self.opt.output_nc,self.opt.input_nc_1, m,n)
        if self.opt.orna:
            inp_orna = self.fake_B0_init
        else:
            inp_orna = self.fake_B0

        for batch in range(self.opt.output_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                real_A1[batch,0,:,:] = inp_orna.data[self.id_[batch],batch,:,:]
                real_A1[batch,1,:,:] = inp_orna.data[self.id_[batch],batch,:,:]
                real_A1[batch,2,:,:] = inp_orna.data[self.id_[batch],batch,:,:]
            else:
                #TODO
                real_A1[batch,:,:,:] = inp_orna.data[batch,self.out_id[batch]*np.array(self.opt.input_nc_1):(self.out_id[batch]+1)*np.array(self.opt.input_nc_1),:,:]
        if self.initial:
            self.real_A1_init = Variable(real_A1, requires_grad=False)
            self.initial = False

        self.real_A1_s = Variable(real_A1, requires_grad=inp_grad)
        self.real_A1 = self.real_A1_s

        self.fake_B1_emb = self.netE1.forward(self.real_A1)
        self.fake_B1 = self.netDE1.forward(self.fake_B1_emb)
        self.real_B1 = Variable(self.input_B0)

        self.real_A1_gt_s = Variable(self.all2observed(inp_orna), requires_grad=True)
        self.real_A1_gt = (self.real_A1_gt_s)

        self.fake_B1_gt_emb = self.netE1.forward(self.real_A1_gt)
        self.fake_B1_gt = self.netDE1.forward(self.fake_B1_gt_emb)

        obs_ = torch.cuda.LongTensor(self.obs) if self.opt.gpu_ids else torch.LongTensor(self.obs)

        if self.opt.base_font:
            real_base_gt = torch.index_select(self.real_base, 0, obs_)
            self.real_base_gt = (Variable(real_base_gt.data, requires_grad=False))


    def add_noise_disc(self, real):
        # add noise to the discriminator target labels
        # real: ground-truth label (True for real samples, False for fakes)
        if self.opt.noisy_disc:
            rand_lbl = random.random()
            if rand_lbl<0.6:
                label = (not real)
            else:
                label = (real)
        else:  
            label = (real)
        return label
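
    # note: with opt.noisy_disc the target label handed to the discriminator is
    # flipped with probability 0.6, a label-noise trick that keeps D from
    # overpowering the generator early in training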
            
                
    
    # no backprop gradients
    def test(self):
        self.real_A0 = Variable(self.input_A0, volatile=True)

        if self.opt.conv3d:
            self.real_A0_indep = self.netG_3d.forward(self.real_A0.unsqueeze(2))
            self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
        else:
            self.fake_B0 = self.netG.forward(self.real_A0)            

        b,c,m,n = self.fake_B0.size()
        
        # at test time we need to generate output for all of the glyphs in each input image
        if self.opt.rgb_in:
            self.batch_ = c // self.opt.input_nc_1
        else:
            self.batch_ = c
        self.out_id = range(self.batch_)
        real_A1 = self.Tensor(self.batch_,self.opt.input_nc_1, m,n)

        
        if self.opt.orna:
            inp_orna = self.real_A0
        else:
            inp_orna = self.fake_B0 
        for batch in range(self.batch_):
            if not self.opt.rgb_in and self.opt.rgb_out:
                real_A1[batch,0,:,:] = inp_orna.data[:,self.out_id[batch],:,:]
                real_A1[batch,1,:,:] = inp_orna.data[:,self.out_id[batch],:,:]
                real_A1[batch,2,:,:] = inp_orna.data[:,self.out_id[batch],:,:]
            else:
                real_A1[batch,:,:,:] = inp_orna.data[:,self.out_id[batch]*np.array(self.opt.input_nc_1):(self.out_id[batch]+1)*np.array(self.opt.input_nc_1),:,:]



        self.real_A1 = Variable(real_A1, volatile=True)
    
        fake_B1_emb = self.netE1.forward(self.real_A1.detach())
        self.fake_B1 = self.netDE1.forward(fake_B1_emb)
        
        self.real_B1 = Variable(self.input_B0, volatile=True)


    #get image paths
    def get_image_paths(self):
        return self.image_paths


    def prepare_data(self):
        if self.opt.conditional:
            if self.opt.base_font:
                self.first_pair = self.real_base
                self.first_pair_gt = self.real_base_gt
            else:
                self.first_pair = Variable(self.real_A1.data, requires_grad=False)
                self.first_pair_gt = Variable(self.real_A1_gt.data,requires_grad=False)


    def backward_D1(self):
        b,c,m,n = self.fake_B1.size()
    
        # Fake
        # stop backprop to the generator by detaching fake_B
        label_fake = self.add_noise_disc(False)
        if self.opt.conditional:

            fake_AB1 = self.fake_AB1_pool.query(torch.cat((self.first_pair, self.fake_B1),1))
            self.pred_fake1 = self.netD1.forward(fake_AB1.detach())
            if self.opt.which_model_preNet != 'none':
                #transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB1.detach())
                self.pred_fake_GL = self.netD1.forward(transformed_AB1)

            self.loss_D1_fake = 0
            self.loss_D1_fake += self.criterionGAN(self.pred_fake1, label_fake) 
            
            if self.opt.which_model_preNet != 'none':
                self.loss_D1_fake += self.criterionGAN(self.pred_fake_GL, label_fake)            
           


        # Real
        label_real = self.add_noise_disc(True)
        if self.opt.conditional:

            real_AB1 = torch.cat((self.first_pair_gt, self.real_B1), 1).detach()                
            self.pred_real1 = self.netD1.forward(real_AB1)

            if self.opt.which_model_preNet != 'none':
                transformed_real_AB1 = self.preNet_A.forward(real_AB1)
                self.pred_real1_GL = self.netD1.forward(transformed_real_AB1)


            self.loss_D1_real = 0
            self.loss_D1_real += self.criterionGAN(self.pred_real1, label_real)    
            if self.opt.which_model_preNet != 'none':                    
                self.loss_D1_real += self.criterionGAN(self.pred_real1_GL, label_real)    

        
        # Combined loss
        self.loss_D1 = (self.loss_D1_fake + self.loss_D1_real) * 0.5
        self.loss_D1.backward()


    def backward_G(self, pass_grad, iter):

        b,c,m,n = self.fake_B0.size()
        if not self.opt.lambda_C or (iter > 700):
            # no L1 guidance towards the initial glyphs: keep a zero placeholder loss
            self.loss_G_L1 = Variable(torch.zeros(1))

        else:
            weight_val = 10.0

            weights = torch.ones(b,c,m,n).cuda() if self.opt.gpu_ids else torch.ones(b,c,m,n)
            obs_ = torch.cuda.LongTensor(self.obs) if self.opt.gpu_ids else torch.LongTensor(self.obs)
            weights.index_fill_(1,obs_,weight_val)
            weights=Variable(weights, requires_grad=False)

            self.loss_G_L1 = self.criterionL1(weights * self.fake_B0, weights * self.fake_B0_init.detach()) * self.opt.lambda_C
     
            self.loss_G_L1.backward(retain_graph=True)                
            
        self.fake_B0.backward(pass_grad)

    def backward_G1(self,iter):

        # First, G(A) should fake the discriminator
        if self.opt.conditional:

            fake_AB = torch.cat((self.first_pair.detach(), self.fake_B1), 1)
            pred_fake = self.netD1.forward(fake_AB)
            if self.opt.which_model_preNet != 'none':
                #transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB)
                pred_fake_GL = self.netD1.forward(transformed_AB1)


            self.loss_G1_GAN = 0
            self.loss_G1_GAN += self.criterionGAN(pred_fake, True)            
        
            if self.opt.which_model_preNet != 'none':
                self.loss_G1_GAN += self.criterionGAN(pred_fake_GL, True)            


        self.loss_G1_L1 = self.criterionL1(self.fake_B1_gt, self.real_B1) * self.opt.lambda_A
        # sigmoid(100 * (x - 0.9)) acts as a soft threshold at 0.9, giving smooth grayscale masks
        fake_B1_gray = 1 - torch.nn.functional.sigmoid(100 * (torch.mean(self.fake_B1, dim=1, keepdim=True) - 0.9))
        real_A1_gray = 1 - torch.nn.functional.sigmoid(100 * (torch.mean(self.real_A1, dim=1, keepdim=True) - 0.9))
        self.loss_G1_MSE_rgb2gray = self.criterionMSE(fake_B1_gray, real_A1_gray.detach()) * self.opt.lambda_A / 3.0


        real_A1_gt_gray = 1 - torch.nn.functional.sigmoid(100 * (torch.mean(self.real_A1_gt, dim=1, keepdim=True) - 0.9))
        real_B1_gray = 1 - torch.nn.functional.sigmoid(100 * (torch.mean(self.real_B1, dim=1, keepdim=True) - 0.9))

        self.loss_G1_MSE_gt = self.criterionMSE(real_A1_gt_gray, real_B1_gray) * self.opt.lambda_A
        
        # include the GAN term in the generator objective only every rate_gen iterations
        if iter < 200:
            rate_gen = 90
        else:
            rate_gen = 60
        

        if (iter % rate_gen) == 0:
            self.loss_G1 = self.loss_G1_GAN + self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True
            G1_GAN_update = True
        else:
            self.loss_G1 = self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True
            G1_GAN_update = False

        if iter < 200:
            self.loss_G1 += self.loss_G1_MSE_rgb2gray
        else:
            self.loss_G1 += 0.01 * self.loss_G1_MSE_rgb2gray

        

        self.loss_G1.backward(retain_graph=True)

        (b,c,m,n) = self.real_A1_s.size()
        self.real_A1_grad = torch.zeros(b,c,m,n).cuda() if self.opt.gpu_ids else torch.zeros(b,c,m,n)

        
        if G1_L1_update:
            for batch in self.obs:
                self.real_A1_grad[batch,:,:,:] = self.real_A1_gt_s.grad.data[self.id_[batch],:,:,:]


    def optimize_parameters(self,iter):
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()
        
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.zero_grad()
        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.step()
        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()
        
        self.loss_G_L1 = Variable(torch.zeros(1))


    def optimize_parameters_Stacked(self,iter):
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()
        
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.zero_grad()

        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.step()
        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()
        
        self.optimizer_G.zero_grad()
        if self.opt.conv3d:
            self.optimizer_G_3d.zero_grad()

        b, c, m, n = self.fake_B0.size()

        fake_B0_grad = torch.zeros(b,c,m,n).cuda() if self.opt.gpu_ids else torch.zeros(b,c,m,n)
        real_A_grad = self.real_A1_grad
        
        for batch in range(self.opt.input_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                fake_B0_grad[self.id_[batch], batch,:,:] += torch.mean(real_A_grad[batch,:,:,:],0)*3
            else: 
                #TODO  
                fake_B0_grad[batch, self.obs[batch]*np.array(self.opt.input_nc_1):(self.obs[batch]+1)*np.array(self.opt.input_nc_1),:,:] = real_A_grad[batch,:,:,:]

        self.backward_G(fake_B0_grad, iter)
        self.optimizer_G.step()
        if self.opt.conv3d:
            self.optimizer_G_3d.step()


    def get_current_errors(self):
        return OrderedDict([('G1_GAN', self.loss_G1_GAN.item()),
                ('G1_L1', self.loss_G1_L1.item()),
                ('G1_MSE_gt', self.loss_G1_MSE_gt.item()),
                ('G1_MSE', self.loss_G1_MSE_rgb2gray.item()),
                ('D1_real', self.loss_D1_real.item()),
                ('D1_fake', self.loss_D1_fake.item()),
                ('G_L1', self.loss_G_L1.item())
        ])


    def get_current_visuals(self):
        real_A1 = self.real_A1.data.clone()
        g,c,m,n = real_A1.size()
        fake_B = self.fake_B1.data.clone()
        real_B = self.real_B1.data.clone()
        
        if self.opt.isTrain:
            real_A_all = real_A1
            fake_B_all = fake_B
        else:
            real_A_all = self.Tensor(real_B.size(0),real_B.size(1),real_A1.size(2),real_A1.size(2)*real_A1.size(0))
            fake_B_all = self.Tensor(real_B.size(0),real_B.size(1),real_A1.size(2),fake_B.size(2)*fake_B.size(0))
            for b in range(g):
                real_A_all[:,:,:,self.out_id[b]*m:m*(self.out_id[b]+1)] = real_A1[b,:,:,:]
                fake_B_all[:,:,:,self.out_id[b]*m:m*(self.out_id[b]+1)] = fake_B[b,:,:,:]

        real_A = util.tensor2im(real_A_all)
        fake_B = util.tensor2im(fake_B_all)
        real_B = util.tensor2im(self.real_B1.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])

    def save(self, label):
        if not self.opt.no_Style2Glyph:
            try:
                G_label = str(int(label) + int(self.opt.which_epoch))
            except ValueError:  # non-numeric labels such as 'latest'
                G_label = label
            if self.opt.conv3d:
                self.save_network(self.netG_3d, 'G_3d', G_label, self.gpu_ids)
            self.save_network(self.netG, 'G', G_label, self.gpu_ids)
        self.save_network(self.netE1, 'E1', label, self.gpu_ids)
        self.save_network(self.netDE1, 'DE1', label, self.gpu_ids)
        self.save_network(self.netD1, 'D1', label, self.gpu_ids)
        if self.opt.which_model_preNet != 'none':
            self.save_network(self.preNet_A, 'PRE_A', label, gpu_ids=self.gpu_ids)


    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        if self.opt.which_model_preNet != 'none':
            for param_group in self.optimizer_preA.param_groups:
                param_group['lr'] = lr
        for param_group in self.optimizer_D1.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_E1.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_DE1.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
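
# Hedged sketch (newer-PyTorch style, names illustrative, not from this listing)
# of the two-stage gradient hand-off StackGANModel performs above: OrnaNet's
# loss is backpropagated into a leaf copy of GlyphNet's output, and the gradient
# collected there is pushed through GlyphNet via fake_B0.backward(grad).
import torch

net_glyph = torch.nn.Linear(8, 8)            # stands in for netG (+ netG_3d)
net_orna = torch.nn.Linear(8, 8)             # stands in for netE1 / netDE1
x, target = torch.randn(4, 8), torch.randn(4, 8)

y = net_glyph(x)                             # stage-1 output (fake_B0)
y_leaf = y.detach().requires_grad_(True)     # leaf copy (real_A1_s) fed to stage 2
loss1 = torch.nn.functional.l1_loss(net_orna(y_leaf), target)
loss1.backward()                             # fills y_leaf.grad, stops at the leaf
y.backward(y_leaf.grad)                      # hand the gradient back to stage 1
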
Example #14
class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'
    
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        return loss_filter
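
    # example: init_loss_filter(use_gan_feat_loss=False, use_vgg_loss=True)
    # yields flags (True, False, True, True, True), so
    # loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')
    # returns ['G_GAN', 'G_VGG', 'D_real', 'D_fake']; the same filter is applied
    # to loss names and loss values so the two stay aligned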
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # enable cudnn autotuning, except when training at full resolution, where it causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks        
        # Generator network
        netG_input_nc = input_nc        
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num                  
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:          
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)   
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:             
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
                
        
            # Names so we can break out the individual losses
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:                
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():       
                    if key.startswith('model' + str(opt.n_local_enhancers)):                    
                        params += [value]
                        finetune_list.add(key.split('.')[0])  
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))                         
            else:
                params = list(self.netG.parameters())
            if self.gen_features:              
                params += list(self.netE.parameters())         
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):             
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map 
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)         

        if infer:
            with torch.no_grad():
                input_label = Variable(input_label)
        else:
            input_label = Variable(input_label)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
                inst_map = label_map.cuda()

        return input_label, inst_map, real_image, feat_map
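
    # shape note: for label_nc = 35 and a 1 x 1 x H x W integer label map,
    # scatter_ produces a 1 x 35 x H x W one-hot volume (a pixel with label 7
    # becomes 1.0 in channel 7, 0.0 elsewhere); the instance edge map, when
    # used, is concatenated as one extra channel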

    def discriminate(self, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:            
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  

        # Fake Generation
        # print(f'Fake gen. Cur mem allocated: {torch.cuda.memory_allocated() / 1e6} MB')
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)                     
            input_concat = torch.cat((input_label, feat_map), dim=1)                        
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        # print(f'Fake detection and loss. Cur mem allocated: {torch.cuda.memory_allocated() / 1e6} MB')
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

        # Real Detection and Loss       
        # print(f'Real detection and loss. Cur mem allocated: {torch.cuda.memory_allocated() / 1e6} MB') 
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)     
        # print(f'GAN loss. Cur mem allocated: {torch.cuda.memory_allocated() / 1e6} MB')   
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
        loss_G_GAN = self.criterionGAN(pred_fake, True)               
        
        # GAN feature matching loss
        # print(f'GAN feature matching loss. Cur mem allocated: {torch.cuda.memory_allocated() / 1e6} MB')
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
                   
        # VGG feature matching loss
        # print(f'VGG feature matching loss. Cur mem allocated: {torch.cuda.memory_allocated() / 1e6} MB')
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
        
        # Only return the fake_B image if necessary to save BW
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

    def inference(self, label, inst, image=None):
        # Encode Inputs        
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

        # Fake Generation
        if self.use_features:
            if self.opt.use_encoded_image:
                # encode the real image to get feature map
                feat_map = self.netE.forward(real_image, inst_map)
            else:
                # sample clusters from precomputed features             
                feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)                        
        else:
            input_concat = input_label        
           
        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_concat)
        else:
            fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst): 
        # read precomputed feature clusters 
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)        
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)                                      
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):    
            label = i if i < 1000 else i//1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0]) 
                                            
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):                                    
                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
        if self.opt.data_type==16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):

        with torch.no_grad():
            image = Variable(image.cuda())

        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num+1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num//2,:]
            val = np.zeros((1, feat_num+1))                        
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        if self.opt.data_type==16:
            return edge.half()
        else:
            return edge.float()
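
    # worked example: for a single-row instance map t = [[3, 3, 5, 5]] the
    # horizontal comparison t[1:] != t[:-1] is [0, 1, 0], and OR-ing it into
    # both shifted views yields edge = [[0, 1, 1, 0]], i.e. a 2-pixel-wide
    # instance boundary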

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())           
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd        
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
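        # worked example: with lr = 0.0002 and niter_decay = 100, each call
        # lowers the rate by lrd = 0.0002 / 100 = 2e-6, reaching zero after
        # the last decay epoch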
    def __init__(self, opt):
        # using super() raises problems here, so call BaseModel.__init__(self, opt) directly
        # super(ComboGANModel, self).__init__(opt)
        BaseModel.__init__(self, opt)
        self.n_domains = opt.n_domains
        self.d_domains = opt.d_domains
        self.batchSize = opt.batchSize
        self.DA, self.DB, self.DC = None, None, None  # classify the domains

        self.real = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)  # images in style 1
        self.real_B = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)  # images in style 2
        self.real_C = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)  # images in style 3
        # images without edges
        self.edge_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.edge_B = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.edge_C = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        # load/define networks
        self.netG = networks.define_G(opt.netG_framework, opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.netG_n_blocks, opt.netG_n_shared,
                                      self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids)
        if self.isTrain:
            self.netD = networks.define_D(opt.netD_framework, opt.output_nc, opt.ndf, opt.netD_n_layers,
                                          self.d_domains, opt.norm, self.gpu_ids)
            self.classifier = networks.define_classifier(opt.classifier_framework, gpu_ids=self.gpu_ids) # for image classification
            self.vgg = networks.define_VGG(init_weights_=opt.vgg_pretrained_mode, feature_mode_=True, gpu_id_=self.gpu_ids) # using conv4_4 layer
        
        # load model weights
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'G', which_epoch)
            if self.isTrain and not opt.init:
                self.load_network(self.netD, 'D', which_epoch)
                self.load_network(self.classifier, 'A', which_epoch)
            print("load weights of pretrained model successfully")

        # test the function of encoder part
        if opt.encoder_test:
            which_epoch = opt.which_epoch
            self.load_part_network(self.netG, 'G', which_epoch, 0)
            print("load weights of encoder successfully")

        # ====================== training initialization ======================
        if self.isTrain:
            self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)]
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)  # use not opt.no_lsgan
            self.criterionContent = torch.nn.L1Loss()
            self.classGAN = networks.ClassLoss(tensor=self.Tensor)
            # initialize optimizers
            self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            self.classifier.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))

            # initialize loss storage
            self.loss_D, self.loss_G_gan = [0]*self.n_domains, [0]*self.n_domains
            # discriminator loss in details
            self.loss_D_real = [0]*self.n_domains
            self.loss_D_fake = [0]*self.n_domains
            self.loss_D_edge = [0]*self.n_domains
            self.loss_D_class_real = [0]*self.n_domains
            self.loss_G_class = [0]*self.n_domains
            self.loss_D_class_edge_fake = [0]*self.n_domains
            # generator loss in details
            self.loss_content = [0]*self.n_domains
            self.loss_content_2 = [0] * self.n_domains
            self.loss_content_3 = [0] * self.n_domains
            # initialize loss multipliers
            self.lambda_con = opt.lambda_content
            self.lambda_cla = opt.lambda_classfication
	    

        print('---------- Networks initialized -------------')
        print(self.netG)
        if self.isTrain:
            print(self.netD)
            print(self.classifier)
        print('-----------------------------------------------')
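Every training branch in these examples buffers generated images in an ImagePool before showing them to the discriminator. The class is defined in each repository's utilities and is not shown here; a minimal sketch with the usual semantics (pass images through while the buffer fills, then return a stored image half of the time) might look like this:

import random
import torch

class ImagePool:
    """Buffer of previously generated images. query() mixes fresh fakes
    with older ones so the discriminator does not overfit to the latest
    generator output (a sketch; the real class ships with each repo)."""
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:              # pool disabled: pass through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)    # still filling: store and return as-is
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())  # return an older fake
                self.images[idx] = image              # and keep the new one
            else:
                out.append(image)
        return torch.cat(out, 0)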
Example #16
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'low_freq_A', 'D_B', 'G_B',
            'cycle_B', 'idt_B', 'low_freq_B'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = [
            'real_A', 'fake_B', 'rec_A', 'blur_real_A', 'blur_fake_B'
        ]
        visual_names_B = [
            'real_B', 'fake_A', 'rec_B', 'blur_real_B', 'blur_fake_A'
        ]
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        self.use_lowfreq_loss = opt.lambda_low_freq > 0
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        noise_generator=True)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        noise_generator=False)

        self.netGaussian = SimpleGaussian(gaussian_std=opt.low_pass_std)
        self.netGaussian.apply(weights_init_Gaussian)
        if len(self.gpu_ids) > 0:
            assert (torch.cuda.is_available())
            self.netGaussian.to(self.gpu_ids[0])
            self.netGaussian = torch.nn.DataParallel(
                self.netGaussian, self.gpu_ids)  # multi-GPUs

        if self.isTrain:  # define discriminators

            use_sigmoid = opt.gan_mode != 'lsgan'

            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert (opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionLowFreq = torch.nn.MSELoss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
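SimpleGaussian and weights_init_Gaussian are helpers of this particular repository and are not shown in the snippet. As a rough sketch of what such a fixed low-pass module could look like (folding the weight initialization into the constructor; kernel size and defaults are assumptions, not the repo's actual code), a depthwise Gaussian convolution does the job:

import torch
import torch.nn as nn

class SimpleGaussian(nn.Module):
    """Fixed (non-trainable) Gaussian low-pass filter, so that
    blur_fake_B = netGaussian(fake_B) can be compared against
    blur_real_A with the MSE criterionLowFreq above."""
    def __init__(self, gaussian_std=1.0, channels=3, kernel_size=7):
        super(SimpleGaussian, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size,
                              padding=kernel_size // 2, groups=channels,
                              bias=False)
        ax = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
        g1d = torch.exp(-ax ** 2 / (2 * gaussian_std ** 2))
        g2d = g1d[:, None] * g1d[None, :]
        g2d = g2d / g2d.sum()                # normalize to preserve brightness
        with torch.no_grad():
            self.conv.weight.copy_(g2d.expand(channels, 1, kernel_size, kernel_size))
        for p in self.parameters():
            p.requires_grad = False          # the filter is fixed, not learned

    def forward(self, x):
        return self.conv(x)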
Example #17
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # Parameters for WGAN
        self.use_which_gan = opt.use_which_gan  # CycleGAN or CycleWGAN or ICycleWGAN
        self.wgan_clip_upper = opt.wgan_clip_upper
        self.wgan_clip_lower = opt.wgan_clip_lower
        self.wgan_n_critic = opt.wgan_n_critic
        self.wgan_optimizer = opt.wgan_optimizer  # rmsprop

        self.wgan_train_critics = True  # Not sure about this part

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_which_gan=self.use_which_gan,
                use_lsgan=not opt.no_lsgan,
                tensor=self.Tensor)
            # L1 norm
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            if (self.use_which_gan == 'CycleWGAN'):
                if (self.wgan_optimizer == 'rmsprop'):
                    self.optimizer_G = torch.optim.RMSprop(itertools.chain(
                        self.netG_A.parameters(), self.netG_B.parameters()),
                                                           lr=opt.wgan_lrG)
                    self.optimizer_D_A = torch.optim.RMSprop(
                        self.netD_A.parameters(), lr=opt.wgan_lrD)
                    self.optimizer_D_B = torch.optim.RMSprop(
                        self.netD_B.parameters(), lr=opt.wgan_lrD)
            elif (self.use_which_gan == 'CycleGAN'
                  or self.use_which_gan == 'ICycleWGAN'):
                self.optimizer_G = torch.optim.Adam(itertools.chain(
                    self.netG_A.parameters(), self.netG_B.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')
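networks.GANLoss, used throughout these examples, hides the construction of the real/fake target tensors. A minimal sketch of the LSGAN/vanilla cases (the tensor= and use_which_gan= arguments of the originals are dropped; with use_lsgan=False the discriminator is built with use_sigmoid=True, so plain BCE applies):

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Sketch of the GANLoss helper: build a target tensor shaped like
    the prediction and apply MSE (LSGAN) or BCE (vanilla GAN with a
    sigmoid-terminated discriminator)."""
    def __init__(self, use_lsgan=True):
        super(GANLoss, self).__init__()
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def __call__(self, prediction, target_is_real):
        if target_is_real:
            target = torch.ones_like(prediction)
        else:
            target = torch.zeros_like(prediction)
        return self.loss(prediction, target)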
Example #18
class Pix2PixModel(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain and (not opt.no_gan):
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain and (not opt.no_gan):
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            if opt.use_l2:
                self.criterionL1 = torch.nn.MSELoss()
            else:
                self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            if not opt.no_gan:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain and (not opt.no_gan):
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        if not self.opt.no_gan:
            # First, G(A) should fake the discriminator
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            self.loss_G_GAN = 0

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()
        if not self.opt.no_gan:
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        if not self.opt.no_gan:
            return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                                ('G_L1', self.loss_G_L1.data[0]),
                                ('D_real', self.loss_D_real.data[0]),
                                ('D_fake', self.loss_D_fake.data[0])
                                ])
        else:
            return OrderedDict([
                ('G_L1', self.loss_G_L1.data[0])
            ])

    def get_current_visuals(self):
        real_A_img, real_A_prior = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        if self.opt.output_nc == 1:
            fake_B_postprocessed = util.postprocess_parsing(fake_B, self.isTrain)
            fake_B_color = util.paint_color(fake_B_postprocessed)
            real_B_color = util.paint_color(util.postprocess_parsing(real_B, self.isTrain))
        if self.opt.output_nc == 1:
            return OrderedDict([
                ('real_A_img', real_A_img),
                ('real_A_prior', real_A_prior),
                ('fake_B', fake_B),
                ('fake_B_postprocessed', fake_B_postprocessed),
                ('fake_B_color', fake_B_color),
                ('real_B', real_B),
                ('real_B_color', real_B_color)]
            )
        else:
            return OrderedDict([
                ('real_A_img', real_A_img),
                ('real_A_prior', real_A_prior),
                ('fake_B', fake_B),
                ('real_B', real_B)]
            )

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        if not self.opt.no_gan:
            self.save_network(self.netD, 'D', label, self.gpu_ids)
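A model like the Pix2PixModel above is driven from the outside by a loop that alternates set_input and optimize_parameters. A hypothetical minimal driver is sketched below (the real repos use a fuller train.py with options parsing and visualizers):

def train(model, dataset, n_epochs, save_freq=5):
    # Minimal training driver (a sketch): one D update and one G update
    # per batch, both handled inside model.optimize_parameters().
    for epoch in range(1, n_epochs + 1):
        for data in dataset:
            model.set_input(data)          # move the batch onto the device
            model.optimize_parameters()    # forward + backward_D + backward_G
        print('epoch %d:' % epoch, model.get_current_errors())
        if epoch % save_freq == 0:
            model.save('latest')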
Example #19
class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'
    
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        return loss_filter
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none': # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else 3
        ##### define networks        
        # Generator network
        netG_input_nc = input_nc        
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num                  
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:          
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)   
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:             
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
                
        
            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                if self.opt.verbose:
                    print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():       
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [{'params':[value],'lr':opt.lr}]
                    else:
                        params += [{'params':[value],'lr':0.0}]                            
            else:
                params = list(self.netG.parameters())
            if self.gen_features:              
                params += list(self.netE.parameters())         
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):             
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map 
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type==16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1) 
        input_label = Variable(input_label, requires_grad = not infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:            
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)                     
            input_concat = torch.cat((input_label, feat_map), dim=1)                        
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

        # Real Detection and Loss        
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)        
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
        loss_G_GAN = self.criterionGAN(pred_fake, True)               
        
        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
                   
        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
        
        # Only return the fake_B image if necessary to save BW
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

    def inference(self, label, inst):
        # Encode Inputs        
        input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)

        # Fake Generation
        if self.use_features:       
            # sample clusters from precomputed features             
            feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)                        
        else:
            input_concat = input_label                
        fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst): 
        # read precomputed feature clusters 
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)        
        features_clustered = np.load(cluster_path, allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)                                      
        feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3])                   
        for i in np.unique(inst_np):    
            label = i if i < 1000 else i//1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0]) 
                                            
                idx = (inst == i).nonzero()
                for k in range(self.opt.feat_num):                                    
                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
        if self.opt.data_type==16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num+1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            idx = (inst == i).nonzero()
            num = idx.size()[0]
            idx = idx[num//2,:]
            val = np.zeros((1, feat_num+1))                        
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]            
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        if self.opt.data_type==16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())           
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd        
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
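update_learning_rate above subtracts opt.lr / opt.niter_decay once per call, so the rate falls linearly from opt.lr to zero over niter_decay epochs after the initial constant-rate phase. The same schedule can be expressed with a standard LambdaLR scheduler; a sketch with illustrative values (not the repo's defaults):

import torch

niter, niter_decay, lr = 100, 100, 0.0002          # illustrative values
params = [torch.nn.Parameter(torch.zeros(1))]      # stand-in parameters
optimizer = torch.optim.Adam(params, lr=lr, betas=(0.5, 0.999))

def lambda_rule(epoch):
    # multiplier on the base lr: 1.0 for the first `niter` epochs,
    # then decaying linearly to 0 over the next `niter_decay` epochs
    return 1.0 - max(0, epoch - niter) / float(niter_decay)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
for epoch in range(niter + niter_decay):
    # ... one epoch of training would go here ...
    scheduler.step()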
Example #20
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # define tensors
        self.input_A0 = self.Tensor(opt.batchSize, opt.input_nc,
                                    opt.fineSize, opt.fineSize)
        self.input_B0 = self.Tensor(opt.batchSize, opt.output_nc,
                                    opt.fineSize, opt.fineSize)

        self.input_base = self.Tensor(opt.batchSize, opt.output_nc,
                                      opt.fineSize, opt.fineSize)


        # load/define networks
        if self.opt.conv3d:
            # one layer for considering a conv filter for each of the 26 channels
            self.netG_3d = networks.define_G_3d(opt.input_nc, opt.input_nc, norm=opt.norm, groups=opt.grps, gpu_ids=self.gpu_ids)

        # Generator of the GlyphNet
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)

        
        #Generator of the OrnaNet as an Encoder and a Decoder
        self.netE1 = networks.define_Enc(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                    opt.which_model_netG, opt.norm, opt.use_dropout1, self.gpu_ids)
        
        self.netDE1 = networks.define_Dec(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                    opt.which_model_netG, opt.norm, opt.use_dropout1, self.gpu_ids)
                            

        if self.opt.conditional:
            # not applicable for non-conditional case
            use_sigmoid = opt.no_lsgan
            if opt.which_model_preNet != 'none':
                self.preNet_A = networks.define_preNet(
                    self.opt.input_nc_1 + self.opt.output_nc_1,
                    self.opt.input_nc_1 + self.opt.output_nc_1,
                    which_model_preNet=opt.which_model_preNet,
                    norm=opt.norm, gpu_ids=self.gpu_ids)

            nif = opt.input_nc_1+opt.output_nc_1

            
            netD_norm = opt.norm

            self.netD1 = networks.define_D(nif, opt.ndf,
                                         opt.which_model_netD,
                                         opt.n_layers_D, netD_norm, use_sigmoid, True, self.gpu_ids)



        if self.isTrain:
            if self.opt.conv3d:
                 self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)

            self.load_network(self.netG, 'G', opt.which_epoch)

            if self.opt.print_weights:
                for key in self.netE1.state_dict().keys():
                    print(key, 'random_init, mean, std:', torch.mean(self.netE1.state_dict()[key]), torch.std(self.netE1.state_dict()[key]))
                for key in self.netDE1.state_dict().keys():
                    print(key, 'random_init, mean, std:', torch.mean(self.netDE1.state_dict()[key]), torch.std(self.netDE1.state_dict()[key]))


        if not self.isTrain:
            print "Load generators from their pretrained models..."
            if opt.no_Style2Glyph:
                if self.opt.conv3d:
                     self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
            else:
                if self.opt.conv3d:
                     self.load_network(self.netG_3d, 'G_3d', str(int(opt.which_epoch)+int(opt.which_epoch1)))
                self.load_network(self.netG, 'G', str(int(opt.which_epoch)+int(opt.which_epoch1)))
                self.load_network(self.netE1, 'E1', str(int(opt.which_epoch1)))
                self.load_network(self.netDE1, 'DE1', str(int(opt.which_epoch1)))
                self.load_network(self.netD1, 'D1', str(int(opt.which_epoch1)))
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)


        if self.isTrain:
            if opt.continue_train:
                print "Load StyleNet from its pretrained model..."
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)


        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        if self.isTrain:
            self.fake_AB1_pool = ImagePool(opt.pool_size)

            self.old_lr = opt.lr
            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()


            # initialize optimizers
            if self.opt.conv3d:
                 self.optimizer_G_3d = torch.optim.Adam(self.netG_3d.parameters(),
                                                     lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_E1 = torch.optim.Adam(self.netE1.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.which_model_preNet != 'none':
                self.optimizer_preA = torch.optim.Adam(self.preNet_A.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))

                                            
            self.optimizer_DE1 = torch.optim.Adam(self.netDE1.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))


            print('---------- Networks initialized -------------')
            if self.opt.conv3d:
                networks.print_network(self.netG_3d)
            networks.print_network(self.netG)
            networks.print_network(self.netE1)
            networks.print_network(self.netDE1)
            if opt.which_model_preNet != 'none':
                networks.print_network(self.preNet_A)

            networks.print_network(self.netD1)
            print('-----------------------------------------------')

            self.initial = True
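load_network and save_network come from BaseModel and are not shown in these snippets. A sketch of the usual pix2pix-style implementation, which keys checkpoints as <epoch>_net_<label>.pth under the model's save directory (the path convention is an assumption modelled on the pix2pix/CycleGAN code base):

import os
import torch

def save_network(self, network, network_label, epoch_label, gpu_ids):
    # e.g. checkpoints/<name>/latest_net_G.pth
    filename = '%s_net_%s.pth' % (epoch_label, network_label)
    torch.save(network.cpu().state_dict(), os.path.join(self.save_dir, filename))
    if len(gpu_ids) and torch.cuda.is_available():
        network.cuda(gpu_ids[0])           # move back to GPU after saving from CPU

def load_network(self, network, network_label, epoch_label):
    filename = '%s_net_%s.pth' % (epoch_label, network_label)
    network.load_state_dict(torch.load(os.path.join(self.save_dir, filename)))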
Example #21
class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # Parameters for WGAN
        self.use_which_gan = opt.use_which_gan  # CycleGAN or CycleWGAN or ICycleWGAN
        self.wgan_clip_upper = opt.wgan_clip_upper
        self.wgan_clip_lower = opt.wgan_clip_lower
        self.wgan_n_critic = opt.wgan_n_critic
        self.wgan_optimizer = opt.wgan_optimizer  # rmsprop

        self.wgan_train_critics = True  # Not sure about this part

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_which_gan=self.use_which_gan,
                use_lsgan=not opt.no_lsgan,
                tensor=self.Tensor)
            # L1 norm
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            if (self.use_which_gan == 'CycleWGAN'):
                if (self.wgan_optimizer == 'rmsprop'):
                    self.optimizer_G = torch.optim.RMSprop(itertools.chain(
                        self.netG_A.parameters(), self.netG_B.parameters()),
                                                           lr=opt.wgan_lrG)
                    self.optimizer_D_A = torch.optim.RMSprop(
                        self.netD_A.parameters(), lr=opt.wgan_lrD)
                    self.optimizer_D_B = torch.optim.RMSprop(
                        self.netD_B.parameters(), lr=opt.wgan_lrD)
            elif (self.use_which_gan == 'CycleGAN'
                  or self.use_which_gan == 'ICycleWGAN'):
                self.optimizer_G = torch.optim.Adam(itertools.chain(
                    self.netG_A.parameters(), self.netG_B.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_wasserstein(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        pred_fake = netD.forward(fake)
        loss_D = self.criterionGAN(pred_fake, pred_real, generator_loss=False)
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    # Backward for discriminator, wgan
    def backward_wgan_D(self, critic_iter):
        # D_A
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_wasserstein(self.netD_A, self.real_B,
                                                    fake_B)
        # D_B
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_wasserstein(self.netD_B, self.real_A,
                                                    fake_A)

        loss_D = (self.loss_D_A + self.loss_D_B) * 0.5
        loss_D.backward(retain_graph=True)  # retain_variables was renamed retain_graph in newer PyTorch

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        # Save all the data from the previous tensors
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def backward_wgan_G(self, do_backward=True):
        lambda_idt = self.opt.lambda_identity  # match the option name used in backward_G
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Wasserstein-GAN loss
        # G_A(A)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.loss_G_A = self.criterionGAN(self.fake_B, generator_loss=True)

        # G_B(B)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.loss_G_B = self.criterionGAN(self.fake_A, generator_loss=True)

        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A

        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        # Combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B

        if do_backward:
            # Backprop
            self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        if (self.use_which_gan == 'CycleGAN'):

            # G_A and G_B
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()

            # D_A
            self.optimizer_D_A.zero_grad()
            self.backward_D_A()
            self.optimizer_D_A.step()
            # D_B
            self.optimizer_D_B.zero_grad()
            self.backward_D_B()
            self.optimizer_D_B.step()

        # The changes here are that we need to add a bound for weights in the range [-c, c]
        elif (self.use_which_gan == 'CycleWGAN'):

            # G_A and G_B
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()

            for t in range(self.wgan_n_critic):
                # D_A
                self.optimizer_D_A.zero_grad()
                self.backward_D_A()
                self.optimizer_D_A.step()
                # clip
                for p in self.netD_A.parameters():
                    p.data.clamp_(self.wgan_clip_lower, self.wgan_clip_upper)

                # D_B
                self.optimizer_D_B.zero_grad()
                self.backward_D_B()
                self.optimizer_D_B.step()
                # clip
                for p in self.netD_B.parameters():
                    p.data.clamp_(self.wgan_clip_lower, self.wgan_clip_upper)

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
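The CycleWGAN branch above enforces the critic's Lipschitz constraint by clamping weights to [wgan_clip_lower, wgan_clip_upper]. The "improved" WGAN recipe (WGAN-GP, Gulrajani et al., 2017) replaces clipping with a gradient penalty; if the ICycleWGAN option refers to that variant (an assumption, the snippet does not say), the extra term added to the critic loss looks like:

import torch

def gradient_penalty(netD, real, fake, lambda_gp=10.0):
    # WGAN-GP: push the critic's gradient norm towards 1 on random
    # interpolates between real and fake samples
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    pred = netD(interp)
    grads = torch.autograd.grad(outputs=pred, inputs=interp,
                                grad_outputs=torch.ones_like(pred),
                                create_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1.0) ** 2).mean()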
Example #22
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        self.count = 0
        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        # Main Generator
        with torch.no_grad():
            self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
            self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval()
            self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()
            self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.BCE = torch.nn.BCEWithLogitsLoss()

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            netB_input_nc = opt.output_nc * 2
            # self.D1 = self.get_D(17, opt)
            # self.D2 = self.get_D(4, opt)
            # self.D3=self.get_D(7+3,opt)
            # self.D = self.get_D(20, opt)
            # self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids)

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.Unet, 'U', opt.which_epoch, pretrained_path)
            self.load_network(self.G1, 'G1', opt.which_epoch, pretrained_path)
            self.load_network(self.G2, 'G2', opt.which_epoch, pretrained_path)
            self.load_network(self.G, 'G', opt.which_epoch, pretrained_path)
        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            self.criterionStyle = networks.StyleLoss(self.gpu_ids)
            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG',
                                               'D_real', 'D_fake')
            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))
Example #23
class Pix2PixModel(BaseModel):
    '''
    * @name: name
    * @description: return the name of this model
    * @return: the name of this model
    '''
    def name(self):
        return 'Pix2PixModel'

    '''
    * @name: initialize
    * @description: initialize the pix2pix model with the parameter set
    * @param opt: the configured parameter set
    '''

    def initialize(self, opt):
        #initialize the base class with given parameter set opt
        BaseModel.initialize(self, opt)
        # get the type of the program (train or test)
        self.isTrain = opt.isTrain

        # load/define Generator
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type)
        #define the Discriminator
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)
        #deploy generator to device
        self.netG = self.netG.to(self.device)

        #deploy discriminator to device
        if self.isTrain:
            self.netD = self.netD.to(self.device)

        #if the program is for training
        if self.isTrain:
            #set the size of image buffer that stores previously generated images
            self.fake_AB_pool = ImagePool(opt.pool_size)

            #set initial learning rate for adam
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 device=self.device)

            self.criterionL1 = torch.nn.L1Loss().to(self.device)

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            #define the optimizer for generator
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            #define the optimizer for discriminator
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            #save the optimizers
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            #save schedulers
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    #get the input data set
    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.input_A = input['A' if AtoB else 'B'].to(self.device)
        self.input_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if 'w' in input:
            self.input_w = input['w']
        if 'h' in input:
            self.input_h = input['h']

    #the forward function
    def forward(self):
        #get the input image
        self.real_A = self.input_A
        #generate the fake image by generator
        self.fake_B = self.netG(self.real_A)
        # get the ground-truth image
        self.real_B = self.input_B

    # no backprop gradients
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    #backpropagate function for discriminator
    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).detach())
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        n = self.real_B.shape[1]
        loss_D_real_set = torch.empty(n, device=self.device)
        for i in range(n):
            sel_B = self.real_B[:, i, :, :].unsqueeze(1)
            real_AB = torch.cat((self.real_A, sel_B), 1)
            pred_real = self.netD(real_AB)
            loss_D_real_set[i] = self.criterionGAN(pred_real, True)

        # average the real loss over all ground-truth candidates
        self.loss_D_real = torch.mean(loss_D_real_set)

        # Combined loss
        self.loss_D = (self.loss_D_fake +
                       self.loss_D_real) * 0.5 * self.opt.lambda_G

        self.loss_D.backward()

    #backpropagate function for generator
    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake,
                                            True) * self.opt.lambda_G

        # Second, G(A) = B
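        # real_B stacks n candidate ground truths along the channel axis;
        # the block below broadcasts fake_B to n channels, takes the
        # per-candidate mean L1, and penalizes only the closest candidate
        # (a min-of-n L1), remembering its index for visualization.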
        n = self.real_B.shape[1]
        fake_B_expand = self.fake_B.expand(-1, n, -1, -1)
        L1 = torch.abs(fake_B_expand - self.real_B)
        L1 = L1.view(-1, n, self.real_B.shape[2] * self.real_B.shape[3])
        L1 = torch.mean(L1, 2)
        min_L1, min_idx = torch.min(L1, 1)
        self.loss_G_L1 = torch.mean(min_L1) * self.opt.lambda_A
        self.min_idx = min_idx

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        #train discriminator
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        #train the generator
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.detach())
        fake_B = util.tensor2im(self.fake_B.detach())
        if self.isTrain:
            sel_B = self.real_B[:, self.min_idx[0], :, :]
        else:
            sel_B = self.real_B[:, 0, :, :]
        real_B = util.tensor2im(sel_B.unsqueeze(1).detach())
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG, 'G', label)
        self.save_network(self.netD, 'D', label)

    def write_image(self, out_dir):
        image_numpy = self.fake_B.detach()[0][0].cpu().float().numpy()
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
        image_pil = Image.fromarray(image_numpy.astype(np.uint8))
        image_pil = image_pil.resize((self.input_w[0], self.input_h[0]),
                                     Image.BICUBIC)
        name, _ = os.path.splitext(os.path.basename(self.image_paths[0]))
        out_path = os.path.join(out_dir, name + self.opt.suffix + '.png')
        image_pil.save(out_path)
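
A minimal sketch of how the Pix2PixModel above is typically driven; opt, dataset, and the option fields here are assumptions borrowed from pix2pix-style training scripts, not part of this listing:

# Hypothetical driver loop for the model above (illustrative names only).
model = Pix2PixModel()
model.initialize(opt)                      # opt: parsed training options
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    for data in dataset:                   # dicts with 'A', 'B' and path keys
        model.set_input(data)              # stage tensors on the device
        model.optimize_parameters()        # one D step, then one G step
    print(model.get_current_errors())      # OrderedDict of scalar losses
    model.save(str(epoch))                 # checkpoint G and D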
Example #24
class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
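
        # e.g. init_loss_filter(True, False) keeps ['G_GAN', 'G_GAN_Feat',
        # 'D_real', 'D_fake'] and drops 'G_VGG'; the same mask is applied
        # to the loss names below and to the loss values returned by forward.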

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [
                l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real,
                                     d_fake), flags) if f
            ]

        return loss_filter

    def get_G(self, in_C, out_c, n_blocks, opt, L=1, S=1):
        return networks.define_G(in_C,
                                 out_c,
                                 opt.ngf,
                                 opt.netG,
                                 L,
                                 S,
                                 opt.n_downsample_global,
                                 n_blocks,
                                 opt.n_local_enhancers,
                                 opt.n_blocks_local,
                                 opt.norm,
                                 gpu_ids=self.gpu_ids)

    def get_D(self, inc, opt):
        netD = networks.define_D(inc,
                                 opt.ndf,
                                 opt.n_layers_D,
                                 opt.norm,
                                 opt.no_lsgan,
                                 opt.num_D,
                                 not opt.no_ganFeat_loss,
                                 gpu_ids=self.gpu_ids)
        return netD

    def cross_entropy2d(self, input, target, weight=None, size_average=True):
        n, c, h, w = input.size()
        nt, ht, wt = target.size()

        # Handle inconsistent size between input and target
        if h != ht or w != wt:
            input = F.interpolate(input,
                                  size=(ht, wt),
                                  mode="bilinear",
                                  align_corners=True)

        input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
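        # Shape check for the flatten above (illustrative sizes): an input of
        # (2, 14, 8, 8) becomes (2*8*8, 14), i.e. one 14-way classification
        # row per pixel, matching the flattened target of length N*H*W.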
        target = target.view(-1)
        loss = F.cross_entropy(input,
                               target,
                               weight=weight,
                               size_average=size_average,
                               ignore_index=250)

        return loss

    def ger_average_color(self, mask, arms):
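        # Fills every pixel of each channel with the mean color of the masked
        # arm region; masks with fewer than 10 nonzero pixels are treated as
        # empty so that near-empty masks do not yield noisy averages.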
        color = torch.zeros(arms.shape).cuda()
        for i in range(arms.shape[0]):
            count = len(torch.nonzero(mask[i, :, :, :]))
            if count < 10:
                color[i, 0, :, :] = 0
                color[i, 1, :, :] = 0
                color[i, 2, :, :] = 0

            else:
                color[i, 0, :, :] = arms[i, 0, :, :].sum() / count
                color[i, 1, :, :] = arms[i, 1, :, :].sum() / count
                color[i, 2, :, :] = arms[i, 2, :, :].sum() / count
        return color

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        self.count = 0
        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        # Main Generator
        with torch.no_grad():
            self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
            self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval()
            self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()
            self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.BCE = torch.nn.BCEWithLogitsLoss()

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            netB_input_nc = opt.output_nc * 2
            # self.D1 = self.get_D(17, opt)
            # self.D2 = self.get_D(4, opt)
            # self.D3=self.get_D(7+3,opt)
            # self.D = self.get_D(20, opt)
            # self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids)

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.Unet, 'U', opt.which_epoch, pretrained_path)
            self.load_network(self.G1, 'G1', opt.which_epoch, pretrained_path)
            self.load_network(self.G2, 'G2', opt.which_epoch, pretrained_path)
            self.load_network(self.G, 'G', opt.which_epoch, pretrained_path)
        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            self.criterionStyle = networks.StyleLoss(self.gpu_ids)
            # Names so we can break out each loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG',
                                               'D_real', 'D_fake')
            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))

    def encode_input(self, label_map, clothes_mask, all_clothes_label):
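        # scatter_ along dim 1 turns each integer label map into a 14-channel
        # one-hot volume; masked_label first zeroes the clothing region via
        # label * (1 - clothes_mask), so those pixels land in channel 0, and
        # c_label one-hot encodes the merged clothes labels.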

        size = label_map.size()
        oneHot_size = (size[0], 14, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1,
                                           label_map.data.long().cuda(), 1.0)

        masked_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        masked_label = masked_label.scatter_(
            1, (label_map * (1 - clothes_mask)).data.long().cuda(), 1.0)

        c_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        c_label = c_label.scatter_(1,
                                   all_clothes_label.data.long().cuda(), 1.0)

        input_label = Variable(input_label)

        return input_label, masked_label, c_label

    def encode_input_test(self,
                          label_map,
                          label_map_ref,
                          real_image_ref,
                          infer=False):

        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
            input_label_ref = label_map_ref.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(
                torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1,
                                               label_map.data.long().cuda(),
                                               1.0)
            input_label_ref = torch.cuda.FloatTensor(
                torch.Size(oneHot_size)).zero_()
            input_label_ref = input_label_ref.scatter_(
                1,
                label_map_ref.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()
                input_label_ref = input_label_ref.half()

        input_label = Variable(input_label, volatile=infer)
        input_label_ref = Variable(input_label_ref, volatile=infer)
        real_image_ref = Variable(real_image_ref.data.cuda())

        return input_label, input_label_ref, real_image_ref

    def discriminate(self, netD, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return netD.forward(fake_query)
        else:
            return netD.forward(input_concat)

    def gen_noise(self, shape):
        noise = np.zeros(shape, dtype=np.uint8)
        ### noise
        noise = cv2.randn(noise, 0, 255)
        noise = np.asarray(noise / 255, dtype=np.uint8)
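        # uint8 division floors, so the array above is effectively binary:
        # 1 only where the Gaussian draw saturated at 255, 0 elsewhere.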
        noise = torch.tensor(noise, dtype=torch.float32)
        return noise.cuda()

    def multi_scale_blend(self, fake_img, fake_c, mask, number=4):
        alpha = [0, 0.1, 0.3, 0.6, 0.9]
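        # Each pass shrinks the mask with morpho and blends fake_c over
        # fake_img on the peeled-off ring with a larger alpha toward the
        # interior; the remaining core is pure fake_c and everything outside
        # the mask stays fake_img, giving a soft multi-scale seam.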
        smaller = mask
        out = 0
        for i in range(1, number + 1):
            bigger = smaller
            smaller = morpho(smaller, 2, False)
            mid = bigger - smaller
            out += mid * (alpha[i] * fake_c + (1 - alpha[i]) * fake_img)
        out += smaller * fake_c
        out += (1 - mask) * fake_img
        return out

    def forward(self, label, pre_clothes_mask, img_fore, clothes_mask, clothes,
                all_clothes_label, real_image, pose, grid, mask_fore):
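        # Pipeline wired below: G1 predicts the arm/body layout from clothes
        # + pose, G2 predicts the warped clothes mask, the Unet warps and
        # composites the clothes, and G refines the final image from the
        # hole-filled person, labels, warped clothes and skin color. Only the
        # CE/BCE terms are computed here; the GAN/VGG/L1 slots are returned
        # as zeros.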
        # Encode Inputs
        input_label, masked_label, all_clothes_label = self.encode_input(
            label, clothes_mask, all_clothes_label)
        arm1_mask = torch.FloatTensor(
            (label.cpu().numpy() == 11).astype(np.float32)).cuda()
        arm2_mask = torch.FloatTensor(
            (label.cpu().numpy() == 13).astype(np.float32)).cuda()
        pre_clothes_mask = torch.FloatTensor(
            (pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(
                np.float32)).cuda()
        clothes = clothes * pre_clothes_mask

        shape = pre_clothes_mask.shape

        G1_in = torch.cat([
            pre_clothes_mask, clothes, all_clothes_label, pose,
            self.gen_noise(shape)
        ],
                          dim=1)
        arm_label = self.G1.refine(G1_in)

        arm_label = self.sigmoid(arm_label)
        CE_loss = self.cross_entropy2d(arm_label,
                                       (label * (1 - clothes_mask)).transpose(
                                           0, 1)[0].long()) * 10

        armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
        dis_label = generate_discrete_label(arm_label.detach(), 14)
        G2_in = torch.cat([
            pre_clothes_mask, clothes, dis_label, pose,
            self.gen_noise(shape)
        ], 1)
        fake_cl = self.G2.refine(G2_in)
        fake_cl = self.sigmoid(fake_cl)
        CE_loss += self.BCE(fake_cl, clothes_mask) * 10

        fake_cl_dis = torch.FloatTensor(
            (fake_cl.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        fake_cl_dis = morpho(fake_cl_dis, 1, True)

        new_arm1_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 11).astype(np.float32)).cuda()
        new_arm2_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 13).astype(np.float32)).cuda()
        fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
        fake_cl_dis *= mask_fore

        arm1_occ = clothes_mask * new_arm1_mask
        arm2_occ = clothes_mask * new_arm2_mask
        bigger_arm1_occ = morpho(arm1_occ, 10)
        bigger_arm2_occ = morpho(arm2_occ, 10)
        arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
        arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
        armlabel_map *= (1 - new_arm1_mask)
        armlabel_map *= (1 - new_arm2_mask)
        armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
        armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
        armlabel_map *= (1 - fake_cl_dis)
        dis_label = encode(armlabel_map, armlabel_map.shape)

        fake_c, warped, warped_mask, warped_grid = self.Unet(
            clothes, fake_cl_dis, pre_clothes_mask, grid)
        mask = fake_c[:, 3, :, :]
        mask = self.sigmoid(mask) * fake_cl_dis
        fake_c = self.tanh(fake_c[:, 0:3, :, :])
        fake_c = fake_c * (1 - mask) + mask * warped
        skin_color = self.ger_average_color(
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask),
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
        occlude = (1 - bigger_arm1_occ *
                   (arm2_mask + arm1_mask + clothes_mask)) * (
                       1 - bigger_arm2_occ *
                       (arm2_mask + arm1_mask + clothes_mask))
        img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (1 -
                                                                   fake_cl_dis)

        G_in = torch.cat([
            img_hole_hand, dis_label, fake_c, skin_color,
            self.gen_noise(shape)
        ], 1)
        fake_image = self.G.refine(G_in.detach())
        fake_image = self.tanh(fake_image)

        loss_D_fake = 0
        loss_D_real = 0
        loss_G_GAN = 0
        loss_G_VGG = 0

        L1_loss = 0

        style_loss = L1_loss

        return [
            self.loss_filter(loss_G_GAN, 0, loss_G_VGG, loss_D_real,
                             loss_D_fake), fake_image, clothes, arm_label,
            L1_loss, style_loss, fake_cl, CE_loss, real_image, warped_grid
        ]

    def inference(self, label, label_ref, image_ref):

        # Encode Inputs
        image_ref = Variable(image_ref)
        input_label, input_label_ref, real_image_ref = self.encode_input_test(
            Variable(label), Variable(label_ref), image_ref, infer=True)

        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_label, input_label_ref,
                                               real_image_ref)
        else:
            fake_image = self.netG.forward(input_label, input_label_ref,
                                           real_image_ref)
        return fake_image

    def save(self, which_epoch):
        # self.save_network(self.Unet, 'U', which_epoch, self.gpu_ids)
        # self.save_network(self.G, 'G', which_epoch, self.gpu_ids)
        # self.save_network(self.G1, 'G1', which_epoch, self.gpu_ids)
        # self.save_network(self.G2, 'G2', which_epoch, self.gpu_ids)
        # # self.save_network(self.G3, 'G3', which_epoch, self.gpu_ids)
        # self.save_network(self.D, 'D', which_epoch, self.gpu_ids)
        # self.save_network(self.D1, 'D1', which_epoch, self.gpu_ids)
        # self.save_network(self.D2, 'D2', which_epoch, self.gpu_ids)
        # self.save_network(self.D3, 'D3', which_epoch, self.gpu_ids)

        pass

        # self.save_network(self.netB, 'B', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print(
                '------------ Now also finetuning global generator -----------'
            )

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
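
The decay above is linear: each call subtracts opt.lr / opt.niter_decay, so the rate reaches zero after exactly niter_decay calls. A quick check with illustrative values (not taken from this listing):

# Illustrative linear-decay check; lr and niter_decay are assumed values.
lr, niter_decay = 2e-4, 100
lrd = lr / niter_decay
for _ in range(niter_decay):
    lr -= lrd
assert abs(lr) < 1e-12   # the schedule ends at (numerically) zero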
Example #25
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,  
                                   opt.fineSize, opt.fineSize).cuda(device=opt.gpu_ids[0])
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, 
                                   opt.fineSize, opt.fineSize).cuda(device=opt.gpu_ids[0])

        # load/define networks
        # The naming convention is different from that used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)

        # If this is training phase
        if self.isTrain:
            use_sigmoid = opt.no_lsgan  # use a sigmoid output only when LSGAN is disabled
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)

        # If this is non-training phase/continue training phase
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            # build up so called history pool
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, gpu_ids=opt.gpu_ids)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            if opt.use_prcp:
                self.criterionPrcp = networks.PrcpLoss(opt.weight_path, opt.bias_path, opt.perceptual_level, tensor=self.Tensor, gpu_ids=opt.gpu_ids)

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG_A)
            networks.print_network(self.netG_B)
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
            print('-----------------------------------------------')
    def initialize(self, opt):
        self.opt = opt
        self.isTrain = opt.isTrain
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num

        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss)

        ### Encoder network
        # if self.gen_features:
        #     self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
        #                                   opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        # if not self.isTrain or opt.continue_train or opt.load_pretrain:
        #     pretrained_path = '' if not self.isTrain else opt.load_pretrain
        #     self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
        #     if self.isTrain:
        #         self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
        #     if self.gen_features:
        #         self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            # self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

            # self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = K.losses.MeanAbsoluteError()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss()


            # Names so we can break out each loss
            self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake']

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                # params_dict = dict(self.netG.named_parameters())
                # params = []
                # for key, value in params_dict.items():
                #     if key.startswith('model' + str(opt.n_local_enhancers)):
                #         params += [value]
                #         finetune_list.add(key.split('.')[0])
                # print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                # print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                pass
Example #27
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.batchSize = opt.batchSize
        self.fineSize = opt.fineSize

        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize, opt.fineSize)

        if self.opt.rise_sobelLoss:
            self.sobelLambda = 0
        else:
            self.sobelLambda = self.opt.lambda_sobel

        # load/define networks

        which_netG = opt.which_model_netG
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      which_netG, opt.norm, opt.use_dropout,
                                      self.gpu_ids)
        if self.isTrain:

            self.D_channel = opt.input_nc + opt.output_nc
            use_sigmoid = opt.no_lsgan

            self.netD = networks.define_D(self.D_channel, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)
        if not self.isTrain:
            self.netG.eval()

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            if self.opt.labelSmooth:
                self.criterionGAN = networks.GANLoss_smooth(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG)
            networks.print_network(self.netD)
            print('-----------------------------------------------')
Example #28
class ReCycleGANModel(BaseModel):
    def name(self):
        return 'ReCycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A0 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A1 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A2 = self.Tensor(nb, opt.input_nc, size, size)

        self.input_B0 = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B1 = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B2 = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from that used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        self.which_model_netP = opt.which_model_netP
        if opt.which_model_netP == 'prediction':
            self.netP_A = networks.define_G(opt.input_nc, opt.input_nc,
                                            opt.npf, opt.which_model_netP,
                                            opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids)
            self.netP_B = networks.define_G(opt.output_nc, opt.output_nc,
                                            opt.npf, opt.which_model_netP,
                                            opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids)
        else:
            self.netP_A = networks.define_G(2 * opt.input_nc, opt.input_nc,
                                            opt.ngf, 'unet_128', opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            self.gpu_ids)
            self.netP_B = networks.define_G(2 * opt.output_nc, opt.output_nc,
                                            opt.ngf, 'unet_128', opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.netP_A, 'P_A', which_epoch)
            self.load_network(self.netP_B, 'P_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters(),
                                self.netP_A.parameters(),
                                self.netP_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        networks.print_network(self.netP_A)
        networks.print_network(self.netP_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A0 = input['A0']
        input_A1 = input['A1']
        input_A2 = input['A2']

        input_B0 = input['B0']
        input_B1 = input['B1']
        input_B2 = input['B2']

        self.input_A0.resize_(input_A0.size()).copy_(input_A0)
        self.input_A1.resize_(input_A1.size()).copy_(input_A1)
        self.input_A2.resize_(input_A2.size()).copy_(input_A2)

        self.input_B0.resize_(input_B0.size()).copy_(input_B0)
        self.input_B1.resize_(input_B1.size()).copy_(input_B1)
        self.input_B2.resize_(input_B2.size()).copy_(input_B2)

        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A0 = Variable(self.input_A0)
        self.real_A1 = Variable(self.input_A1)
        self.real_A2 = Variable(self.input_A2)

        self.real_B0 = Variable(self.input_B0)
        self.real_B1 = Variable(self.input_B1)
        self.real_B2 = Variable(self.input_B2)

    def test(self):
        real_A0 = Variable(self.input_A0, volatile=True)
        real_A1 = Variable(self.input_A1, volatile=True)

        fake_B0 = self.netG_A(real_A0)
        fake_B1 = self.netG_A(real_A1)
        # fake_B2 = self.netP_B(torch.cat((fake_B0, fake_B1),1))
        if self.which_model_netP == 'prediction':
            fake_B2 = self.netP_B(fake_B0, fake_B1)
        else:
            fake_B2 = self.netP_B(torch.cat((fake_B0, fake_B1), 1))

        self.rec_A = self.netG_B(fake_B2)
        self.fake_B0 = fake_B0
        self.fake_B1 = fake_B1
        self.fake_B2 = fake_B2

        real_B0 = Variable(self.input_B0, volatile=True)
        real_B1 = Variable(self.input_B1, volatile=True)

        fake_A0 = self.netG_B(real_B0)
        fake_A1 = self.netG_B(real_B1)
        # fake_A2 = self.netP_A(torch.cat((fake_A0, fake_A1),1))
        if self.which_model_netP == 'prediction':
            fake_A2 = self.netP_A(fake_A0, fake_A1)
        else:
            fake_A2 = self.netP_A(torch.cat((fake_A0, fake_A1), 1))

        self.rec_B = self.netG_A(fake_A2)
        self.fake_A0 = fake_A0
        self.fake_A1 = fake_A1
        self.fake_A2 = fake_A2

        # pred_A2 = self.netP_A(torch.cat((real_A0, real_A1),1))
        if self.which_model_netP == 'prediction':
            pred_A2 = self.netP_A(real_A0, real_A1)
        else:
            pred_A2 = self.netP_A(torch.cat((real_A0, real_A1), 1))

        self.pred_A2 = pred_A2

        # pred_B2 = self.netP_B(torch.cat((real_B0, real_B1),1))
        if self.which_model_netP == 'prediction':
            pred_B2 = self.netP_B(real_B0, real_B1)
        else:
            pred_B2 = self.netP_B(torch.cat((real_B0, real_B1), 1))

        self.pred_B2 = pred_B2

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B0 = self.fake_B_pool.query(self.fake_B0)
        loss_D_A0 = self.backward_D_basic(self.netD_A, self.real_B0, fake_B0)

        fake_B1 = self.fake_B_pool.query(self.fake_B1)
        loss_D_A1 = self.backward_D_basic(self.netD_A, self.real_B1, fake_B1)

        fake_B2 = self.fake_B_pool.query(self.fake_B2)
        loss_D_A2 = self.backward_D_basic(self.netD_A, self.real_B2, fake_B2)

        pred_B = self.fake_B_pool.query(self.pred_B2)
        loss_D_A3 = self.backward_D_basic(self.netD_A, self.real_B2, pred_B)

        self.loss_D_A = loss_D_A0 + loss_D_A1 + loss_D_A2 + loss_D_A3

    def backward_D_B(self):
        fake_A0 = self.fake_A_pool.query(self.fake_A0)
        loss_D_B0 = self.backward_D_basic(self.netD_B, self.real_A0, fake_A0)

        fake_A1 = self.fake_A_pool.query(self.fake_A1)
        loss_D_B1 = self.backward_D_basic(self.netD_B, self.real_A1, fake_A1)

        fake_A2 = self.fake_A_pool.query(self.fake_A2)
        loss_D_B2 = self.backward_D_basic(self.netD_B, self.real_A2, fake_A2)

        pred_A = self.fake_A_pool.query(self.pred_A2)
        loss_D_B3 = self.backward_D_basic(self.netD_B, self.real_A2, pred_A)

        self.loss_D_B = loss_D_B0 + loss_D_B1 + loss_D_B2 + loss_D_B3

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A0 = self.netG_A(self.real_B0)
            idt_A1 = self.netG_A(self.real_B1)
            loss_idt_A = (self.criterionIdt(idt_A0, self.real_B0) +
                          self.criterionIdt(idt_A1, self.real_B1)) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B0 = self.netG_B(self.real_A0)
            idt_B1 = self.netG_B(self.real_A1)
            loss_idt_B = (self.criterionIdt(idt_B0, self.real_A0) +
                          self.criterionIdt(idt_B1, self.real_A1)) * lambda_A * lambda_idt

            self.idt_A = idt_A0
            self.idt_B = idt_B0
            self.loss_idt_A = loss_idt_A
            self.loss_idt_B = loss_idt_B

        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B0 = self.netG_A(self.real_A0)
        pred_fake = self.netD_A(fake_B0)
        loss_G_A0 = self.criterionGAN(pred_fake, True)

        fake_B1 = self.netG_A(self.real_A1)
        pred_fake = self.netD_A(fake_B1)
        loss_G_A1 = self.criterionGAN(pred_fake, True)

        # fake_B2 = self.netP_B(torch.cat((fake_B0,fake_B1),1))
        if self.which_model_netP == 'prediction':
            fake_B2 = self.netP_B(fake_B0, fake_B1)
        else:
            fake_B2 = self.netP_B(torch.cat((fake_B0, fake_B1), 1))

        pred_fake = self.netD_A(fake_B2)
        loss_G_A2 = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A0 = self.netG_B(self.real_B0)
        pred_fake = self.netD_B(fake_A0)
        loss_G_B0 = self.criterionGAN(pred_fake, True)

        fake_A1 = self.netG_B(self.real_B1)
        pred_fake = self.netD_B(fake_A1)
        loss_G_B1 = self.criterionGAN(pred_fake, True)

        # fake_A2 = self.netP_A(torch.cat((fake_A0,fake_A1),1))
        if self.which_model_netP == 'prediction':
            fake_A2 = self.netP_A(fake_A0, fake_A1)
        else:
            fake_A2 = self.netP_A(torch.cat((fake_A0, fake_A1), 1))

        pred_fake = self.netD_B(fake_A2)
        loss_G_B2 = self.criterionGAN(pred_fake, True)

        # prediction loss --
        # pred_A2 = self.netP_A(torch.cat((self.real_A0, self.real_A1),1))
        if self.which_model_netP == 'prediction':
            pred_A2 = self.netP_A(self.real_A0, self.real_A1)
        else:
            pred_A2 = self.netP_A(torch.cat((self.real_A0, self.real_A1), 1))

        loss_pred_A = self.criterionCycle(pred_A2, self.real_A2) * lambda_A

        # pred_B2 = self.netP_B(torch.cat((self.real_B0, self.real_B1),1))
        if self.which_model_netP == 'prediction':
            pred_B2 = self.netP_B(self.real_B0, self.real_B1)
        else:
            pred_B2 = self.netP_B(torch.cat((self.real_B0, self.real_B1), 1))

        loss_pred_B = self.criterionCycle(pred_B2, self.real_B2) * lambda_B

        # Forward recycle loss
        rec_A = self.netG_B(fake_B2)
        loss_recycle_A = self.criterionCycle(rec_A, self.real_A2) * lambda_A

        # Backward recycle loss
        rec_B = self.netG_A(fake_A2)
        loss_recycle_B = self.criterionCycle(rec_B, self.real_B2) * lambda_B

        # Fwd cycle loss
        rec_A0 = self.netG_B(fake_B0)
        loss_cycle_A0 = self.criterionCycle(rec_A0, self.real_A0) * lambda_A

        rec_A1 = self.netG_B(fake_B1)
        loss_cycle_A1 = self.criterionCycle(rec_A1, self.real_A1) * lambda_A

        rec_B0 = self.netG_A(fake_A0)
        loss_cycle_B0 = self.criterionCycle(rec_B0, self.real_B0) * lambda_B

        rec_B1 = self.netG_A(fake_A1)
        loss_cycle_B1 = self.criterionCycle(rec_B1, self.real_B1) * lambda_B

        # combined loss
        loss_G = (loss_G_A0 + loss_G_A1 + loss_G_A2 +
                  loss_G_B0 + loss_G_B1 + loss_G_B2 +
                  loss_recycle_A + loss_recycle_B +
                  loss_pred_A + loss_pred_B +
                  loss_idt_A + loss_idt_B +
                  loss_cycle_A0 + loss_cycle_A1 +
                  loss_cycle_B0 + loss_cycle_B1)
        loss_G.backward()

        self.fake_B0 = fake_B0
        self.fake_B1 = fake_B1
        self.fake_B2 = fake_B2
        self.pred_B2 = pred_B2

        self.fake_A0 = fake_A0
        self.fake_A1 = fake_A1
        self.fake_A2 = fake_A2
        self.pred_A2 = pred_A2

        self.rec_A = rec_A
        self.rec_B = rec_B

        self.loss_G_A = loss_G_A0 + loss_G_A1 + loss_G_A2
        self.loss_G_B = loss_G_B0 + loss_G_B1 + loss_G_B2
        self.loss_recycle_A = loss_recycle_A
        self.loss_recycle_B = loss_recycle_B
        self.loss_pred_A = loss_pred_A
        self.loss_pred_B = loss_pred_B

        self.loss_cycle_A = loss_cycle_A0 + loss_cycle_A1
        self.loss_cycle_B = loss_cycle_B0 + loss_cycle_B1

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):

        ret_errors = OrderedDict(
            [('D_A', self.loss_D_A), ('G_A', self.loss_G_A),
             ('Recyc_A', self.loss_recycle_A), ('Pred_A', self.loss_pred_A),
             ('Cyc_A', self.loss_cycle_A), ('D_B', self.loss_D_B),
             ('G_B', self.loss_G_B), ('Recyc_B', self.loss_recycle_B),
             ('Pred_B', self.loss_pred_B), ('Cyc_B', self.loss_cycle_B)])

        if self.opt.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        real_A0 = util.tensor2im(self.input_A0)
        real_A1 = util.tensor2im(self.input_A1)
        real_A2 = util.tensor2im(self.input_A2)

        fake_B0 = util.tensor2im(self.fake_B0)
        fake_B1 = util.tensor2im(self.fake_B1)
        fake_B2 = util.tensor2im(self.fake_B2)

        rec_A = util.tensor2im(self.rec_A)

        real_B0 = util.tensor2im(self.input_B0)
        real_B1 = util.tensor2im(self.input_B1)
        real_B2 = util.tensor2im(self.input_B2)

        fake_A0 = util.tensor2im(self.fake_A0)
        fake_A1 = util.tensor2im(self.fake_A1)
        fake_A2 = util.tensor2im(self.fake_A2)

        rec_B = util.tensor2im(self.rec_B)

        pred_A2 = util.tensor2im(self.pred_A2)
        pred_B2 = util.tensor2im(self.pred_B2)

        ret_visuals = OrderedDict([('real_A0', real_A0), ('fake_B0', fake_B0),
                                   ('real_A1', real_A1), ('fake_B1', fake_B1),
                                   ('fake_B2', fake_B2), ('rec_A', rec_A),
                                   ('real_A2', real_A2), ('pred_A2', pred_A2),
                                   ('real_B0', real_B0), ('fake_A0', fake_A0),
                                   ('real_B1', real_B1), ('fake_A1', fake_A1),
                                   ('fake_A2', fake_A2), ('rec_B', rec_B),
                                   ('real_B2', real_B2), ('pred_B2', pred_B2)])
        if self.opt.isTrain and self.opt.identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.netP_A, 'P_A', label, self.gpu_ids)
        self.save_network(self.netP_B, 'P_B', label, self.gpu_ids)
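
A minimal sketch of driving the ReCycleGANModel above; the triplet keys match set_input, while opt and dataset are illustrative assumptions (a sequential loader yielding consecutive-frame triplets):

# Hypothetical driver; only the dict keys come from set_input above.
model = ReCycleGANModel()
model.initialize(opt)
for data in dataset:   # dicts with 'A0','A1','A2','B0','B1','B2' and path keys
    model.set_input(data)
    model.optimize_parameters()        # G/P step, then D_A and D_B steps
errors = model.get_current_errors()    # includes Pred_* and Recyc_* entries
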
class CycleGANcdModel(BaseModel):
    def name(self):
        return 'CycleGANcdModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (A -> C -> A)')
            parser.add_argument('--lambda_B',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (B -> C -> B)')
            parser.add_argument(
                '--lambda_C',
                type=float,
                default=10.0,
                help='weight for cycle loss (C -> A -> C) and (C -> B -> C)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=0.5,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )

        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
            'D_C_A', 'D_C_B', 'G_C', 'cycle_C', 'idt_C_A', 'idt_C_B'
        ]
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        visual_names_C = [
            'real_C', 'fake_C_A', 'fake_C_B', 'rec_C_A', 'rec_C_B'
        ]
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
            visual_names_C.append('idt_C_A')
            visual_names_C.append('idt_C_B')

        self.visual_names = visual_names_A + visual_names_B + visual_names_C
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = [
                'G_A', 'G_B', 'D_A', 'D_B', 'G_C_A', 'G_C_B', 'D_C'
            ]
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B', 'G_C_A', 'G_C_B']

        # load/define networks
        # The naming convention is different from that used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        self.netG_C_A = networks.define_G(opt.input_nc, opt.input_nc, opt.ngf,
                                          opt.netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
        self.netG_C_B = networks.define_G(opt.output_nc, opt.output_nc,
                                          opt.ngf, opt.netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

            self.netD_C = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            self.fake_C_A_pool = ImagePool(opt.pool_size)
            self.fake_C_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters(),
                self.netG_C_A.parameters(), self.netG_C_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters(),
                self.netD_C.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.real_C = input['C'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        if self.isTrain:
            self.fake_C_A = self.netG_A(self.real_A)
            self.rec_A = self.netG_C_A(self.fake_C_A)
            self.fake_A = self.netG_C_A(self.real_C)
            self.rec_C_A = self.netG_A(self.fake_A)

            # self.real_C.detach_()

            self.fake_C_B = self.netG_B(self.real_B)
            self.rec_B = self.netG_C_B(self.fake_C_B)
            self.fake_B = self.netG_C_B(self.real_C)
            self.rec_C_B = self.netG_B(self.fake_B)
        else:
            self.fake_C_A = self.netG_A(self.real_A)
            self.rec_A = self.netG_C_A(self.fake_C_A)
            self.fake_A = self.netG_C_A(self.real_C)
            self.rec_C_A = self.netG_A(self.fake_A)

            # self.real_C.detach_()

            self.fake_C_B = self.netG_B(self.real_B)
            self.rec_B = self.netG_C_B(self.fake_C_B)
            self.fake_B = self.netG_C_B(self.fake_C_A)  # note: chained from fake_C_A at test time (training uses real_C)
            self.rec_C_B = self.netG_B(self.fake_B)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_D_C(self):
        fake_C_A = self.fake_C_A_pool.query(self.fake_C_A)
        self.loss_D_C_A = self.backward_D_basic(self.netD_C, self.real_C,
                                                fake_C_A)

        fake_C_B = self.fake_C_B_pool.query(self.fake_C_B)
        self.loss_D_C_B = self.backward_D_basic(self.netD_C, self.real_C,
                                                fake_C_B)

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_C = self.opt.lambda_C
        # Identity loss
        if lambda_idt > 0:
            # G_A should be an identity mapping when real_C is fed.
            self.idt_A = self.netG_A(self.real_C)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_C) * lambda_C * lambda_idt
            # G_B should be an identity mapping when real_C is fed.
            self.idt_B = self.netG_B(self.real_C)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_C) * lambda_C * lambda_idt
            # G_C_A/B should be identity if real_A/B is fed.
            self.idt_C_A = self.netG_C_A(self.real_A)
            self.loss_idt_C_A = self.criterionIdt(
                self.idt_C_A, self.real_A) * lambda_A * lambda_idt / 2.
            self.idt_C_B = self.netG_C_B(self.real_B)
            # weighted by lambda_B to mirror the lambda_A weighting of idt_C_A
            self.loss_idt_C_B = self.criterionIdt(
                self.idt_C_B, self.real_B) * lambda_B * lambda_idt / 2.
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
            self.loss_idt_C_A = 0
            self.loss_idt_C_B = 0

        self.loss_idt = self.loss_idt_A + self.loss_idt_B + self.loss_idt_C_A + self.loss_idt_C_B

        # GAN losses. Unlike the original code (D_A for B, D_B for A), here
        # D_A is used for domain A, D_B for domain B, and D_C for domain C.
        self.loss_G_C = (
            self.criterionGAN(self.netD_C(self.fake_C_A), True) +
            self.criterionGAN(self.netD_C(self.fake_C_B), True)) / 2.
        # GAN loss D_A(G_C_A(C))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_A), True)
        # GAN loss D_B(G_C_B(C))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_B), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # Cycle loss through domain C
        self.loss_cycle_C = (
            self.criterionCycle(self.rec_C_A, self.real_C) +
            self.criterionCycle(self.rec_C_B, self.real_C)) * lambda_C / 2.
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_G_C + self.loss_cycle_A + self.loss_cycle_B + self.loss_cycle_C + self.loss_idt
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A, G_B, G_C_A and G_C_B
        self.set_requires_grad([self.netD_A, self.netD_B, self.netD_C], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A, D_B and D_C
        self.set_requires_grad([self.netD_A, self.netD_B, self.netD_C], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.backward_D_C()
        self.optimizer_D.step()
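
# A minimal sketch of how the three-domain model above is driven (assumptions:
# `model` is an initialized instance and `dataloader` yields dicts with the
# 'A', 'B', 'C' and 'A_paths'/'B_paths' keys that set_input() unpacks):
#
#     for epoch in range(n_epochs):
#         for data in dataloader:
#             model.set_input(data)          # unpack real_A / real_B / real_C
#             model.optimize_parameters()    # one G update, then one joint D update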
Example #30
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A0 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A1 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A2 = self.Tensor(nb, opt.input_nc, size, size)

        self.input_B0 = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B1 = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B2 = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        self.which_model_netP = opt.which_model_netP
        if opt.which_model_netP == 'prediction':
            self.netP_A = networks.define_G(opt.input_nc, opt.input_nc,
                                            opt.npf, opt.which_model_netP,
                                            opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids)
            self.netP_B = networks.define_G(opt.output_nc, opt.output_nc,
                                            opt.npf, opt.which_model_netP,
                                            opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids)
        else:
            self.netP_A = networks.define_G(2 * opt.input_nc, opt.input_nc,
                                            opt.ngf, 'unet_128', opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            self.gpu_ids)
            self.netP_B = networks.define_G(2 * opt.output_nc, opt.output_nc,
                                            opt.ngf, 'unet_128', opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.netP_A, 'P_A', which_epoch)
            self.load_network(self.netP_B, 'P_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters(),
                                self.netP_A.parameters(),
                                self.netP_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        networks.print_network(self.netP_A)
        networks.print_network(self.netP_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')
Example #31
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
        self.AE = networks.define_AE(28*28, 28*28, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.AE, 'AE', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionAE = torch.nn.MSELoss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_D_A_AE = torch.optim.Adam(self.netD_A.parameters(),
                                                     lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B_AE = torch.optim.Adam(self.netD_B.parameters(),
                                                     lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(self.AE.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_AE_GA_GB = torch.optim.Adam(
                itertools.chain(self.AE.parameters(), self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG_A)
            networks.print_network(self.netG_B)
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
            networks.print_network(self.AE)
            print('-----------------------------------------------')
Example #32
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser
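    # Worked illustration of the formulas in the docstring above: with the
    # defaults, the full generator objective is
    #   L_G = GAN(G_A) + GAN(G_B)
    #         + lambda_A * ||G_B(G_A(A)) - A||_1
    #         + lambda_B * ||G_A(G_B(B)) - B||_1
    #         + lambda_identity * (lambda_B * ||G_A(B) - B||_1
    #                              + lambda_A * ||G_B(A) - A||_1)
    # so lambda_A = lambda_B = 10.0 and lambda_identity = 0.5 give the identity
    # terms half the weight of the cycle terms (see backward_G below).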

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.upsample_conv_type, opt.conv_dilation_G, opt.upsample_conv_dilation_G, opt.resnet_activation_G, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.upsample_conv_type, opt.conv_dilation_G, opt.upsample_conv_dilation_G, opt.resnet_activation_G, self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
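        # e.g. with '--direction BtoA' the loader's 'B' images become real_A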

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        # (the 0.5 factor slows D's learning relative to G, as in the paper)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate gradients for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights
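
# The ImagePool used above buffers previously generated images; what follows is
# a minimal sketch of that behaviour (assumptions: the real implementation is
# util/image_pool.py in the CycleGAN repo; `ImagePoolSketch` is an illustrative
# name, not the repo's exact code). With probability 0.5 the pool returns an
# older fake instead of the current one, which keeps the discriminator from
# overfitting to the most recent generator outputs.
import random

import torch


class ImagePoolSketch:
    def __init__(self, pool_size):
        self.pool_size = pool_size  # 0 disables buffering
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if len(self.images) < self.pool_size:
                # buffer not full yet: store and return the current image
                self.images.append(image)
                out.append(image)
            elif random.uniform(0, 1) > 0.5:
                # swap the current image with a randomly chosen stored one
                idx = random.randint(0, self.pool_size - 1)
                tmp = self.images[idx].clone()
                self.images[idx] = image
                out.append(tmp)
            else:
                # return the current image unchanged
                out.append(image)
        return torch.cat(out, 0)
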
Example #33
class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)  # 'async' is a reserved word in Python 3.7+
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
Example #34
class T2NetModel(BaseModel):
    def name(self):
        return 'T2Net model'

    def initialize(self, opt, labeled_dataset=None, unlabeled_dataset=None):
        BaseModel.initialize(self, opt)

        self.loss_names = [
            'img_rec', 'img_G', 'img_D', 'lab_s', 'lab_t', 'f_G', 'f_D',
            'lab_smooth'
        ]
        self.visual_names = [
            'img_s', 'img_t', 'lab_s', 'lab_t', 'img_s2t', 'img_t2t',
            'lab_s_g', 'lab_t_g'
        ]

        if self.isTrain:
            self.model_names = ['img2task', 's2t', 'img_D', 'f_D']
        else:
            self.model_names = ['img2task', 's2t']

        # define the transform network
        self.net_s2t = network.define_G(opt.image_nc, opt.image_nc, opt.ngf,
                                        opt.transform_layers, opt.norm,
                                        opt.activation, opt.trans_model_type,
                                        opt.init_type, opt.drop_rate, False,
                                        opt.gpu_ids, opt.U_weight)
        # define the task network
        self.net_img2task = network.define_G(opt.image_nc, opt.label_nc,
                                             opt.ngf, opt.task_layers,
                                             opt.norm, opt.activation,
                                             opt.task_model_type,
                                             opt.init_type, opt.drop_rate,
                                             False, opt.gpu_ids, opt.U_weight)

        # define the discriminator
        if self.isTrain:
            self.net_img_D = network.define_D(opt.image_nc, opt.ndf,
                                              opt.image_D_layers, opt.num_D,
                                              opt.norm, opt.activation,
                                              opt.init_type, opt.gpu_ids)
            self.net_f_D = network.define_featureD(opt.image_feature,
                                                   opt.feature_D_layers,
                                                   opt.norm, opt.activation,
                                                   opt.init_type, opt.gpu_ids)

        if self.isTrain:
            self.fake_img_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.l1loss = torch.nn.L1Loss()
            self.nonlinearity = torch.nn.ReLU()
            # initialize optimizers
            self.optimizer_T2Net = torch.optim.Adam([{
                'params':
                filter(lambda p: p.requires_grad, self.net_s2t.parameters())
            }, {
                'params':
                filter(lambda p: p.requires_grad,
                       self.net_img2task.parameters()),
                'lr':
                opt.lr_task,
                'betas': (0.95, 0.999)
            }],
                                                    lr=opt.lr_trans,
                                                    betas=(0.5, 0.9))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                filter(lambda p: p.requires_grad, self.net_img_D.parameters()),
                filter(lambda p: p.requires_grad, self.net_f_D.parameters())),
                                                lr=opt.lr_trans,
                                                betas=(0.5, 0.9))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_T2Net)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(network.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        # initializing GPstruct
        if self.isTrain and opt.gp:
            self.labeled_dataset = labeled_dataset
            self.unlabeled_dataset = unlabeled_dataset
            self.gp_struct = GPStruct(num_lbl=len(labeled_dataset),
                                      num_unlbl=len(unlabeled_dataset),
                                      train_batch_size=self.opt.batch_size,
                                      version=self.opt.version,
                                      kernel_type=self.opt.kernel_type,
                                      pre_trained_enc=opt.pre_trained_enc,
                                      img_size=opt.load_size)

    def set_input(self, input):
        self.input = input
        self.img_source = input['img_source'].cuda(self.gpu_ids[0])
        self.img_target = input['img_target'].cuda(self.gpu_ids[0])
        if self.isTrain:
            self.lab_source = input['lab_source'].cuda(self.gpu_ids[0])
            self.lab_target = input['lab_target'].cuda(self.gpu_ids[0])


    def forward(self):
        self.img_s = Variable(self.img_source)
        self.img_t = Variable(self.img_target)
        self.lab_s = Variable(self.lab_source)
        self.lab_t = Variable(self.lab_target)

    def backward_D_basic(self, netD, real, fake):

        D_loss = 0
        for (real_i, fake_i) in zip(real, fake):
            # Real
            D_real = netD(real_i.detach())
            # fake
            D_fake = netD(fake_i.detach())

            for (D_real_i, D_fake_i) in zip(D_real, D_fake):
                D_loss += (torch.mean((D_real_i - 1.0)**2) + torch.mean(
                    (D_fake_i - 0.0)**2)) * 0.5

        D_loss.backward()

        return D_loss

    def backward_D_image(self):
        network._freeze(self.net_s2t, self.net_img2task, self.net_f_D)
        network._unfreeze(self.net_img_D)
        size = len(self.img_s2t)
        fake = []
        for i in range(size):
            fake.append(self.fake_img_pool.query(self.img_s2t[i]))
        real = task.scale_pyramid(self.img_t, size)
        self.loss_img_D = self.backward_D_basic(self.net_img_D, real, fake)

    def backward_D_feature(self):
        network._freeze(self.net_s2t, self.net_img2task, self.net_img_D)
        network._unfreeze(self.net_f_D)
        self.loss_f_D = self.backward_D_basic(self.net_f_D, [self.lab_f_t],
                                              [self.lab_f_s])

    def forward_G_basic(self, net_G, img_s, img_t):

        img = torch.cat([img_s, img_t], 0)
        fake = net_G(img)

        size = len(fake)

        f_s, f_t = fake[0].chunk(2)
        img_fake = fake[1:]

        img_s_fake = []
        img_t_fake = []

        for img_fake_i in img_fake:
            img_s, img_t = img_fake_i.chunk(2)
            img_s_fake.append(img_s)
            img_t_fake.append(img_t)

        return img_s_fake, img_t_fake, f_s, f_t, size

    def backward_synthesis2real(self):

        # image to image transform
        network._freeze(self.net_img2task, self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_s2t)
        self.img_s2t, self.img_t2t, self.img_f_s, self.img_f_t, size = \
            self.forward_G_basic(self.net_s2t, self.img_s, self.img_t)

        # image GAN loss and reconstruction loss
        img_real = task.scale_pyramid(self.img_t, size - 1)
        G_loss = 0
        rec_loss = 0
        for i in range(size - 1):
            rec_loss += self.l1loss(self.img_t2t[i], img_real[i])
            D_fake = self.net_img_D(self.img_s2t[i])
            for D_fake_i in D_fake:
                G_loss += torch.mean((D_fake_i - 1.0)**2)

        self.loss_img_G = G_loss * self.opt.lambda_gan_img
        self.loss_img_rec = rec_loss * self.opt.lambda_rec_img

        total_loss = self.loss_img_G + self.loss_img_rec

        total_loss.backward(retain_graph=True)

    def backward_translated2depth(self):

        # task network
        network._freeze(self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_s2t, self.net_img2task)
        fake = self.net_img2task.forward(self.img_s2t[-1])

        size = len(fake)
        self.lab_f_s = fake[0]
        self.lab_s_g = fake[1:]

        #feature GAN loss
        D_fake = self.net_f_D(self.lab_f_s)
        G_loss = 0
        for D_fake_i in D_fake:
            G_loss += torch.mean((D_fake_i - 1.0)**2)
        self.loss_f_G = G_loss * self.opt.lambda_gan_feature

        # task loss
        lab_real = task.scale_pyramid(self.lab_s, size - 1)
        task_loss = 0
        for (lab_fake_i, lab_real_i) in zip(self.lab_s_g, lab_real):
            task_loss += self.l1loss(lab_fake_i, lab_real_i)

        self.loss_lab_s = task_loss * self.opt.lambda_rec_lab

        total_loss = self.loss_f_G + self.loss_lab_s

        total_loss.backward()

    def backward_real2depth(self):

        # image2depth
        network._freeze(self.net_s2t, self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_img2task)
        fake = self.net_img2task.forward(self.img_t)
        size = len(fake)

        # Gan depth
        self.lab_f_t = fake[0]
        self.lab_t_g = fake[1:]

        img_real = task.scale_pyramid(self.img_t, size - 1)
        self.loss_lab_smooth = task.get_smooth_weight(
            self.lab_t_g, img_real, size - 1) * self.opt.lambda_smooth

        total_loss = self.loss_lab_smooth

        total_loss.backward()

    def optimize_parameters(self, epoch_iter):

        self.forward()
        # T2Net
        self.optimizer_T2Net.zero_grad()
        self.backward_synthesis2real()
        self.backward_translated2depth()
        self.backward_real2depth()
        self.optimizer_T2Net.step()
        # Discriminator
        self.optimizer_D.zero_grad()
        self.backward_D_feature()
        self.backward_D_image()


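        # update the discriminators only on every 5th iteration, and clip the
        # feature-discriminator weights to [-0.01, 0.01] (WGAN-style clipping)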
        if epoch_iter % 5 == 0:
            self.optimizer_D.step()
            for p in self.net_f_D.parameters():
                p.data.clamp_(-0.01, 0.01)

    def validation_target(self):

        lab_real = task.scale_pyramid(self.lab_t, len(self.lab_t_g))
        task_loss = 0
        for (lab_fake_i, lab_real_i) in zip(self.lab_t_g, lab_real):
            task_loss += task.rec_loss(lab_fake_i, lab_real_i)

        self.loss_lab_t = task_loss * self.opt.lambda_rec_lab

    def generate_fmaps_GP(self):
        self.gp_struct.gen_featmaps(self.labeled_dataset, self.net_img2task,
                                    self.device)
        self.gp_struct.gen_featmaps_unlbl(self.unlabeled_dataset,
                                          self.net_img2task, self.device)

    def optimize_parameters_GP(self, iter, data):
        input_im = data['img_target'].cuda(self.gpu_ids[0])
        # gt = data['lab_target'].cuda(self.device)
        imgid = data['img_target_paths']

        self.optimizer_T2Net.zero_grad()
        network._freeze(self.net_s2t, self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_img2task)
        self.net_img2task.train()

        ### center in
        # outputs = self.netTask(input_im)
        # zy_in = outputs[0]

        ### center_out
        _, zy_in = self.net_img2task(input_im, gp=True)

        loss_gp = self.gp_struct.compute_gploss(zy_in, imgid, iter, 0)
        self.loss_gp = loss_gp * self.opt.lambda_gp
        self.loss_gp.backward()
        self.optimizer_T2Net.step()
Example #35
# create discriminator train functions; bind the results to new names so the
# factory (defined earlier in the script) is not shadowed by its own result
netD_A_train = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_A, opt.finesize, opt.input_nc)
netD_B_train = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_B, opt.finesize, opt.input_nc)

# train loop
time_start = time.time()
how_many_epochs = 5
iteration_count = 0
epoch_count = 0
batch_size = opt.batch_size
display_freq = 10000

netG_A_function = get_generater_function(netG_A)
netG_B_function = get_generater_function(netG_B)

fake_A_pool = ImagePool()
fake_B_pool = ImagePool()

while epoch_count < how_many_epochs:
    target_label = np.zeros((batch_size, 1))
    epoch_count, A, B = next(train_batch)

    tmp_fake_B = netG_A_function([A])[0]
    tmp_fake_A = netG_B_function([B])[0]

    _fake_B = fake_B_pool.query(tmp_fake_B)
    _fake_A = fake_A_pool.query(tmp_fake_A)

    netG_train_function.train_on_batch([A, B], target_label)

    netD_B_train.train_on_batch([B, _fake_B], target_label)
Example #37
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none': # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else 3
        ##### define networks        
        # Generator network
        netG_input_nc = input_nc        
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num                  
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:          
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)   
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:             
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
                
        
            # Names so we can break out each loss term
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                if self.opt.verbose:
                    print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():       
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [{'params':[value],'lr':opt.lr}]
                    else:
                        params += [{'params':[value],'lr':0.0}]                            
            else:
                params = list(self.netG.parameters())
            if self.gen_features:              
                params += list(self.netE.parameters())         
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
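
The niter_fix_global branch above freezes the global generator not by toggling requires_grad, but by giving every non-enhancer parameter its own Adam parameter group with lr=0. A minimal sketch of the same idea; the two-part net is a hypothetical stand-in:

import torch

net = torch.nn.ModuleDict({'model0': torch.nn.Linear(4, 4),   # plays the global net
                           'model1': torch.nn.Linear(4, 4)})  # plays the local enhancer

params = []
for name, p in net.named_parameters():
    # Enhancer parameters train at the base lr; everything else gets lr=0,
    # which keeps it in the optimizer but effectively frozen.
    lr = 2e-4 if name.startswith('model1') else 0.0
    params.append({'params': [p], 'lr': lr})
optimizer_G = torch.optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))

Unlike requires_grad=False, gradients are still computed for the frozen part, so this costs compute but makes it trivial to unfreeze later by raising the group's learning rate.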
Example #38
    def initialize(self, opt):
        #initialize the base class with given parameter set opt
        BaseModel.initialize(self, opt)
        # get the program phase (train or test)
        self.isTrain = opt.isTrain

        # load/define Generator
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type)
        #define the Discriminator
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)
        #deploy generator to device
        self.netG = self.netG.to(self.device)

        #deploy discriminator to device
        if self.isTrain:
            self.netD = self.netD.to(self.device)

        #if the program is for training
        if self.isTrain:
            #set the size of image buffer that stores previously generated images
            self.fake_AB_pool = ImagePool(opt.pool_size)

            #set initial learning rate for adam
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 device=self.device)

            self.criterionL1 = torch.nn.L1Loss().to(self.device)

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            #define the optimizer for generator
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            #define the optimizer for discriminator
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            #save the optimizers
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            #save schedulers
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')
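
networks.get_scheduler is called but not shown in these snippets. In the pix2pix/CycleGAN family it usually wraps a LambdaLR that keeps the learning rate constant for opt.niter epochs and then decays it linearly to zero over opt.niter_decay epochs; a sketch under that assumption:

import torch

def get_scheduler(optimizer, opt):
    # Constant lr for opt.niter epochs, then linear decay to 0 over
    # the following opt.niter_decay epochs (assumed 'lambda' policy).
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)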
Example #39
class PairModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.opt = opt
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        if opt.vgg > 0:
            self.vgg_loss = networks.PerceptualLoss()
            self.vgg_loss.cuda()
            self.vgg = networks.load_vgg16("./model")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        skip = opt.skip > 0
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        self.gpu_ids,
                                        skip=skip,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        self.gpu_ids,
                                        skip=False,
                                        opt=opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            if opt.use_wgan:
                self.criterionGAN = networks.DiscLossWGANGP()
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            if opt.use_mse:
                self.criterionCycle = torch.nn.MSELoss()
            else:
                self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        if opt.isTrain:
            self.netG_A.train()
            self.netG_B.train()
        else:
            self.netG_A.eval()
            self.netG_B.eval()
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:])
        if self.opt.skip == 1:
            self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A)
        else:
            self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        if self.opt.skip == 1:
            self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A)
        else:
            self.rec_B = self.netG_A.forward(self.fake_A)

    def predict(self):
        self.real_A = Variable(self.input_A, volatile=True)
        # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:])
        if self.opt.skip == 1:
            self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A)
        else:
            self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        if self.opt.skip == 1:
            latent_real_A = util.tensor2im(self.latent_real_A.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ("latent_real_A", latent_real_A),
                                ("rec_A", rec_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ("rec_A", rec_A)])

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        if self.opt.use_wgan:
            loss_D_real = pred_real.mean()
        else:
            loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        if self.opt.use_wgan:
            loss_D_fake = pred_fake.mean()
        else:
            loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        if self.opt.use_wgan:
            loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(
                netD, real.data, fake.data)
        else:
            loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            if self.opt.skip == 1:
                self.idt_A, _ = self.netG_A.forward(self.real_B)
            else:
                self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        if self.opt.skip == 1:
            self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A)
        else:
            self.fake_B = self.netG_A.forward(self.real_A)
        # = self.latent_real_A + self.opt.skip * self.real_A
        pred_fake = self.netD_A.forward(self.fake_B)
        if self.opt.use_wgan:
            self.loss_G_A = -pred_fake.mean()
        else:
            self.loss_G_A = self.criterionGAN(pred_fake, True)
        self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1
        if self.opt.use_wgan:
            self.loss_G_B = -pred_fake.mean()
        else:
            self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss

        if lambda_A > 0:
            self.rec_A = self.netG_B.forward(self.fake_B)
            self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                    self.real_A) * lambda_A
        else:
            self.loss_cycle_A = 0
        # Backward cycle loss

        # = self.latent_fake_A + self.opt.skip * self.fake_A
        if lambda_B > 0:
            if self.opt.skip == 1:
                self.rec_B, self.latent_fake_A = self.netG_A.forward(
                    self.fake_A)
            else:
                self.rec_B = self.netG_A.forward(self.fake_A)
            self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                    self.real_B) * lambda_B
        else:
            self.loss_cycle_B = 0
        self.loss_vgg_a = self.vgg_loss.compute_vgg_loss(
            self.vgg, self.fake_A,
            self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0
        self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(
            self.vgg, self.fake_B,
            self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.L1_AB + self.L1_BA + self.loss_cycle_A + self.loss_cycle_B + \
                        self.loss_vgg_a + self.loss_vgg_b + \
                        self.loss_idt_A + self.loss_idt_B
        # self.loss_G = self.L1_AB + self.L1_BA
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        D_A = self.loss_D_A.data[0]
        G_A = self.loss_G_A.data[0]
        L1 = (self.L1_AB + self.L1_BA).data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        D_B = self.loss_D_B.data[0]
        G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        vgg = (self.loss_vgg_a.data[0] + self.loss_vgg_b.data[0]
               ) / self.opt.vgg if self.opt.vgg > 0 else 0
        if self.opt.identity > 0:
            idt = self.loss_idt_A.data[0] + self.loss_idt_B.data[0]
            if self.opt.lambda_A > 0.0:
                return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1),
                                    ('Cyc_A', Cyc_A), ('D_B', D_B),
                                    ('G_B', G_B), ('Cyc_B', Cyc_B),
                                    ("vgg", vgg), ("idt", idt)])
            else:
                return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1),
                                    ('D_B', D_B), ('G_B', G_B),
                                    ('vgg', vgg), ('idt', idt)])
        else:
            if self.opt.lambda_A > 0.0:
                return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1),
                                    ('Cyc_A', Cyc_A), ('D_B', D_B),
                                    ('G_B', G_B), ('Cyc_B', Cyc_B),
                                    ("vgg", vgg)])
            else:
                return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1),
                                    ('D_B', D_B), ('G_B', G_B), ('vgg', vgg)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        if self.opt.skip > 0:
            latent_real_A = util.tensor2im(self.latent_real_A.data)

        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)

        if self.opt.identity > 0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            if self.opt.lambda_A > 0.0:
                rec_A = util.tensor2im(self.rec_A.data)
                rec_B = util.tensor2im(self.rec_B.data)
                if self.opt.skip > 0:
                    latent_fake_A = util.tensor2im(self.latent_fake_A.data)
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B),
                                        ('latent_fake_A', latent_fake_A),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
            else:
                if self.opt.skip > 0:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('real_B', real_B), ('fake_A', fake_A),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('real_B', real_B), ('fake_A', fake_A),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
        else:
            if self.opt.lambda_A > 0.0:
                rec_A = util.tensor2im(self.rec_A.data)
                rec_B = util.tensor2im(self.rec_B.data)
                if self.opt.skip > 0:
                    latent_fake_A = util.tensor2im(self.latent_fake_A.data)
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B),
                                        ('latent_fake_A', latent_fake_A)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B)])
            else:
                if self.opt.skip > 0:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('real_B', real_B),
                                        ('fake_A', fake_A)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('real_B', real_B),
                                        ('fake_A', fake_A)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):

        if self.opt.new_lr:
            lr = self.old_lr / 2
        else:
            lrd = self.opt.lr / self.opt.niter_decay
            lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
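
When opt.use_wgan is set, backward_D_basic above calls criterionGAN.calc_gradient_penalty, which is defined elsewhere in this codebase. A sketch of the standard WGAN-GP penalty (Gulrajani et al., 2017) that such a method is assumed to implement:

import torch

def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=10.0):
    # Interpolate real and fake samples with a random per-sample alpha.
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
    interp = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    pred = netD(interp)
    # Gradient of D's output w.r.t. the interpolated inputs.
    grads = torch.autograd.grad(outputs=pred, inputs=interp,
                                grad_outputs=torch.ones_like(pred),
                                create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Penalize deviation of the per-sample gradient norm from 1.
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()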
Example #40
class ObjectVariedGANModel(BaseModel):
    def name(self):
        return 'ObjectVariedGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.set_defaults(no_dropout=True)
        parser.add_argument('--set_order', type=str, default='decreasing', help='order of segmentation')
        parser.add_argument('--ins_max', type=int, default=1, help='maximum number of objects to forward')
        parser.add_argument('--ins_per', type=int, default=1, help='number of objects to forward in one pass')
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_idt', type=float, default=1.0, help='use identity mapping; a nonzero lambda_idt scales the weight of the identity mapping loss')
            parser.add_argument('--lambda_ctx', type=float, default=1.0, help='use context preserving; a nonzero lambda_ctx scales the weight of the context preserving loss')
            parser.add_argument('--lambda_fs', type=float, default=10.0, help='use feature similarity; a nonzero lambda_fs scales the weight of the feature similarity loss')

        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        self.ins_iter = self.opt.ins_max // self.opt.ins_per  				# number of forward iterations, e.g. ins_max=4, ins_per=2 gives ins_iter=2
                                                                            # "//" is Python's floor division, e.g. 3 // 2 = 1

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cyc_A', 'idt_A', 'ctx_A', 'fs_A', 'D_B', 'G_B', 'cyc_B', 'idt_B', 'ctx_B', 'fs_B']

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A_img = ['real_A_img', 'fake_B_img', 'rec_A_img']
        visual_names_B_img = ['real_B_img', 'fake_A_img', 'rec_B_img']
        visual_names_A_seg = ['real_A_seg', 'fake_B_seg', 'rec_A_seg']
        visual_names_B_seg = ['real_B_seg', 'fake_A_seg', 'rec_B_seg']
        self.visual_names = visual_names_A_img + visual_names_A_seg + visual_names_B_img + visual_names_B_seg

        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:													# isTrain is True when running train.py, False when running test.py
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']					# during training, save both generators and discriminators
        else:
            self.model_names = ['G_A', 'G_B']								# at test time, only the generators are saved

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc,  opt.ins_per, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)	# opt.norm defaults to 'instance'
        self.netG_B = networks.define_G(opt.output_nc,  opt.ins_per, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc,  opt.ins_per, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc,  opt.ins_per, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)	# '--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images'
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)	# controlled by opt.no_lsgan: MSELoss (LSGAN) or BCELoss (vanilla GAN)
            self.criterionCyc = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # The optimizer setup below uses two built-ins, filter() and lambda:
            # filter(function, iterable) keeps only the elements for which the
            # function returns True (in Python 3 it returns an iterator).
            # lambda p: p.requires_grad is an anonymous function: p is the
            # argument and p.requires_grad the expression it evaluates.

            # initialize optimizers
            # filter() here takes the anonymous function and the chained parameters of
            # netG_A and netG_B, yielding only those with requires_grad=True.
            # In other words, a parameter is handed to Adam() only if it is trainable.
            self.optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(self.netG_A.parameters(), self.netG_B.parameters())), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(self.netD_A.parameters(), self.netD_B.parameters())), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def select_masks(self, segs_batch):
        """Select object masks to use"""
        if self.opt.set_order == 'decreasing':
            return self.select_masks_decreasing(segs_batch)
        elif self.opt.set_order == 'random':
            return self.select_masks_random(segs_batch)
        else:
            raise NotImplementedError('Set order name [%s] is not recognized' % self.opt.set_order)

    def select_masks_decreasing(self, segs_batch):
        """Select masks in decreasing order"""
        ret = list()
        for segs in segs_batch:

            mean = segs.mean(-1).mean(-1)       # mean has size torch.Size([20]);
                                                # averaging twice over the last dim reduces each (H, W) mask to a scalar
            m, i = mean.topk(self.opt.ins_max)  # m holds the top values, e.g. tensor([-0.7352, -0.7675, -1.0000, -1.0000]), size torch.Size([4]);
                                                # i holds their indices, e.g. tensor([0, 1, 5, 3]): the indices of the ins_max largest masks
                                                # '--ins_max', type=int, default=4, help='maximum number of objects to forward'
            ret.append(segs[i, :, :])           # ret is a list whose elements each have shape torch.Size([4, 200, 200])

        return torch.stack(ret)                 # torch.stack concatenates along a new dim,
                                                # returning torch.Size([1, 4, 200, 200])

    def select_masks_random(self, segs_batch):
        """Select masks in random order"""
        ret = list()
        for segs in segs_batch:
            mean = (segs + 1).mean(-1).mean(-1)                                                         # torch.Size([20])
            m, i = mean.topk(self.opt.ins_max)
            num = min(len(mean.nonzero()), self.opt.ins_max)                                            # e.g. num = 2
            reorder = np.concatenate((np.random.permutation(num), np.arange(num, self.opt.ins_max)))    # e.g. reorder = array([0, 1, 2, 3])
            ret.append(segs[i[reorder], :, :])                                                          # ret is a list whose elements each have shape torch.Size([4, 200, 200])
        return torch.stack(ret)

    def merge_masks(self, segs):
        """Merge masks (B, N, W, H) -> (B, 1, W, H)"""
        ret = torch.sum((segs + 1)/2, dim=1, keepdim=True)  				# (B, 1, W, H)
        return ret.clamp(max=1, min=0) * 2 - 1

    def get_weight_for_ctx(self, x, y):
        """Get weight for context preserving loss"""
        z = self.merge_masks(torch.cat([x, y], dim=1))
        return (1 - z) / 2  # [-1,1] -> [1,0]

    def weighted_L1_loss(self, src, tgt, weight):
        """L1 loss with given weight (used for context preserving loss)"""
        return torch.mean(weight * torch.abs(src - tgt))

    def get_weight_for_cx(self, x, y):
        """Get weight for context preserving loss"""
        z = self.merge_masks(torch.cat([x, y], dim=1))
        return (1 - z) / 2  # [-1,1] -> [1,0]

    def multiply_cx(self, src, weight):
        """L1 loss with given weight (used for context preserving loss)"""
        return torch.mean(weight * torch.abs(src))

    def split(self, x):
        """Split data into image and mask (only assume 3-channel image)"""
        return x[:, :3, :, :], x[:, 3:, :, :]								# the first three channels are the image, the remaining channels the masks

    # input is one dataset item (an instance from the UnalignedSegDataset class)
    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
                                                                            # input is a dict of dataset items; input[key] fetches each one,
                                                                            # e.g. input['A'], input['B'], input['A_segs'], input['B_segs'];
                                                                            # see what __getitem__ returns in data/unaligned_seg_dataset.py
        self.real_A_img = input['A' if AtoB else 'B'].to(self.device)		# real_A_img has shape torch.Size([1, 3, 256, 256]): one 3-channel source image
        self.real_B_img = input['B' if AtoB else 'A'].to(self.device)

        real_A_segs = input['A_segs' if AtoB else 'B_segs']					# real_A_segs holds the several segs belonging to one image of domain A (when AtoB); the segs are joined with torch.cat
        real_B_segs = input['B_segs' if AtoB else 'A_segs']

        self.real_A_segs = self.select_masks(real_A_segs).to(self.device)	# real_A_segs has shape torch.Size([1, 4, 200, 200]): four segs
        self.real_B_segs = self.select_masks(real_B_segs).to(self.device)

        self.real_A = torch.cat([self.real_A_img, self.real_A_segs], dim=1)	# real_A has shape torch.Size([1, 7, 200, 200]): one image fused with four segs
        self.real_B = torch.cat([self.real_B_img, self.real_B_segs], dim=1)

        self.real_A_seg = self.merge_masks(self.real_A_segs)  				# merged mask: (B, N, W, H) -> (B, 1, W, H), collapsing the N seg channels into one
        self.real_B_seg = self.merge_masks(self.real_B_segs)

        self.image_paths = input['A_paths' if AtoB else 'B_paths']			# A_paths is a list of length 1, e.g. './datasets/shp2gir_coco/trainA/788.png'

    def forward(self, idx=0):
        N = self.opt.ins_per												# '--ins_per': number of objects to forward in one pass,
                                                                            # i.e. how many object masks each iteration uses

        self.real_A_seg_sng = self.real_A_segs[:, N*idx:N*(idx+1), :, :]  	# ith batch of masks, ins_per masks per pass ('sng' presumably means single)
                                                                            # real_A_segs has shape torch.Size([1, 4, 200, 200]): four segs
        self.real_B_seg_sng = self.real_B_segs[:, N*idx:N*(idx+1), :, :]  	# ith batch of masks
        empty = -torch.ones(self.real_A_seg_sng.size()).to(self.device)  	# empty image

        self.forward_A = (self.real_A_seg_sng + 1).sum() > 0  				# check whether any objects remain;
                                                                            # the forward/backward passes only run when this is True.
                                                                            # read_segs() fills a missing seg with -1 at every pixel, hence the (seg + 1) test
        self.forward_B = (self.real_B_seg_sng + 1).sum() > 0  				# check whether any objects remain

        # forward A
        if self.forward_A:
            self.real_A_fuse_sng = torch.cat([self.real_A_img_sng, self.real_A_seg_sng], dim=1)

            self.fake_B_fuse_sng = self.netG_A(self.real_A_fuse_sng)  # the image and its masks (real_A_fuse_sng) are fed to the generator as one tensor
            self.fake_B_img_sng, self.fake_B_seg_sng = self.split(self.fake_B_fuse_sng)
            self.rec_A_fuse_sng = self.netG_B(self.fake_B_fuse_sng)  # the generated fake domain-B result is fed to G_B for reconstruction
            self.rec_A_img_sng, self.rec_A_seg_sng = self.split(self.rec_A_fuse_sng)

            self.fake_B_seg_mul = self.fake_B_seg_sng
            self.fake_B_mul = self.fake_B_fuse_sng  # fake_B_mul is the fake domain-B result used for the loss computation

        # forward B
        if self.forward_B:

            self.real_B_fuse_sng = torch.cat([self.real_B_img_sng, self.real_B_seg_sng], dim=1)
            self.fake_A_fuse_sng = self.netG_B(self.real_B_fuse_sng)
            self.fake_A_img_sng, self.fake_A_seg_sng = self.split(self.fake_A_fuse_sng)

            self.rec_B_fuse_sng = self.netG_A(self.fake_A_fuse_sng)
            self.rec_B_img_sng, self.rec_B_seg_sng = self.split(self.rec_B_fuse_sng)

            self.fake_A_seg_mul = self.fake_A_seg_sng
            self.fake_A_mul = self.fake_A_fuse_sng


    def test(self):															# used by test.py
        # init setting														# same initialization as optimize_parameters()
        self.real_A_img_sng = self.real_A_img								# real_A_img has shape torch.Size([1, 3, 200, 200]): one 3-channel source image
        self.real_B_img_sng = self.real_B_img
        self.fake_A_seg_list = list()
        self.fake_B_seg_list = list()
        self.rec_A_seg_list = list()
        self.rec_B_seg_list = list()

        # sequential mini-batch translation
        for i in range(self.ins_iter):
            # forward
            with torch.no_grad():  											# no grad: parameters are not updated at test time, so the forward pass runs under no_grad
                self.forward(i)

            # update setting for next iteration
            self.real_A_img_sng = self.fake_B_img_sng.detach()
            self.real_B_img_sng = self.fake_A_img_sng.detach()
            self.fake_A_seg_list.append(self.fake_A_seg_sng.detach())
            self.fake_B_seg_list.append(self.fake_B_seg_sng.detach())
            self.rec_A_seg_list.append(self.rec_A_seg_sng.detach())
            self.rec_B_seg_list.append(self.rec_B_seg_sng.detach())

            # save visuals
            if i == 0:  # first
                self.rec_A_img = self.rec_A_img_sng
                self.rec_B_img = self.rec_B_img_sng
            if i == self.ins_iter - 1:  # last
                self.fake_A_img = self.fake_A_img_sng
                self.fake_B_img = self.fake_B_img_sng
                self.fake_A_seg = self.merge_masks(self.fake_A_seg_mul)
                self.fake_B_seg = self.merge_masks(self.fake_B_seg_mul)
                self.rec_A_seg = self.merge_masks(torch.cat(self.rec_A_seg_list, dim=1))
                self.rec_B_seg = self.merge_masks(torch.cat(self.rec_B_seg_list, dim=1))

    def backward_G(self):													# compute the total generator loss and backpropagate
        lambda_A = self.opt.lambda_A										# used for backward A
        lambda_B = self.opt.lambda_B										# used for backward B
        lambda_idt = self.opt.lambda_idt									# used for loss_idt_A and loss_idt_B
        lambda_ctx = self.opt.lambda_ctx
        lambda_fs = self.opt.lambda_fs

        # backward A
        if self.forward_A:

            self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B_mul), True)
            self.loss_cyc_A = self.criterionCyc(self.rec_A_fuse_sng, self.real_A_fuse_sng) * lambda_A

            self.fake_A_fuse_sng_idt = self.netG_B(self.real_A_fuse_sng)
            self.fake_A_img_idt, self.fake_A_seg_idt = self.split(self.fake_A_fuse_sng_idt)
            self.loss_idt_B = self.criterionIdt(self.fake_A_fuse_sng_idt,
                                                self.real_A_fuse_sng.detach()) * lambda_A * lambda_idt

            weight_A = self.get_weight_for_ctx(self.real_A_seg_sng, self.fake_B_seg_sng)
            self.loss_ctx_A = self.weighted_L1_loss(self.real_A_img_sng, self.fake_B_img_sng,
                                                    weight=weight_A) * lambda_A * lambda_ctx

            layers = {"conv_1_1": 1.0,"conv_3_2": 1.0}
            I = self.fake_B_img_sng # 生成的B域的图
            T = self.real_B_img_sng # 目标域B的真实图
            I_multiply = self.fake_B_seg_mul * I
            T_multiply = self.real_B_seg_sng * T

            feature_similarity_loss = Feature_Similarity_Loss(layers, max_1d_size=64).cuda()
            # print('fsloss_A', feature_similarity_loss(I_multiply, T_multiply))
            self.loss_fs_A = feature_similarity_loss(I_multiply, T_multiply)[0] * lambda_fs
        else:
            self.loss_G_A = 0
            self.loss_cyc_A = 0
            self.loss_idt_B = 0
            self.loss_ctx_A = 0
            self.loss_fs_A = 0

        # backward B
        if self.forward_B:
            self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A_mul), True)
            self.loss_cyc_B = self.criterionCyc(self.rec_B_fuse_sng, self.real_B_fuse_sng) * lambda_B
            self.fake_B_fuse_sng_idt = self.netG_A(self.real_B_fuse_sng)
            self.fake_B_img_idt, self.fake_B_seg_idt = self.split(self.fake_B_fuse_sng_idt)
            self.loss_idt_A = self.criterionIdt(self.fake_B_fuse_sng_idt,
                                                self.real_B_fuse_sng.detach()) * lambda_B * lambda_idt

            weight_B = self.get_weight_for_ctx(self.real_B_seg_sng, self.fake_A_seg_sng)
            self.loss_ctx_B = self.weighted_L1_loss(self.real_B_img_sng, self.fake_A_img_sng,
                                                    weight=weight_B) * lambda_B * lambda_ctx

            layers = {"conv_1_1": 1.0, "conv_3_2": 1.0}
            I = self.fake_A_img_sng  # generated image in domain A
            T = self.real_A_img_sng  # real image of target domain A
            I_multiply = self.fake_A_seg_mul * I
            T_multiply = self.real_A_seg_sng * T

            feature_similarity_loss = Feature_Similarity_Loss(layers, max_1d_size=64).cuda()
            # print('fsloss_B', feature_similarity_loss(I_multiply, T_multiply))
            self.loss_fs_B = feature_similarity_loss(I_multiply, T_multiply)[0] * lambda_fs
        else:
            self.loss_G_B = 0
            self.loss_cyc_B = 0
            self.loss_idt_A = 0
            self.loss_ctx_B = 0
            self.loss_fs_B = 0

        # combined loss (note: the context-preserving terms are computed above but left out of the total here)
        # self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cyc_A + self.loss_cyc_B + self.loss_idt_A + self.loss_idt_B + self.loss_ctx_A + self.loss_ctx_B + self.loss_fs_A + self.loss_fs_B
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cyc_A + self.loss_cyc_B + self.loss_idt_A + self.loss_idt_B + self.loss_fs_A + self.loss_fs_B
        self.loss_G.backward()	# the losses of generators A and B sum to the total G loss; backpropagate

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B_mul)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A_mul)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def optimize_parameters(self):											# used by train.py; closely mirrors test()
        # init setting														# same initialization as test()
        self.real_A_img_sng = self.real_A_img								# real_A_img has shape torch.Size([1, 3, 200, 200]): one 3-channel source image
        self.real_B_img_sng = self.real_B_img
        self.fake_A_seg_list = list()
        self.fake_B_seg_list = list()
        self.rec_A_seg_list = list()
        self.rec_B_seg_list = list()

        # sequential mini-batch translation
        for i in range(self.ins_iter):
            # forward
            self.forward(i)

            # G_A and G_B													# extra step compared with test()
            if self.forward_A or self.forward_B:
                self.set_requires_grad([self.netD_A, self.netD_B], False)	# freeze D_A and D_B so the generator step does not update them
                self.optimizer_G.zero_grad()
                self.backward_G()											# backpropagate the generator loss
                self.optimizer_G.step()										# update parameters

            # D_A and D_B													# extra step compared with test()
            if self.forward_A or self.forward_B:
                self.set_requires_grad([self.netD_A, self.netD_B], True)	# the discriminator parameters need updating again
                self.optimizer_D.zero_grad()
                if self.forward_A:
                    self.backward_D_A()										# backpropagate D_A's loss; each discriminator backprops separately, only when its direction ran
                if self.forward_B:
                    self.backward_D_B()										# backpropagate D_B's loss
                self.optimizer_D.step()										# update parameters

            # update setting for next iteration
            self.real_A_img_sng = self.fake_B_img_sng.detach()
            self.real_B_img_sng = self.fake_A_img_sng.detach()
            self.fake_A_seg_list.append(self.fake_A_seg_sng.detach())
            self.fake_B_seg_list.append(self.fake_B_seg_sng.detach())
            self.rec_A_seg_list.append(self.rec_A_seg_sng.detach())
            self.rec_B_seg_list.append(self.rec_B_seg_sng.detach())

            # save visuals
            if i == 0:  # first
                self.rec_A_img = self.rec_A_img_sng
                self.rec_B_img = self.rec_B_img_sng
            if i == self.ins_iter - 1:  # last
                self.fake_A_img = self.fake_A_img_sng
                self.fake_B_img = self.fake_B_img_sng
                self.fake_A_seg = self.merge_masks(self.fake_A_seg_mul)
                self.fake_B_seg = self.merge_masks(self.fake_B_seg_mul)
                self.rec_A_seg = self.merge_masks(torch.cat(self.rec_A_seg_list, dim=1))
                self.rec_B_seg = self.merge_masks(torch.cat(self.rec_B_seg_list, dim=1))
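
optimize_parameters toggles the discriminators through set_requires_grad, which is inherited from BaseModel and not shown in this snippet. A sketch of the usual CycleGAN-style helper it is assumed to match:

    def set_requires_grad(self, nets, requires_grad=False):
        # Accept a single network or a list, and flip requires_grad on every
        # parameter so frozen nets accumulate no gradients during G's update.
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad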
Example #42
class FlowRefineModel(BaseModel):
    def name(self):
        return 'PVHMModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        opt.output_nc = opt.input_nc
        # load/define networks
        self.netG = networks.define_G(2, 2, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, tanh=True)
        self.flow_remapper = networks.flow_remapper(size=opt.fineSize, batch=opt.batchSize, gpu_ids=opt.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        grid = np.zeros((opt.fineSize,opt.fineSize,2))

        for i in range(grid.shape[0]):
            for j in range(grid.shape[1]):
                grid[i,j,0] = j
                grid[i,j,1] = i
        grid /= (opt.fineSize/2)
        grid -= 1
        self.grid = torch.from_numpy(grid).cuda().float() #Variable(torch.from_numpy(grid))
        self.grid = self.grid.view(1,self.grid.size(0),self.grid.size(1),self.grid.size(2))
        self.grid = Variable(self.grid)

        intrinsics = np.array(
            [128. / 32. * 60, 0., 64., \
             0., 128. / 32. * 60, 64., \
             0., 0., 1.]).reshape((1, 3, 3))
        intrinsics_inv = np.linalg.inv(np.array(
            [128. / 32. * 60, 0., 64., \
             0., 128. / 32. * 60, 64., \
             0., 0., 1.]).reshape((3, 3))).reshape((1, 3, 3))
        self.intrinsics = Variable(torch.from_numpy(intrinsics.astype(np.float32)).cuda()).expand(opt.batchSize,3,3)
        self.intrinsics_inv = Variable(torch.from_numpy(intrinsics_inv.astype(np.float32)).cuda()).expand(opt.batchSize,3,3)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        input_C = input['C']
        if len(self.gpu_ids) > 0:
            input_C = input_C.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_C = input_C

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_C = Variable(self.input_C)

        pose = np.array([0, 0, 0, 0, -np.pi / 4., 0, ]).reshape((1, 6))
        pose = Variable(torch.from_numpy(pose.astype(np.float32)).cuda()).expand(self.opt.batchSize,6)
        self.forward_map = inverse_warp(self.real_A,self.real_C, pose, self.intrinsics, self.intrinsics_inv)
        self.backward_map = self.flow_remapper(self.forward_map, self.forward_map)

        self.backward_map_refined = self.netG(self.backward_map.permute(0,3,1,2)).permute(0,2,3,1)
        self.fake_B = F.grid_sample(self.real_A, self.backward_map_refined)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_C = Variable(self.input_C)

        pose = np.array([0, 0, 0, 0, -np.pi / 8., 0, ]).reshape((1, 6))
        pose = Variable(torch.from_numpy(pose.astype(np.float32)).cuda()).expand(self.opt.batchSize, 6)
        self.forward_map = inverse_warp(self.real_A, self.real_C, pose, self.intrinsics, self.intrinsics_inv)
        self.backward_map = self.flow_remapper(self.forward_map, self.forward_map)

        self.backward_map_refined = self.netG(self.backward_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        self.fake_B = F.grid_sample(self.real_A, self.backward_map_refined)

        # self.fake_B = self.fake_B_flow
        # self.fake_B = self.fake_B.permute(0, 3, 1, 2)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.opt.lambda_gan * self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.opt.lambda_gan * self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.opt.lambda_gan * self.criterionGAN(pred_fake, True)

        # Second, G(A) = B

        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
        # self.loss_G_flow = self.criterionL1(self.forward_flow, self.real_C) * self.opt.lambda_flow
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward(retain_graph=True)

    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])
                            ])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        # real_C = util.tensor2im(self.real_C.data)
        forward_map = util.tensor2im(self.forward_map.permute(0, 3, 1, 2).data)
        backward_map = util.tensor2im(self.backward_map.permute(0, 3, 1, 2).data)
        backward_map_refined = util.tensor2im(self.backward_map_refined.permute(0, 3, 1, 2).data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
                            ('forward_map', forward_map), ('backward_map', backward_map),
                            ('backward_map_refined', backward_map_refined)])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
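
A note on the grid_sample calls above: F.grid_sample expects its sampling grid in (N, H, W, 2) layout with (x, y) coordinates normalized to [-1, 1], which is why the maps are permuted out of NCHW before sampling. A minimal, self-contained identity-warp sketch under that convention (standalone toy tensors; align_corners=True, available in PyTorch >= 1.3, matches this normalization):

import torch
import torch.nn.functional as F

n, c, h, w = 1, 3, 4, 4
img = torch.arange(float(n * c * h * w)).view(n, c, h, w)
# identity grid: x varies along width, y along height, both spanning [-1, 1]
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w))
grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)    # (N, H, W, 2), x first
out = F.grid_sample(img, grid, align_corners=True)   # reproduces img exactly
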
class Pix2PixModel(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # 'async' became a reserved word in Python 3.7; the kwarg is now non_blocking
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
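
In use, a driver loop feeds batches through set_input and optimize_parameters, which runs one discriminator update followed by one generator update per batch. A rough sketch (opt, data_loader, and the epoch options are placeholders; save_networks is assumed to be the BaseModel counterpart of the load_networks call in initialize):

model = Pix2PixModel()
model.initialize(opt)
for epoch in range(opt.niter + opt.niter_decay):
    for data in data_loader:           # each item: {'A': ..., 'B': ..., 'A_paths': ...}
        model.set_input(data)          # move the batch to the configured GPU
        model.optimize_parameters()    # one D step, then one G step
    model.save_networks(str(epoch))    # assumed BaseModel checkpoint helper
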
Example #44
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        opt.output_nc = opt.input_nc
        # load/define networks; the generator refines 2-channel (x, y) sampling maps,
        # so both its input and output are 2 channels
        self.netG = networks.define_G(2, 2, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, tanh=True)
        self.flow_remapper = networks.flow_remapper(size=opt.fineSize, batch=opt.batchSize, gpu_ids=opt.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        # build an identity sampling grid with (x, y) coordinates normalized to [-1, 1),
        # the layout F.grid_sample expects
        grid = np.zeros((opt.fineSize, opt.fineSize, 2))
        for i in range(grid.shape[0]):
            for j in range(grid.shape[1]):
                grid[i, j, 0] = j
                grid[i, j, 1] = i
        grid /= (opt.fineSize / 2)
        grid -= 1
        self.grid = torch.from_numpy(grid).cuda().float()
        self.grid = Variable(self.grid.view(1, self.grid.size(0), self.grid.size(1), self.grid.size(2)))

        # pinhole intrinsics for a 128x128 image: focal length 128 / 32 * 60 = 240 px,
        # principal point at the image center (64, 64)
        intrinsics = np.array(
            [128. / 32. * 60, 0., 64.,
             0., 128. / 32. * 60, 64.,
             0., 0., 1.]).reshape((1, 3, 3))
        intrinsics_inv = np.linalg.inv(intrinsics.reshape((3, 3))).reshape((1, 3, 3))
        self.intrinsics = Variable(torch.from_numpy(intrinsics.astype(np.float32)).cuda()).expand(opt.batchSize, 3, 3)
        self.intrinsics_inv = Variable(torch.from_numpy(intrinsics_inv.astype(np.float32)).cuda()).expand(opt.batchSize, 3, 3)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')
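
The nested loops above fill the identity grid one pixel at a time; np.meshgrid produces the same array in one shot, which matters for larger fineSize. A behavior-equivalent sketch:

import numpy as np

def identity_grid(size):
    # (x, y) pixel coordinates: x varies along width, y along height
    xs, ys = np.meshgrid(np.arange(size), np.arange(size))
    grid = np.stack((xs, ys), axis=-1).astype(np.float64)
    return grid / (size / 2) - 1  # shape (size, size, 2), values in [-1, 1)
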
Example #45
class CycleDRPANModel(BaseModel):
    def name(self):
        return 'CycleDRPANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'R_A', 'GR_A']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        if self.isTrain:
            visual_names_A = ['real_A', 'fake_B', 'rec_A', 'fake_Br', 'real_Ar', 'fake_Bf', 'real_Af']
            visual_names_B = ['real_B', 'fake_A', 'rec_B', 'fake_Ar', 'real_Br', 'fake_Af', 'real_Bf']

        else:
            visual_names_A = ['real_A', 'fake_B', 'rec_A']
            visual_names_B = ['real_B', 'fake_A', 'rec_B']

        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'R_A', 'R_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X), R_A(R_Y), R_B(R_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netR_A = networks.define_R(opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D,
                                            opt.norm, use_sigmoid,
                                            opt.init_type, opt.init_gain, self.gpu_ids)
            self.netR_B = networks.define_R(opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D,
                                            opt.norm, use_sigmoid,
                                            opt.init_type, opt.init_gain, self.gpu_ids)


        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_A = torch.optim.Adam(self.netR_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_B = torch.optim.Adam(self.netR_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_R_A)
            self.optimizers.append(self.optimizer_R_B)

            self.proposal = Proposal()

            # self.batchsize = opt.batchSize
            # self.label_r = torch.FloatTensor(self.batchsize)

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def reviser_A(self):
        # training with reviser
        for n_step in range(3):
            fake_B_ = self.netG_A(self.real_A)
            output = self.netD_A(fake_B_.detach())

            # proposal
            (self.fake_Br, self.real_Ar, self.fake_Bf, self.real_Af,
             self.fake_ABf, self.real_ABr) = self.proposal.forward_A(self.real_B, fake_B_, self.real_A, output)
            # train with real
            self.netD_A.zero_grad()
            output_r = self.netR_A(self.real_ABr.detach())
            self.loss_errR_real_A = self.criterionGAN(output_r, True)
            self.loss_errR_real_A.backward()

            # train with fake
            output_r = self.netR_A(self.fake_ABf.detach())
            self.loss_errR_fake_A = self.criterionGAN(output_r, False)
            self.loss_errR_fake_A.backward()

            self.loss_R_A = (self.loss_errR_real_A + self.loss_errR_fake_A) / 2
            self.optimizer_R_A.step()

            # train Generator with reviser
            self.netG_A.zero_grad()
            output_r = self.netR_A(self.fake_ABf)
            self.loss_GR_A = self.criterionGAN(output_r, True)
            self.loss_GR_A.backward()
            self.optimizer_G.step()

    def reviser_B(self):
        # training with reviser
        for n_step in range(3):
            fake_A_ = self.netG_B(self.real_B)
            output = self.netD_B(fake_A_.detach())

            # proposal
            (self.fake_Ar, self.real_Br, self.fake_Af, self.real_Bf,
             self.fake_BAf, self.real_BAr) = self.proposal.forward_B(self.real_A, fake_A_, self.real_B, output)
            # train with real
            self.netD_B.zero_grad()
            output_r = self.netR_B(self.real_BAr.detach())
            self.loss_errR_real_B = self.criterionGAN(output_r, True)
            self.loss_errR_real_B.backward()

            # train with fake
            output_r = self.netR_B(self.fake_BAf.detach())
            self.loss_errR_fake_B = self.criterionGAN(output_r, False)
            self.loss_errR_fake_B.backward()

            self.loss_R_B = (self.loss_errR_real_B + self.loss_errR_fake_B) / 2
            self.optimizer_R_B.step()

            # train Generator with reviser
            self.netG_B.zero_grad()
            output_r = self.netR_B(self.fake_BAf)
            self.loss_GR_B = self.criterionGAN(output_r, True)
            self.loss_GR_B.backward()
            self.optimizer_G.step()


    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
        # R_A and R_B
        self.set_requires_grad([self.netR_A, self.netR_B], True)
        self.optimizer_R_A.zero_grad()
        self.optimizer_R_B.zero_grad()
        self.reviser_A()
        self.reviser_B()
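
optimize_parameters above freezes D_A and D_B while the generators update, then re-enables them, via set_requires_grad. That helper is not defined in this snippet; in the upstream pix2pix/CycleGAN BaseModel it is a small loop along these lines:

def set_requires_grad(self, nets, requires_grad=False):
    # toggle autograd on every parameter of one or more networks
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
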
Example #46
class ReCycleGANModel(BaseModel):
    def name(self):
        return 'ReCycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        assert 'recycle_skips' in opt.which_model_netG

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids, fourier_mode=opt.fourier_mode)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids,
                                            fourier_mode=opt.fourier_mode)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # 'async' became a reserved word in Python 3.7; the kwarg is now non_blocking
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B, is_cycle=True)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A, is_cycle=True)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
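
Every model above stabilizes its discriminator with ImagePool, whose query method serves the discriminator a 50/50 mix of current and previously generated fakes once the pool is full. A condensed sketch of the upstream pix2pix implementation (the Variable wrapping used in this era of the code is omitted):

import random
import torch

class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:                  # pool disabled: pass through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)        # fill the pool first
                out.append(image)
            elif random.random() > 0.5:          # 50%: return an old fake, store the new one
                idx = random.randint(0, self.pool_size - 1)
                old = self.images[idx].clone()
                self.images[idx] = image
                out.append(old)
            else:                                # 50%: return the current fake
                out.append(image)
        return torch.cat(out, 0)
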