def initialize(self, opt):
    """Prepare the generator for inference only.

    Fails an ``assert`` when ``opt.isTrain`` is set. Builds ``self.netG`` via
    ``networks.define_G``, loads the weights saved under tag ``'G'`` for
    ``opt.which_epoch`` and prints a summary of the network.

    Args:
        opt: parsed options; reads isTrain, input_nc, output_nc, ngf,
            which_model_netG, norm, init_type and which_epoch.
    """
    assert (not opt.isTrain)
    BaseModel.initialize(self, opt)

    self.netG = networks.define_G(
        opt.input_nc,
        opt.output_nc,
        opt.ngf,
        opt.which_model_netG,
        opt,
        opt.norm,
        opt.init_type,
        self.gpu_ids,
    )
    epoch_tag = opt.which_epoch
    self.load_network(self.netG, 'G', epoch_tag)

    # Print the network summary between two separator lines.
    print('---------- Networks initialized -------------')
    networks.print_network(self.netG)
    print('-----------------------------------------------')
def initialize(self, opt):
    """Build the shift-net inpainting model with style/content/TV losses.

    Always builds the generator (``self.netG``) and a VGG16 feature extractor.
    When training, also builds the discriminator, the loss criteria, the Adam
    optimizers and their LR schedulers. Saved weights are loaded at test time
    or when ``opt.continue_train`` is set.

    Args:
        opt: parsed options object. Fields read here include isTrain,
            show_flow, batchSize, fineSize, overlap, gpu_ids, add_mask2input,
            input_nc, output_nc, ngf, ndf, which_model_netG, which_model_netD,
            n_layers_D, norm, use_spectral_norm_G, use_spectral_norm_D,
            init_type, init_gain, gan_type, lr, beta1, tv_weight,
            continue_train, which_epoch and verbose.
            NOTE: for ``gan_type == 'wgan_gp'`` this method sets
            ``opt.beta1 = 0`` in place.
    """
    BaseModel.initialize(self, opt)
    self.opt = opt
    self.isTrain = opt.isTrain

    # Losses reported via base_model.get_current_losses.
    self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv']
    # Images returned via base_model.get_current_visuals.
    if self.opt.show_flow:
        self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
    else:
        self.visual_names = ['real_A', 'fake_B', 'real_B']
    # Networks handled by base_model.save_networks / load_networks.
    if self.isTrain:
        self.model_names = ['G', 'D']
    else:  # during test time, only load G
        self.model_names = ['G']

    # Artificial placeholder mask used while constructing netG. It has one
    # entry per sample (batchSize) and a square hole of side
    # roughly fineSize/4 - 2*overlap near the image center.
    # torch.zeros already zero-fills the tensor.
    self.mask_global = torch.zeros(
        (self.opt.batchSize, 1, opt.fineSize, opt.fineSize), dtype=torch.bool)
    hole_start = int(self.opt.fineSize * 3 / 8) + self.opt.overlap
    hole_end = (int(self.opt.fineSize / 2) + int(self.opt.fineSize / 8)
                - self.opt.overlap)
    self.mask_global[:, :, hole_start:hole_end, hole_start:hole_end] = 1
    if len(opt.gpu_ids) > 0:
        self.mask_global = self.mask_global.to(self.device)

    # Optionally feed the mask to netG as one extra input channel.
    if opt.add_mask2input:
        input_nc = opt.input_nc + 1
    else:
        input_nc = opt.input_nc
    # ng_innerCos_list: guidance-loss layers inside netG.
    # ng_shift_list: layers performing the mask-driven shift operation.
    self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
        input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
        self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type,
        self.gpu_ids, opt.init_gain)

    if self.isTrain:
        # Only the vanilla GAN (BCE criterion) needs a sigmoid on D's output.
        use_sigmoid = opt.gan_type == 'vanilla'
        # Unconditional discriminator (not a cGAN).
        # NOTE(review): D is built with 1 input channel, whereas netG emits
        # opt.output_nc channels -- confirm this is intended.
        self.netD = networks.define_D(1, opt.ndf, opt.which_model_netD,
                                      opt.n_layers_D, opt.norm, use_sigmoid,
                                      opt.use_spectral_norm_D, opt.init_type,
                                      self.gpu_ids, opt.init_gain)

    # VGG16 feature extractor for the style/content losses. Only move it to a
    # GPU / wrap it in DataParallel when GPUs are configured; indexing
    # gpu_ids[0] unconditionally crashed CPU-only runs.
    self.vgg16_extractor = util.VGG16FeatureExtractor()
    if len(self.gpu_ids) > 0:
        self.vgg16_extractor = self.vgg16_extractor.to(self.gpu_ids[0])
        self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor,
                                                     self.gpu_ids)

    if self.isTrain:
        self.old_lr = opt.lr
        # Loss functions; .to() moves any weights/buffers to the model device.
        self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(
            self.device)
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionL1_mask = networks.Discounted_L1(opt).to(self.device)
        # VGG-feature losses.
        self.criterionL2_style_loss = torch.nn.MSELoss()
        self.criterionL2_content_loss = torch.nn.MSELoss()
        # Total-variation loss.
        self.tv_criterion = networks.TVLoss(self.opt.tv_weight)

        # Optimizers: WGAN-GP uses Adam(beta1=0, beta2=0.9); otherwise
        # Adam(opt.beta1, 0.999).
        if self.opt.gan_type == 'wgan_gp':
            opt.beta1 = 0
            beta2 = 0.9
        else:
            beta2 = 0.999
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr,
                                            betas=(opt.beta1, beta2))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr,
                                            betas=(opt.beta1, beta2))
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        self.schedulers = [networks.get_scheduler(optimizer, opt)
                           for optimizer in self.optimizers]

    if not self.isTrain or opt.continue_train:
        self.load_networks(opt.which_epoch)
    self.print_networks(opt.verbose)
