Example no. 1
def get_nets(path, which_epoch='latest', def_opt=None):
    """Load the pretrained generator pair (G_A, G_B) saved under *path*.

    Parameters:
        path: directory containing the saved options file and the
            'G_A' / 'G_B' network checkpoints.
        which_epoch: epoch label of the checkpoint (currently unused;
            util.load_network_with_path selects its own default --
            kept for interface compatibility).
        def_opt: fallback/default options forwarded to util.load_opt.

    Returns:
        dict mapping 'A' and 'B' to the two generators, both moved to GPU.
    """
    gpu_ids = [0]
    opt = util.load_opt(path, def_opt)
    # assume caffe style model
    opt.not_caffe = False

    def _build_generator(label):
        # Both generators share identical architecture options; only the
        # checkpoint label ('G_A' vs 'G_B') differs.
        net = networks.define_G(opt.input_nc,
                                opt.output_nc,
                                opt.ngf,
                                opt.which_model_netG,
                                opt.norm,
                                not opt.no_dropout,
                                opt.init_type,
                                gpu_ids,
                                opt=opt)
        util.load_network_with_path(net, label, path)
        net.cuda()
        return net

    return {'A': _build_generator('G_A'), 'B': _build_generator('G_B')}
Example no. 2
    def initialize(self, opt):
        """Build all networks, loss functions and optimizers for this model.

        Constructs an encoder/decoder pipeline (E_A, E_C, shared decoder D,
        secondary decoder Dc, output generators G_A and G_C), a single
        discriminator D_A (training only), the GAN/cycle/identity/color
        losses and four Adam optimizers with LR schedulers.

        Parameters:
            opt: parsed option namespace (batchSize, fineSize, input_nc,
                output_nc, ngf, ndf, norm, init_type, lr, beta1, ...).
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Per-scale GAN target weights; stays empty unless populated
        # elsewhere (see the GANLoss branch below).
        self.target_weight = []
        # Pre-allocated batch buffers -- presumably refilled each
        # iteration by a set_input() method (TODO confirm with caller).
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C_sr = self.Tensor(nb, opt.output_nc, size, size)
        if opt.aux:
            # Optional auxiliary-input buffers, same shapes as the
            # corresponding main inputs.
            self.A_aux = self.Tensor(nb, opt.input_nc, size, size)
            self.B_aux = self.Tensor(nb, opt.output_nc, size, size)
            self.C_aux = self.Tensor(nb, opt.output_nc, size, size)

        # Encoder for domain A: 2 downsampling stages.
        self.netE_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'ResnetEncoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        n_downsampling=2)

        # Channel multiplier reported by the encoder; consumed by the
        # decoders built below.
        mult = self.netE_A.get_mult()

        # Encoder for domain C: hard-coded ngf=64 and 3 downsampling
        # stages (deliberately different from netE_A).
        self.netE_C = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        64,
                                        'ResnetEncoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        n_downsampling=3)

        # Shared decoder operating on netE_A's feature multiplier.
        self.net_D = networks.define_G(opt.input_nc,
                                       opt.output_nc,
                                       opt.ngf,
                                       'ResnetDecoder_my',
                                       opt.norm,
                                       not opt.no_dropout,
                                       opt.init_type,
                                       self.gpu_ids,
                                       opt=opt,
                                       mult=mult)

        # Refresh the multiplier with the decoder's output width.
        mult = self.net_D.get_mult()

        # Secondary decoder with a single upsampling stage.
        self.net_Dc = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'ResnetDecoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult,
                                        n_upsampling=1)

        # NOTE(review): netG_A uses net_D's multiplier -- mult is only
        # updated from net_Dc after this call. Looks intentional but
        # worth confirming.
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'GeneratorLL',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult)

        mult = self.net_Dc.get_mult()

        # Final generator for domain C, fed from net_Dc's features.
        self.netG_C = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'GeneratorLL',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult)

        #        self.netG_A_running = networks.define_G(opt.input_nc, opt.output_nc,
        #                                       opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt)
        #      set_eval(self.netG_A_running)
        #     accumulate(self.netG_A_running, self.netG_A, 0)
        #        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
        #                                       opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt)
        #    self.netG_B_running = networks.define_G(opt.output_nc, opt.input_nc,
        #                                   opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt)
        #  set_eval(self.netG_B_running)
        # accumulate(self.netG_B_running, self.netG_B, 0)
        if self.isTrain:
            # no_lsgan=True selects the vanilla GAN loss, which needs a
            # sigmoid output on the discriminator.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)
#         self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
#                                          opt.which_model_netD,
#                                        opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt)
        print('---------- Networks initialized -------------')
        #        networks.print_network(self.netG_B, opt, (opt.input_nc, opt.fineSize, opt.fineSize))
        networks.print_network(self.netE_C, opt,
                               (opt.input_nc, opt.fineSize, opt.fineSize))
        networks.print_network(
            self.net_D, opt, (opt.ngf * 4, opt.fineSize / 4, opt.fineSize / 4))
        networks.print_network(self.net_Dc, opt,
                               (opt.ngf, opt.CfineSize / 2, opt.CfineSize / 2))
        # networks.print_network(self.netG_B, opt)
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
            # networks.print_network(self.netD_B, opt)
        print('-----------------------------------------------')

        if not self.isTrain or opt.continue_train:
            print('Loaded model')
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            # NOTE(review): self.netG_B, self.netG_A_running,
            # self.netG_B_running and self.netD_B are never created in
            # this method (their definitions are commented out above),
            # so the loads below look like they would raise
            # AttributeError -- confirm before relying on this path.
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netG_A_running, 'G_A', which_epoch)
                self.load_network(self.netG_B_running, 'G_B', which_epoch)
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain and opt.load_path != '':
            print('Loaded model from load_path')
            which_epoch = opt.which_epoch
            load_network_with_path(self.netG_A,
                                   'G_A',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            # NOTE(review): same concern as above -- netG_B and netD_B
            # are not defined by this method.
            load_network_with_path(self.netG_B,
                                   'G_B',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netD_A,
                                   'D_A',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netD_B,
                                   'D_B',
                                   opt.load_path,
                                   epoch_label=which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            # Image pools of previously generated samples for
            # discriminator updates.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_C_pool = ImagePool(opt.pool_size)
            # define loss functions
            # target_weight is [] at this point (set above), so this
            # branch is taken only when opt.num_D == 0 unless the list
            # is filled elsewhere -- TODO confirm.
            if len(self.target_weight) == opt.num_D:
                print(self.target_weight)
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan,
                    tensor=self.Tensor,
                    target_weight=self.target_weight,
                    gan=opt.gan)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan,
                    tensor=self.Tensor,
                    gan=opt.gan)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionColor = networks.ColorLoss()
            # initialize optimizers
            # Each optimizer covers a different subset of the networks;
            # several subsets overlap (net_D, net_Dc, netG_C appear in
            # more than one optimizer).
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netE_A.parameters(), self.net_D.parameters(),
                self.netG_A.parameters(), self.net_Dc.parameters(),
                self.netG_C.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(itertools.chain(
                self.netE_C.parameters(), self.net_D.parameters(),
                self.net_Dc.parameters(), self.netG_C.parameters()),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
            self.optimizer_G_A_sr = torch.optim.Adam(itertools.chain(
                self.netE_A.parameters(), self.net_D.parameters(),
                self.net_Dc.parameters(), self.netG_C.parameters()),
                                                     lr=opt.lr,
                                                     betas=(opt.beta1, 0.999))
            self.optimizer_AE_sr = torch.optim.Adam(itertools.chain(
                self.netE_C.parameters(), self.net_D.parameters(),
                self.netG_A.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            #       self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_AE)
            # self.optimizers.append(self.optimizer_G_A_sr)
            self.optimizers.append(self.optimizer_AE_sr)
            self.optimizers.append(self.optimizer_D_A)
            #   self.optimizers.append(self.optimizer_D_B)
            # One LR scheduler per registered optimizer; optimizer_G_A_sr
            # is intentionally excluded from scheduling (commented out
            # above).
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
Example no. 3
    def initialize(self, opt):
        """Build the CycleGAN-style generator/discriminator pairs, losses
        and optimizers.

        Creates netG_A (A->B) and netG_B (B->A), the matching
        discriminators netD_A / netD_B (training only), optionally grafts
        pretrained discriminators into a stacked discriminator, then sets
        up GAN/cycle losses, Adam optimizers and LR schedulers.

        Parameters:
            opt: parsed option namespace (batchSize, fineSize, input_nc,
                output_nc, ngf, ndf, norm, init_type, lr, beta1, ...).
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated batch buffers -- presumably refilled each
        # iteration by a set_input() method (TODO confirm with caller).
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        #self.down_2 = torch.nn.AvgPool2d(2**int((np.log(opt.fineSize) - np.log(opt.pre.fineSize))/np.log(2)))
        #self.up_2 = torch.nn.Upsample(scale_factor=2**int((np.log(opt.fineSize) - np.log(opt.pre.fineSize))/np.log(2)))
        # With opt.idt the 2x down/up-sample pair degenerates to 1x1
        # average pools (effectively identity resolution).
        if not opt.idt:
            self.down_2 = torch.nn.AvgPool2d(2)
            self.up_2 = torch.nn.Upsample(scale_factor=2)
        else:
            self.down_2 = torch.nn.AvgPool2d(1)
            self.up_2 = torch.nn.AvgPool2d(1)
        # Generator A->B.
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        n_upsampling=opt.n_upsample,
                                        n_downsampling=opt.n_downsample,
                                        side='A',
                                        opt=opt)
        # Generator B->A (input/output channel counts swapped).
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        n_upsampling=opt.n_upsample,
                                        n_downsampling=opt.n_downsample,
                                        side='B',
                                        opt=opt)

        if self.isTrain:
            # no_lsgan=True selects the vanilla GAN loss, which needs a
            # sigmoid output on the discriminator.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)
            self.netD_B = networks.define_D(opt.input_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)
            if 'stack' in opt.which_model_netD and opt.load_pre:
                # Build plain 'n_layers' discriminators, load their
                # pretrained weights, and graft them into the layers of
                # the stacked discriminators.
                netD_A = networks.define_D(opt.output_nc,
                                           opt.ndf,
                                           'n_layers',
                                           opt.n_layers_D,
                                           opt.norm,
                                           use_sigmoid,
                                           opt.init_type,
                                           self.gpu_ids,
                                           opt=opt)
                netD_B = networks.define_D(opt.input_nc,
                                           opt.ndf,
                                           'n_layers',
                                           opt.n_layers_D,
                                           opt.norm,
                                           use_sigmoid,
                                           opt.init_type,
                                           self.gpu_ids,
                                           opt=opt)
                pre_opt = opt
                # NOTE(review): every iteration reloads the SAME netD_A /
                # netD_B instances from the same path and assigns them to
                # each layer, so all layers end up sharing one module;
                # pre_opt is reassigned but never used afterwards.
                # Confirm this sharing is intended.
                for i in range(opt.num_D - 2, -1, -1):
                    util.load_network_with_path(netD_A, 'D_A', opt.pre_path)
                    util.load_network_with_path(netD_B, 'D_B', opt.pre_path)
                    # exec only formats an internal loop index (no
                    # external input); equivalent to
                    # setattr(self.netD_A, 'layer%d' % i, netD_A).
                    exec('self.netD_A.layer%d = netD_A' % i)
                    exec('self.netD_B.layer%d = netD_B' % i)
                    pre_opt = opt.pre

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A,
                               opt,
                               input_shape=(opt.input_nc, opt.fineSize,
                                            opt.fineSize))
        if self.isTrain:
            networks.print_network(self.netD_A,
                                   opt,
                                   input_shape=(3, opt.fineSize, opt.fineSize))
        print('-----------------------------------------------')

        # Resume weights when testing or explicitly continuing training.
        if not self.isTrain or opt.continue_train:
            print('Continue from ', opt.which_epoch)
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain and not opt.test:
            self.old_lr = opt.lr
            # Image pools of previously generated samples for
            # discriminator updates.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions

            # print('Load VGG 16')
            # vgg_model = vgg.vgg16(pretrained=True)
            # if torch.cuda.is_available():
            # vgg_model.cuda()
            # self.loss_network = networks.VGGLossNetwork(vgg_model)
            # self.loss_network.eval()
            # del vgg_model

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            # if self.opt.cyc_perc:
            # self.criterionCycle = networks.PerceptualLoss(self.loss_network, tensor=self.Tensor)
            # else:
            self.criterionCycle = networks.RECLoss()
            # self.criterionPerc = networks.PerceptualLoss(self.loss_network, tensor=self.Tensor)
            # self.criterionColor = networks.ColorLoss()
            # initialize optimizers
            # Generator optimizers: tune_pre optimizes whole generators;
            # otherwise only the in/mid/out sub-modules (plus the gate
            # when alpha_gate is set) are trained.
            if opt.tune_pre:
                self.optimizer_G_A = torch.optim.Adam(self.netG_A.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_G_B = torch.optim.Adam(self.netG_B.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            elif opt.alpha_gate != '':
                self.optimizer_G_A = torch.optim.Adam(itertools.chain(
                    self.netG_A.model_in.parameters(),
                    self.netG_A.model_mid.parameters(),
                    self.netG_A.model_out.parameters(),
                    self.netG_A.gate.parameters()),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_G_B = torch.optim.Adam(itertools.chain(
                    self.netG_B.model_in.parameters(),
                    self.netG_B.model_mid.parameters(),
                    self.netG_B.model_out.parameters(),
                    self.netG_B.gate.parameters()),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            else:
                self.optimizer_G_A = torch.optim.Adam(itertools.chain(
                    self.netG_A.model_in.parameters(),
                    self.netG_A.model_mid.parameters(),
                    self.netG_A.model_out.parameters()),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_G_B = torch.optim.Adam(itertools.chain(
                    self.netG_B.model_in.parameters(),
                    self.netG_B.model_mid.parameters(),
                    self.netG_B.model_out.parameters()),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            # Discriminators optionally run at half the generator LR
            # (common two-timescale trick).
            if opt.d_lr2:
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=(opt.lr / 2.0),
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=(opt.lr / 2.0),
                                                      betas=(opt.beta1, 0.999))
            else:
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G_A)
            self.optimizers.append(self.optimizer_G_B)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            # One LR scheduler per optimizer.
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))