Example #1
    def setup(self, args, parser=None):
        # Fresh run: build one scheduler per optimizer. Resumed run:
        # restore the saved networks instead.
        if 'continue_train' not in args:
            self.schedulers = [
                get_scheduler(optimizer, args) for optimizer in self.optimizers
            ]
        else:
            self.load_networks(args.which_epoch)
        self.print_networks(args.verbose)
Example #2
    def __init__(self, opt, device):
        super(CycleGAN, self).__init__()

        self.device = device
        self.opt = opt

        self.netG_A = networks.define_G(self.opt.input_nc, self.opt.output_nc,
                                        self.opt.ngf, self.opt.netG,
                                        self.opt.norm, self.opt.dropout,
                                        self.opt.init_type, self.opt.init_gain,
                                        self.opt.task_num,
                                        self.opt.netG_A_filter_list)
        self.netG_B = networks.define_G(self.opt.input_nc, self.opt.output_nc,
                                        self.opt.ngf, self.opt.netG,
                                        self.opt.norm, self.opt.dropout,
                                        self.opt.init_type, self.opt.init_gain,
                                        self.opt.task_num,
                                        self.opt.netG_B_filter_list)

        if opt.train:
            self.netD_A = networks.define_D(self.opt.input_nc, self.opt.ndf,
                                            self.opt.netD, self.opt.norm,
                                            self.opt.init_type,
                                            self.opt.init_gain)
            self.netD_B = networks.define_D(self.opt.input_nc, self.opt.ndf,
                                            self.opt.netD, self.opt.norm,
                                            self.opt.init_type,
                                            self.opt.init_gain)

            self.fake_A_pool = ImageBuffer(
                self.opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_pool = ImageBuffer(
                self.opt.pool_size
            )  # create image buffer to store previously generated images

            self.criterionGAN = networks.GANLoss(self.opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=self.opt.lr,
                                                betas=(self.opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=self.opt.lr,
                                                betas=(self.opt.beta1, 0.999))

            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]
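The two pools above are described only by their inline comments; ImageBuffer itself is not part of these excerpts. Below is a minimal sketch of such a buffer, patterned after the ImagePool from the CycleGAN reference implementation; the query method and the 50% replacement probability are assumptions based on that implementation, not on this repository:

import random

import torch


class ImageBuffer:
    """Store up to pool_size previously generated images and sometimes
    return an old one in place of the newest (stabilizes D updates)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # buffer disabled: pass images through
            return images
        out = []
        for image in images:
            image = image.detach().unsqueeze(0)  # keep a batch dimension
            if len(self.images) < self.pool_size:  # pool still filling up
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:  # return a stored image, keep the new one
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:  # return the new image, leave the pool untouched
                out.append(image)
        return torch.cat(out, 0)

A call such as self.fake_B_pool.query(fake_B) would then hand the discriminator a mix of fresh and historical fakes.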
Example #3
    def __init__(self, opt, _isTrain=False):
        self.initialize(opt)

        self.mode = opt.mode
        if opt.input == 'single_view':
            self.num_input = 3
        elif opt.input == 'two_view':
            self.num_input = 6
        elif opt.input == 'two_view_k':
            self.num_input = 7
        else:
            raise ValueError("Unknown input type %s" % opt.input)

        if self.mode == 'Ours_Bilinear':
            # print(
            #     '======================================  DIW NETWORK TRAIN FROM %s======================='
            #     % self.mode)

            new_model = hourglass.HourglassModel(self.num_input)

            # print(
            #     '===================Loading Pretrained Model OURS ==================================='
            # )

            if not _isTrain:
                if self.num_input == 7:
                    model_parameters = self.load_network(
                        new_model, 'G', 'best_depth_Ours_Bilinear_inc_7')
                elif self.num_input == 3:
                    model_parameters = self.load_network(
                        new_model, 'G', 'best_depth_Ours_Bilinear_inc_3')
                elif self.num_input == 6:
                    model_parameters = self.load_network(
                        new_model, 'G', 'best_depth_Ours_Bilinear_inc_6')
                else:
                    print('Something Wrong')
                    sys.exit()

                new_model.load_state_dict(model_parameters)

            # new_model = torch.nn.parallel.DataParallel(
            #     new_model.cuda(), device_ids=range(torch.cuda.device_count()))

            self.netG = new_model

        else:
            print('ONLY SUPPORT Ours_Bilinear')
            sys.exit()

        self.old_lr = opt.lr
        self.netG.train()

        self.criterion_joint = networks.JointLoss(opt)
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(
            self.netG.parameters(), lr=opt.lr, betas=(0.9, 0.999))
        self.scheduler = networks.get_scheduler(self.optimizer_G, opt)
Example #4
    def setup(self, opt, verbose=True):
        self.schedulers = [
            networks.get_scheduler(optimizer, opt)
            for optimizer in self.optimizers
        ]
        self.load_networks(verbose)
        if verbose:
            self.print_networks()
        self.add_mapping_hook()
Example #5
    def setup(self, opt):
        if self.isTrain:
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)
Example #6
    def setup(self, args, parser=None):
        if self.isTrain:
            self.schedulers = [
                networks.get_scheduler(optimizer, args)
                for optimizer in self.optimizers
            ]

        # if not self.isTrain or args.continue_train:
        if not self.isTrain:
            self.load_networks(args.epoch)
        self.print_networks(args.verbose)
Example #7
    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)
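Example #7 is the canonical form of this hook from the pix2pix/CycleGAN codebase. For context, here is a sketch of how a training script typically drives it, stepping every scheduler once per epoch; create_model, set_input, and optimize_parameters are assumed to exist alongside setup and are not shown in these excerpts:

# Hypothetical driver loop showing when setup() runs and when the
# schedulers it creates are stepped.
model = create_model(opt)   # assumed model factory
model.setup(opt)            # builds schedulers, loads weights, prints nets

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for data in dataset:    # assumed DataLoader over the training set
        model.set_input(data)
        model.optimize_parameters()   # forward, backward, optimizer steps
    for scheduler in model.schedulers:
        scheduler.step()    # per-epoch learning-rate decay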
Example #8
def test_scheduler():
    from options.gan_options import TrainGANOptions
    from models.networks import get_scheduler

    model = nn.Linear(10, 10)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    opt = TrainGANOptions().parse('--gpu_ids -1 --benchmark debug', False,
                                  False, False)
    scheduler = get_scheduler(optim, opt)

    # Step once per epoch and print the resulting lr curve. (In a real
    # training loop, optimizer.step() should precede scheduler.step()
    # on PyTorch >= 1.1.)
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        scheduler.step()
        print('epoch: %d, lr: %f' % (epoch, optim.param_groups[0]['lr']))
Example #9
    def setup(self, opt, verbose=True):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.load_networks(verbose=verbose)
        if self.isTrain:
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]
        if verbose:
            self.print_networks()
Example #10
    def setup(self):
        """Load and print networks; create schedulers.

        Reads the experiment flags from self.opt instead of taking an
        opt argument.
        """

        load_prefix = self.opt.load_flag if self.opt.load_flag else 'latest'
        epoch = self.load_networks(load_prefix)

        if self.isTrain:
            self.schedulers = [
                get_scheduler(optimizer, self.opt)
                for optimizer in self.optimizers
            ]
        self.print_networks(self.opt.print_network)
        return epoch
Example #11
    def setup(self, opt, verbose=True):
        self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        self.load_networks(verbose)
        if verbose:
            self.print_networks()
        if self.opt.lambda_distill > 0:
            def get_activation(mem, name):
                def get_output_hook(module, input, output):
                    mem[name] = output

                return get_output_hook

            def add_hook(net, mem, mapping_layers):
                for n, m in net.named_modules():
                    if n in mapping_layers:
                        m.register_forward_hook(get_activation(mem, n))

            add_hook(self.netG_teacher, self.Tacts, self.mapping_layers)
            add_hook(self.netG_student, self.Sacts, self.mapping_layers)
Example #12
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')
Example #13
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.opt = opt
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = [
            'G_GAN', 'G_L1', 'D', 'style', 'content', 'tv', 'hole', 'valid'
        ]
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        if self.opt.show_flow:
            self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
        else:
            self.visual_names = [
                'real_input', 'fake_B', 'real_GTimg', 'mask_global',
                'output_comp'
            ]
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']

        #
        # self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
        #                               opt.which_model_netG, opt,  opt.norm, opt.use_spectral_norm_G, opt.init_type, self.gpu_ids, opt.init_gain)
        self.netG = PConvUNet().to(self.device)
        print(self.netG)
        if self.isTrain:
            use_sigmoid = False
            if opt.gan_type == 'vanilla':
                use_sigmoid = True  # only vanilla GAN using BCECriterion
            # don't use cGAN
            self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid,
                                          opt.use_spectral_norm_D,
                                          opt.init_type, self.gpu_ids,
                                          opt.init_gain)

        # add style extractor
        self.vgg16_extractor = util.VGG16FeatureExtractor().to(self.gpu_ids[0])
        self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor,
                                                     self.gpu_ids)

        if self.isTrain:
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(
                self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionL1_mask = networks.Discounted_L1(opt).to(
                self.device
            )  # make weights/buffers transfer to the correct device
            # VGG loss
            self.criterionL2_style_loss = torch.nn.MSELoss()
            self.criterionL2_content_loss = torch.nn.MSELoss()
            # TV loss
            self.tv_criterion = networks.TVLoss(self.opt.tv_weight)

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.opt.gan_type == 'wgan_gp':
                opt.beta1 = 0
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.9))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.9))
            else:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)
Example #14
    dataloader = DataLoader(opt)
    model = RegresserModel(opt)

    if not opt.no_cuda:
        if len(opt.gpu_ids) > 1:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=opt.gpu_ids)
            net_dict = model.state_dict()
        else:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
            net_dict = model.state_dict()

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.001)
    scheduler = get_scheduler(optimizer, opt)

    if opt.resume_path:
        if os.path.isfile(opt.resume_path):
            print("=> loading checkpoint '{}'".format(opt.resume_path))
            checkpoint = torch.load(opt.resume_path)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(opt.resume_path, checkpoint['epoch']))

    train(dataloader=dataloader,
          model=model,
          optimizer=optimizer,
          scheduler=scheduler,
          total_epochs=opt.nepoch)
Example #15
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.target_weight = []
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C_sr = self.Tensor(nb, opt.output_nc, size, size)
        if opt.aux:
            self.A_aux = self.Tensor(nb, opt.input_nc, size, size)
            self.B_aux = self.Tensor(nb, opt.output_nc, size, size)
            self.C_aux = self.Tensor(nb, opt.output_nc, size, size)

        self.netE_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'ResnetEncoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        n_downsampling=2)

        mult = self.netE_A.get_mult()

        self.netE_C = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        64,
                                        'ResnetEncoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        n_downsampling=3)

        self.net_D = networks.define_G(opt.input_nc,
                                       opt.output_nc,
                                       opt.ngf,
                                       'ResnetDecoder_my',
                                       opt.norm,
                                       not opt.no_dropout,
                                       opt.init_type,
                                       self.gpu_ids,
                                       opt=opt,
                                       mult=mult)

        mult = self.net_D.get_mult()

        self.net_Dc = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'ResnetDecoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult,
                                        n_upsampling=1)

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'GeneratorLL',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult)

        mult = self.net_Dc.get_mult()

        self.netG_C = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'GeneratorLL',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult)

        #        self.netG_A_running = networks.define_G(opt.input_nc, opt.output_nc,
        #                                       opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt)
        #      set_eval(self.netG_A_running)
        #     accumulate(self.netG_A_running, self.netG_A, 0)
        #        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
        #                                       opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt)
        #    self.netG_B_running = networks.define_G(opt.output_nc, opt.input_nc,
        #                                   opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt)
        #  set_eval(self.netG_B_running)
        # accumulate(self.netG_B_running, self.netG_B, 0)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)
#         self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
#                                          opt.which_model_netD,
#                                        opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt)
        print('---------- Networks initialized -------------')
        #        networks.print_network(self.netG_B, opt, (opt.input_nc, opt.fineSize, opt.fineSize))
        networks.print_network(self.netE_C, opt,
                               (opt.input_nc, opt.fineSize, opt.fineSize))
        networks.print_network(
            self.net_D, opt, (opt.ngf * 4, opt.fineSize / 4, opt.fineSize / 4))
        networks.print_network(self.net_Dc, opt,
                               (opt.ngf, opt.CfineSize / 2, opt.CfineSize / 2))
        # networks.print_network(self.netG_B, opt)
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
            # networks.print_network(self.netD_B, opt)
        print('-----------------------------------------------')

        if not self.isTrain or opt.continue_train:
            print('Loaded model')
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netG_A_running, 'G_A', which_epoch)
                self.load_network(self.netG_B_running, 'G_B', which_epoch)
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain and opt.load_path != '':
            print('Loaded model from load_path')
            which_epoch = opt.which_epoch
            load_network_with_path(self.netG_A,
                                   'G_A',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netG_B,
                                   'G_B',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netD_A,
                                   'D_A',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netD_B,
                                   'D_B',
                                   opt.load_path,
                                   epoch_label=which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_C_pool = ImagePool(opt.pool_size)
            # define loss functions
            if len(self.target_weight) == opt.num_D:
                print(self.target_weight)
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan,
                    tensor=self.Tensor,
                    target_weight=self.target_weight,
                    gan=opt.gan)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan,
                    tensor=self.Tensor,
                    gan=opt.gan)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionColor = networks.ColorLoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netE_A.parameters(), self.net_D.parameters(),
                self.netG_A.parameters(), self.net_Dc.parameters(),
                self.netG_C.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(itertools.chain(
                self.netE_C.parameters(), self.net_D.parameters(),
                self.net_Dc.parameters(), self.netG_C.parameters()),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
            self.optimizer_G_A_sr = torch.optim.Adam(itertools.chain(
                self.netE_A.parameters(), self.net_D.parameters(),
                self.net_Dc.parameters(), self.netG_C.parameters()),
                                                     lr=opt.lr,
                                                     betas=(opt.beta1, 0.999))
            self.optimizer_AE_sr = torch.optim.Adam(itertools.chain(
                self.netE_C.parameters(), self.net_D.parameters(),
                self.netG_A.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            #       self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_AE)
            # self.optimizers.append(self.optimizer_G_A_sr)
            self.optimizers.append(self.optimizer_AE_sr)
            self.optimizers.append(self.optimizer_D_A)
            #   self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
Example #16
    def __init__(self, opt, data=None):
        super(Skipganomaly, self).__init__(opt, data)
        ##

        # -- Misc attributes
        self.add_noise = True
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = define_G(self.opt,
                             norm='batch',
                             use_dropout=False,
                             init_type='normal')
        self.netd = define_D(self.opt,
                             norm='batch',
                             use_sigmoid=False,
                             init_type='normal')

        ##
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            # Load each checkpoint once and reuse it.
            ckpt_g = torch.load(os.path.join(self.opt.resume, 'netG.pth'))
            ckpt_d = torch.load(os.path.join(self.opt.resume, 'netD.pth'))
            self.opt.iter = ckpt_g['epoch']
            self.netg.load_state_dict(ckpt_g['state_dict'])
            self.netd.load_state_dict(ckpt_d['state_dict'])
            print("\tDone.\n")

        if self.opt.verbose:
            print(self.netg)
            print(self.netd)

        ##
        # Loss Functions
        self.l_adv = nn.BCELoss()
        self.l_con = nn.L1Loss()
        self.l_lat = l2_loss

        ##
        # Initialize input tensors.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.noise = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32,
                                     device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32,
                                      device=self.device)

        ##
        # Setup optimizer
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizers = []
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_d)
            self.optimizers.append(self.optimizer_g)
            self.schedulers = [
                get_scheduler(optimizer, opt) for optimizer in self.optimizers
            ]
Example #17
    def set_scheduler(self, opts, ep):
        assert self.isTrain
        self.schedulers = [
            networks.get_scheduler(optimizer, opts, ep)
            for optimizer in self.optimizers
        ]
Example #18
    def __init__(self, opts):

        BaseModel.__init__(self, opts)

        lr = self.opt.lr

        self.model_names = ['G_A', 'G_B']
        self.loss_names = [
            'd_total', 'g_total', 'g_rec_x_a', 'g_rec_x_b', 'g_rec_s_a',
            'g_rec_s_b', 'g_rec_c_a', 'g_rec_c_b', 'g_adv_a', 'g_adv_b'
        ]
        self.visual_names = []
        # Initiate the networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
            self.netD_A = init_net(MsImageDis(self.opt.input_dim_a,
                                              self.opt.dis),
                                   init_type=self.opt.init,
                                   gpu_ids=self.gpu_ids)
            self.netD_B = init_net(MsImageDis(self.opt.input_dim_b,
                                              self.opt.dis),
                                   init_type=self.opt.init,
                                   gpu_ids=self.gpu_ids)
        self.netG_A = init_net(AdaINGen(self.opt.input_dim_a, self.opt.gen),
                               init_type=self.opt.init,
                               gpu_ids=self.gpu_ids)
        self.netG_B = init_net(AdaINGen(self.opt.input_dim_b, self.opt.gen),
                               init_type=self.opt.init,
                               gpu_ids=self.gpu_ids)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = self.opt.gen['style_dim']

        # fix the noise used in sampling
        display_size = self.opt.display_size
        self.s_a_fixed = torch.randn(display_size, self.style_dim, 1,
                                     1).to(self.device)
        self.s_b_fixed = torch.randn(display_size, self.style_dim, 1,
                                     1).to(self.device)

        if self.isTrain:
            # Setup the optimizers

            d_params = list(self.netD_A.parameters()) + list(
                self.netD_B.parameters())
            g_params = list(self.netG_A.parameters()) + list(
                self.netG_B.parameters())

            self.optimizer_D = torch.optim.Adam(
                [p for p in d_params if p.requires_grad],
                lr=lr,
                betas=(self.opt.beta1, self.opt.beta2),
                weight_decay=self.opt.weight_decay)
            self.optimizer_G = torch.optim.Adam(
                [p for p in g_params if p.requires_grad],
                lr=lr,
                betas=(self.opt.beta1, self.opt.beta2),
                weight_decay=self.opt.weight_decay)
            self.optimizer_names = ['optimizer_D', 'optimizer_G']
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_G)

            self.scheduler_D = get_scheduler(self.optimizer_D, self.opt)
            self.scheduler_G = get_scheduler(self.optimizer_G, self.opt)
            self.schedulers = [self.scheduler_D, self.scheduler_G]

        # Load VGG model if needed
        if (self.opt.vgg_w is not None) and self.opt.vgg_w > 0:
            self.vgg = load_vgg16(self.opt.vgg_model_path + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Example #19
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.count = 0
        input_depth = opt.input_nc
        output_depth = opt.output_nc
        self.net_shared = skip(input_depth,
                               num_channels_down=[64, 128, 256, 256, 256],
                               num_channels_up=[64, 128, 256, 256, 256],
                               num_channels_skip=[4, 4, 4, 4, 4],
                               upsample_mode=[
                                   'nearest', 'nearest', 'bilinear',
                                   'bilinear', 'bilinear'
                               ],
                               need_sigmoid=True,
                               need_bias=True,
                               pad='reflection')
        self.netDec_a = ResNet_decoders(opt.ngf, output_depth)
        self.netDec_b = ResNet_decoders(opt.ngf, output_depth)

        self.net_input = self.get_noise(input_depth, 'noise',
                                        (self.opt.fineSize, self.opt.fineSize))
        self.net_input_saved = self.net_input.detach().clone()
        self.noise = self.net_input.detach().clone()

        use_sigmoid = opt.no_lsgan
        self.netD_b = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, opt.init_type,
                                        self.gpu_ids)

        if not opt.dont_load_pretrained_autoencoder:
            which_epoch = opt.which_epoch
            self.load_network(self.netDec_b, 'Dec_b', which_epoch)
            self.load_network(self.net_shared, 'Net_shared', which_epoch)
            self.load_network(self.netD_b, 'D', which_epoch)

        if len(self.gpu_ids) > 0:
            dtype = torch.cuda.FloatTensor
            self.net_input = self.net_input.type(dtype).detach()
            self.net_shared = self.net_shared.type(dtype)
            self.netDec_a = self.netDec_a.type(dtype)
            self.netDec_b = self.netDec_b.type(dtype)
            self.netD_b = self.netD_b.type(dtype)

        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.mse = torch.nn.MSELoss()

        # initialize optimizers
        self.optimizer_Net = torch.optim.Adam(
            itertools.chain(self.net_shared.parameters(),
                            self.netDec_a.parameters()),
            lr=0.007,
            betas=(opt.beta1,
                   0.999))  # skip 0.01   # OST 0.001 # skip large 0.007
        self.optimizer_Dec_b = torch.optim.Adam(
            self.netDec_b.parameters(), lr=0.000007,
            betas=(opt.beta1,
                   0.999))  # OST 0.000007 skip 0.00002 skip large 0.000007
        self.optimizer_D_b = torch.optim.Adam(self.netD_b.parameters(),
                                              lr=0.0002,
                                              betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_Net)
        self.optimizers.append(self.optimizer_Dec_b)
        self.optimizers.append(self.optimizer_D_b)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))
Example #20
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.opt = opt
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        if self.opt.show_flow:
            self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
        else:
            self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']

        # batch size should be 1 for mask_global
        self.mask_global = torch.ByteTensor(1, 1, opt.fineSize, opt.fineSize)

        # Here we need to set an artificial mask_global (a centered square
        # hole keeps the mask valid). For fineSize=256 and overlap=4 the
        # hole covers rows/cols 68:188.
        self.mask_global.zero_()
        self.mask_global[:, :,
                         int(self.opt.fineSize / 4) + self.opt.overlap:
                         int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap,
                         int(self.opt.fineSize / 4) + self.opt.overlap:
                         int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap] = 1

        self.mask_type = opt.mask_type
        self.gMask_opts = {}

        self.wgan_gp = False
        # added for wgan-gp
        if opt.gan_type == 'wgan_gp':
            self.gp_lambda = opt.gp_lambda
            self.ncritic = opt.ncritic
            self.wgan_gp = True

        if len(opt.gpu_ids) > 0:
            self.use_gpu = True
            self.mask_global = self.mask_global.to(self.device)

        # load/define networks
        # self.ng_innerCos_list is the constraint list in netG inner layers.
        # self.ng_mask_list is the mask list constructing shift operation.
        if opt.add_mask2input:
            input_nc = opt.input_nc + 1
        else:
            input_nc = opt.input_nc

        self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
            input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
            self.mask_global, opt.norm, opt.use_dropout,
            opt.use_spectral_norm_G, opt.init_type, self.gpu_ids,
            opt.init_gain)  # add opt, we need opt.shift_sz and other stuffs
        if self.isTrain:
            use_sigmoid = False
            if opt.gan_type == 'vanilla':
                use_sigmoid = True  # only vanilla GAN using BCECriterion
            # don't use cGAN
            self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid,
                                          opt.use_spectral_norm_D,
                                          opt.init_type, self.gpu_ids,
                                          opt.init_gain)

        if self.isTrain:
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(
                self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionL1_mask = util.Discounted_L1(opt).to(
                self.device
            )  # make weights/buffers transfer to the correct device

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.wgan_gp:
                opt.beta1 = 0
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            else:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)
Example #21
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.opt = opt
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        if self.opt.show_flow:
            self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
        else:
            self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']

        # batch size should be 1 for mask_global
        self.mask_global = torch.zeros((self.opt.batchSize, 1, opt.fineSize, opt.fineSize),
                                       dtype=torch.bool)

        # Here we need to set an artificial mask_global (center hole is ok.)
        self.mask_global[:, :,
                         int(self.opt.fineSize * 3 / 8) + self.opt.overlap:
                         int(self.opt.fineSize / 2) + int(self.opt.fineSize / 8) - self.opt.overlap,
                         int(self.opt.fineSize * 3 / 8) + self.opt.overlap:
                         int(self.opt.fineSize / 2) + int(self.opt.fineSize / 8) - self.opt.overlap] = 1
        if len(opt.gpu_ids) > 0:
            self.mask_global = self.mask_global.to(self.device)

        # load/define networks
        # self.ng_innerCos_list is the guidance loss list in netG inner layers.
        # self.ng_shift_list is the mask list constructing shift operation.
        if opt.add_mask2input:
            input_nc = opt.input_nc + 1
        else:
            input_nc = opt.input_nc

        self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
            input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
            self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type,
            self.gpu_ids, opt.init_gain)

        if self.isTrain:
            use_sigmoid = False
            if opt.gan_type == 'vanilla':
                use_sigmoid = True  # only vanilla GAN using BCECriterion
            # don't use cGAN
            self.netD = networks.define_D(1, opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.use_spectral_norm_D,
                                          opt.init_type, self.gpu_ids,
                                          opt.init_gain)

        # add style extractor
        self.vgg16_extractor = util.VGG16FeatureExtractor().to(self.gpu_ids[0])
        self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor,
                                                     self.gpu_ids)

        if self.isTrain:
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(
                self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionL1_mask = networks.Discounted_L1(opt).to(
                self.device
            )  # make weights/buffers transfer to the correct device
            # VGG loss
            self.criterionL2_style_loss = torch.nn.MSELoss()
            self.criterionL2_content_loss = torch.nn.MSELoss()
            # TV loss
            self.tv_criterion = networks.TVLoss(self.opt.tv_weight)

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.opt.gan_type == 'wgan_gp':
                opt.beta1 = 0
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.9))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.9))
            else:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)
Example #22
    def __init__(self, args):

        self.gpu_ids = [0]
        self.isTrain = True

        self.checkpoints_dir = './checkpoints'
        self.which_epoch = 'latest' # which epoch to load? set to latest to use latest cached model
        self.args = args
        # self.name = 'G_GAN_%s_lambdar_%s_lambdas_%s_alpha_%s' % (self.args.lambda_d, self.args.lambda_r, self.args.lambda_s, self.args.alpha)
        self.name = 'Res_convolution_Gram'
        expr_dir = os.path.join(self.checkpoints_dir, self.name)
        if not os.path.exists(expr_dir):
            os.makedirs(expr_dir)

        self.save_dir = os.path.join(self.checkpoints_dir, self.name)
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_style', 'G_content']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        self.model_names = ['C', 'E', 'G']
        self.visual_names = ['A', 'B', 'C', 'R']
        self.input_nc = 3
        self.output_nc = 3
        self.ndf = 64  # number of filters in the first layer of the discriminator
        self.ngf = 64

        use_sigmoid = False
        # define networks

        self.netCA = networks.define_channel_attention(self.gpu_ids)

        self.netKA = networks.define_kernel_attention(self.gpu_ids)

        self.netSA = networks.define_spatial_attention(self.gpu_ids)

        self.netC = networks.define_Convolution(self.gpu_ids)

        self.netKC = networks.define_K_Convolution(self.gpu_ids)

        self.netVGG = networks.define_VGG()

        self.netE = networks.define_E(self.input_nc, self.ngf, self.gpu_ids)

        self.netG = networks.define_G(self.input_nc, self.output_nc, self.ngf, self.gpu_ids)

        self.criterionMSE = torch.nn.MSELoss()
        self.criterionL1 = torch.nn.L1Loss()

        # initialize optimizers
        self.schedulers = []
        self.optimizers = []

        self.optimizer_CA = torch.optim.Adam(self.netCA.parameters(),
                                             lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_KA = torch.optim.Adam(self.netKA.parameters(),
                                             lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_SA = torch.optim.Adam(self.netSA.parameters(),
                                             lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_C = torch.optim.Adam(self.netC.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_KC = torch.optim.Adam(self.netKC.parameters(),
                                             lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_E = torch.optim.Adam(self.netE.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))

        self.optimizers.append(self.optimizer_CA)
        self.optimizers.append(self.optimizer_KA)
        self.optimizers.append(self.optimizer_SA)
        self.optimizers.append(self.optimizer_C)
        self.optimizers.append(self.optimizer_KC)
        self.optimizers.append(self.optimizer_E)
        self.optimizers.append(self.optimizer_G)

        for optimizer in self.optimizers:
            self.schedulers.append(
                networks.get_scheduler(optimizer, lr_policy='lambda', epoch_count=1,
                                       niter=100, niter_decay=100, lr_decay_iters=50))

        if not self.isTrain or args.continue_train:
            self.load_networks(self.which_epoch)

        self.print_networks()
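None of the excerpts include get_scheduler itself. Below is a minimal sketch consistent with these call sites (opt.lr_policy appears as a keyword in Example #22, and opt.epoch_count, opt.niter, opt.niter_decay, and lr_decay_iters appear in Examples #8 and #22), modeled on the pix2pix/CycleGAN reference implementation; the 'plateau' and 'cosine' branches and their hyperparameters are assumptions and may differ per repository:

from torch.optim import lr_scheduler


def get_scheduler(optimizer, opt):
    """Return a learning-rate scheduler selected by opt.lr_policy
    (a sketch; the exact decay rules may vary across these repos)."""
    if opt.lr_policy == 'lambda':
        # Hold the initial lr for opt.niter epochs, then decay it
        # linearly to zero over the next opt.niter_decay epochs.
        def lambda_rule(epoch):
            return 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if opt.lr_policy == 'step':
        # Multiply the lr by 0.1 every opt.lr_decay_iters epochs.
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    if opt.lr_policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                              factor=0.2, threshold=0.01, patience=5)
    if opt.lr_policy == 'cosine':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)

With Example #22's settings (lr_policy='lambda', epoch_count=1, niter=100, niter_decay=100), this rule would keep the lr flat for the first 100 epochs and then decay it roughly linearly to zero over the next 100, which is the curve test_scheduler in Example #8 prints.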