Example #1
    def set_input(self, input):
        # good practice: select A/B according to opt.which_direction
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        # input_A is a CPU tensor; copying into self.input_A moves it to the GPU
        # (self.Tensor is a CUDA tensor type when gpu_ids is non-empty).
        self.input_A.resize_(input_A.size()).copy_(input_A)  
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        # Add mask to input_A
        # Whenever the mask is random or not fixed, it must be regenerated with create_gMask.
        if self.fixed_mask:
            if self.opt.mask_type == 'center':
                self.mask_global.zero_()
                start = int(self.opt.fineSize / 4) + self.opt.overlap
                end = int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap
                self.mask_global[:, :, start:end, start:end] = 1
            elif self.opt.mask_type == 'random':
                self.mask_global = util.create_gMask(self.gMask_opts).type_as(self.mask_global)
            else:
                raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type)
        else:
            self.mask_global = util.create_gMask(self.gMask_opts).type_as(self.mask_global)
        
        self.set_latent_mask(self.mask_global, 3, self.opt.threshold)

        # Keep consistent with the preprocessing: the fill values differ slightly
        # from the Torch version, but the difference is negligible.
        self.input_A.narrow(1,0,1).masked_fill_(self.mask_global, 2*123.0/255.0 - 1.0)
        self.input_A.narrow(1,1,1).masked_fill_(self.mask_global, 2*104.0/255.0 - 1.0)
        self.input_A.narrow(1,2,1).masked_fill_(self.mask_global, 2*117.0/255.0 - 1.0)
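
The three per-channel fill constants are the dataset means mapped into [-1, 1]; the preprocessing comment above suggests inputs were scaled with x -> 2*x/255 - 1. A quick check of where the means land (values rounded):

    for mean in (123.0, 104.0, 117.0):
        print(round(2 * mean / 255.0 - 1.0, 4))
    # -0.0353, -0.1843, -0.0824
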
Example #2
    def set_input(self, input):
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths']

        # Add mask to input_A
        # Whenever the mask is random or not fixed, it must be regenerated with create_gMask.
        if self.fixed_mask:
            if self.opt.mask_type == 'center':
                self.mask_global.zero_()
                start = int(self.opt.fineSize / 4) + self.opt.overlap
                end = int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap
                self.mask_global[:, :, start:end, start:end] = 1
            elif self.opt.mask_type == 'random':
                self.mask_global = util.create_gMask(self.gMask_opts).type_as(
                    self.mask_global)
            else:
                raise ValueError("Mask_type [%s] not recognized." %
                                 self.opt.mask_type)
        else:
            self.mask_global = util.create_gMask(self.gMask_opts).type_as(
                self.mask_global)

        self.set_latent_mask(self.mask_global, 3, self.opt.threshold)

        self.input_A.narrow(1, 0, 1).masked_fill_(self.mask_global,
                                                  2 * 123.0 / 255.0 - 1.0)
        self.input_A.narrow(1, 1, 1).masked_fill_(self.mask_global,
                                                  2 * 104.0 / 255.0 - 1.0)
        self.input_A.narrow(1, 2, 1).masked_fill_(self.mask_global,
                                                  2 * 117.0 / 255.0 - 1.0)
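
A note on the in-place fill used in Examples 1 and 2: narrow(1, c, 1) returns a view of channel c, so masked_fill_ mutates input_A directly, and the single-channel mask broadcasts over that view. A self-contained sketch with toy sizes (illustrative values only):

    import torch

    x = torch.zeros(1, 3, 4, 4)                       # NCHW batch
    hole = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
    hole[:, :, 1:3, 1:3] = True                       # 2x2 hole in the middle
    # The fill writes through the channel-0 view into x itself.
    x.narrow(1, 0, 1).masked_fill_(hole, 2 * 123.0 / 255.0 - 1.0)
    print(x[0, 0])
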
Example #3
    def set_input(self, input):
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths']

        # Add mask to input_A
        # Whenever the mask is random or not fixed, it must be regenerated with create_gMask.
        if self.fixed_mask:
            if self.opt.mask_type == 'center':
                self.mask_global.zero_()
                start = int(self.opt.fineSize / 4) + self.opt.overlap
                end = int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap
                self.mask_global[:, :, start:end, start:end] = 1
            elif 'text' in self.opt.mask_type:
                input_A_char_mask = input['A_char_mask']
                input_A_txt_bb_mask = input['A_txt_bb_mask']
                if 'char' in self.opt.mask_type:
                    chosen_mask = input_A_char_mask
                elif 'txt_bb' in self.opt.mask_type:
                    chosen_mask = input_A_txt_bb_mask
                else:
                    raise ValueError('Unknown opt.mask_type')

                self.mask_global = chosen_mask.type_as(self.mask_global)

            elif self.opt.mask_type == 'random':
                self.mask_global = util.create_gMask(self.gMask_opts).type_as(
                    self.mask_global)
            else:
                raise ValueError("Mask_type [%s] not recognized." %
                                 self.opt.mask_type)
        else:
            self.mask_global = util.create_gMask(self.gMask_opts).type_as(
                self.mask_global)

        self.set_latent_mask(self.mask_global, 3, self.opt.threshold)

        # Per-channel fill means; the original code swapped the G and B means
        # (g=104, b=117), which is corrected here.
        img_avg_r = 123.0
        img_avg_g = 117.0
        img_avg_b = 104.0

        self.input_A.narrow(1, 0, 1).masked_fill_(self.mask_global,
                                                  2 * img_avg_r / 255.0 - 1.0)
        self.input_A.narrow(1, 1, 1).masked_fill_(self.mask_global,
                                                  2 * img_avg_g / 255.0 - 1.0)
        self.input_A.narrow(1, 2, 1).masked_fill_(self.mask_global,
                                                  2 * img_avg_b / 255.0 - 1.0)
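
The center-mask slice that recurs in these examples carves a square hole of side fineSize/2 - 2*overlap, inset from the image center. A worked check with illustrative values (not the repository's defaults):

    fineSize, overlap = 256, 4                     # illustrative values
    start = fineSize // 4 + overlap                # 68
    end = fineSize // 2 + fineSize // 4 - overlap  # 188
    print(end - start)                             # 120 == fineSize // 2 - 2 * overlap
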
Example #4
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.opt = opt
        self.isTrain = opt.isTrain
        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                   opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize)

        # batch size should be 1 for mask_global
        self.mask_global = torch.ByteTensor(1, 1, opt.fineSize, opt.fineSize)

        # Set a placeholder mask_global up front so it is never invalid; a center hole works fine.
        self.mask_global.zero_()
        start = int(self.opt.fineSize / 4) + self.opt.overlap
        end = int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap
        self.mask_global[:, :, start:end, start:end] = 1

        self.mask_type = opt.mask_type
        self.gMask_opts = {}
        # NOTE by JH: for 'text' mask types, no mask is generated; the loaded mask is used instead.
        self.fixed_mask = opt.fixed_mask if (opt.mask_type == 'center' or 'text' in opt.mask_type) else 0
        if opt.mask_type == 'center':
            assert opt.fixed_mask == 1, "Center mask must be fixed mask!"

        if 'text' in opt.mask_type:
            assert opt.fixed_mask == 1, "Text mask must be fixed mask!"

        if self.mask_type == 'random':
            res = 0.06  # the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large
            density = 0.25
            MAX_SIZE = 10000
            maxPartition = 30
            low_pattern = torch.rand(1, 1, int(res * MAX_SIZE),
                                     int(res * MAX_SIZE)).mul(255)
            # F.upsample is deprecated; F.interpolate is its modern equivalent.
            pattern = F.interpolate(low_pattern, (MAX_SIZE, MAX_SIZE),
                                    mode='bilinear').data
            low_pattern = None
            pattern.div_(255)
            pattern = torch.lt(pattern, density).byte()  # 25% 1s and 75% 0s
            pattern = torch.squeeze(pattern).byte()
            print('...Random pattern generated')
            self.gMask_opts['pattern'] = pattern
            self.gMask_opts['MAX_SIZE'] = MAX_SIZE
            self.gMask_opts['fineSize'] = opt.fineSize
            self.gMask_opts['maxPartition'] = maxPartition
            self.gMask_opts['mask_global'] = self.mask_global
            self.mask_global = util.create_gMask(
                self.gMask_opts)  # create an initial random mask.

        self.wgan_gp = False
        # added for wgan-gp
        if opt.gan_type == 'wgan_gp':
            self.gp_lambda = opt.gp_lambda
            self.ncritic = opt.ncritic
            self.wgan_gp = True

        if len(opt.gpu_ids) > 0:
            self.use_gpu = True
            self.mask_global = self.mask_global.cuda()

        # load/define networks
        # self.ng_innerCos_list is the constraint list in netG inner layers.
        # self.ng_mask_list is the mask list constructing shift operation.
        self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
            self.mask_global, opt.norm, opt.use_dropout, opt.init_type,
            self.gpu_ids,
            opt.init_gain)  # opt is passed through because define_G needs opt.shift_sz and related options
        if self.isTrain:
            use_sigmoid = False
            if opt.gan_type == 'vanilla':
                use_sigmoid = True  # only the vanilla GAN uses the BCE criterion
            # don't use cGAN
            self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, opt.init_type,
                                          self.gpu_ids, opt.init_gain)
        if not self.isTrain or opt.continue_train:
            print('Loading pre-trained network!')
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.wgan_gp:
                opt.beta1 = 0  # WGAN-GP convention: Adam with beta1 = 0
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG)
            if self.isTrain:
                networks.print_network(self.netD)
            print('-----------------------------------------------')
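
The 'random' branch above builds its pattern with a low-resolution-noise trick: sample coarse uniform noise, bilinearly upsample it into smooth blobs, then threshold so that low-valued regions become hole pixels (the gMask_opts dict suggests util.create_gMask later crops fineSize-sized masks out of this pattern). A minimal standalone sketch of the same idea, with small illustrative sizes:

    import torch
    import torch.nn.functional as F

    res, density, size = 0.06, 0.25, 256           # illustrative values
    low = torch.rand(1, 1, int(res * size), int(res * size))
    # Bilinear upsampling turns i.i.d. noise into spatially smooth blobs...
    pattern = F.interpolate(low, (size, size), mode='bilinear', align_corners=False)
    # ...and thresholding marks the low-valued regions as hole pixels.
    mask = (pattern < density).byte()
    print(mask.float().mean())                     # fraction of hole pixels
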
Example #5
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.opt = opt
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']


        # batch size should be 1 for mask_global
        self.mask_global = torch.ByteTensor(1, 1, opt.fineSize, opt.fineSize)

        # Set a placeholder mask_global up front so it is never invalid; a center hole works fine.
        self.mask_global.zero_()
        start = int(self.opt.fineSize / 4) + self.opt.overlap
        end = int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap
        self.mask_global[:, :, start:end, start:end] = 1

        self.mask_type = opt.mask_type
        self.gMask_opts = {}
        self.fixed_mask = opt.fixed_mask if opt.mask_type == 'center' else 0
        if opt.mask_type == 'center':
            assert opt.fixed_mask == 1, "Center mask must be fixed mask!"

        if self.mask_type == 'random':
            res = 0.06 # the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large
            density = 0.25
            MAX_SIZE = 10000
            maxPartition = 30
            low_pattern = torch.rand(1, 1, int(res*MAX_SIZE), int(res*MAX_SIZE)).mul(255)
            # F.functional.interpolate is a typo when F is torch.nn.functional; use F.interpolate.
            pattern = F.interpolate(low_pattern, (MAX_SIZE, MAX_SIZE), mode='bilinear').detach()
            low_pattern = None
            pattern.div_(255)
            pattern = torch.lt(pattern, density).byte()  # 25% 1s and 75% 0s
            pattern = torch.squeeze(pattern).byte()
            print('...Random pattern generated')
            self.gMask_opts['pattern'] = pattern
            self.gMask_opts['MAX_SIZE'] = MAX_SIZE
            self.gMask_opts['fineSize'] = opt.fineSize
            self.gMask_opts['maxPartition'] = maxPartition
            self.gMask_opts['mask_global'] = self.mask_global
            self.mask_global = util.create_gMask(self.gMask_opts) # create an initial random mask.


        self.wgan_gp = False
        # added for wgan-gp
        if opt.gan_type == 'wgan_gp':
            self.gp_lambda = opt.gp_lambda
            self.ncritic = opt.ncritic
            self.wgan_gp = True


        if len(opt.gpu_ids) > 0:
            self.use_gpu = True
            self.mask_global = self.mask_global.to(self.device)

        # load/define networks
        # self.ng_innerCos_list is the constraint list in netG inner layers.
        # self.ng_mask_list is the mask list constructing shift operation.
        self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
            self.mask_global, opt.norm, opt.use_dropout, opt.init_type,
            self.gpu_ids, opt.init_gain)  # opt is passed through because define_G needs opt.shift_sz and related options
        if self.isTrain:
            use_sigmoid = False
            if opt.gan_type == 'vanilla':
                use_sigmoid = True  # only the vanilla GAN uses the BCE criterion
            # don't use cGAN
            self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt.init_gain)

        if self.isTrain:
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.wgan_gp:
                opt.beta1 = 0  # WGAN-GP convention: Adam with beta1 = 0
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)
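
Both initialize variants end with the same optimizer/scheduler pattern: one optimizer per network, one scheduler per optimizer, all collected in lists so the training loop can step them uniformly. A minimal sketch with torch's built-in StepLR standing in for networks.get_scheduler (all hyperparameters illustrative):

    import torch

    net = torch.nn.Linear(8, 8)                    # stand-in for netG / netD
    optimizers = [torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))]
    schedulers = [torch.optim.lr_scheduler.StepLR(opt_, step_size=50, gamma=0.5)
                  for opt_ in optimizers]
    for epoch in range(3):
        for opt_ in optimizers:
            opt_.step()                            # would follow backward() in a real loop
        for sch in schedulers:
            sch.step()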