def initialize(self, opt):
    """Build the shift-net inpainting model (GAN + L1 losses).

    Always builds the generator (``self.netG``). When training, also builds
    the discriminator, the loss criteria, the Adam optimizers and their LR
    schedulers. Saved weights are loaded at test time or when
    ``opt.continue_train`` is set.

    Args:
        opt: parsed options object. Fields read here include isTrain,
            show_flow, fineSize, overlap, mask_type, gan_type, gp_lambda,
            ncritic, gpu_ids, add_mask2input, input_nc, output_nc, ngf, ndf,
            which_model_netG, which_model_netD, n_layers_D, norm, use_dropout,
            use_spectral_norm_G, use_spectral_norm_D, init_type, init_gain,
            lr, beta1, continue_train, which_epoch and verbose.
            NOTE: for ``gan_type == 'wgan_gp'`` this method sets
            ``opt.beta1 = 0`` in place.
    """
    BaseModel.initialize(self, opt)
    self.opt = opt
    self.isTrain = opt.isTrain

    # Losses reported via base_model.get_current_losses.
    self.loss_names = ['G_GAN', 'G_L1', 'D']
    # Images returned via base_model.get_current_visuals.
    if self.opt.show_flow:
        self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
    else:
        self.visual_names = ['real_A', 'fake_B', 'real_B']
    # Networks handled by base_model.save_networks / load_networks.
    if self.isTrain:
        self.model_names = ['G', 'D']
    else:  # during test time, only load G
        self.model_names = ['G']

    # Artificial placeholder mask used while constructing netG: batch
    # dimension fixed at 1, and a square hole of side roughly
    # fineSize/2 - 2*overlap around the image center.
    # torch.zeros(dtype=uint8) gives the same zero-filled byte tensor as the
    # legacy torch.ByteTensor(...).zero_().
    self.mask_global = torch.zeros((1, 1, opt.fineSize, opt.fineSize),
                                   dtype=torch.uint8)
    hole_start = int(self.opt.fineSize / 4) + self.opt.overlap
    hole_end = (int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4)
                - self.opt.overlap)
    self.mask_global[:, :, hole_start:hole_end, hole_start:hole_end] = 1

    self.mask_type = opt.mask_type
    self.gMask_opts = {}

    # WGAN-GP specific hyper-parameters.
    self.wgan_gp = opt.gan_type == 'wgan_gp'
    if self.wgan_gp:
        self.gp_lambda = opt.gp_lambda
        self.ncritic = opt.ncritic

    # Always define use_gpu. It used to be set only when GPUs were present,
    # so reading it on a CPU run raised AttributeError.
    self.use_gpu = len(opt.gpu_ids) > 0
    if self.use_gpu:
        self.mask_global = self.mask_global.to(self.device)

    # Optionally feed the mask to netG as one extra input channel.
    input_nc = opt.input_nc + 1 if opt.add_mask2input else opt.input_nc
    # ng_innerCos_list: constraint layers inside netG.
    # ng_shift_list: layers performing the mask-driven shift operation.
    # define_G receives opt because it needs opt.shift_sz and other settings.
    self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
        input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
        self.mask_global, opt.norm, opt.use_dropout, opt.use_spectral_norm_G,
        opt.init_type, self.gpu_ids, opt.init_gain)

    if self.isTrain:
        # Only the vanilla GAN (BCE criterion) needs a sigmoid on D's output.
        use_sigmoid = opt.gan_type == 'vanilla'
        # Unconditional discriminator (not a cGAN).
        self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                      opt.which_model_netD, opt.n_layers_D,
                                      opt.norm, use_sigmoid,
                                      opt.use_spectral_norm_D, opt.init_type,
                                      self.gpu_ids, opt.init_gain)

        self.old_lr = opt.lr
        # Loss functions; .to() moves any weights/buffers to the model device.
        self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(
            self.device)
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionL1_mask = util.Discounted_L1(opt).to(self.device)

        # Optimizers: Adam(beta1, 0.999); WGAN-GP forces beta1 to 0.
        if self.wgan_gp:
            opt.beta1 = 0
        betas = (opt.beta1, 0.999)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr,
                                            betas=betas)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr,
                                            betas=betas)
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        self.schedulers = [networks.get_scheduler(optimizer, opt)
                           for optimizer in self.optimizers]

    if not self.isTrain or opt.continue_train:
        self.load_networks(opt.which_epoch)
    self.print_networks(opt.verbose)