class Pix2PixHDModel(BaseModel):
    """Conditional GAN that turns (one-hot) label maps into images.

    Owns a generator ``netG``; while training also a discriminator ``netD``;
    and, when features are used but not loaded from disk, an encoder ``netE``.
    Inputs are moved to CUDA unconditionally, so a GPU is required.
    """

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixHDModel'

    def initialize(self, opt):
        """Build the networks, optionally load weights, and set up training.

        Args:
            opt: parsed option namespace; the fields read are the ``opt.*``
                attributes referenced below.

        Raises:
            NotImplementedError: if an image pool is requested together with
                more than one GPU.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none':  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features

        ##### define networks
        # Generator network
        netG_input_nc = opt.label_nc
        if not opt.no_instance:
            netG_input_nc += 1  # one extra channel for the instance edge map
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf,
                                      opt.netG, opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = opt.label_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num,
                                          opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm,
                                          gpu_ids=self.gpu_ids)

        print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch,
                                  pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch,
                                  pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can breakout loss
            self.loss_names = [
                'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'
            ]

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                local_prefix = 'model' + str(opt.n_local_enhancers)
                # One param group per tensor: the local enhancer trains at
                # opt.lr, everything else is held still with lr 0.
                params = [{
                    'params': [value],
                    'lr': opt.lr if key.startswith(local_prefix) else 0.0
                } for key, value in dict(self.netG.named_parameters()).items()]
                if self.gen_features:
                    # The list above holds dict groups, so the encoder must be
                    # added as a group too; mixing bare tensors with dicts is
                    # rejected by torch.optim.
                    params.append({
                        'params': list(self.netE.parameters()),
                        'lr': opt.lr
                    })
            else:
                params = list(self.netG.parameters())
                if self.gen_features:
                    params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

    def encode_input(self,
                     label_map,
                     inst_map=None,
                     real_image=None,
                     feat_map=None,
                     infer=False):
        """Move the raw inputs to CUDA and build the generator input.

        Args:
            label_map: 4-D integer label map (indexed as size[0], size[2],
                size[3]); scattered into ``opt.label_nc`` one-hot channels.
            inst_map: instance map; required unless ``opt.no_instance``.
            real_image: optional target image.
            feat_map: optional precomputed feature map, only used when
                ``opt.load_features`` is set.
            infer: forwarded as ``volatile`` to the label Variable.

        Returns:
            Tuple ``(input_label, inst_map, real_image, feat_map)``.
        """
        # create one-hot vector for label map
        size = label_map.size()
        oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1,
                                           label_map.data.long().cuda(), 1.0)

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run the discriminator on ``input_label`` concatenated with a detached image.

        When ``use_pool`` is true the concatenated pair is first passed
        through ``self.fake_pool.query``.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Compute all generator and discriminator losses for one batch.

        Returns:
            ``[losses, fake_image_or_None]`` where ``losses`` is
            ``[G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake]`` and the generated
            image is only returned when ``infer`` is true.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label,
                                           fake_image,
                                           use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        # No detach here: this term must back-propagate into the generator.
        pred_fake = self.netD.forward(
            torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                # The last entry of each discriminator output is skipped.
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image,
                                           real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [[
            loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake
        ], None if not infer else fake_image]

    def inference(self, label, inst):
        """Generate an image from a label map and an instance map."""
        # Encode Inputs
        input_label, inst_map, _, _ = self.encode_input(Variable(label),
                                                        Variable(inst),
                                                        infer=True)

        # Fake Generation
        if self.use_features:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Build a feature map by sampling one stored cluster per instance id.

        Args:
            inst: 4-D instance map (indexed on dims 0, 2 and 3).

        Returns:
            CUDA float tensor with ``opt.feat_num`` channels and the batch and
            spatial size of ``inst``. Pixels whose label has no stored
            cluster are zero.
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name,
                                    self.opt.cluster_path)
        # The file stores a pickled dict inside a 0-d object array, so it
        # needs allow_pickle=True on NumPy >= 1.16.3.
        # NOTE(security): unpickling runs arbitrary code; only load cluster
        # files produced by this project.
        features_clustered = np.load(cluster_path, allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        # Zero-initialised and sized to the real batch: the previous
        # uninitialised (1, C, H, W) buffer leaked garbage into unmatched
        # pixels and overflowed for batches larger than one.
        feat_map = torch.cuda.FloatTensor(inst.size()[0], self.opt.feat_num,
                                          inst.size()[2],
                                          inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            # ids of 1000 or more are reduced to a label by integer division
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                # int(): compare the tensor against a plain Python int, not
                # a NumPy scalar.
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2],
                             idx[:, 3]] = feat[cluster_idx, k]
        return feat_map

    def encode_features(self, image, inst):
        """Encode ``image`` with ``netE`` and collect one feature row per instance.

        Returns:
            Dict mapping label -> array of shape (n, feat_num + 1). The first
            ``feat_num`` columns are the encoder output at the middle entry
            of the instance's nonzero-index list; the last column is the
            instance's pixel count divided by ``h * w // 32``.
        """
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {
            i: np.zeros((0, feat_num + 1))
            for i in range(self.opt.label_nc)
        }
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]  # middle entry of the index list
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2],
                                     idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a float map that is 1 where ``t`` differs from a 4-neighbour."""
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        # Horizontal neighbours: mark both pixels of every differing pair.
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] !=
                                                 t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] !=
                                                   t[:, :, :, :-1])
        # Vertical neighbours.
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] !=
                                                 t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] !=
                                                   t[:, :, :-1, :])
        return edge.float()

    def save(self, which_epoch):
        """Save generator, discriminator and (if present) encoder weights."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild ``optimizer_G`` over all generator (and encoder) parameters."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class FTAEModel(BaseModel):
    """Flow-predicting autoencoder: ``netG`` maps (image, yaw) to a sampling flow.

    The predicted flow is converted with ``convert_flow`` and applied to the
    input image with ``grid_sample``. The model calls ``.cuda()``
    unconditionally in several places, so a GPU is required.
    """

    def name(self):
        """Return the model's display name."""
        return 'FTAEModel'

    def initialize(self, opt):
        """Create networks, losses, optimizers and the normalised sampling grid."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # `async` is a reserved word from Python 3.7 on; `non_blocking` is
        # the PyTorch >= 0.4 spelling of the same flag.
        self.yaw = Variable(torch.Tensor([np.pi / 4.]).cuda(opt.gpu_ids[0],
                                                            non_blocking=True),
                            requires_grad=False)

        # load/define networks
        self.netG = FTAE(
            opt.input_nc,
            opt.ngf,
            n_layers=int(np.log2(opt.fineSize)),
            upsample=opt.upsample,
            norm_layer=networks.get_norm_layer(norm_type=opt.norm),
            nl_layer=networks.get_non_linearity(layer_type='lrelu'),
            gpu_ids=opt.gpu_ids,
            nz=opt.nz)
        if len(opt.gpu_ids) > 0:
            self.netG.cuda(opt.gpu_ids[0])
        networks.init_weights(self.netG, init_type="normal")

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionTV = networks.TVLoss()

            # initialize optimizers
            # (self.yaw is deliberately not among the optimised parameters)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        # Coordinate grid: channel 0 holds the column index, channel 1 the
        # row index, both rescaled by fineSize / 2 and shifted by -1.
        coords = np.arange(opt.fineSize)
        grid = np.zeros((opt.fineSize, opt.fineSize, 2))
        grid[:, :, 0] = coords[np.newaxis, :]
        grid[:, :, 1] = coords[:, np.newaxis]
        grid /= (opt.fineSize / 2)
        grid -= 1
        self.grid = torch.from_numpy(grid).cuda().float()
        self.grid = self.grid.view(1, self.grid.size(0), self.grid.size(1),
                                   self.grid.size(2)).expand(
                                       opt.batchSize, opt.fineSize,
                                       opt.fineSize, 2)
        self.grid = Variable(self.grid)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def _yaw_variable(self, angle):
        """Wrap a scalar angle as a one-element Variable on the first GPU."""
        return Variable(
            torch.Tensor([angle]).cuda(self.gpu_ids[0], non_blocking=True))

    def _convert(self, flow):
        """Convert a raw generator flow using the grid and the flow options."""
        return convert_flow(flow, self.grid, self.opt.add_grid,
                            self.opt.rectified)

    def set_input(self, input):
        """Store one data batch on the model and derive the flow masks.

        Args:
            input: dict with keys 'A', 'B', 'A_paths'/'B_paths' and, in
                'aligned_with_C' dataset mode, 'C'.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        if self.opt.dataset_mode == 'aligned_with_C':
            input_C = input['C']
            if len(self.gpu_ids) > 0:
                input_C = input_C.cuda(self.gpu_ids[0], non_blocking=True)
            # Stored regardless of device, mirroring input_A / input_B.
            self.input_C = input_C

        # Masks select pixels whose channel sum is below 3.0, expanded to the
        # two flow channels.
        # NOTE(review): presumably "not pure white" for images normalised to
        # [-1, 1] - confirm against the data loader.
        self.mask = torch.sum(self.input_B, dim=1)
        self.mask = (self.mask < 3.0).unsqueeze(1)
        self.mask = self.mask.expand(self.input_B.size(0), 2,
                                     self.input_B.size(2),
                                     self.input_B.size(3))

        self.mask0 = torch.sum(self.input_A, dim=1)
        self.mask0 = (self.mask0 < 3.0).unsqueeze(1)
        self.mask0 = self.mask0.expand(self.input_B.size(0), 2,
                                       self.input_B.size(2),
                                       self.input_B.size(3))

    def forward(self):
        """Predict flows for yaw = self.yaw, 0 and pi/8 and warp ``real_A`` with each."""
        self.real_A = Variable(self.input_A)
        if self.opt.dataset_mode == 'aligned_with_C':
            self.real_C = Variable(self.input_C) + self.grid

        self.fake_B_flow, _ = self.netG(self.real_A, self.yaw)
        self.fake_B_flow_converted = self._convert(self.fake_B_flow)
        self.fake_B = torch.nn.functional.grid_sample(
            self.real_A, self.fake_B_flow_converted)
        self.real_B = Variable(self.input_B)

        self.fake_B_0_flow, _ = self.netG(self.real_A, self._yaw_variable(0))
        self.fake_B_flow_converted0 = self._convert(self.fake_B_0_flow)
        self.fake_B_0 = torch.nn.functional.grid_sample(
            self.real_A, self.fake_B_flow_converted0)

        self.fake_B_18_flow, _ = self.netG(self.real_A,
                                           self._yaw_variable(np.pi / 8.))
        self.fake_B_18 = torch.nn.functional.grid_sample(
            self.real_A, self._convert(self.fake_B_18_flow))

    # no backprop gradients
    def test(self):
        """Warp ``real_A`` for ten evenly spaced yaw values in [0, pi/4]."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_B_list = []
        for i in range(10):
            fake_B_flow, z = self.netG(
                self.real_A, self._yaw_variable(i / 9. * np.pi / 4.))
            fake_B = torch.nn.functional.grid_sample(
                self.real_A, self._convert(fake_B_flow))
            self.fake_B_list.append(fake_B)

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss on pooled fakes and reals and back-propagate."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.opt.lambda_gan * self.criterionGAN(
            pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.opt.lambda_gan * self.criterionGAN(
            pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the combined generator loss and back-propagate it."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.opt.lambda_gan * self.criterionGAN(
            pred_fake, True)

        # Total variation loss
        self.loss_TV = self.criterionTV(self.fake_B_flow) * self.opt.lambda_tv
        self.loss_TV_2 = self.criterionTV(
            self.fake_B_0_flow) * self.opt.lambda_tv

        if self.opt.lambda_flow > 0:
            self.loss_G_flow = self.criterionL1(
                self.fake_B_flow_converted.permute(0, 3, 1, 2)[self.mask],
                self.real_C.permute(0, 3, 1, 2)[self.mask]
            ) * self.opt.lambda_flow
        else:
            # Zero tensor so the .data access in get_current_errors works.
            self.loss_G_flow = 0. * self.loss_TV

        if self.opt.lambda_flow0 > 0:
            # Weighted by lambda_flow0, the option that gates this term; the
            # previous code scaled it by lambda_flow.
            self.loss_G_flow0 = self.criterionL1(
                self.fake_B_flow_converted.permute(0, 3, 1, 2)[self.mask0],
                self.grid.permute(0, 3, 1, 2)[self.mask0]
            ) * self.opt.lambda_flow0
        else:
            self.loss_G_flow0 = 0. * self.loss_TV

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A
        self.loss_G_L1_2 = self.criterionL1(self.fake_B_0,
                                            self.real_A) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_L1_2 \
            + self.loss_TV + self.loss_TV_2 + self.loss_G_flow + self.loss_G_flow0

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one generator update; the discriminator update is disabled."""
        self.forward()

        # Discriminator step intentionally left disabled, as in the original:
        # self.optimizer_D.zero_grad()
        # self.backward_D()
        # self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an OrderedDict.

        'D_real' and 'D_fake' are only reported once ``backward_D`` has run,
        because ``optimize_parameters`` does not call it.
        """
        errors = [('G_GAN', self.loss_G_GAN.data[0]),
                  ('G_L1', self.loss_G_L1.data[0]),
                  ('G_L1_2', self.loss_G_L1_2.data[0]),
                  ('F_L1', self.loss_G_flow.data[0]),
                  ('F_L10', self.loss_G_flow0.data[0]),
                  ('TV', self.loss_TV.data[0]),
                  ('TV2', self.loss_TV_2.data[0])]
        if hasattr(self, 'loss_D_real') and hasattr(self, 'loss_D_fake'):
            errors.append(('D_real', self.loss_D_real.data[0]))
            errors.append(('D_fake', self.loss_D_fake.data[0]))
        errors.append(('Yaw', self.yaw.data[0]))
        return OrderedDict(errors)

    def get_current_visuals(self):
        """Return images and flow visualisations of the last forward pass."""
        if not self.opt.isTrain:
            return self.get_current_visuals_test()

        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_B_0 = util.tensor2im(self.fake_B_0.data)
        fake_B_18 = util.tensor2im(self.fake_B_18.data)
        flow = util.tensor2im(
            self.fake_B_flow_converted.permute(0, 3, 1, 2).data)
        flow0 = util.tensor2im(
            self.fake_B_flow_converted0.permute(0, 3, 1, 2).data)
        if self.opt.dataset_mode == 'aligned_with_C':
            real_flow = util.tensor2im(self.real_C.permute(0, 3, 1, 2).data)
        else:
            # No ground-truth flow available: show the predicted one instead.
            real_flow = util.tensor2im(
                self.fake_B_flow_converted.permute(0, 3, 1, 2).data)

        return OrderedDict([('real_A', real_A), ('fake_B_36', fake_B),
                            ('real_B', real_B), ('fake_B_0', fake_B_0),
                            ('fake_B_18', fake_B_18), ('flow', flow),
                            ('flow0', flow0), ('real_flow', real_flow)])

    def get_current_visuals_test(self):
        """Return the input, every warped image from ``test`` and the target."""
        real_A = util.tensor2im(self.real_A.data)
        real_B = util.tensor2im(self.real_B.data)
        visual_list = OrderedDict([('real_A', real_A)])
        for idx, fake_B_var in enumerate(self.fake_B_list):
            visual_list['%d' % idx] = util.tensor2im(fake_B_var.data)
        visual_list['real_B'] = real_B

        return visual_list

    def save(self, label):
        """Save generator and discriminator weights under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
class Pix2PixHDModel_Mapping(BaseModel):
    """Mapping model: encode with ``netG_A``, map features, decode with ``netG_B``."""

    def name(self):
        """Return the model's display name."""
        return "Pix2PixHDModel_Mapping"

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1,
                         stage_1_feat_l2):
        """Return a function that keeps only the enabled entries of a loss tuple.

        The first two and the fifth/sixth positions are always kept; the
        others follow the four boolean arguments.
        """
        enabled = (True, True, use_gan_feat_loss, use_vgg_loss, True, True,
                   use_smooth_l1, stage_1_feat_l2)

        def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake,
                        smooth_l1, stage_1_feat_l2):
            candidates = (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake,
                          smooth_l1, stage_1_feat_l2)
            kept = []
            for value, keep in zip(candidates, enabled):
                if keep:
                    kept.append(value)
            return kept

        return loss_filter

    def initialize(self, opt):
        """Build both generators, the mapping net and the training state."""
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != "none" or not opt.isTrain:
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        # ---- generators: two identically configured networks ----
        netG_input_nc = input_nc

        def build_generator():
            return networks.GlobalGenerator_DCDCv2(
                netG_input_nc,
                opt.output_nc,
                opt.ngf,
                opt.k_size,
                opt.n_downsample_global,
                networks.get_norm_layer(norm_type=opt.norm),
                opt=opt,
            )

        self.netG_A = build_generator()
        self.netG_B = build_generator()

        # ---- mapping network: masked variant when requested ----
        if opt.non_local == "Setting_42" or opt.NL_use_mask:
            mapping_cls = Mapping_Model_with_mask
        else:
            mapping_cls = Mapping_Model
        self.mapping_net = mapping_cls(
            min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
            opt.map_mc,
            n_blocks=opt.mapping_n_block,
            opt=opt,
        )
        self.mapping_net.apply(networks.weights_init)

        if opt.load_pretrain != "":
            self.load_network(self.mapping_net, "mapping_net",
                              opt.which_epoch, opt.load_pretrain)

        if not opt.no_load_VAE:
            # Load both generators, freeze them and put them in eval mode.
            self.load_network(self.netG_A, "G", opt.use_vae_which_epoch,
                              opt.load_pretrainA)
            self.load_network(self.netG_B, "G", opt.use_vae_which_epoch,
                              opt.load_pretrainB)
            for frozen in self.netG_A.parameters():
                frozen.requires_grad = False
            for frozen in self.netG_B.parameters():
                frozen.requires_grad = False
            self.netG_A.eval()
            self.netG_B.eval()

        if opt.gpu_ids:
            first_gpu = opt.gpu_ids[0]
            self.netG_A.cuda(first_gpu)
            self.netG_B.cuda(first_gpu)
            self.mapping_net.cuda(first_gpu)

        if not self.isTrain:
            self.load_network(self.mapping_net, "mapping_net",
                              opt.which_epoch)
            return  # everything below is training-only

        # ---- discriminator ----
        use_sigmoid = opt.no_lsgan
        if opt.feat_gan:
            netD_input_nc = opt.ngf * 2
        else:
            netD_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(netD_input_nc,
                                      opt.ndf,
                                      opt.n_layers_D,
                                      opt,
                                      opt.norm,
                                      use_sigmoid,
                                      opt.num_D,
                                      not opt.no_ganFeat_loss,
                                      gpu_ids=self.gpu_ids)

        # ---- losses ----
        if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
            raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                 not opt.no_vgg_loss,
                                                 opt.Smooth_L1,
                                                 opt.use_two_stage_mapping)

        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()
        if opt.use_l1_feat:
            self.criterionFeat_feat = torch.nn.L1Loss()
        else:
            self.criterionFeat_feat = torch.nn.MSELoss()

        if self.opt.image_L1:
            self.criterionImage = torch.nn.L1Loss()
        else:
            self.criterionImage = torch.nn.SmoothL1Loss()

        print(self.criterionFeat_feat)

        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

        # Loss names go through the same filter as the loss values.
        self.loss_names = self.loss_filter('G_Feat_L2', 'G_GAN', 'G_GAN_Feat',
                                           'G_VGG', 'D_real', 'D_fake',
                                           'Smooth_L1', 'G_Feat_L2_Stage_1')

        # ---- optimizers ----
        if opt.no_TTUR:
            beta1, beta2 = opt.beta1, 0.999
            G_lr, D_lr = opt.lr, opt.lr
        else:
            # Two-time-scale setting: halve the mapping lr, double the D lr.
            beta1, beta2 = 0, 0.9
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        adam_betas = (beta1, beta2)
        if not opt.no_load_VAE:
            mapping_params = list(self.mapping_net.parameters())
            self.optimizer_mapping = torch.optim.Adam(mapping_params,
                                                      lr=G_lr,
                                                      betas=adam_betas)

        disc_params = list(self.netD.parameters())
        self.optimizer_D = torch.optim.Adam(disc_params,
                                            lr=D_lr,
                                            betas=adam_betas)

        print("---------- Optimizers initialized -------------")

    def encode_input(self,
                     label_map,
                     inst_map=None,
                     real_image=None,
                     feat_map=None,
                     infer=False):
        """Move inputs to CUDA and, if ``opt.label_nc`` is non-zero, one-hot the labels.

        Returns:
            Tuple ``(input_label, inst_map, real_image, feat_map)``;
            ``feat_map`` is passed through untouched.
        """
        if self.opt.label_nc != 0:
            # create one-hot vector for label map
            dims = label_map.size()
            one_hot_shape = torch.Size(
                (dims[0], self.opt.label_nc, dims[2], dims[3]))
            blank = torch.cuda.FloatTensor(one_hot_shape).zero_()
            input_label = blank.scatter_(1,
                                         label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()
        else:
            input_label = label_map.data.cuda()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edges = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edges), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Score ``input_label`` concatenated with a detached ``test_image``."""
        pair = torch.cat((input_label, test_image.detach()), dim=1)
        if not use_pool:
            return self.netD.forward(pair)
        pooled = self.fake_pool.query(pair)
        return self.netD.forward(pooled)

    def forward(self,
                label,
                inst,
                image,
                feat,
                pair=True,
                infer=False,
                last_label=None,
                last_image=None):
        """Compute the filtered loss list for one batch.

        Returns:
            ``[filtered_losses, fake_image_or_None]``; the generated image is
            included only when ``infer`` is true.
        """
        opt = self.opt

        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat)

        # Encode the input, map its features, decode; also encode the target.
        label_feat = self.netG_A.forward(input_label, flow='enc')
        if opt.NL_use_mask:
            label_feat_map = self.mapping_net(label_feat.detach(), inst)
        else:
            label_feat_map = self.mapping_net(label_feat.detach())

        fake_image = self.netG_B.forward(label_feat_map, flow='dec')
        image_feat = self.netG_B.forward(real_image, flow='enc')

        loss_feat_l2_stage_1 = 0
        loss_feat_l2 = self.criterionFeat_feat(
            label_feat_map, image_feat.data) * opt.l2_feat

        if opt.feat_gan:
            # Adversarial game in feature space.
            cond = label_feat.detach()
            pred_fake_pool = self.discriminate(cond,
                                               label_feat_map,
                                               use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            pred_real = self.discriminate(label_feat.detach(), image_feat)
            loss_D_real = self.criterionGAN(pred_real, True)

            pred_fake = self.netD.forward(
                torch.cat((label_feat.detach(), label_feat_map), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            # Adversarial game in image space.
            pred_fake_pool = self.discriminate(input_label,
                                               fake_image,
                                               use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            if pair:
                pred_real = self.discriminate(input_label, real_image)
            else:
                pred_real = self.discriminate(last_label, last_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            pred_fake = self.netD.forward(
                torch.cat((input_label, fake_image), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        if not opt.no_ganFeat_loss and pair:
            loss_G_GAN_Feat = 0
            feat_weights = 4.0 / (opt.n_layers_D + 1)
            D_weights = 1.0 / opt.num_D
            scale = D_weights * feat_weights
            for d_idx in range(opt.num_D):
                # The last entry of each discriminator output is skipped.
                for layer in range(len(pred_fake[d_idx]) - 1):
                    term = self.criterionFeat(
                        pred_fake[d_idx][layer],
                        pred_real[d_idx][layer].detach()) * opt.lambda_feat
                    loss_G_GAN_Feat += scale * term
        else:
            loss_G_GAN_Feat = torch.zeros(1).to(label.device)

        # VGG feature matching loss
        loss_G_VGG = 0
        if not opt.no_vgg_loss:
            if pair:
                loss_G_VGG = self.criterionVGG(fake_image,
                                               real_image) * opt.lambda_feat
            else:
                loss_G_VGG = torch.zeros(1).to(label.device)

        smooth_l1_loss = 0
        if opt.Smooth_L1:
            smooth_l1_loss = self.criterionImage(fake_image,
                                                 real_image) * opt.L1_weight

        losses = self.loss_filter(loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat,
                                  loss_G_VGG, loss_D_real, loss_D_fake,
                                  smooth_l1_loss, loss_feat_l2_stage_1)
        generated = fake_image if infer else None
        return [losses, generated]

    def inference(self, label, inst):
        """Encode ``label``, map its features and decode them to an image."""
        if len(self.opt.gpu_ids) > 0:
            input_concat = label.data.cuda()
            inst_data = inst.cuda()
        else:
            input_concat = label.data
            inst_data = inst

        label_feat = self.netG_A.forward(input_concat, flow="enc")

        mapper_args = [label_feat.detach()]
        if self.opt.NL_use_mask:
            mapper_args.append(inst_data)
        label_feat_map = self.mapping_net(*mapper_args)

        return self.netG_B.forward(label_feat_map, flow="dec")
class Pix2PixModel(BaseModel):
    """Paired image-to-image GAN with L1 and optional feature-matching losses."""

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build generator/discriminator, optionally load weights, set up training."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.no_ganFeat_loss = opt.no_ganFeat_loss

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids,
                                          not opt.no_ganFeat_loss)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store one batch; 'A' and 'B' are swapped unless the direction is 'AtoB'."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async` is a reserved word from Python 3.7 on; `non_blocking`
            # is the PyTorch >= 0.4 spelling of the same flag.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run the generator on the stored input."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run the generator with the inputs wrapped as volatile Variables."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss and back-propagate it.

        Also stores ``self.pred_real``, which ``backward_G`` reads for the
        feature-matching loss.
        """
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake_pool = self.netD.forward(fake_AB)
        self.loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) / 2.0

        self.loss_D.backward()

    def backward_G(self):
        """Compute GAN + feature-matching + L1 generator loss and back-propagate.

        Must run after ``backward_D`` because it reads ``self.pred_real``.
        """
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        self.pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(self.pred_fake, True)

        # Feature matching
        self.loss_G_GAN_Feat = 0
        if not self.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                # The last entry of each discriminator output is skipped.
                for j in range(len(self.pred_fake[i]) - 1):
                    self.loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionL1(self.pred_fake[i][j], self.pred_real[i][j].detach()) * self.opt.lambda_feat

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_GAN_Feat

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one discriminator step followed by one generator step."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest losses; 'G_GAN_feature' only when feature matching is on."""
        errors = [('G_GAN', self.loss_G_GAN.data[0])]
        if not self.no_ganFeat_loss:
            errors.append(('G_GAN_feature', self.loss_G_GAN_Feat.data[0]))
        errors += [('G_L1', self.loss_G_L1.data[0]),
                   ('D_real', self.loss_D_real.data[0]),
                   ('D_fake', self.loss_D_fake.data[0])]
        return OrderedDict(errors)

    def get_current_visuals(self):
        """Return input, generated and target images converted with ``util.tensor2im``."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save generator and discriminator weights under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
class cGANModel(BaseModel):
    """Conditional GAN that maps ``input_A`` to ``fake_B`` with ``netG``.

    Optional pieces, all switched by options read in this class:

    * ``opt.conv3d`` -- run ``netG_3d`` on the input before ``netG``.
    * ``opt.conditional`` -- the discriminator sees ``cat(A, B)`` instead
      of ``B`` alone.
    * ``opt.which_model_preNet`` -- extra network applied to the
      discriminator input; it is only built and used together with
      ``opt.conditional``.
    * ``opt.noisy_disc`` -- randomly flip the discriminator target labels.
    """

    def name(self):
        return 'cGANModel'

    def _uses_preNet(self):
        """Return True when ``preNet_A`` exists and has to be trained/saved.

        ``preNet_A`` is only constructed for conditional discriminators, so
        every access to it (loading, optimizer, saving, LR decay) must be
        gated on both options; gating on ``which_model_preNet`` alone raised
        ``AttributeError`` for non-conditional runs.
        """
        return bool(self.opt.conditional
                    and self.opt.which_model_preNet != 'none')

    def initialize(self, opt):
        """Build networks, losses and optimizers from the parsed options."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # Pre-allocated input buffers; set_input() resizes them per batch.
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        # load/define networks
        if self.opt.conv3d:
            self.netG_3d = networks.define_G_3d(opt.input_nc, opt.input_nc,
                                                norm=opt.norm,
                                                groups=opt.grps,
                                                gpu_ids=self.gpu_ids)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      opt.use_dropout, gpu_ids=self.gpu_ids)

        disc_ch = opt.input_nc
        use_preNet = self._uses_preNet()

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if self.opt.conditional:
                if use_preNet:
                    self.preNet_A = networks.define_preNet(
                        disc_ch + disc_ch, disc_ch + disc_ch,
                        which_model_preNet=opt.which_model_preNet,
                        norm=opt.norm, gpu_ids=self.gpu_ids)
                # The conditional discriminator receives A and B stacked on
                # the channel axis (see torch.cat(..., 1) in backward_D).
                nif = disc_ch + disc_ch
                netD_norm = opt.norm
                self.netD = networks.define_D(nif, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, netD_norm,
                                              use_sigmoid,
                                              gpu_ids=self.gpu_ids)
            else:
                self.netD = networks.define_D(disc_ch, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid,
                                              gpu_ids=self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            if self.opt.conv3d:
                self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                if use_preNet:
                    self.load_network(self.preNet_A, 'PRE_A',
                                      opt.which_epoch)
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            if self.opt.conv3d:
                self.optimizer_G_3d = torch.optim.Adam(
                    self.netG_3d.parameters(),
                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            if use_preNet:
                self.optimizer_preA = torch.optim.Adam(
                    self.preNet_A.parameters(),
                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            if self.opt.conv3d:
                networks.print_network(self.netG_3d)
            networks.print_network(self.netG)
            if use_preNet:
                networks.print_network(self.preNet_A)
            networks.print_network(self.netD)
            print('-----------------------------------------------')

    def set_input(self, input):
        """Copy the batch ``input['A']`` / ``input['B']`` into the buffers."""
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths']

    def _generate(self):
        """Run the generator(s) on ``input_A`` and store the results."""
        self.real_A = self.input_A
        if self.opt.conv3d:
            # netG_3d gets an extra axis inserted at dim 2, which is removed
            # again before the result is handed to netG.
            self.real_A_indep = self.netG_3d.forward(
                self.real_A.unsqueeze(2))
            self.fake_B = self.netG.forward(self.real_A_indep.squeeze(2))
        else:
            self.fake_B = self.netG.forward(self.real_A)
        self.real_B = self.input_B

    def forward(self):
        """Compute ``fake_B`` from the current input (gradients enabled)."""
        self._generate()

    def add_noise_disc(self, real):
        """Return the discriminator target label, optionally corrupted.

        With ``opt.noisy_disc`` set, the label is flipped whenever
        ``random.random() < 0.6``; otherwise ``real`` is returned unchanged.
        """
        if self.opt.noisy_disc and random.random() < 0.6:
            return not real
        return real

    # no backprop gradients
    def test(self):
        """Same as ``forward`` but inside ``torch.no_grad()``."""
        with torch.no_grad():
            self._generate()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss and back-propagate it.

        Fake inputs are detached so no gradient reaches the generator.
        """
        # These aliases are read again by backward_G().
        self.fake_B_reshaped = self.fake_B
        self.real_A_reshaped = self.real_A
        self.real_B_reshaped = self.real_B
        use_preNet = self._uses_preNet()

        # Fake
        label_fake = self.add_noise_disc(False)
        if self.opt.conditional:
            fake_AB = self.fake_AB_pool.query(
                torch.cat((self.real_A_reshaped, self.fake_B_reshaped), 1))
            self.pred_fake_patch = self.netD.forward(fake_AB.detach())
            self.loss_D_fake = self.criterionGAN(self.pred_fake_patch,
                                                 label_fake)
            if use_preNet:
                # Second term: same discriminator on the transformed input.
                transformed_AB = self.preNet_A.forward(fake_AB.detach())
                self.pred_fake = self.netD.forward(transformed_AB)
                self.loss_D_fake += self.criterionGAN(self.pred_fake,
                                                      label_fake)
        else:
            self.pred_fake = self.netD.forward(self.fake_B.detach())
            self.loss_D_fake = self.criterionGAN(self.pred_fake, label_fake)

        # Real
        label_real = self.add_noise_disc(True)
        if self.opt.conditional:
            real_AB = torch.cat(
                (self.real_A_reshaped, self.real_B_reshaped), 1)
            self.pred_real_patch = self.netD.forward(real_AB)
            self.loss_D_real = self.criterionGAN(self.pred_real_patch,
                                                 label_real)
            if use_preNet:
                transformed_A_real = self.preNet_A.forward(real_AB)
                self.pred_real = self.netD.forward(transformed_A_real)
                self.loss_D_real += self.criterionGAN(self.pred_real,
                                                      label_real)
        else:
            self.pred_real = self.netD.forward(self.real_B)
            self.loss_D_real = self.criterionGAN(self.pred_real, label_real)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and back-propagate.

        Reads ``real_A_reshaped`` / ``fake_B_reshaped``, which are set by
        ``backward_D``; ``optimize_parameters`` calls them in that order.
        """
        # First, G(A) should fake the discriminator
        if self.opt.conditional:
            fake_AB = torch.cat(
                (self.real_A_reshaped, self.fake_B_reshaped), 1)
            pred_fake_patch = self.netD.forward(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake_patch, True)
            if self._uses_preNet():
                transformed_A = self.preNet_A.forward(fake_AB)
                pred_fake = self.netD.forward(transformed_A)
                self.loss_G_GAN += self.criterionGAN(pred_fake, True)
        else:
            pred_fake = self.netD.forward(self.fake_B)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) should match B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, update D (+preNet), update G (+3d)."""
        use_preNet = self._uses_preNet()
        self.forward()

        self.optimizer_D.zero_grad()
        if use_preNet:
            self.optimizer_preA.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        if use_preNet:
            self.optimizer_preA.step()

        self.optimizer_G.zero_grad()
        if self.opt.conv3d:
            self.optimizer_G_3d.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        if self.opt.conv3d:
            self.optimizer_G_3d.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed by display name."""
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        """Return the current input, output and target as images."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save every trainable network under the checkpoint ``label``."""
        if self.opt.conv3d:
            self.save_network(self.netG_3d, 'G_3d', label,
                              gpu_ids=self.gpu_ids)
        self.save_network(self.netG, 'G', label, gpu_ids=self.gpu_ids)
        self.save_network(self.netD, 'D', label, gpu_ids=self.gpu_ids)
        if self._uses_preNet():
            self.save_network(self.preNet_A, 'PRE_A', label,
                              gpu_ids=self.gpu_ids)

    def update_learning_rate(self):
        """Lower every optimizer's LR by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd

        optimizers = [self.optimizer_D]
        if self._uses_preNet():
            optimizers.append(self.optimizer_preA)
        optimizers.append(self.optimizer_G)
        if self.opt.conv3d:
            optimizers.append(self.optimizer_G_3d)

        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class CycleGANModel(BaseModel):
    """CycleGAN variant with two discriminators per domain and optional masks.

    Generators: ``netG_A`` (A -> B) and ``netG_B`` (B -> A). Each domain is
    judged by two discriminators (``netD_A1``/``netD_A2`` for B-like images,
    ``netD_B1``/``netD_B2`` for A-like images) built with
    ``opt.n_layers_D1`` / ``opt.n_layers_D2``. When ``opt.face_mask`` is set,
    the cycle losses are weighted element-wise by ``A_mask`` / ``B_mask``.
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``."""
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated buffers; set_input() resizes or replaces them.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        self.A_mask = self.Tensor(nb, opt.input_nc, size, size)
        self.B_mask = self.Tensor(nb, opt.input_nc, size, size)

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A1 = networks.define_D(opt.output_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D1, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
            self.netD_A2 = networks.define_D(opt.output_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D2, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
            self.netD_B1 = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D1, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
            self.netD_B2 = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D2, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A1, 'D_A1', which_epoch)
                self.load_network(self.netD_A2, 'D_A2', which_epoch)
                self.load_network(self.netD_B1, 'D_B1', which_epoch)
                self.load_network(self.netD_B2, 'D_B2', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            adam_betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=adam_betas)
            self.optimizer_D_A1 = torch.optim.Adam(
                self.netD_A1.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_D_A2 = torch.optim.Adam(
                self.netD_A2.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_D_B1 = torch.optim.Adam(
                self.netD_B1.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_D_B2 = torch.optim.Adam(
                self.netD_B2.parameters(), lr=opt.lr, betas=adam_betas)

            self.optimizers = [self.optimizer_G,
                               self.optimizer_D_A1,
                               self.optimizer_D_A2,
                               self.optimizer_D_B1,
                               self.optimizer_D_B2]
            self.schedulers = [networks.get_scheduler(optimizer, opt)
                               for optimizer in self.optimizers]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A1)
            networks.print_network(self.netD_A2)
            networks.print_network(self.netD_B1)
            networks.print_network(self.netD_B2)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch into the buffers, honouring ``which_direction``."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if self.opt.face_mask:
            # NOTE(review): masks are not swapped with which_direction,
            # unlike the images above -- confirm this is intended.
            A_mask = input['A_mask']
            B_mask = input['B_mask']
            self.A_mask = Variable(A_mask, requires_grad=False).cuda()
            self.B_mask = Variable(B_mask, requires_grad=False).cuda()

    def forward(self):
        """Expose the input buffers as ``real_A`` / ``real_B``.

        The generators themselves are run inside ``backward_G``.
        """
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without building autograd graphs.

        ``volatile=True`` has no effect on PyTorch >= 0.4, so inference is
        wrapped in ``torch.no_grad()`` like the other models in this file.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG_A.forward(self.real_A)
            self.rec_A = self.netG_B.forward(self.fake_B)

            self.real_B = Variable(self.input_B)
            self.fake_A = self.netG_B.forward(self.real_B)
            self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the averaged real/fake GAN loss of ``netD``.

        ``fake`` is detached so no gradient reaches the generator.
        Returns the combined loss.
        """
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A1(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A1 = self.backward_D_basic(self.netD_A1, self.real_B,
                                               fake_B)

    def backward_D_A2(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A2 = self.backward_D_basic(self.netD_A2, self.real_B,
                                               fake_B)

    def backward_D_B1(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B1 = self.backward_D_basic(self.netD_B1, self.real_A,
                                               fake_A)

    def backward_D_B2(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B2 = self.backward_D_basic(self.netD_B2, self.real_A,
                                               fake_A)

    def _cycle_loss(self, rec, real, mask, weight):
        """Weighted L1 cycle loss, mask-weighted when ``opt.face_mask``."""
        if not self.opt.face_mask:
            return self.criterionCycle(rec, real) * weight
        return ((rec - real).abs() * mask).mean() * weight

    def backward_G(self):
        """Compute identity, GAN and cycle losses and back-propagate them."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A)), averaged over both discriminators
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake1 = self.netD_A1.forward(self.fake_B)
        pred_fake2 = self.netD_A2.forward(self.fake_B)
        self.loss_G_A = (self.criterionGAN(pred_fake1, True)
                         + self.criterionGAN(pred_fake2, True)) * 0.5

        # GAN loss D_B(G_B(B)), averaged over both discriminators
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake1 = self.netD_B1.forward(self.fake_A)
        pred_fake2 = self.netD_B2.forward(self.fake_A)
        self.loss_G_B = (self.criterionGAN(pred_fake1, True)
                         + self.criterionGAN(pred_fake2, True)) * 0.5

        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self._cycle_loss(self.rec_A, self.real_A,
                                             self.A_mask, lambda_A)
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self._cycle_loss(self.rec_B, self.real_B,
                                             self.B_mask, lambda_B)

        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: generators first, then each discriminator."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A1, D_A2, D_B1, D_B2 -- same order as before
        for optimizer, backward in (
                (self.optimizer_D_A1, self.backward_D_A1),
                (self.optimizer_D_A2, self.backward_D_A2),
                (self.optimizer_D_B1, self.backward_D_B1),
                (self.optimizer_D_B2, self.backward_D_B2)):
            optimizer.zero_grad()
            backward()
            optimizer.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed by display name.

        Uses ``.item()``: indexing a 0-dim loss with ``.data[0]`` raises on
        current PyTorch releases.
        """
        with_idt = self.opt.identity > 0.0
        errors = OrderedDict()
        errors['D_A1'] = self.loss_D_A1.item()
        errors['D_A2'] = self.loss_D_A2.item()
        errors['G_A'] = self.loss_G_A.item()
        errors['Cyc_A'] = self.loss_cycle_A.item()
        if with_idt:
            errors['idt_A'] = self.loss_idt_A.item()
        errors['D_B1'] = self.loss_D_B1.item()
        errors['D_B2'] = self.loss_D_B2.item()
        errors['G_B'] = self.loss_G_B.item()
        errors['Cyc_B'] = self.loss_cycle_B.item()
        if with_idt:
            errors['idt_B'] = self.loss_idt_B.item()
        return errors

    def get_current_visuals(self):
        """Return current images; identity outputs take priority over masks.

        Identity images are only available while training with
        ``opt.identity > 0``; masks are only returned when identity images
        are not.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)

        if self.opt.isTrain and self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])

        visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                               ('rec_A', rec_A), ('real_B', real_B),
                               ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.face_mask:
            # Masks are converted only on the path that returns them.
            visuals['mask_A'] = util.mask2im(
                self.A_mask.data, face_weight=self.opt.face_weight)
            visuals['mask_B'] = util.mask2im(
                self.B_mask.data, face_weight=self.opt.face_weight)
        return visuals

    def save(self, label):
        """Save all six networks under the checkpoint ``label``."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A1, 'D_A1', label, self.gpu_ids)
        self.save_network(self.netD_A2, 'D_A2', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B1, 'D_B1', label, self.gpu_ids)
        self.save_network(self.netD_B2, 'D_B2', label, self.gpu_ids)
class DCLModel(BaseModel): """ This class implements DCLGAN model. This code is inspired by CUT and CycleGAN. """ @staticmethod def modify_commandline_options(parser, is_train=True): """ Configures options specific for DCLGAN """ parser.add_argument('--DCL_mode', type=str, default="DCL", choices='DCL') parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))') parser.add_argument('--lambda_NCE', type=float, default=2.0, help='weight for NCE loss: NCE(G(X), X)') parser.add_argument('--lambda_IDT', type=float, default=1.0, help='weight for l1 identical loss: (G(X),X)') parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') parser.add_argument('--nce_layers', type=str, default='4,8,12,16', help='compute NCE loss on which layers') parser.add_argument('--nce_includes_all_negatives_from_minibatch', type=util.str2bool, nargs='?', const=True, default=False, help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.') parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') parser.add_argument('--netF_nc', type=int, default=256) parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss') parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer') parser.add_argument('--flip_equivariance', type=util.str2bool, nargs='?', const=True, default=False, help="Enforce flip-equivariance as additional regularization.") parser.set_defaults(pool_size=0) # no image pooling opt, _ = parser.parse_known_args() # Set default parameters for DCLGAN. 
if opt.DCL_mode.lower() == "dcl": parser.set_defaults(nce_idt=True, lambda_NCE=2.0) else: raise ValueError(opt.DCL_mode) return parser def __init__(self, opt): BaseModel.__init__(self, opt) # specify the training losses you want to print out. # The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['D_A', 'G_A', 'NCE1', 'D_B', 'G_B', 'NCE2', 'G'] visual_names_A = ['real_A', 'fake_B'] visual_names_B = ['real_B', 'fake_A'] self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')] if opt.nce_idt and self.isTrain: self.loss_names += ['idt_B', 'idt_A'] visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B if self.isTrain: self.model_names = ['G_A', 'F1', 'D_A', 'G_B', 'F2', 'D_B'] else: # during test time, only load G self.model_names = ['G_A', 'G_B'] # define networks (both generator and discriminator) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) self.netF1 = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) self.netF2 = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) if self.isTrain: self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) self.netD_B = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) 
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) self.criterionNCE = [] for nce_layer in self.nce_layers: self.criterionNCE.append(PatchNCELoss(opt).to(self.device)) self.criterionIdt = torch.nn.L1Loss().to(self.device) self.criterionSim = torch.nn.L1Loss('sum').to(self.device) self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) def data_dependent_initialize(self, data): """ The feature network netF is defined in terms of the shape of the intermediate, extracted features of the encoder portion of netG. Because of this, the weights of netF are initialized at the first feedforward pass with some input images. Please also see PatchSampleF.create_mlp(), which is called at the first forward() call. 
""" self.set_input(data) bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1) self.real_A = self.real_A[:bs_per_gpu] self.real_B = self.real_B[:bs_per_gpu] self.forward() # compute fake images: G(A) if self.opt.isTrain: self.compute_G_loss().backward() # calculate graidents for G self.backward_D_A() # calculate gradients for D_A self.backward_D_B() # calculate graidents for D_B self.optimizer_F = torch.optim.Adam(itertools.chain(self.netF1.parameters(), self.netF2.parameters())) self.optimizers.append(self.optimizer_F) def optimize_parameters(self): # forward self.forward() # update D self.set_requires_grad([self.netD_A, self.netD_B], True) self.optimizer_D.zero_grad() self.backward_D_A() # calculate gradients for D_A self.backward_D_B() # calculate graidents for D_B self.optimizer_D.step() # update G self.set_requires_grad([self.netD_A, self.netD_B], False) self.optimizer_G.zero_grad() if self.opt.netF == 'mlp_sample': self.optimizer_F.zero_grad() self.loss_G = self.compute_G_loss() self.loss_G.backward() self.optimizer_G.step() if self.opt.netF == 'mlp_sample': self.optimizer_F.step() def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap domain A and domain B. 
""" AtoB = self.opt.direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG_A(self.real_A) # G_A(A) self.fake_A = self.netG_B(self.real_B) # G_B(B) if self.opt.nce_idt: self.idt_A = self.netG_A(self.real_B) self.idt_B = self.netG_B(self.real_A) def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) * self.opt.lambda_GAN def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) * self.opt.lambda_GAN def compute_G_loss(self): """Calculate GAN and NCE loss for the generator""" fakeB = self.fake_B fakeA = self.fake_A # First, G(A) should fake the discriminator if self.opt.lambda_GAN > 0.0: pred_fakeB = self.netD_A(fakeB) pred_fakeA = self.netD_B(fakeA) self.loss_G_A = self.criterionGAN(pred_fakeB, True).mean() * self.opt.lambda_GAN self.loss_G_B = self.criterionGAN(pred_fakeA, True).mean() * self.opt.lambda_GAN else: self.loss_G_A = 0.0 self.loss_G_B = 0.0 
if self.opt.lambda_NCE > 0.0: self.loss_NCE1 = self.calculate_NCE_loss1(self.real_A, self.fake_B) * self.opt.lambda_NCE self.loss_NCE2 = self.calculate_NCE_loss2(self.real_B, self.fake_A) * self.opt.lambda_NCE else: self.loss_NCE1, self.loss_NCE_bd, self.loss_NCE2 = 0.0, 0.0, 0.0 if self.opt.lambda_NCE > 0.0: # L1 IDENTICAL Loss self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * self.opt.lambda_IDT self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * self.opt.lambda_IDT loss_NCE_both = (self.loss_NCE1 + self.loss_NCE2) * 0.5 + (self.loss_idt_A + self.loss_idt_B) * 0.5 else: loss_NCE_both = (self.loss_NCE1 + self.loss_NCE2) * 0.5 self.loss_G = (self.loss_G_A + self.loss_G_B) * 0.5 + loss_NCE_both return self.loss_G def calculate_NCE_loss1(self, src, tgt): n_layers = len(self.nce_layers) feat_q = self.netG_B(tgt, self.nce_layers, encode_only=True) feat_k = self.netG_A(src, self.nce_layers, encode_only=True) feat_k_pool, sample_ids = self.netF1(feat_k, self.opt.num_patches, None) feat_q_pool, _ = self.netF2(feat_q, self.opt.num_patches, sample_ids) total_nce_loss = 0.0 for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers): loss = crit(f_q, f_k) total_nce_loss += loss.mean() return total_nce_loss / n_layers def calculate_NCE_loss2(self, src, tgt): n_layers = len(self.nce_layers) feat_q = self.netG_A(tgt, self.nce_layers, encode_only=True) feat_k = self.netG_B(src, self.nce_layers, encode_only=True) feat_k_pool, sample_ids = self.netF2(feat_k, self.opt.num_patches, None) feat_q_pool, _ = self.netF1(feat_q, self.opt.num_patches, sample_ids) total_nce_loss = 0.0 for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers): loss = crit(f_q, f_k) total_nce_loss += loss.mean() return total_nce_loss / n_layers def generate_visuals_for_evaluation(self, data, mode): with torch.no_grad(): visuals = {} AtoB = self.opt.direction == "AtoB" G = self.netG_A source = data["A" if 
AtoB else "B"].to(self.device) if mode == "forward": visuals["fake_B"] = G(source) else: raise ValueError("mode %s is not recognized" % mode) return visuals
class CycleGANModel(BaseModel): def name(self): return 'CycleGANModel' def initialize(self, opt): BaseModel.initialize(self, opt) nb = opt.batchSize size = opt.fineSize self.input_A = self.Tensor(nb, opt.input_nc, size, size) self.input_B = self.Tensor(nb, opt.output_nc, size, size) # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- 
Networks initialized -------------') networks.print_network(self.netG_A) networks.print_network(self.netG_B) if self.isTrain: networks.print_network(self.netD_A) networks.print_network(self.netD_B) print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): self.real_A = Variable(self.input_A, volatile=True) self.fake_B = self.netG_A.forward(self.real_A) self.rec_A = self.netG_B.forward(self.fake_B) self.real_B = Variable(self.input_B, volatile=True) self.fake_A = self.netG_B.forward(self.real_B) self.rec_B = self.netG_A.forward(self.fake_A) # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD.forward(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD.forward(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): lambda_idt = self.opt.identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. 
self.idt_A = self.netG_A.forward(self.real_B) self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. self.idt_B = self.netG_B.forward(self.real_A) self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss # D_A(G_A(A)) self.fake_B = self.netG_A.forward(self.real_A) pred_fake = self.netD_A.forward(self.fake_B) self.loss_G_A = self.criterionGAN(pred_fake, True) # D_B(G_B(B)) self.fake_A = self.netG_B.forward(self.real_B) pred_fake = self.netD_B.forward(self.fake_A) self.loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def optimize_parameters(self): # forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() self.backward_D_A() self.optimizer_D_A.step() # D_B self.optimizer_D_B.zero_grad() self.backward_D_B() self.optimizer_D_B.step() def get_current_errors(self): D_A = self.loss_D_A.data[0] G_A = self.loss_G_A.data[0] Cyc_A = self.loss_cycle_A.data[0] D_B = self.loss_D_B.data[0] G_B = self.loss_G_B.data[0] Cyc_B = self.loss_cycle_B.data[0] if self.opt.identity > 0.0: idt_A = self.loss_idt_A.data[0] idt_B = self.loss_idt_B.data[0] return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)]) else: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)]) 
def get_current_visuals(self):
    """Return the current images as an ordered name -> image mapping.

    Every tensor is converted with ``util.tensor2im``. The identity outputs
    are included only when ``opt.identity`` is positive.
    """
    base_keys = ('real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B')
    images = {
        key: util.tensor2im(getattr(self, key).data) for key in base_keys
    }
    layout = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']
    if self.opt.identity > 0.0:
        for key in ('idt_A', 'idt_B'):
            images[key] = util.tensor2im(getattr(self, key).data)
        # Each identity image is listed after the reconstruction of the
        # domain it was computed from.
        layout = ['real_A', 'fake_B', 'rec_A', 'idt_B',
                  'real_B', 'fake_A', 'rec_B', 'idt_A']
    return OrderedDict((key, images[key]) for key in layout)


def save(self, label):
    """Save both generators and both discriminators under ``label``."""
    for suffix in ('G_A', 'D_A', 'G_B', 'D_B'):
        network = getattr(self, 'net' + suffix)
        self.save_network(network, suffix, label, self.gpu_ids)


def update_learning_rate(self):
    """Lower the learning rate of all three optimizers by one linear step.

    The step is ``opt.lr / opt.niter_decay``; the new rate is stored in
    ``self.old_lr`` after being written to every parameter group.
    """
    step = self.opt.lr / self.opt.niter_decay
    new_lr = self.old_lr - step
    for optimizer in (self.optimizer_D_A, self.optimizer_D_B,
                      self.optimizer_G):
        for group in optimizer.param_groups:
            group['lr'] = new_lr

    print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
    self.old_lr = new_lr
class CycleGANSemanticMaskModel(BaseModel):
    """CycleGAN trained jointly with a semantic network ``f_s`` and mask losses.

    Generators: G_A (A -> B) and G_B (B -> A). Discriminators: D_A is fed
    domain-B images, D_B is fed domain-A images. ``f_s`` is trained on real
    images with their labels; its predictions on translated images drive a
    semantic loss on the generators. With ``--out_mask`` the generators are
    also penalised for changing pixels outside the label mask, and with
    ``--disc_in_mask`` extra discriminators are trained on masked images.
    """

    def name(self):
        """Return the model name."""
        return 'CycleGANSemanticMaskModel'

    # new, copied from cyclegansemantic model
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        parser.set_defaults(
            no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=0.5,
                help='use identity mapping. Setting lambda_identity other '
                'than 0 has an effect of scaling the weight of the identity '
                'mapping loss. For example, if the weight of the identity '
                'loss should be 10 times smaller than the weight of the '
                'reconstruction loss, please set lambda_identity = 0.1')
            parser.add_argument('--out_mask',
                                action='store_true',
                                help='use loss out mask')
            parser.add_argument('--lambda_out_mask',
                                type=float,
                                default=10.0,
                                help='weight for loss out mask')
            parser.add_argument('--loss_out_mask',
                                type=str,
                                default='L1',
                                help='loss mask')
            parser.add_argument('--charbonnier_eps',
                                type=float,
                                default=1e-6,
                                help='Charbonnier loss epsilon value')
            parser.add_argument('--disc_in_mask',
                                action='store_true',
                                help='use in-mask discriminator')
            parser.add_argument(
                '--train_f_s_B',
                action='store_true',
                help='if true f_s will be trained not only on domain A but '
                'also on domain B')
            parser.add_argument(
                '--fs_light',
                action='store_true',
                help='whether to use a light (unet) network for f_s')
            parser.add_argument('--lr_f_s',
                                type=float,
                                default=0.0002,
                                help='f_s learning rate')
            parser.add_argument(
                '--D_noise',
                type=float,
                default=0.0,
                help='whether to add instance noise to discriminator inputs')
            parser.add_argument(
                '--D_label_smooth',
                action='store_true',
                help='whether to use one-sided label smoothing with '
                'discriminator')
            parser.add_argument('--rec_noise',
                                type=float,
                                default=0.0,
                                help='whether to add noise to reconstruction')
            parser.add_argument('--nb_attn',
                                type=int,
                                default=10,
                                help='number of attention masks')
            parser.add_argument(
                '--nb_mask_input',
                type=int,
                default=1,
                help='number of attention masks which will be applied on '
                'the input image')
            parser.add_argument('--lambda_sem',
                                type=float,
                                default=1.0,
                                help='weight for semantic loss')

        return parser

    def __init__(self, opt):
        """Build the networks and, when training, the losses and optimizers.

        Parameters:
            opt -- parsed options. Options that modify_commandline_options
                   registers only in its training branch receive a default
                   here when they are missing.
        """
        BaseModel.__init__(self, opt)
        option_defaults = (
            ('disc_in_mask', False),
            ('out_mask', False),
            ('nb_attn', 10),
            ('nb_mask_input', 1),
            ('fs_light', False),
        )
        for option, default in option_defaults:
            if not hasattr(opt, option):
                setattr(opt, option, default)

        # forward(), backward_G() and optimize_parameters() read this flag
        # from the instance, so mirror it from the options.
        self.disc_in_mask = opt.disc_in_mask

        # specify the training losses you want to print out. The program will
        # call base_model.get_current_losses
        losses = ['G_A', 'G_B']
        if opt.disc_in_mask:
            losses += ['D_A_mask', 'D_B_mask']
        losses += ['D_A', 'D_B']
        if opt.out_mask:
            losses += ['out_mask_AB', 'out_mask_BA']
        losses += [
            'cycle_A', 'idt_A', 'cycle_B', 'idt_B', 'sem_AB', 'sem_BA', 'f_s'
        ]
        self.loss_names = losses

        # specify the images you want to save/display. The program will call
        # base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')  # inverted for original

        visual_names_seg_A = ['input_A_label', 'gt_pred_A', 'pfB_max']
        visual_names_seg_B = ['gt_pred_B', 'pfA_max']
        visual_names_out_mask = ['real_A_out_mask', 'fake_B_out_mask']

        # XXX: model is created after dataset is populated so this check stands
        # NOTE(review): this class only assigns input_B_label in set_input(),
        # so the branch looks unreachable here unless the base class provides
        # the attribute -- confirm.
        if hasattr(self, 'input_B_label') and len(self.input_B_label) > 0:
            visual_names_seg_B.append('input_B_label')
            visual_names_out_mask.append('real_B_out_mask')
            visual_names_out_mask.append('fake_A_out_mask')

        visual_names_mask_in = [
            'real_B_mask', 'fake_B_mask', 'real_A_mask', 'fake_A_mask',
            'real_B_mask_in', 'fake_B_mask_in', 'real_A_mask_in',
            'fake_A_mask_in'
        ]

        self.visual_names = (visual_names_A + visual_names_B +
                             visual_names_seg_A + visual_names_seg_B)
        if opt.out_mask:
            self.visual_names += visual_names_out_mask
        if opt.disc_in_mask:
            self.visual_names += visual_names_mask_in

        # specify the models you want to save to the disk. The program will
        # call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'f_s']
            if opt.disc_in_mask:
                self.model_names += ['D_A_mask', 'D_B_mask']
            self.model_names += ['D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.G_spectral,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        nb_attn=opt.nb_attn,
                                        nb_mask_input=opt.nb_mask_input)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.G_spectral,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        nb_attn=opt.nb_attn,
                                        nb_mask_input=opt.nb_mask_input)

        if self.isTrain:
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            opt.D_dropout, opt.D_spectral,
                                            opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            opt.D_dropout, opt.D_spectral,
                                            opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            if opt.disc_in_mask:
                self.netD_A_mask = networks.define_D(
                    opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D,
                    opt.norm, opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, self.gpu_ids)
                self.netD_B_mask = networks.define_D(
                    opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D,
                    opt.norm, opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, self.gpu_ids)

        # f_s is built in both phases because forward() always runs it on
        # fake_B.
        self.netf_s = networks.define_f(opt.input_nc,
                                        nclasses=opt.semantic_nclasses,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        fs_light=opt.fs_light)

        if self.isTrain:
            if opt.lambda_identity > 0.0:
                # only works when input and output images have the same
                # number of channels
                assert (opt.input_nc == opt.output_nc)
            # create image buffers to store previously generated images
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            if opt.disc_in_mask:
                self.fake_A_pool_mask = ImagePool(opt.pool_size)
                self.fake_B_pool_mask = ImagePool(opt.pool_size)

            # define loss functions
            target_real_label = 0.9 if opt.D_label_smooth else 1.0
            self.criterionGAN = loss.GANLoss(
                opt.gan_mode,
                target_real_label=target_real_label).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionf_s = torch.nn.modules.CrossEntropyLoss()

            # criterionMask only exists when --out_mask is set with a known
            # loss name; forward() and backward_G() test for it with hasattr.
            if opt.out_mask:
                if opt.loss_out_mask == 'L1':
                    self.criterionMask = torch.nn.L1Loss()
                elif opt.loss_out_mask == 'MSE':
                    self.criterionMask = torch.nn.MSELoss()
                elif opt.loss_out_mask == 'Charbonnier':
                    self.criterionMask = L1_Charbonnier_loss(
                        opt.charbonnier_eps)

            # initialize optimizers
            betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=betas)
            discriminators = [self.netD_A, self.netD_B]
            if opt.disc_in_mask:
                discriminators += [self.netD_A_mask, self.netD_B_mask]
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                *(net.parameters() for net in discriminators)),
                                                lr=opt.D_lr,
                                                betas=betas)
            self.optimizer_f_s = torch.optim.Adam(self.netf_s.parameters(),
                                                  lr=opt.lr_f_s,
                                                  betas=betas)

            # optimizer_f_s is deliberately left out of this list, as in the
            # original code.
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.rec_noise = opt.rec_noise
            self.D_noise = opt.D_noise

    def set_input(self, input):
        """Move one batch to the model device and store it on the instance.

        Parameters:
            input (dict) -- must hold 'A', 'B' and the matching '*_paths'
                            entry; 'A_label' and 'B_label' are optional.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        # Labels are not swapped with the direction; dimension 1 is squeezed.
        if 'A_label' in input:
            self.input_A_label = input['A_label'].to(self.device).squeeze(1)
        if 'B_label' in input and len(input['B_label']) > 0:
            self.input_B_label = input['B_label'].to(self.device).squeeze(
                1)  # beniz: unused

    def forward(self):
        """Run the generators and f_s and store every intermediate result.

        Reconstructions, the reverse translation, masked images and noisy
        discriminator inputs are only computed when ``self.isTrain`` is set.
        """
        self.fake_B = self.netG_A(self.real_A)
        d = 1  # dimension f_s predictions are normalised and arg-maxed over

        if self.isTrain:
            if self.rec_noise > 0.0:
                self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
                self.rec_A = self.netG_B(self.fake_B_noisy1)
            else:
                self.rec_A = self.netG_B(self.fake_B)

            self.fake_A = self.netG_B(self.real_B)
            if self.rec_noise > 0.0:
                self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
                self.rec_B = self.netG_A(self.fake_A_noisy1)
            else:
                self.rec_B = self.netG_A(self.fake_A)

            self.pred_real_A = self.netf_s(self.real_A)
            self.gt_pred_A = F.log_softmax(self.pred_real_A,
                                           dim=d).argmax(dim=d)

            self.pred_real_B = self.netf_s(self.real_B)
            self.gt_pred_B = F.log_softmax(self.pred_real_B,
                                           dim=d).argmax(dim=d)

            self.pred_fake_A = self.netf_s(self.fake_A)
            self.pfA = F.log_softmax(self.pred_fake_A, dim=d)
            self.pfA_max = self.pfA.argmax(dim=d)

            if hasattr(self, 'criterionMask'):
                label_A = self.input_A_label
                label_A_in = label_A.unsqueeze(1)
                # True wherever (1 - label) > 0, i.e. outside the mask.
                label_A_inv = torch.tensor(np.ones(label_A.size())).to(
                    self.device) - label_A > 0
                label_A_inv = label_A_inv.unsqueeze(1)

                self.real_A_out_mask = self.real_A * label_A_inv
                self.fake_B_out_mask = self.fake_B * label_A_inv

                if self.disc_in_mask:
                    self.real_A_mask_in = self.real_A * label_A_in
                    self.fake_B_mask_in = self.fake_B * label_A_in
                    self.real_A_mask = self.real_A
                    # translated content inside the mask, real content
                    # outside of it
                    self.fake_B_mask = (self.fake_B_mask_in +
                                        self.real_A_out_mask.float())

                if hasattr(self,
                           'input_B_label') and len(self.input_B_label) > 0:
                    label_B = self.input_B_label
                    label_B_in = label_B.unsqueeze(1)
                    label_B_inv = torch.tensor(np.ones(label_B.size())).to(
                        self.device) - label_B > 0
                    label_B_inv = label_B_inv.unsqueeze(1)

                    self.real_B_out_mask = self.real_B * label_B_inv
                    self.fake_A_out_mask = self.fake_A * label_B_inv

                    if self.disc_in_mask:
                        self.real_B_mask_in = self.real_B * label_B_in
                        self.fake_A_mask_in = self.fake_A * label_B_in
                        self.real_B_mask = self.real_B
                        self.fake_A_mask = (self.fake_A_mask_in +
                                            self.real_B_out_mask.float())

            # backward_D_A and backward_D_B read all four noisy tensors
            # whenever D_noise is positive, so they must not depend on the
            # mask or label branches above.
            if self.D_noise > 0.0:
                self.fake_B_noisy = gaussian(self.fake_B, self.D_noise)
                self.real_A_noisy = gaussian(self.real_A, self.D_noise)
                self.fake_A_noisy = gaussian(self.fake_A, self.D_noise)
                self.real_B_noisy = gaussian(self.real_B, self.D_noise)

        self.pred_fake_B = self.netf_s(self.fake_B)
        self.pfB = F.log_softmax(self.pred_fake_B, dim=d)
        self.pfB_max = self.pfB.argmax(dim=d)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the GAN loss of one discriminator and return it.

        Parameters:
            netD -- the discriminator
            real -- real images
            fake -- generated images; detached so no gradient reaches the
                    generators

        Returns the mean of the real and fake losses.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_f_s(self):
        """Backpropagate the classification loss of f_s on real images."""
        label_A = self.input_A_label
        # forward only real source image through semantic classifier
        pred_A = self.netf_s(self.real_A)
        self.loss_f_s = self.criterionf_s(pred_A, label_A)
        if self.opt.train_f_s_B:
            label_B = self.input_B_label
            pred_B = self.netf_s(self.real_B)
            self.loss_f_s += self.criterionf_s(pred_B, label_B)
        self.loss_f_s.backward()

    def backward_D_A(self):
        """Train D_A on real B against pooled fake B (noisy if D_noise > 0)."""
        if self.D_noise > 0.0:
            fake_B = self.fake_B_pool.query(self.fake_B_noisy)
            self.loss_D_A = self.backward_D_basic(self.netD_A,
                                                  self.real_B_noisy, fake_B)
        else:
            fake_B = self.fake_B_pool.query(self.fake_B)
            self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B,
                                                  fake_B)

    def backward_D_B(self):
        """Train D_B on real A against pooled fake A (noisy if D_noise > 0)."""
        if self.D_noise > 0.0:
            fake_A = self.fake_A_pool.query(self.fake_A_noisy)
            self.loss_D_B = self.backward_D_basic(self.netD_B,
                                                  self.real_A_noisy, fake_A)
        else:
            fake_A = self.fake_A_pool.query(self.fake_A)
            self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A,
                                                  fake_A)

    def backward_D_A_mask(self):
        """Train D_A_mask on real_B_mask against pooled fake_B_mask."""
        fake_B_mask = self.fake_B_pool_mask.query(self.fake_B_mask)
        self.loss_D_A_mask = self.backward_D_basic(self.netD_A_mask,
                                                   self.real_B_mask,
                                                   fake_B_mask)

    def backward_D_B_mask(self):
        """Train D_B_mask on real_A_mask against pooled fake_A_mask."""
        fake_A_mask = self.fake_A_pool_mask.query(self.fake_A_mask)
        self.loss_D_B_mask = self.backward_D_basic(self.netD_B_mask,
                                                   self.real_A_mask,
                                                   fake_A_mask)

    def backward_D_A_mask_in(self):
        """Train D_A on the in-mask images: real_B_mask_in vs fake_B_mask_in."""
        fake_B_mask_in = self.fake_B_pool.query(self.fake_B_mask_in)
        self.loss_D_A = self.backward_D_basic(self.netD_A,
                                              self.real_B_mask_in,
                                              fake_B_mask_in)

    def backward_D_B_mask_in(self):
        """Train D_B on the in-mask images: real_A_mask_in vs fake_A_mask_in."""
        # The fake sample is the in-mask image, matching backward_D_A_mask_in
        # and the tensor that backward_G feeds to netD_B.
        fake_A_mask_in = self.fake_A_pool.query(self.fake_A_mask_in)
        self.loss_D_B = self.backward_D_basic(self.netD_B,
                                              self.real_A_mask_in,
                                              fake_A_mask_in)

    def backward_G(self):
        """Compute every generator loss, sum them into loss_G and backpropagate."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_sem = self.opt.lambda_sem

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        if self.disc_in_mask:
            # D_A / D_B judge the in-mask images, the *_mask discriminators
            # judge the composites (see optimize_parameters).
            self.loss_G_A_mask = self.criterionGAN(
                self.netD_A(self.fake_B_mask_in), True)
            self.loss_G_B_mask = self.criterionGAN(
                self.netD_B(self.fake_A_mask_in), True)
            self.loss_G_A = self.criterionGAN(
                self.netD_A_mask(self.fake_B_mask), True)
            self.loss_G_B = self.criterionGAN(
                self.netD_B_mask(self.fake_A_mask), True)
        else:
            # GAN loss D_A(G_A(A))
            self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
            # GAN loss D_B(G_B(B))
            self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        # combined loss standard cyclegan
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        if self.disc_in_mask:
            self.loss_G += self.loss_G_A_mask + self.loss_G_B_mask

        # semantic loss AB
        self.loss_sem_AB = lambda_sem * self.criterionf_s(
            self.pfB, self.input_A_label)
        # semantic loss BA: fall back to the f_s prediction on real B when no
        # B labels are available
        if hasattr(self, 'input_B_label'):
            self.loss_sem_BA = lambda_sem * self.criterionf_s(
                self.pfA, self.input_B_label)
        else:
            self.loss_sem_BA = lambda_sem * self.criterionf_s(
                self.pfA, self.gt_pred_B)

        # only use semantic loss when classifier has reasonably low loss
        if not hasattr(
                self, 'loss_f_s') or self.loss_f_s.detach().item() > 1.0:
            self.loss_sem_AB = 0 * self.loss_sem_AB
            self.loss_sem_BA = 0 * self.loss_sem_BA
        self.loss_G += self.loss_sem_BA + self.loss_sem_AB

        lambda_out_mask = self.opt.lambda_out_mask
        if hasattr(self, 'criterionMask'):
            self.loss_out_mask_AB = self.criterionMask(
                self.real_A_out_mask, self.fake_B_out_mask) * lambda_out_mask
            if hasattr(self, 'input_B_label') and len(self.input_B_label) > 0:
                self.loss_out_mask_BA = self.criterionMask(
                    self.real_B_out_mask,
                    self.fake_A_out_mask) * lambda_out_mask
            else:
                self.loss_out_mask_BA = 0
            self.loss_G += self.loss_out_mask_AB + self.loss_out_mask_BA

        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # compute fake images and reconstruction images.
        self.forward()

        discriminators = [self.netD_A, self.netD_B]
        if self.disc_in_mask:
            discriminators += [self.netD_A_mask, self.netD_B_mask]

        # G_A and G_B
        # Ds require no gradients when optimizing Gs
        self.set_requires_grad(discriminators, False)
        self.set_requires_grad([self.netG_A, self.netG_B], True)
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()  # update G_A and G_B's weights

        # D_A and D_B
        self.set_requires_grad(discriminators, True)
        self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        if self.disc_in_mask:
            self.backward_D_A_mask_in()
            self.backward_D_B_mask_in()
            self.backward_D_A_mask()
            self.backward_D_B_mask()
        else:
            self.backward_D_A()  # calculate gradients for D_A
            self.backward_D_B()  # calculate gradients for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights

        # f_s
        self.set_requires_grad(discriminators, False)
        self.set_requires_grad([self.netf_s], True)
        self.optimizer_f_s.zero_grad()
        self.backward_f_s()
        self.optimizer_f_s.step()
class CycleGANModel(BaseModel):
    """CycleGAN with two generators and, when training, two discriminators.

    Naming follows the code rather than the paper: G_A maps A to B, G_B maps
    B to A, D_A is fed domain-B images and D_B is fed domain-A images.
    """

    def name(self):
        """Return the model name."""
        return 'CycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Disable dropout by default and register the training loss weights.

        Parameters:
            parser          -- option parser to extend
            is_train (bool) -- the lambda_* options are only added when true

        Returns the same parser.
        """
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if not is_train:
            return parser

        parser.add_argument(
            '--lambda_A', type=float, default=10.0,
            help='weight for cycle loss (A -> B -> A)')
        parser.add_argument(
            '--lambda_B', type=float, default=10.0,
            help='weight for cycle loss (B -> A -> B)')
        parser.add_argument(
            '--lambda_identity', type=float, default=0.5,
            help='use identity mapping. Setting lambda_identity other than 0 '
            'has an effect of scaling the weight of the identity mapping '
            'loss. For example, if the weight of the identity loss should be '
            '10 times smaller than the weight of the reconstruction loss, '
            'please set lambda_identity = 0.1')
        return parser

    def initialize(self, opt):
        """Create the networks and, when training, pools, losses and optimizers."""
        BaseModel.initialize(self, opt)

        # Losses reported through base_model.get_current_losses.
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B']

        # Images reported through base_model.get_current_visuals; identity
        # outputs only exist while training with a positive identity weight.
        show_identity = self.isTrain and self.opt.lambda_identity > 0.0
        side_A = ['real_A', 'fake_B', 'rec_A']
        side_B = ['real_B', 'fake_A', 'rec_B']
        if show_identity:
            side_A += ['idt_A']
            side_B += ['idt_B']
        self.visual_names = side_A + side_B

        # Networks handled by base_model.save_networks / load_networks.
        # during test time, only load Gs
        self.model_names = ['G_A', 'G_B']
        if self.isTrain:
            self.model_names += ['D_A', 'D_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        g_common = (opt.ngf, opt.netG, opt.norm, not opt.no_dropout,
                    opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        *g_common)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        *g_common)

        if not self.isTrain:
            return

        use_sigmoid = opt.no_lsgan
        d_common = (opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid,
                    opt.init_type, opt.init_gain, self.gpu_ids)
        self.netD_A = networks.define_D(opt.output_nc, *d_common)
        self.netD_B = networks.define_D(opt.input_nc, *d_common)

        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # define loss functions
        self.criterionGAN = networks.GANLoss(
            use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # initialize optimizers
        generator_params = itertools.chain(self.netG_A.parameters(),
                                           self.netG_B.parameters())
        self.optimizer_G = torch.optim.Adam(generator_params, lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        discriminator_params = itertools.chain(self.netD_A.parameters(),
                                               self.netD_B.parameters())
        self.optimizer_D = torch.optim.Adam(discriminator_params, lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D]

    def set_input(self, input):
        """Store one batch on the model device, swapping domains for BtoA."""
        if self.opt.direction == 'AtoB':
            key_A, key_B, key_paths = 'A', 'B', 'A_paths'
        else:
            key_A, key_B, key_paths = 'B', 'A', 'B_paths'
        self.real_A = input[key_A].to(self.device)
        self.real_B = input[key_B].to(self.device)
        self.image_paths = input[key_paths]

    def forward(self):
        """Translate both real images and reconstruct each from its translation."""
        fake_B = self.netG_A(self.real_A)
        self.fake_B = fake_B
        self.rec_A = self.netG_B(fake_B)

        fake_A = self.netG_B(self.real_B)
        self.fake_A = fake_A
        self.rec_B = self.netG_A(fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate one discriminator's loss and return it.

        ``fake`` is detached so that no gradient reaches the generators. The
        result is the mean of the loss on ``real`` and the loss on ``fake``.
        """
        real_term = self.criterionGAN(netD(real), True)
        fake_term = self.criterionGAN(netD(fake.detach()), False)
        total = (real_term + fake_term) * 0.5
        total.backward()
        return total

    def backward_D_A(self):
        """Train D_A on real B against a pooled fake B."""
        pooled = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B,
                                              pooled)

    def backward_D_B(self):
        """Train D_B on real A against a pooled fake A."""
        pooled = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A,
                                              pooled)

    def backward_G(self):
        """Compute identity, GAN and cycle losses and backpropagate their sum."""
        weight_idt = self.opt.lambda_identity
        weight_A = self.opt.lambda_A
        weight_B = self.opt.lambda_B

        # Identity loss
        if weight_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            idt_A_error = self.criterionIdt(self.idt_A, self.real_B)
            self.loss_idt_A = idt_A_error * weight_B * weight_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            idt_B_error = self.criterionIdt(self.idt_B, self.real_A)
            self.loss_idt_B = idt_B_error * weight_A * weight_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        score_B = self.netD_A(self.fake_B)
        self.loss_G_A = self.criterionGAN(score_B, True)
        # GAN loss D_B(G_B(B))
        score_A = self.netD_B(self.fake_A)
        self.loss_G_B = self.criterionGAN(score_A, True)

        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * weight_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * weight_B

        # combined loss, summed left to right
        total = self.loss_G_A + self.loss_G_B
        total = total + self.loss_cycle_A
        total = total + self.loss_cycle_B
        total = total + self.loss_idt_A
        total = total + self.loss_idt_B
        self.loss_G = total
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: forward, generator step, discriminator step."""
        self.forward()
        discriminators = [self.netD_A, self.netD_B]

        # G_A and G_B: freeze the discriminators while the generators update
        self.set_requires_grad(discriminators, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # D_A and D_B
        self.set_requires_grad(discriminators, True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
class Pix2PixHDModel(BaseModel): def name(self): return 'Pix2PixHDModel' def initialize(self, opt): BaseModel.initialize(self, opt) if opt.resize_or_crop != 'none': # when training at full res this causes OOM torch.backends.cudnn.benchmark = True self.isTrain = opt.isTrain ##### define networks # Generator network netG_input_nc = opt.label_nc if not opt.no_instance: netG_input_nc += 1 self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) # Discriminator network if self.isTrain: use_sigmoid = opt.no_lsgan netD_input_nc = 4 * opt.output_nc if not opt.no_instance: netD_input_nc += 1 self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) # Face discriminator network if self.isTrain and opt.face_discrim: use_sigmoid = opt.no_lsgan netD_input_nc = 2 * opt.output_nc if not opt.no_instance: netD_input_nc += 1 self.netDface = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids, netD='face') #Face residual network if opt.face_generator: if opt.faceGtype == 'unet': self.faceGen = networks.define_G(opt.output_nc * 2, opt.output_nc, 32, 'unet', n_downsample_global=2, n_blocks_global=5, n_local_enhancers=0, n_blocks_local=0, norm=opt.norm, gpu_ids=self.gpu_ids) elif opt.faceGtype == 'global': self.faceGen = networks.define_G(opt.output_nc * 2, opt.output_nc, 64, 'global', n_downsample_global=3, n_blocks_global=5, n_local_enhancers=0, n_blocks_local=0, norm=opt.norm, gpu_ids=self.gpu_ids) else: raise ('face generator not implemented!') print('---------- Networks initialized -------------') # load networks if (not self.isTrain or opt.continue_train or opt.load_pretrain): pretrained_path = '' if not self.isTrain else opt.load_pretrain self.load_network(self.netG, 'G', 
opt.which_epoch, pretrained_path) if self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) if opt.face_discrim: self.load_network(self.netDface, 'Dface', opt.which_epoch, pretrained_path) if opt.face_generator: self.load_network(self.faceGen, 'Gface', opt.which_epoch, pretrained_path) # set loss functions and optimizers if self.isTrain: if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: raise NotImplementedError( "Fake Pool Not Implemented for MultiGPU") self.fake_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionFeat = torch.nn.L1Loss() if not opt.no_vgg_loss: self.criterionVGG = networks.VGGLoss(self.gpu_ids) if opt.use_l1: self.criterionL1 = torch.nn.L1Loss() # Loss names self.loss_names = [ 'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake', 'G_GANface', 'D_realface', 'D_fakeface' ] # initialize optimizers # optimizer G if opt.niter_fix_global > 0: print( '------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): if key.startswith('model' + str(opt.n_local_enhancers)): params += [{'params': [value], 'lr': opt.lr}] else: params += [{'params': [value], 'lr': 0.0}] else: params = list(self.netG.parameters()) if opt.face_generator: params = list(self.faceGen.parameters()) else: if opt.niter_fix_main == 0: params += list(self.netG.parameters()) self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) # optimizer D if opt.niter_fix_main > 0: print( '------------- Only training the face discriminator network (for %d epochs) ------------' % opt.niter_fix_main) params = list(self.netDface.parameters()) else: if opt.face_discrim: params = list(self.netD.parameters()) + list( self.netDface.parameters()) else: params = list(self.netD.parameters()) 
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) def encode_input(self, label_map, real_image=None, next_label=None, next_image=None, zeroshere=None, infer=False): input_label = label_map.data.float() input_label = Variable(input_label, volatile=infer) # next label for training if next_label is not None: next_label = next_label.data.float() next_label = Variable(next_label, volatile=infer) # real images for training if real_image is not None: real_image = Variable(real_image.data.float()) # real images for training if next_image is not None: next_image = Variable(next_image.data.float()) if zeroshere is not None: zeroshere = zeroshere.data.float() zeroshere = Variable(zeroshere, volatile=infer) return input_label, real_image, next_label, next_image, zeroshere def discriminate(self, input_label, test_image, use_pool=False): input_concat = torch.cat((input_label, test_image.detach()), dim=1) if use_pool: fake_query = self.fake_pool.query(input_concat) return self.netD.forward(fake_query) else: return self.netD.forward(input_concat) def discriminate_4(self, s0, s1, i0, i1, use_pool=False): input_concat = torch.cat((s0, s1, i0.detach(), i1.detach()), dim=1) if use_pool: fake_query = self.fake_pool.query(input_concat) return self.netD.forward(fake_query) else: return self.netD.forward(input_concat) def discriminateface(self, input_label, test_image, use_pool=False): input_concat = torch.cat((input_label, test_image.detach()), dim=1) if use_pool: fake_query = self.fake_pool.query(input_concat) return self.netDface.forward(fake_query) else: return self.netDface.forward(input_concat) def forward(self, label, next_label, image, next_image, face_coords, zeroshere, infer=False): # Encode Inputs input_label, real_image, next_label, next_image, zeroshere = self.encode_input(label, image, \ next_label=next_label, next_image=next_image, zeroshere=zeroshere) if self.opt.face_discrim: miny = face_coords.data[0][0] maxy = face_coords.data[0][1] minx = 
face_coords.data[0][2] maxx = face_coords.data[0][3] initial_I_0 = 0 # Fake Generation I_0 input_concat = torch.cat((input_label, zeroshere), dim=1) #face residual for I_0 face_residual_0 = 0 if self.opt.face_generator: initial_I_0 = self.netG.forward(input_concat) face_label_0 = input_label[:, :, miny:maxy, minx:maxx] face_residual_0 = self.faceGen.forward( torch.cat( (face_label_0, initial_I_0[:, :, miny:maxy, minx:maxx]), dim=1)) I_0 = initial_I_0.clone() I_0[:, :, miny:maxy, minx:maxx] = initial_I_0[:, :, miny:maxy, minx:maxx] + face_residual_0 else: I_0 = self.netG.forward(input_concat) input_concat1 = torch.cat((next_label, I_0), dim=1) #face residual for I_1 face_residual_1 = 0 if self.opt.face_generator: initial_I_1 = self.netG.forward(input_concat1) face_label_1 = next_label[:, :, miny:maxy, minx:maxx] face_residual_1 = self.faceGen.forward( torch.cat( (face_label_1, initial_I_1[:, :, miny:maxy, minx:maxx]), dim=1)) I_1 = initial_I_1.clone() I_1[:, :, miny:maxy, minx:maxx] = initial_I_1[:, :, miny:maxy, minx:maxx] + face_residual_1 else: I_1 = self.netG.forward(input_concat1) loss_D_fake_face = loss_D_real_face = loss_G_GAN_face = 0 fake_face_0 = fake_face_1 = real_face_0 = real_face_1 = 0 fake_face = real_face = face_residual = 0 if self.opt.face_discrim: fake_face_0 = I_0[:, :, miny:maxy, minx:maxx] fake_face_1 = I_1[:, :, miny:maxy, minx:maxx] real_face_0 = real_image[:, :, miny:maxy, minx:maxx] real_face_1 = next_image[:, :, miny:maxy, minx:maxx] # Fake Detection and Loss pred_fake_pool_face = self.discriminateface(face_label_0, fake_face_0, use_pool=True) loss_D_fake_face += 0.5 * self.criterionGAN( pred_fake_pool_face, False) # Face Real Detection and Loss pred_real_face = self.discriminateface(face_label_0, real_face_0) loss_D_real_face += 0.5 * self.criterionGAN(pred_real_face, True) # Face GAN loss (Fake Passability Loss) pred_fake_face = self.netDface.forward( torch.cat((face_label_0, fake_face_0), dim=1)) loss_G_GAN_face += 0.5 * 
self.criterionGAN(pred_fake_face, True) pred_fake_pool_face = self.discriminateface(face_label_1, fake_face_1, use_pool=True) loss_D_fake_face += 0.5 * self.criterionGAN( pred_fake_pool_face, False) # Face Real Detection and Loss pred_real_face = self.discriminateface(face_label_1, real_face_1) loss_D_real_face += 0.5 * self.criterionGAN(pred_real_face, True) # Face GAN loss (Fake Passability Loss) pred_fake_face = self.netDface.forward( torch.cat((face_label_1, fake_face_1), dim=1)) loss_G_GAN_face += 0.5 * self.criterionGAN(pred_fake_face, True) fake_face = torch.cat((fake_face_0, fake_face_1), dim=3) real_face = torch.cat((real_face_0, real_face_1), dim=3) if self.opt.face_generator: face_residual = torch.cat((face_residual_0, face_residual_1), dim=3) # Fake Detection and Loss pred_fake_pool = self.discriminate_4(input_label, next_label, I_0, I_1, use_pool=True) loss_D_fake = self.criterionGAN(pred_fake_pool, False) # Real Detection and Loss pred_real = self.discriminate_4(input_label, next_label, real_image, next_image) loss_D_real = self.criterionGAN(pred_real, True) # GAN loss (Fake Passability Loss) pred_fake = self.netD.forward( torch.cat((input_label, next_label, I_0, I_1), dim=1)) loss_G_GAN = self.criterionGAN(pred_fake, True) # GAN feature matching loss loss_G_GAN_Feat = 0 if not self.opt.no_ganFeat_loss: feat_weights = 4.0 / (self.opt.n_layers_D + 1) D_weights = 1.0 / self.opt.num_D for i in range(self.opt.num_D): for j in range(len(pred_fake[i]) - 1): loss_G_GAN_Feat += D_weights * feat_weights * \ self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat # VGG feature matching loss loss_G_VGG = 0 if not self.opt.no_vgg_loss: loss_G_VGG0 = self.criterionVGG(I_0, real_image) * self.opt.lambda_feat loss_G_VGG1 = self.criterionVGG(I_1, next_image) * self.opt.lambda_feat loss_G_VGG = loss_G_VGG0 + loss_G_VGG1 if self.opt.netG == 'global': #need 2x VGG for artifacts when training local loss_G_VGG *= 0.5 if self.opt.face_discrim: 
loss_G_VGG += 0.5 * self.criterionVGG( fake_face_0, real_face_0) * self.opt.lambda_feat loss_G_VGG += 0.5 * self.criterionVGG( fake_face_1, real_face_1) * self.opt.lambda_feat if self.opt.use_l1: loss_G_VGG += (self.criterionL1(I_1, next_image)) * self.opt.lambda_A # Only return the fake_B image if necessary to save BW return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake, \ loss_G_GAN_face, loss_D_real_face, loss_D_fake_face], \ None if not infer else [torch.cat((I_0, I_1), dim=3), fake_face, face_residual, initial_I_0] ] def inference(self, label, prevouts, face_coords): # Encode Inputs input_label, _, _, _, prevouts = self.encode_input( Variable(label), zeroshere=Variable(prevouts), infer=True) if self.opt.face_generator: miny = face_coords[0][0] maxy = face_coords[0][1] minx = face_coords[0][2] maxx = face_coords[0][3] """ new face """ I_0 = 0 # Fake Generation input_concat = torch.cat((input_label, prevouts), dim=1) initial_I_0 = self.netG.forward(input_concat) if self.opt.face_generator: face_label_0 = input_label[:, :, miny:maxy, minx:maxx] face_residual_0 = self.faceGen.forward( torch.cat( (face_label_0, initial_I_0[:, :, miny:maxy, minx:maxx]), dim=1)) I_0 = initial_I_0.clone() I_0[:, :, miny:maxy, minx:maxx] = initial_I_0[:, :, miny:maxy, minx:maxx] + face_residual_0 fake_face_0 = I_0[:, :, miny:maxy, minx:maxx] return I_0 return initial_I_0 def get_edges(self, t): edge = torch.ByteTensor(t.size()).zero_() edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) return edge.float() def save(self, which_epoch): self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) if self.opt.face_discrim: 
self.save_network(self.netDface, 'Dface', which_epoch, self.gpu_ids) if self.opt.face_generator: self.save_network(self.faceGen, 'Gface', which_epoch, self.gpu_ids) def update_fixed_params(self): # after fixing the global generator for a number of iterations, also start finetuning it params = list(self.netG.parameters()) if self.opt.face_generator: params += list(self.faceGen.parameters()) self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) print('------------ Now also finetuning global generator -----------') def update_fixed_params_netD(self): params = list(self.netD.parameters()) + list( self.netDface.parameters()) self.optimizer_D = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) print( '------------ Now also finetuning multiscale discriminator -----------' ) def update_learning_rate(self): lrd = self.opt.lr / self.opt.niter_decay lr = self.old_lr - lrd for param_group in self.optimizer_D.param_groups: param_group['lr'] = lr for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr
class CycleGANModel(BaseModel):
    """CycleGAN with an additional "shift" loss on both generators.

    Two generators (``netG_A``: A -> B, ``netG_B``: B -> A) and, at training
    time, two discriminators (``netD_A``, ``netD_B``) are trained with the
    usual GAN, cycle-consistency and identity losses.  On top of that a shift
    loss compares "generate, then translate" against "translate, then
    generate" for every generator.
    """

    def name(self):
        """Return the display name of this model."""
        return 'CycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register the CycleGAN-specific options on ``parser`` and return it.

        The loss-weight options are only added when ``is_train`` is true.
        """
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_shift_A', type=float, default=0.003,
                                help='weight for shift loss for A')
            parser.add_argument('--lambda_shift_B', type=float, default=0.003,
                                help='weight for shift loss for B')
            parser.add_argument('--lambda_identity', type=float, default=0.5,
                                help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        return parser

    def initialize(self, opt):
        """Create networks and, when training, losses, image pools and optimizers."""
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out.
        # The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'shift_A', 'shift_B']

        # specify the images you want to save/display.
        # The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B

        # specify the models you want to save to the disk.
        # The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:
            # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # size_average=False: the shift loss is summed, not averaged
            self.criterionShift = torch.nn.MSELoss(size_average=False)
            self.shift_transform = torchsample.transforms.RandomTranslate((1./8., 1./8.))
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Move one batch onto ``self.device``, swapping A/B unless direction is 'AtoB'."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def inference(self, direction, image):
        """Run one generator on ``image`` without tracking gradients.

        ``image`` is passed to the generator as-is; no conversion or device
        move is performed here.

        Raises:
            ValueError: if ``direction`` is neither 'AtoB' nor 'BtoA'.
        """
        if direction not in ['AtoB', 'BtoA']:
            raise ValueError('{} is not a valid direction'.format(direction))
        with torch.no_grad():
            # NOTE(review): forward() uses netG_A for real_A -> fake_B, yet
            # 'AtoB' selects netG_B here.  This is only consistent if the model
            # was trained with direction == 'BtoA' (set_input then swaps the
            # domains) -- confirm against the callers before changing it.
            if direction == 'AtoB':
                return self.netG_B(image)
            else:
                return self.netG_A(image)

    def forward(self):
        """Compute the fake and reconstructed images for both directions."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the GAN loss of ``netD`` on a real/fake pair; return the loss."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no gradient reaches the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator D_A update on real_B versus a pooled fake_B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Discriminator D_B update on real_A versus a pooled fake_A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute all generator losses and back-propagate their sum."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_shift_A = self.opt.lambda_shift_A
        lambda_shift_B = self.opt.lambda_shift_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # Shift losses from VR-Goggles for Robots.
        # The transform is applied on CPU to the individual images of the batch.
        real_A = torch.unbind(self.real_A.cpu(), 0)
        real_B = torch.unbind(self.real_B.cpu(), 0)
        fake_A = torch.unbind(self.fake_A.cpu(), 0)
        fake_B = torch.unbind(self.fake_B.cpu(), 0)

        # The transform is unpacked as (images, height, width); the two extra
        # values are fed back below so the fakes get the same translation.
        shifted_real_A, height_A, width_A = self.shift_transform(*real_A)
        shifted_real_B, height_B, width_B = self.shift_transform(*real_B)

        # Translate first, then generate.  Tensors go to self.device (rather
        # than a hard-coded .cuda()) so CPU-only runs and non-default GPUs work.
        gen_B = self.netG_A(torch.stack(shifted_real_A, 0).to(self.device))
        gen_A = self.netG_B(torch.stack(shifted_real_B, 0).to(self.device))

        # Generate first (done in forward()), then translate by the same offsets.
        shifted_fake_A, _, _ = self.shift_transform(
            *fake_A, random_height=height_B, random_width=width_B)  # netG_B
        shifted_fake_B, _, _ = self.shift_transform(
            *fake_B, random_height=height_A, random_width=width_A)  # netG_A
        shifted_fake_A = torch.stack(shifted_fake_A).to(self.device)
        shifted_fake_B = torch.stack(shifted_fake_B).to(self.device)

        self.loss_shift_A = self.criterionShift(shifted_fake_A, gen_A) * lambda_shift_A
        self.loss_shift_B = self.criterionShift(shifted_fake_B, gen_B) * lambda_shift_B

        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + \
            self.loss_idt_A + self.loss_idt_B + self.loss_shift_A + self.loss_shift_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: generators first, then both discriminators."""
        # forward
        self.forward()
        # G_A and G_B (discriminators frozen while the generators are updated)
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
class StackGANModel(BaseModel):
    """Stacked model: a glyph generator (``netG``) feeding an encoder/decoder
    ornamentation generator (``netE1``/``netDE1``) with discriminator ``netD1``.
    """

    def name(self):
        """Return the display name of this model."""
        return 'StackGANModel'

    def initialize(self, opt):
        """Build the networks, load checkpoints and, when training, the
        losses, image pool and optimizers."""
        BaseModel.initialize(self, opt)
        # define tensors
        self.input_A0 = self.Tensor(opt.batchSize, opt.input_nc,
                                    opt.fineSize, opt.fineSize)
        self.input_B0 = self.Tensor(opt.batchSize, opt.output_nc,
                                    opt.fineSize, opt.fineSize)
        self.input_base = self.Tensor(opt.batchSize, opt.output_nc,
                                      opt.fineSize, opt.fineSize)

        # load/define networks
        if self.opt.conv3d:
            # one layer for considering a conv filter for each of the 26 channels
            self.netG_3d = networks.define_G_3d(opt.input_nc, opt.input_nc,
                                                norm=opt.norm, groups=opt.grps,
                                                gpu_ids=self.gpu_ids)

        # Generator of the GlyphNet
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      opt.use_dropout, self.gpu_ids)

        # Generator of the OrnaNet as an Encoder and a Decoder
        self.netE1 = networks.define_Enc(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.use_dropout1, self.gpu_ids)
        self.netDE1 = networks.define_Dec(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                          opt.which_model_netG, opt.norm,
                                          opt.use_dropout1, self.gpu_ids)

        if self.opt.conditional:
            # not applicable for non-conditional case
            use_sigmoid = opt.no_lsgan
            if opt.which_model_preNet != 'none':
                self.preNet_A = networks.define_preNet(
                    self.opt.input_nc_1 + self.opt.output_nc_1,
                    self.opt.input_nc_1 + self.opt.output_nc_1,
                    which_model_preNet=opt.which_model_preNet,
                    norm=opt.norm, gpu_ids=self.gpu_ids)
            nif = opt.input_nc_1 + opt.output_nc_1
            netD_norm = opt.norm
            self.netD1 = networks.define_D(nif, opt.ndf, opt.which_model_netD,
                                           opt.n_layers_D, netD_norm,
                                           use_sigmoid, True, self.gpu_ids)

        if self.isTrain:
            if self.opt.conv3d:
                self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
            self.load_network(self.netG, 'G', opt.which_epoch)

            if self.opt.print_weights:
                for key, value in self.netE1.state_dict().items():
                    print(key, 'random_init, mean, std:',
                          torch.mean(value), torch.std(value))
                for key, value in self.netDE1.state_dict().items():
                    print(key, 'random_init, mean, std:',
                          torch.mean(value), torch.std(value))

        if not self.isTrain:
            print("Load generators from their pretrained models...")
            if opt.no_Style2Glyph:
                if self.opt.conv3d:
                    self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
            else:
                # Glyph-generator checkpoints are labelled with the sum of both
                # epoch counters (save() writes them under the same label).
                if self.opt.conv3d:
                    self.load_network(
                        self.netG_3d, 'G_3d',
                        str(int(opt.which_epoch) + int(opt.which_epoch1)))
                self.load_network(
                    self.netG, 'G',
                    str(int(opt.which_epoch) + int(opt.which_epoch1)))
                self.load_network(self.netE1, 'E1', str(int(opt.which_epoch1)))
                self.load_network(self.netDE1, 'DE1', str(int(opt.which_epoch1)))
                self.load_network(self.netD1, 'D1', str(int(opt.which_epoch1)))
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)

        if self.isTrain:
            if opt.continue_train:
                print("Load StyleNet from its pretrained model...")
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)

        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)

        if self.isTrain:
            self.fake_AB1_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()

            # initialize optimizers
            if self.opt.conv3d:
                self.optimizer_G_3d = torch.optim.Adam(
                    self.netG_3d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_G = torch.optim.Adam(
                self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_E1 = torch.optim.Adam(
                self.netE1.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.which_model_preNet != 'none':
                self.optimizer_preA = torch.optim.Adam(
                    self.preNet_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_DE1 = torch.optim.Adam(
                self.netDE1.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D1 = torch.optim.Adam(
                self.netD1.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            print('-----------------------------------------------')

            # True until the first forward1() call has stored real_A1_init
            self.initial = True

    def set_input(self, input):
        """Copy a batch into the preallocated input tensors and, when training,
        rebuild the glyph-index bookkeeping (``id_``, ``obs``)."""
        input_A0 = input['A']
        input_B0 = input['B']
        self.input_A0.resize_(input_A0.size()).copy_(input_A0)
        self.input_B0.resize_(input_B0.size()).copy_(input_B0)
        self.image_paths = input['B_paths']

        if self.opt.base_font:
            input_base = input['A_base']
            self.input_base.resize_(input_base.size()).copy_(input_base)
            b, c, m, n = self.input_base.size()

            real_base = self.Tensor(self.opt.output_nc, self.opt.input_nc_1, m, n)
            for batch in range(self.opt.output_nc):
                if not self.opt.rgb_in and self.opt.rgb_out:
                    # replicate the single channel of glyph `batch` into 3 channels
                    real_base[batch, 0, :, :] = self.input_base[0, batch, :, :]
                    real_base[batch, 1, :, :] = self.input_base[0, batch, :, :]
                    real_base[batch, 2, :, :] = self.input_base[0, batch, :, :]

            self.real_base = torch.tensor(real_base, requires_grad=False)

        if self.opt.isTrain:
            self.id_ = {}  # char to batch_id dict batch_id aka 0~
            self.obs = []  # chars list
            for i, im in enumerate(self.image_paths):
                # the glyph index is the trailing "_<int>" of the .png file name
                char_id = int(im.split('/')[-1].split('.png')[0].split('_')[-1])
                self.id_[char_id] = i
                self.obs += [char_id]
            # if its not train char set to random batch id
            # 26 -> batch id
            for i in list(set(range(self.opt.output_nc)) - set(self.obs)):
                self.id_[i] = np.random.randint(low=0, high=len(self.image_paths))

            self.num_disc = self.opt.output_nc + 1

    def all2observed(self, tensor_all):
        """For each batch item, pick the channel(s) of its observed glyph from
        ``tensor_all`` and return them as a new tensor."""
        b, c, m, n = self.real_A0.size()
        self.out_id = self.obs
        tensor_gt = self.Tensor(b, self.opt.input_nc_1, m, n)
        for batch in range(b):
            if not self.opt.rgb_in and self.opt.rgb_out:
                tensor_gt[batch, 0, :, :] = tensor_all.data[batch, self.out_id[batch], :, :]
                tensor_gt[batch, 1, :, :] = tensor_all.data[batch, self.out_id[batch], :, :]
                tensor_gt[batch, 2, :, :] = tensor_all.data[batch, self.out_id[batch], :, :]
            else:
                # TODO
                tensor_gt[batch, :, :, :] = tensor_all.data[
                    batch,
                    self.out_id[batch] * np.array(self.opt.input_nc_1):
                    (self.out_id[batch] + 1) * np.array(self.opt.input_nc_1),
                    :, :]
        return tensor_gt

    def forward0(self):
        """Run the glyph generator (optionally preceded by ``netG_3d``)."""
        self.real_A0 = torch.tensor(self.input_A0)
        if self.opt.conv3d:
            self.real_A0_indep = self.netG_3d.forward(self.real_A0.unsqueeze(2))
            self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
        else:
            self.fake_B0 = self.netG.forward(self.real_A0)

        if self.initial:
            if self.opt.orna:  # False
                self.fake_B0_init = self.real_A0
            else:
                self.fake_B0_init = self.fake_B0

    def forward1(self, inp_grad=False):
        """Build the second-stage input from the glyph output and run the
        encoder/decoder on both the full set and the observed glyphs.

        Args:
            inp_grad: ``requires_grad`` flag for the second-stage input tensor.
        """
        b, c, m, n = self.real_A0.size()
        self.batch_ = b
        self.out_id = self.obs
        real_A1 = self.Tensor(self.opt.output_nc, self.opt.input_nc_1, m, n)  # 26 3 m n

        if self.opt.orna:
            inp_orna = self.fake_B0_init
        else:
            inp_orna = self.fake_B0

        for batch in range(self.opt.output_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                real_A1[batch, 0, :, :] = inp_orna.data[self.id_[batch], batch, :, :]
                real_A1[batch, 1, :, :] = inp_orna.data[self.id_[batch], batch, :, :]
                real_A1[batch, 2, :, :] = inp_orna.data[self.id_[batch], batch, :, :]
            else:
                # TODO
                real_A1[batch, :, :, :] = inp_orna.data[
                    batch,
                    self.out_id[batch] * np.array(self.opt.input_nc_1):
                    (self.out_id[batch] + 1) * np.array(self.opt.input_nc_1),
                    :, :]

        if self.initial:
            self.real_A1_init = torch.tensor(real_A1, requires_grad=False)
            self.initial = False

        self.real_A1_s = torch.tensor(real_A1, requires_grad=inp_grad)
        self.real_A1 = self.real_A1_s
        self.fake_B1_emb = self.netE1.forward(self.real_A1)
        self.fake_B1 = self.netDE1.forward(self.fake_B1_emb)
        self.real_B1 = torch.tensor(self.input_B0)

        # requires_grad=True: backward_G1 reads .grad of this tensor
        self.real_A1_gt_s = torch.tensor(self.all2observed(inp_orna),
                                         requires_grad=True)
        self.real_A1_gt = (self.real_A1_gt_s)
        self.fake_B1_gt_emb = self.netE1.forward(self.real_A1_gt)
        self.fake_B1_gt = self.netDE1.forward(self.fake_B1_gt_emb)

        obs_ = torch.cuda.LongTensor(
            self.obs) if self.opt.gpu_ids else LongTensor(self.obs)
        if self.opt.base_font:
            real_base_gt = index_select(self.real_base, 0, obs_)
            self.real_base_gt = (torch.tensor(real_base_gt.data,
                                              requires_grad=False))

    def add_noise_disc(self, real):
        """Return the discriminator target label, flipped with probability 0.6
        when ``opt.noisy_disc`` is set."""
        # add noise to the discriminator target labels
        # real: True/False?
        if self.opt.noisy_disc:
            rand_lbl = random.random()
            if rand_lbl < 0.6:
                label = (not real)
            else:
                label = (real)
        else:
            label = (real)
        return label

    # no backprop gradients
    def test(self):
        """Run both stages without tracking gradients, producing one output
        per glyph of the input."""
        with torch.no_grad():
            self.real_A0 = self.input_A0
            if self.opt.conv3d:
                self.real_A0_indep = self.netG_3d.forward(
                    self.real_A0.unsqueeze(2))
                self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
            else:
                self.fake_B0 = self.netG.forward(self.real_A0)

            b, c, m, n = self.fake_B0.size()
            # for test time: we need to generate output for all of the glyphs
            # in each input image
            if self.opt.rgb_in:
                # Floor division: batch_ is used as a range() bound and as a
                # tensor dimension, so it must stay an int under Python 3.
                self.batch_ = c // self.opt.input_nc_1
            else:
                self.batch_ = c
            self.out_id = range(self.batch_)
            real_A1 = self.Tensor(self.batch_, self.opt.input_nc_1, m, n)

            if self.opt.orna:
                inp_orna = self.real_A0
            else:
                inp_orna = self.fake_B0

            for batch in range(self.batch_):
                if not self.opt.rgb_in and self.opt.rgb_out:
                    real_A1[batch, 0, :, :] = inp_orna.data[:, self.out_id[batch], :, :]
                    real_A1[batch, 1, :, :] = inp_orna.data[:, self.out_id[batch], :, :]
                    real_A1[batch, 2, :, :] = inp_orna.data[:, self.out_id[batch], :, :]
                else:
                    real_A1[batch, :, :, :] = inp_orna.data[
                        :,
                        self.out_id[batch] * np.array(self.opt.input_nc_1):
                        (self.out_id[batch] + 1) * np.array(self.opt.input_nc_1),
                        :, :]

            self.real_A1 = real_A1
            fake_B1_emb = self.netE1.forward(self.real_A1.detach())
            self.fake_B1 = self.netDE1.forward(fake_B1_emb)
            self.real_B1 = self.input_B0

    # get image paths
    def get_image_paths(self):
        """Return the image paths of the current batch."""
        return self.image_paths

    def prepare_data(self):
        """Select the conditioning tensors (``first_pair``/``first_pair_gt``)
        used by the conditional discriminator."""
        if self.opt.conditional:
            if self.opt.base_font:
                self.first_pair = self.real_base
                self.first_pair_gt = self.real_base_gt
            else:
                self.first_pair = torch.tensor(self.real_A1.data,
                                               requires_grad=False)
                self.first_pair_gt = torch.tensor(self.real_A1_gt.data,
                                                  requires_grad=False)

    def backward_D1(self):
        """Compute the discriminator loss on fake and real pairs and
        back-propagate it."""
        b, c, m, n = self.fake_B1.size()
        # Fake
        # stop backprop to the generator by detaching fake_B
        label_fake = self.add_noise_disc(False)
        if self.opt.conditional:
            fake_AB1 = self.fake_AB1_pool.query(
                torch.cat((self.first_pair, self.fake_B1), 1))
            self.pred_fake1 = self.netD1.forward(fake_AB1.detach())
            if self.opt.which_model_preNet != 'none':
                # transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB1.detach())
                self.pred_fake_GL = self.netD1.forward(transformed_AB1)

        self.loss_D1_fake = 0
        self.loss_D1_fake += self.criterionGAN(self.pred_fake1, label_fake)
        if self.opt.which_model_preNet != 'none':
            self.loss_D1_fake += self.criterionGAN(self.pred_fake_GL, label_fake)

        # Real
        label_real = self.add_noise_disc(True)
        if self.opt.conditional:
            real_AB1 = torch.cat((self.first_pair_gt, self.real_B1), 1).detach()
            self.pred_real1 = self.netD1.forward(real_AB1)
            if self.opt.which_model_preNet != 'none':
                transformed_real_AB1 = self.preNet_A.forward(real_AB1)
                self.pred_real1_GL = self.netD1.forward(transformed_real_AB1)

        self.loss_D1_real = 0
        self.loss_D1_real += self.criterionGAN(self.pred_real1, label_real)
        if self.opt.which_model_preNet != 'none':
            self.loss_D1_real += self.criterionGAN(self.pred_real1_GL, label_real)

        # Combined loss
        self.loss_D1 = (self.loss_D1_fake + self.loss_D1_real) * 0.5
        self.loss_D1.backward()

    def backward_G(self, pass_grad, iter):
        """Back-propagate through the glyph generator.

        Args:
            pass_grad: gradient w.r.t. ``fake_B0`` handed down from stage two.
            iter: iteration counter; the weighted L1 term is dropped once it
                exceeds 700 (or when ``opt.lambda_C`` is falsy).
        """
        b, c, m, n = self.fake_B0.size()
        if not self.opt.lambda_C or (iter > 700):
            self.loss_G_L1 = torch.tensor(torch.zeros(1))
        else:
            weight_val = 10.0
            weights = torch.ones(b, c, m, n).cuda() if self.opt.gpu_ids else torch.ones(
                b, c, m, n)
            obs_ = torch.cuda.LongTensor(
                self.obs) if self.opt.gpu_ids else LongTensor(self.obs)
            # observed glyph channels get weight_val, all others stay at 1
            weights.index_fill_(1, obs_, weight_val)
            weights = torch.tensor(weights, requires_grad=False)
            self.loss_G_L1 = self.criterionL1(
                weights * self.fake_B0,
                weights * self.fake_B0_init.detach()) * self.opt.lambda_C
            self.loss_G_L1.backward(retain_graph=True)

        self.fake_B0.backward(pass_grad)

    def backward_G1(self, iter):
        """Compute the encoder/decoder losses, back-propagate them and store
        the gradient w.r.t. the stage-two input in ``real_A1_grad``."""
        # First, G(A) should fake the discriminator
        if self.opt.conditional:
            fake_AB = torch.cat((self.first_pair.detach(), self.fake_B1), 1)
            pred_fake = self.netD1.forward(fake_AB)
            if self.opt.which_model_preNet != 'none':
                # transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB)
                pred_fake_GL = self.netD1.forward(transformed_AB1)

        self.loss_G1_GAN = 0
        self.loss_G1_GAN += self.criterionGAN(pred_fake, True)
        if self.opt.which_model_preNet != 'none':
            self.loss_G1_GAN += self.criterionGAN(pred_fake_GL, True)

        self.loss_G1_L1 = self.criterionL1(self.fake_B1_gt,
                                           self.real_B1) * self.opt.lambda_A

        # Steep sigmoid around 0.9 on the channel mean: a soft mask that is
        # close to 1 where the mean is below 0.9 and close to 0 above it.
        fake_B1_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.fake_B1, dim=1, keepdim=True) - 0.9))
        real_A1_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.real_A1, dim=1, keepdim=True) - 0.9))
        self.loss_G1_MSE_rgb2gay = self.criterionMSE(
            fake_B1_gray, real_A1_gray.detach()) * self.opt.lambda_A / 3.0

        real_A1_gt_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.real_A1_gt, dim=1, keepdim=True) - 0.9))
        real_B1_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.real_B1, dim=1, keepdim=True) - 0.9))
        self.loss_G1_MSE_gt = self.criterionMSE(
            real_A1_gt_gray, real_B1_gray) * self.opt.lambda_A

        # update generator less frequently
        if iter < 200:
            rate_gen = 90
        else:
            rate_gen = 60

        if (iter % rate_gen) == 0:
            self.loss_G1 = self.loss_G1_GAN + self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True
        else:
            self.loss_G1 = self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True

        if (iter < 200):
            self.loss_G1 += self.loss_G1_MSE_rgb2gay
        else:
            self.loss_G1 += 0.01 * self.loss_G1_MSE_rgb2gay

        self.loss_G1.backward(retain_graph=True)

        (b, c, m, n) = self.real_A1_s.size()
        self.real_A1_grad = torch.zeros(
            b, c, m, n).cuda() if self.opt.gpu_ids else torch.zeros(
                b, c, m, n)

        if G1_L1_update:
            for batch in self.obs:
                self.real_A1_grad[batch, :, :, :] = \
                    self.real_A1_gt_s.grad.data[self.id_[batch], :, :, :]

    def _optimize_ornanet(self, iter):
        """Forward both stages, then update D1 (and preNet) followed by E1/DE1."""
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()

        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.zero_grad()
        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.step()

        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()

    def optimize_parameters(self, iter):
        """One training step that leaves the glyph generator untouched."""
        self._optimize_ornanet(iter)
        # netG is not trained here; report a zero L1 loss for it
        self.loss_G_L1 = torch.tensor(torch.zeros(1))

    def optimize_parameters_Stacked(self, iter):
        """One end-to-end training step: stage-two update, then the glyph
        generator using the gradient handed down from stage two."""
        self._optimize_ornanet(iter)

        self.optimizer_G.zero_grad()
        if self.opt.conv3d:
            self.optimizer_G_3d.zero_grad()

        b, c, m, n = self.fake_B0.size()
        fake_B0_grad = torch.zeros(
            b, c, m, n).cuda() if self.opt.gpu_ids else torch.zeros(
                b, c, m, n)
        real_A_grad = self.real_A1_grad

        for batch in range(self.opt.input_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                # collapse the 3 replicated channels back onto the glyph channel
                fake_B0_grad[self.id_[batch], batch, :, :] += torch.mean(
                    real_A_grad[batch, :, :, :], 0) * 3
            else:
                # TODO
                fake_B0_grad[
                    batch,
                    self.obs[batch] * np.array(self.opt.input_nc_1):
                    (self.obs[batch] + 1) * np.array(self.opt.input_nc_1),
                    :, :] = real_A_grad[batch, :, :, :]

        self.backward_G(fake_B0_grad, iter)
        self.optimizer_G.step()
        if self.opt.conv3d:
            self.optimizer_G_3d.step()

    def get_current_errors(self):
        """Return the latest loss values as an ordered name -> float mapping."""
        return OrderedDict([('G1_GAN', self.loss_G1_GAN.item()),
                            ('G1_L1', self.loss_G1_L1.item()),
                            ('G1_MSE_gt', self.loss_G1_MSE_gt.item()),
                            ('G1_MSE', self.loss_G1_MSE_rgb2gay.item()),
                            ('D1_real', self.loss_D1_real.item()),
                            ('D1_fake', self.loss_D1_fake.item()),
                            ('G_L1', self.loss_G_L1.item())])

    def get_current_visuals(self):
        """Return input, output and target images; at test time the per-glyph
        results are tiled side by side along the width."""
        real_A1 = self.real_A1.data.clone()
        g, c, m, n = real_A1.size()
        fake_B = self.fake_B1.data.clone()
        real_B = self.real_B1.data.clone()

        if self.opt.isTrain:
            real_A_all = real_A1
            fake_B_all = fake_B
        else:
            real_A_all = self.Tensor(real_B.size(0), real_B.size(1),
                                     real_A1.size(2),
                                     real_A1.size(2) * real_A1.size(0))
            fake_B_all = self.Tensor(real_B.size(0), real_B.size(1),
                                     real_A1.size(2),
                                     fake_B.size(2) * fake_B.size(0))
            for b in range(g):
                real_A_all[:, :, :, self.out_id[b] * m:m * (self.out_id[b] + 1)] = \
                    real_A1[b, :, :, :]
                fake_B_all[:, :, :, self.out_id[b] * m:m * (self.out_id[b] + 1)] = \
                    fake_B[b, :, :, :]

        real_A = util.tensor2im(real_A_all)
        fake_B = util.tensor2im(fake_B_all)
        real_B = util.tensor2im(self.real_B1.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save all networks under ``label``; the glyph generator uses the
        label offset by ``opt.which_epoch`` when that is numeric."""
        if not self.opt.no_Style2Glyph:
            try:
                G_label = str(int(label) + int(self.opt.which_epoch))
            except (TypeError, ValueError):
                # non-numeric labels (e.g. a name instead of an epoch) are kept
                G_label = label
            if self.opt.conv3d:
                self.save_network(self.netG_3d, 'G_3d', G_label, self.gpu_ids)
            self.save_network(self.netG, 'G', G_label, self.gpu_ids)

        self.save_network(self.netE1, 'E1', label, self.gpu_ids)
        self.save_network(self.netDE1, 'DE1', label, self.gpu_ids)
        self.save_network(self.netD1, 'D1', label, self.gpu_ids)
        if self.opt.which_model_preNet != 'none':
            self.save_network(self.preNet_A, 'PRE_A', label,
                              gpu_ids=self.gpu_ids)

    def update_learning_rate(self):
        """Decrease the learning rate of the optimizers by one linear step."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        if self.opt.which_model_preNet != 'none':
            for param_group in self.optimizer_preA.param_groups:
                param_group['lr'] = lr
        # NOTE(review): optimizer_G_3d (created when opt.conv3d is set) is not
        # decayed here -- confirm whether that is intentional.
        for param_group in self.optimizer_D1.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_E1.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_DE1.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class CycleGANModel(BaseModel):
    """Two-generator / two-discriminator model for unpaired translation.

    ``netG_A`` maps domain-A images to domain B and ``netG_B`` maps back.
    During training each generator is paired with a discriminator
    (``netD_A`` judges B-domain images, ``netD_B`` judges A-domain images)
    and the generators are additionally constrained by L1 cycle losses and
    an optional L1 identity loss.
    """

    def name(self):
        """Return the human-readable model name."""
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build networks, optionally restore checkpoints, and set up training.

        Args:
            opt: parsed options object; the attributes read here are
                ``input_nc``, ``output_nc``, ``ngf``, ``ndf``,
                ``which_model_netG``, ``which_model_netD``, ``n_layers_D``,
                ``norm``, ``no_dropout``, ``init_type``, ``no_lsgan``,
                ``continue_train``, ``which_epoch``, ``pool_size``, ``lr``
                and ``beta1``.
        """
        BaseModel.initialize(self, opt)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            # A sigmoid output is only requested when LSGAN is disabled.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            # NOTE(review): ImagePool is presumably a history buffer of
            # previously generated images - confirm against its definition.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; both generators share one optimizer
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = [
                self.optimizer_G,
                self.optimizer_D_A,
                self.optimizer_D_B,
            ]
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store one batch, swapping A/B when ``which_direction`` is not 'AtoB'.

        Args:
            input: mapping with keys ``'A'``, ``'B'`` and ``'A_paths'`` /
                ``'B_paths'``.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async` became a reserved word in Python 3.7; PyTorch renamed
            # the keyword argument to `non_blocking`.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Expose the current batch as ``real_A`` / ``real_B``."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without recording gradients.

        Sets ``fake_B``, ``rec_A``, ``fake_A`` and ``rec_B``.
        """
        # `volatile=True` no longer disables autograd; no_grad() does.
        with torch.no_grad():
            real_A = self.input_A
            fake_B = self.netG_A(real_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data

            real_B = self.input_B
            fake_A = self.netG_B(real_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the GAN loss of one discriminator.

        Args:
            netD: discriminator network to train.
            real: batch of real images.
            fake: batch of generated images; detached here so no gradient
                reaches the generator.

        Returns:
            The combined loss tensor, ``0.5 * (loss_real + loss_fake)``.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update gradients of ``netD_A`` using real B and pooled fake B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        """Update gradients of ``netD_B`` using real A and pooled fake A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Compute all generator losses and back-propagate their sum.

        Stores the generated images (``fake_*``, ``rec_*``, and ``idt_*``
        when the identity loss is enabled) and the scalar loss values as
        attributes for later reporting.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B

        # combined loss
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B
                  + loss_idt_A + loss_idt_B)
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        # .item() replaces .data[0], which fails on 0-dim loss tensors.
        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        """Run one training step: generators first, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an ordered name -> value map."""
        ret_errors = OrderedDict([
            ('D_A', self.loss_D_A),
            ('G_A', self.loss_G_A),
            ('Cyc_A', self.loss_cycle_A),
            ('D_B', self.loss_D_B),
            ('G_B', self.loss_G_B),
            ('Cyc_B', self.loss_cycle_B),
        ])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return the latest images converted with ``util.tensor2im``."""
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([
            ('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
            ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B),
        ])
        # idt_* images are only produced by backward_G, i.e. while training.
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Save all four networks under the checkpoint tag ``label``."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
class ThreeLayersSeparateModel(BaseModel):
    """Three generators trained one after another on a shared forward pass.

    ``netG_A`` predicts ``predication`` from ``rgb_img``; ``netG_B`` takes
    that prediction concatenated with the image and returns two outputs
    (``shading1``, ``shading2``); ``netG_C`` renders an estimate from the
    image concatenated with each of those outputs. No discriminator is used.
    """

    def name(self):
        """Return the human-readable model name."""
        return 'ThreeLayersSeparateModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Override option defaults and add the training-only ``--lambda_L1``.

        Args:
            parser: argparse-style parser to modify in place.
            is_train: when true, training-only options are registered.

        Returns:
            The same ``parser`` object.
        """
        # changing the default values to match the pix2pix paper
        # (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')
        parser.set_defaults(dataset_mode='aligned')
        parser.set_defaults(netG='unet_256')
        if is_train:
            parser.add_argument('--lambda_L1', type=float, default=100.0,
                                help='weight for L1 loss')
        return parser

    def initialize(self, opt):
        """Create the three generators and, for training, losses/optimizers.

        Args:
            opt: parsed options; reads ``isTrain``, ``input_nc``,
                ``output_nc``, ``ngf``, ``norm``, ``no_dropout``,
                ``init_gain``, ``pool_size``, ``lrA``, ``lrB``, ``beta1``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program
        # will call base_model.get_current_losses
        self.loss_names = ['G_A', 'G_B', 'G_C']
        # specify the images you want to save/display. The program will call
        # base_model.get_current_visuals
        self.visual_names = [
            'rgb_img', 'im1', 'im2', 'chrom', 'predication', 'shading1',
            'shading2', 'est_im1', 'est_im2'
        ]
        # specify the models you want to save to the disk. The program will
        # call base_model.save_networks and base_model.load_networks.
        # The same three generators are used for both training and testing.
        self.model_names = ['G_A', 'G_B', 'G_C']

        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        "resnet_9blocks", opt.norm,
                                        not opt.no_dropout, "kaiming",
                                        opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        "upunet_256", opt.norm,
                                        not opt.no_dropout, "kaiming",
                                        opt.init_gain, self.gpu_ids)
        self.netG_C = networks.define_G(opt.input_nc * 2, opt.output_nc,
                                        opt.ngf, "render", opt.norm,
                                        not opt.no_dropout, "kaiming",
                                        opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # NOTE(review): ImagePool is presumably a history buffer of past
            # network inputs - confirm against its definition.
            self.image_pool = ImagePool(opt.pool_size)
            self.image_pool1 = ImagePool(opt.pool_size)
            self.image_pool2 = ImagePool(opt.pool_size)
            # define loss functions
            self.loss = networks.JointLoss()
            self.sloss = networks.ShadingLoss()
            self.rloss = networks.ReconstructionLoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(),
                                                lr=opt.lrA,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_G_B = torch.optim.Adam(self.netG_B.parameters(),
                                                  lr=opt.lrB,
                                                  betas=(opt.beta1, 0.999))
            # NOTE(review): G_C reuses opt.lrB; confirm this is intended.
            self.optimizer_G_C = torch.optim.Adam(self.netG_C.parameters(),
                                                  lr=opt.lrB,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = [
                self.optimizer_G,
                self.optimizer_G_B,
                self.optimizer_G_C,
            ]

    def set_input(self, input):
        """Move one batch onto ``self.device`` and keep its image paths.

        Args:
            input: mapping with keys ``'rgb_img'``, ``'chrom'``, ``'mask'``,
                ``'im1'``, ``'im2'``, ``'img1'``, ``'img2'`` and
                ``'A_paths'``.
        """
        self.rgb_img = input['rgb_img'].to(self.device)
        self.chrom = input['chrom'].to(self.device)
        self.image_paths = input['A_paths']
        self.mask = input['mask'].to(self.device)
        self.im1 = input['im1'].to(self.device)
        self.im2 = input['im2'].to(self.device)
        self.img1 = input['img1'].to(self.device)
        self.img2 = input['img2'].to(self.device)

    def forward(self):
        """Run G_A -> G_B -> G_C and store every intermediate output."""
        self.predication = self.netG_A(self.rgb_img)
        inputG = torch.cat((self.predication, self.rgb_img), 1)
        self.shading1, self.shading2 = self.netG_B(inputG)
        input_ = torch.cat((self.rgb_img, self.shading1), 1)
        self.est_im1 = self.netG_C(input_)
        input_ = torch.cat((self.rgb_img, self.shading2), 1)
        self.est_im2 = self.netG_C(input_)

    def L1Loss(self, prediction, gt, mask):
        """Return the masked absolute error summed and divided by ``sum(mask)``.

        The result is NaN/inf when ``mask`` sums to zero.
        """
        num_valid = torch.sum(mask)
        diff = torch.mul(mask, torch.abs(prediction - gt))
        return torch.sum(diff) / num_valid

    def backward_G_C(self):
        """Back-propagate the G_C loss on predicted and ground-truth inputs.

        G_C is evaluated on detached (pooled) G_B outputs and on the
        ground-truth pair ``im1`` / ``im2``.
        """
        input_G_B1 = self.image_pool1.query(
            torch.cat((self.rgb_img, self.shading1), 1))
        est_im1 = self.netG_C(input_G_B1.detach())
        input_G_B2 = self.image_pool2.query(
            torch.cat((self.rgb_img, self.shading2), 1))
        est_im2 = self.netG_C(input_G_B2.detach())

        # Order the ground-truth pair so that the first entry is the one
        # closer (masked L1) to the predicted shading1.
        if self.L1Loss(self.shading1, self.im1, self.mask) < self.L1Loss(
                self.shading1, self.im2, self.mask):
            input_GT1 = torch.cat((self.rgb_img, self.im1), 1)
            input_GT2 = torch.cat((self.rgb_img, self.im2), 1)
        else:
            input_GT1 = torch.cat((self.rgb_img, self.im2), 1)
            input_GT2 = torch.cat((self.rgb_img, self.im1), 1)

        gt_im1 = self.netG_C(input_GT1)
        gt_im2 = self.netG_C(input_GT2)

        img = est_im1 + est_im2
        gt_img = gt_im1 + gt_im2
        self.loss_G_C = (
            .5 * self.rloss(self.img1, self.img2, est_im1, est_im2, self.mask)
            + .5 * self.rloss(self.img1, self.img2, gt_im1, gt_im2, self.mask)
            + .5 * self.loss(self.rgb_img, img, self.mask)
            + .5 * self.loss(self.rgb_img, gt_img, self.mask)
        )
        self.loss_G_C.backward()

    def backward_G_B(self):
        """Back-propagate the G_B loss on predicted and ground-truth inputs."""
        input_G_B = self.image_pool.query(
            torch.cat((self.predication, self.rgb_img), 1))
        input_G_T = torch.cat((self.chrom, self.rgb_img), 1)
        # Detach so that this loss does not reach G_A.
        shading1, shading2 = self.netG_B(input_G_B.detach())
        gt_shading1, gt_shading2 = self.netG_B(input_G_T)
        self.loss_G_B = (
            .5 * self.sloss(self.im1, self.im2, shading1, shading2, self.mask)
            + .5 * self.sloss(self.im1, self.im2, gt_shading1, gt_shading2,
                              self.mask)
        )
        self.loss_G_B.backward()

    def backward_G(self):
        """Back-propagate the G_A loss between ``chrom`` and ``predication``."""
        self.loss_G_A = self.loss(self.chrom, self.predication, self.mask)
        self.loss_G = self.loss_G_A
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: update G_C, then G_B, then G_A.

        After its own update each of G_C and G_B has ``requires_grad``
        switched off so that later backward passes in the same step do not
        accumulate gradients into it.
        """
        self.forward()
        # update G_C
        self.set_requires_grad(self.netG_C, True)
        self.optimizer_G_C.zero_grad()
        self.backward_G_C()
        self.optimizer_G_C.step()
        self.set_requires_grad(self.netG_C, False)
        # update G_B
        self.set_requires_grad(self.netG_B, True)
        self.optimizer_G_B.zero_grad()
        self.backward_G_B()
        self.optimizer_G_B.step()
        self.set_requires_grad(self.netG_B, False)
        # update G_A
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
# Training-loop script: alternates generator and discriminator updates until
# `epoch_count` reaches `how_many_epochs`.
# NOTE(review): the *_train_function objects, train_batch/val_batch,
# epoch_count, iteration_count, batch_size and time_start are defined outside
# this fragment.

# Show progress every `display_freq` iterations.
display_freq = 10000
# Callables that run each generator on a list of inputs
# (presumably backend prediction functions - confirm in get_generater_function).
netG_A_function = get_generater_function(netG_A)
netG_B_functionr = get_generater_function(netG_B)
# NOTE(review): ImagePool presumably buffers previously generated images so
# the discriminators also see older fakes - confirm against its definition.
fake_A_pool = ImagePool()
fake_B_pool = ImagePool()
while epoch_count < how_many_epochs:
    # All-zero target passed to every train_on_batch call below.
    target_label = np.zeros((batch_size, 1))
    epoch_count, A, B = next(train_batch)
    # Generate the fakes BEFORE the generator update so the discriminators
    # are trained on images from the pre-update generators.
    tmp_fake_B = netG_A_function([A])[0]
    tmp_fake_A = netG_B_functionr([B])[0]
    _fake_B = fake_B_pool.query(tmp_fake_B)
    _fake_A = fake_A_pool.query(tmp_fake_A)
    netG_train_function.train_on_batch([A, B], target_label)
    netD_B_train_function.train_on_batch([B, _fake_B], target_label)
    netD_A_train_function.train_on_batch([A, _fake_A], target_label)
    iteration_count += 1
    if iteration_count % display_freq == 0:
        clear_output()
        # Average wall-clock seconds per iteration since time_start.
        traintime = (time.time() - time_start) / iteration_count
        print('epoch_count: {} iter_count: {} timecost/iter: {}s'.format(epoch_count, iteration_count, traintime))
        _, val_A, val_B = next(val_batch)
        show_generator_image(val_A, val_B, netG_A, netG_B)
class GcGANMixModel(BaseModel):
    """One-directional A->B GAN with rotation and vertical-flip consistency.

    A single generator ``netG_AB`` is applied to the input, to a rotated
    copy and to a vertically flipped copy. Three discriminators judge the
    three outputs, and L1 consistency losses tie each transformed output to
    the correspondingly transformed plain output.
    """

    def name(self):
        """Return the human-readable model name."""
        return 'GcGANMixModel'

    def initialize(self, opt):
        """Build networks, optionally restore checkpoints, set up training.

        Args:
            opt: parsed options; reads ``batchSize``, ``fineSize``,
                ``input_nc``, ``output_nc``, ``ngf``, ``ndf``,
                ``which_model_netG``, ``which_model_netD``, ``n_layers_D``,
                ``norm``, ``no_dropout``, ``init_type``, ``no_lsgan``,
                ``continue_train``, ``which_epoch``, ``pool_size``, ``lr``
                and ``beta1``.
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated input buffers; set_input() resizes and fills them.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        self.netG_AB = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         not opt.no_dropout, opt.init_type,
                                         self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_B = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_rot_B = networks.define_D(opt.output_nc, opt.ndf,
                                                opt.which_model_netD,
                                                opt.n_layers_D, opt.norm,
                                                use_sigmoid, opt.init_type,
                                                self.gpu_ids)
            self.netD_vf_B = networks.define_D(opt.output_nc, opt.ndf,
                                               opt.which_model_netD,
                                               opt.n_layers_D, opt.norm,
                                               use_sigmoid, opt.init_type,
                                               self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_AB, 'G_AB', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_B, 'D_B', which_epoch)
                self.load_network(self.netD_rot_B, 'D_rot_B', which_epoch)
                self.load_network(self.netD_vf_B, 'D_vf_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_rot_B_pool = ImagePool(opt.pool_size)
            self.fake_vf_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionGc = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_AB.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            # All three discriminators share one optimizer.
            self.optimizer_D_B = torch.optim.Adam(
                itertools.chain(self.netD_B.parameters(),
                                self.netD_rot_B.parameters(),
                                self.netD_vf_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer_G, self.optimizer_D_B]
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_AB)
        if self.isTrain:
            networks.print_network(self.netD_B)
            networks.print_network(self.netD_rot_B)
            networks.print_network(self.netD_vf_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch into the input buffers, honouring ``which_direction``.

        Args:
            input: mapping with keys ``'A'``, ``'B'`` and ``'A_paths'`` /
                ``'B_paths'``.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def backward_D_basic(self, netD, real, fake, netD_rot, real_rot, fake_rot,
                         netD_vf, real_vf, fake_vf):
        """Back-propagate the summed GAN loss of the three discriminators.

        Each (discriminator, real, fake) triple contributes
        ``0.5 * (loss_real + loss_fake)``; fakes are detached so no gradient
        reaches the generator.

        Returns:
            The summed loss tensor.
        """
        loss_D = 0
        triples = (
            (netD, real, fake),
            (netD_rot, real_rot, fake_rot),
            (netD_vf, real_vf, fake_vf),
        )
        for disc, real_batch, fake_batch in triples:
            loss_real = self.criterionGAN(disc(real_batch), True)
            loss_fake = self.criterionGAN(disc(fake_batch.detach()), False)
            loss_D = loss_D + (loss_real + loss_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def _reverse_dim(self, tensor, dim):
        """Return ``tensor`` with dimension ``dim`` reversed.

        The index has ``opt.fineSize`` entries and is created on the
        tensor's own device, so this works on CPU and on any GPU.
        """
        size = self.opt.fineSize
        inv_idx = torch.arange(size - 1, -1, -1, device=tensor.device).long()
        return torch.index_select(tensor, dim, inv_idx)

    def rot90(self, tensor, direction):
        """Rotate a 4-D batch by a quarter turn in its last two dimensions.

        Args:
            tensor: batch whose dims 2 and 3 both have size ``opt.fineSize``.
            direction: ``0`` reverses dim 3 after the transpose; any other
                value reverses dim 2, which is the opposite rotation.

        Returns:
            A new rotated tensor.
        """
        tensor = tensor.transpose(2, 3)
        return self._reverse_dim(tensor, 3 if direction == 0 else 2)

    def forward(self):
        """Expose the current batch plus its rotated and flipped copies."""
        self.real_A = self.input_A
        self.real_B = self.input_B

        # rot90 / _reverse_dim return new tensors, so the input buffers are
        # never modified and no defensive clone is needed.
        self.real_rot_A = self.rot90(self.input_A, 0)
        self.real_rot_B = self.rot90(self.input_B, 0)
        self.real_vf_A = self._reverse_dim(self.input_A, 2)
        self.real_vf_B = self._reverse_dim(self.input_B, 2)

    def get_gc_rot_loss(self, AB, AB_gc, direction):
        """Return the rotation-consistency loss between two generator outputs.

        Args:
            AB: generator output for the plain input.
            AB_gc: generator output for the rotated input.
            direction: rotation direction that was applied to the input
                (same convention as ``rot90``).

        Returns:
            Sum of both L1 terms scaled by ``lambda_AB * lambda_gc``.
        """
        applied, inverse = (0, 1) if direction == 0 else (1, 0)
        # Targets are detached: each output is pulled towards the other's
        # transformed value without gradient flowing through the target.
        AB_gt = self.rot90(AB_gc.detach(), inverse)
        loss_gc = self.criterionGc(AB, AB_gt)
        AB_gc_gt = self.rot90(AB.detach(), applied)
        loss_gc += self.criterionGc(AB_gc, AB_gc_gt)
        return loss_gc * self.opt.lambda_AB * self.opt.lambda_gc

    def get_gc_vf_loss(self, AB, AB_gc):
        """Return the vertical-flip consistency loss between two outputs.

        Args:
            AB: generator output for the plain input.
            AB_gc: generator output for the flipped input.

        Returns:
            Sum of both L1 terms scaled by ``lambda_AB * lambda_gc``.
        """
        AB_gt = self._reverse_dim(AB_gc.detach(), 2)
        loss_gc = self.criterionGc(AB, AB_gt)
        AB_gc_gt = self._reverse_dim(AB.detach(), 2)
        loss_gc += self.criterionGc(AB_gc, AB_gc_gt)
        return loss_gc * self.opt.lambda_AB * self.opt.lambda_gc

    def backward_D_B(self):
        """Update discriminator gradients using pooled fakes."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        fake_rot_B = self.fake_rot_B_pool.query(self.fake_rot_B)
        fake_vf_B = self.fake_vf_B_pool.query(self.fake_vf_B)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_B, fake_B,
                                         self.netD_rot_B, self.real_rot_B,
                                         fake_rot_B, self.netD_vf_B,
                                         self.real_vf_B, fake_vf_B)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Compute all generator losses and back-propagate their sum."""
        # adversarial loss
        fake_B = self.netG_AB(self.real_A)
        pred_fake = self.netD_B(fake_B)
        loss_G_AB = self.criterionGAN(pred_fake, True) * self.opt.lambda_G

        fake_rot_B = self.netG_AB(self.real_rot_A)
        pred_fake = self.netD_rot_B(fake_rot_B)
        loss_G_gc_AB = self.criterionGAN(pred_fake, True) * self.opt.lambda_G

        fake_vf_B = self.netG_AB(self.real_vf_A)
        pred_fake = self.netD_vf_B(fake_vf_B)
        loss_G_gc_AB += self.criterionGAN(pred_fake, True) * self.opt.lambda_G

        # Average of the two transformed-branch adversarial terms.
        loss_G_gc_AB = loss_G_gc_AB * 0.5

        loss_gc = self.get_gc_rot_loss(fake_B, fake_rot_B, 0)
        loss_gc += self.get_gc_vf_loss(fake_B, fake_vf_B)
        loss_gc = loss_gc * 0.5

        if self.opt.identity > 0:
            idt_weight = self.opt.lambda_AB * self.opt.identity
            # G_AB should be identity if real_B is fed.
            idt_A = self.netG_AB(self.real_B)
            loss_idt = self.criterionIdt(idt_A, self.real_B) * idt_weight

            idt_gc_A = self.netG_AB(self.real_rot_B)
            loss_idt_gc = self.criterionIdt(idt_gc_A,
                                            self.real_rot_B) * idt_weight

            idt_gc_A = self.netG_AB(self.real_vf_B)
            loss_idt_gc += self.criterionIdt(idt_gc_A,
                                             self.real_vf_B) * idt_weight
            loss_idt_gc = loss_idt_gc * 0.5

            self.idt_A = idt_A.data
            # Only the flipped-branch identity output is kept.
            self.idt_gc_A = idt_gc_A.data
            self.loss_idt = loss_idt.item()
            self.loss_idt_gc = loss_idt_gc.item()
        else:
            loss_idt = 0
            loss_idt_gc = 0
            self.loss_idt = 0
            self.loss_idt_gc = 0

        loss_G = loss_G_AB + loss_G_gc_AB + loss_gc + loss_idt + loss_idt_gc
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_rot_B = fake_rot_B.data
        self.fake_vf_B = fake_vf_B.data

        self.loss_G_AB = loss_G_AB.item()
        self.loss_G_gc_AB = loss_G_gc_AB.item()
        self.loss_gc = loss_gc.item()

    def optimize_parameters(self):
        """Run one training step: generator first, then the discriminators."""
        # forward
        self.forward()
        # G_AB
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_B and D_gc_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an ordered name -> value map."""
        ret_errors = OrderedDict([
            ('D_B', self.loss_D_B),
            ('G_AB', self.loss_G_AB),
            ('Gc', self.loss_gc),
            ('G_gc_AB', self.loss_G_gc_AB),
        ])
        if self.opt.identity > 0.0:
            ret_errors['idt'] = self.loss_idt
            ret_errors['idt_gc'] = self.loss_idt_gc
        return ret_errors

    def get_current_visuals(self):
        """Return the latest images converted with ``util.tensor2im``."""
        real_A = util.tensor2im(self.real_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_B = util.tensor2im(self.fake_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('real_B', real_B)])
        return ret_visuals

    def save(self, label):
        """Save the generator and discriminators under the tag ``label``."""
        self.save_network(self.netG_AB, 'G_AB', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.netD_rot_B, 'D_rot_B', label, self.gpu_ids)
        self.save_network(self.netD_vf_B, 'D_vf_B', label, self.gpu_ids)

    def test(self):
        """Generate ``fake_B`` for the current batch without autograd."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            self.fake_B = self.netG_AB(self.real_A).data
class Pix2PixModel(BaseModel):
    """Conditional GAN mapping input images A to paired target images B.

    The generator ``netG`` is trained with a GAN loss from a discriminator
    that sees the (A, B) pair concatenated along the channel dimension, plus
    an L1 loss between the generated and the real B.
    """

    def name(self):
        """Return the human-readable model name."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build networks, optionally restore checkpoints, set up training.

        Args:
            opt: parsed options; reads ``isTrain``, ``batchSize``,
                ``fineSize``, ``input_nc``, ``output_nc``, ``ngf``, ``ndf``,
                ``which_model_netG``, ``which_model_netD``, ``n_layers_D``,
                ``norm``, ``no_dropout``, ``no_lsgan``, ``continue_train``,
                ``which_epoch``, ``pool_size``, ``lr`` and ``beta1``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # define tensors; set_input() resizes and fills these buffers
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # The discriminator sees A and B stacked along channels.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch into the input buffers, honouring ``which_direction``.

        Args:
            input: mapping with keys ``'A'``, ``'B'`` and ``'A_paths'`` /
                ``'B_paths'``.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Generate ``fake_B`` from the current batch with gradients enabled."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Generate ``fake_B`` from the current batch without autograd."""
        # `volatile=True` no longer disables autograd; no_grad() does.
        with torch.no_grad():
            self.real_A = self.input_A
            self.fake_B = self.netG(self.real_A)
            self.real_B = self.input_B

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D(self):
        """Back-propagate the discriminator loss on fake and real pairs."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Back-propagate the generator's GAN + weighted L1 loss."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: discriminator first, then generator."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an ordered name -> float map."""
        # .item() replaces .data[0], which fails on 0-dim loss tensors.
        return OrderedDict([
            ('G_GAN', self.loss_G_GAN.item()),
            ('G_L1', self.loss_G_L1.item()),
            ('D_real', self.loss_D_real.item()),
            ('D_fake', self.loss_D_fake.item()),
        ])

    def get_current_visuals(self):
        """Return the latest images converted with ``util.tensor2im``."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save generator and discriminator under the checkpoint tag ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        """Lower both optimizers' learning rate by ``lr / niter_decay``.

        Raises ``ZeroDivisionError`` if ``opt.niter_decay`` is 0.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for optimizer in (self.optimizer_D, self.optimizer_G):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class DAnetmodel(BaseModel): def name(self): return 'DAnetModel' @staticmethod def modify_commandline_options(parser, is_train=True): parser.set_defaults(no_dropout=True) if is_train: parser.add_argument( '--lambda_Dehazing', type=float, default=10.0, help='weight for reconstruction loss (dehazing)') parser.add_argument('--lambda_Dehazing_Con', type=float, default=50.0, help='weight for consistency') parser.add_argument('--lambda_Dehazing_DC', type=float, default=0.01, help='weight for dark channel loss') parser.add_argument('--lambda_Dehazing_TV', type=float, default=0.01, help='weight for TV loss') parser.add_argument('--lambda_gan_feat', type=float, default=0.1, help='weight for feature GAN loss') # cyclegan parser.add_argument('--lambda_S', type=float, default=1.0, help='weight for cycle loss (A -> B -> A)') parser.add_argument('--lambda_R', type=float, default=1.0, help='weight for cycle loss (B -> A -> B)') parser.add_argument( '--lambda_identity', type=float, default=30.0, help= 'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. 
For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1' ) parser.add_argument('--which_model_netG_A', type=str, default='resnet_9blocks', help='selects model to use for netG_A') parser.add_argument('--which_model_netG_B', type=str, default='resnet_9blocks', help='selects model to use for netG_B') parser.add_argument('--S_Dehazing_premodel', type=str, default=" ", help='pretrained dehazing model') parser.add_argument('--R_Dehazing_premodel', type=str, default=" ", help='pretrained dehazing model') parser.add_argument('--g_s2r_premodel', type=str, default=" ", help='pretrained G_s2r model') parser.add_argument('--g_r2s_premodel', type=str, default=" ", help='pretrained G_r2s model') parser.add_argument('--d_s_premodel', type=str, default=" ", help='pretrained D_s model') parser.add_argument('--d_r_premodel', type=str, default=" ", help='pretrained D_r model') parser.add_argument('--freeze_bn', action='store_true', help='freeze the bn in mde') parser.add_argument('--freeze_in', action='store_true', help='freeze the in in cyclegan') return parser def initialize(self, opt): BaseModel.initialize(self, opt) # specify the training losses you want to print out. The program will call base_model.get_current_losses if self.isTrain: self.loss_names = [ 'S2R_Dehazing', 'S_Dehazing', 'R2S_Dehazing_DC', 'R_Dehazing_DC' ] self.loss_names += [ 'R2S_Dehazing_TV', 'R_Dehazing_TV', 'Dehazing_Con' ] self.loss_names += [ 'idt_R', 'idt_S', 'D_R', 'D_S', 'G_S2R', 'G_R2S', 'cycle_S', 'cycle_R', 'G_Rfeat', 'G_Sfeat', 'D_Rfeat', 'D_Sfeat' ] # specify the images you want to save/display. 
The program will call base_model.get_current_visuals if self.isTrain: visual_names_S = [ 'syn_haze_img', 'img_s2r', 'clear_img', 's2r_dehazing_img', 's_dehazing_img' ] #, 's_rec_img'] visual_names_R = [ 'real_haze_img', 'img_r2s', 'r2s_dehazing_img', 'r_dehazing_img' ] #, 'r_rec_img'] # if self.opt.lambda_identity > 0.0: # visual_names_S.append('idt_S') # visual_names_R.append('idt_R') self.visual_names = visual_names_S + visual_names_R else: self.visual_names = ['pred', 'img', 'img_trans'] # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks if self.isTrain: self.model_names = ['S_Dehazing', 'R_Dehazing'] self.model_names += [ 'S2R', 'R2S', 'D_R', 'D_S', 'D_Sfeat', 'D_Rfeat' ] else: self.model_names = ['S_Dehazing', 'R_Dehzaing', 'S2R', 'R2S'] # Temp Fix for nn.parallel as nn.parallel crashes oc calculating gradient penalty # use_parallel = not opt.gan_type == 'wgan-gp' use_parallel = False # define the transform network self.netS2R = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG_A, opt.norm, not opt.no_dropout, self.gpu_ids, use_parallel, opt.learn_residual) self.netR2S = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG_A, opt.norm, not opt.no_dropout, self.gpu_ids, use_parallel, opt.learn_residual) # define the image dehazing network self.netR_Dehazing = networks.define_Gen( opt.input_nc, opt.output_nc, opt.ngf, opt.task_layers, opt.norm, opt.activation, opt.task_model_type, opt.init_type, opt.drop_rate, False, opt.gpu_ids, opt.U_weight) self.netS_Dehazing = networks.define_Gen( opt.input_nc, opt.output_nc, opt.ngf, opt.task_layers, opt.norm, opt.activation, opt.task_model_type, opt.init_type, opt.drop_rate, False, opt.gpu_ids, opt.U_weight) # define the discriminator if self.isTrain: use_sigmoid = False self.netD_R = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, 
use_parallel) self.netD_S = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, use_parallel) self.netD_Sfeat = networks.define_featureD(opt.image_feature, n_layers=2, norm='batch', activation='PReLU', init_type='xavier', gpu_ids=self.gpu_ids) self.netD_Rfeat = networks.define_featureD(opt.image_feature, n_layers=2, norm='batch', activation='PReLU', init_type='xavier', gpu_ids=self.gpu_ids) if self.isTrain and not opt.continue_train: self.init_with_pretrained_model('S2R', self.opt.g_s2r_premodel) self.init_with_pretrained_model('R2S', self.opt.g_r2s_premodel) self.init_with_pretrained_model('R_Dehazing', self.opt.R_Dehazing_premodel) self.init_with_pretrained_model('S_Dehazing', self.opt.S_Dehazing_premodel) self.init_with_pretrained_model('D_R', self.opt.d_r_premodel) self.init_with_pretrained_model('D_S', self.opt.d_s_premodel) if opt.continue_train: self.load_networks(opt.which_epoch) if self.isTrain: self.fake_s_pool = ImagePool(opt.pool_size) self.fake_r_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = losses.GANLoss(use_ls=not opt.no_lsgan).to( self.device) self.l1loss = torch.nn.L1Loss() self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() self.criterionDehazing = torch.nn.MSELoss() self.criterionCons = torch.nn.L1Loss() self.nonlinearity = torch.nn.ReLU() self.TVLoss = L1_TVLoss_Charbonnier() # initialize optimizers self.optimizer_G_task = torch.optim.Adam(itertools.chain( self.netS_Dehazing.parameters(), self.netR_Dehazing.parameters()), lr=opt.lr_task, betas=(0.95, 0.999)) self.optimizer_G_trans = torch.optim.Adam(itertools.chain( self.netS2R.parameters(), self.netR2S.parameters()), lr=opt.lr_trans, betas=(0.5, 0.9)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_S.parameters(), self.netD_R.parameters(), self.netD_Sfeat.parameters(), self.netD_Rfeat.parameters()), lr=opt.lr_trans, betas=(0.5, 0.9)) self.optimizers = [] 
self.optimizers.append(self.optimizer_G_task) self.optimizers.append(self.optimizer_G_trans) self.optimizers.append(self.optimizer_D) if opt.freeze_bn: self.netS_Dehazing.apply(networks.freeze_bn) self.netR_Dehazing.apply(networks.freeze_bn) if opt.freeze_in: self.netS2R.apply(networks.freeze_in) self.netR2S.apply(networks.freeze_in) def set_input(self, input): if self.isTrain: AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] input_C = input['C'] self.syn_haze_img = input_A.to(self.device) self.real_haze_img = input_C.to(self.device) self.clear_img = input_B.to(self.device) #self.depth = input['D'].to(self.device) #self.real_depth = input['E'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] else: self.img = input['A'].to(self.device) def forward(self): if self.isTrain: pass # else: # if self.opt.phase == 'test': # self.pred_s = self.netS_Dehazing(self.img)[-1] # self.img_trans = self.netS2R(self.img) # self.pred_r = self.netR_Dehazing(self.img_trans)[-1] # self.pred = 0.5 * (self.pred_s + self.pred_r) # else: # self.pred_r = self.netR_Dehazing(self.img)[-1] # self.img_trans = self.netR2S(self.img) # self.pred_s = self.netS_Dehazing(self.img_trans)[-1] # self.pred = 0.5 * (self.pred_s + self.pred_r) # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD(real.detach()) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D_S(self): img_r2s = self.fake_s_pool.query(self.img_r2s) self.loss_D_S = self.backward_D_basic(self.netD_S, self.syn_haze_img, img_r2s) def backward_D_R(self): img_s2r = self.fake_r_pool.query(self.img_s2r) self.loss_D_R = self.backward_D_basic(self.netD_R, self.real_haze_img, 
img_s2r) def backward_D_Sfeat(self): self.loss_D_Sfeat = self.backward_D_basic(self.netD_Sfeat, self.s_dehazing_feat, self.r2s_dehazing_feat) def backward_D_Rfeat(self): self.loss_D_Rfeat = self.backward_D_basic(self.netD_Rfeat, self.r_dehazing_feat, self.s2r_dehazing_feat) def backward_G(self): lambda_Dehazing = self.opt.lambda_Dehazing lambda_Dehazing_Con = self.opt.lambda_Dehazing_Con lambda_gan_feat = self.opt.lambda_gan_feat lambda_idt = self.opt.lambda_identity lambda_S = self.opt.lambda_S lambda_R = self.opt.lambda_R # =========================== synthetic ========================== self.img_s2r = self.netS2R(self.syn_haze_img) self.idt_S = self.netR2S(self.syn_haze_img) self.s_rec_img = self.netR2S(self.img_s2r) self.out_r = self.netR_Dehazing(self.img_s2r) self.out_s = self.netS_Dehazing(self.syn_haze_img) self.s2r_dehazing_feat = self.out_r[0] self.s_dehazing_feat = self.out_s[0] self.s2r_dehazing_img = self.out_r[-1] self.s_dehazing_img = self.out_s[-1] self.loss_G_S2R = self.criterionGAN(self.netD_R(self.img_s2r), True) self.loss_G_Rfeat = self.criterionGAN( self.netD_Rfeat(self.s2r_dehazing_feat), True) * lambda_gan_feat self.loss_cycle_S = self.criterionCycle(self.s_rec_img, self.syn_haze_img) * lambda_S self.loss_idt_S = self.criterionIdt( self.idt_S, self.syn_haze_img) * lambda_S * lambda_idt size = len(self.out_s) self.loss_S_Dehazing = 0.0 clear_imgs = task.scale_pyramid(self.clear_img, size - 1) for (s_dehazing_img, clear_img) in zip(self.out_s[1:], clear_imgs): self.loss_S_Dehazing += self.criterionDehazing( s_dehazing_img, clear_img) * lambda_Dehazing self.loss_S2R_Dehazing = 0.0 for (s2r_dehazing_img, clear_img) in zip(self.out_r[1:], clear_imgs): self.loss_S2R_Dehazing += self.criterionDehazing( s2r_dehazing_img, clear_img) * lambda_Dehazing self.loss = self.loss_G_S2R + self.loss_G_Rfeat + self.loss_cycle_S + self.loss_idt_S + self.loss_S_Dehazing + self.loss_S2R_Dehazing self.loss.backward() # ============================= real 
============================= self.img_r2s = self.netR2S(self.real_haze_img) self.idt_R = self.netS2R(self.real_haze_img) self.r_rec_img = self.netS2R(self.img_r2s) self.out_s = self.netS_Dehazing(self.img_r2s) self.out_r = self.netR_Dehazing(self.real_haze_img) self.r_dehazing_feat = self.out_r[0] self.r2s_dehazing_feat = self.out_s[0] self.r_dehazing_img = self.out_r[-1] self.r2s_dehazing_img = self.out_s[-1] self.loss_G_R2S = self.criterionGAN(self.netD_S(self.img_r2s), True) self.loss_G_Sfeat = self.criterionGAN( self.netD_Sfeat(self.r2s_dehazing_feat), True) * lambda_gan_feat self.loss_cycle_R = self.criterionCycle(self.r_rec_img, self.real_haze_img) * lambda_R self.loss_idt_R = self.criterionIdt( self.idt_R, self.real_haze_img) * lambda_R * lambda_idt # TV LOSS self.loss_R2S_Dehazing_TV = self.TVLoss( self.r2s_dehazing_img) * self.opt.lambda_Dehazing_TV self.loss_R_Dehazing_TV = self.TVLoss( self.r_dehazing_img) * self.opt.lambda_Dehazing_TV # DC LOSS self.loss_R2S_Dehazing_DC = DCLoss( (self.r2s_dehazing_img + 1) / 2, self.opt.patch_size) * self.opt.lambda_Dehazing_DC self.loss_R_Dehazing_DC = DCLoss( (self.r_dehazing_img + 1) / 2, self.opt.patch_size) * self.opt.lambda_Dehazing_DC # dehazing consistency self.loss_Dehazing_Con = 0.0 for (out_s1, out_r2) in zip(self.out_s, self.out_r): self.loss_Dehazing_Con += self.criterionCons( out_s1, out_r2) * lambda_Dehazing_Con self.loss_G = self.loss_G_R2S + self.loss_G_Sfeat + self.loss_cycle_R + self.loss_idt_R + self.loss_R2S_Dehazing_TV \ + self.loss_R_Dehazing_TV + self.loss_R2S_Dehazing_DC + self.loss_R_Dehazing_DC + self.loss_Dehazing_Con self.loss_G.backward() self.real_dehazing_img = (self.r_dehazing_img + self.r2s_dehazing_img) / 2.0 self.syn_dehazing_img = (self.s_dehazing_img + self.s2r_dehazing_img) / 2.0 def optimize_parameters(self): self.forward() self.set_requires_grad( [self.netD_S, self.netD_R, self.netD_Sfeat, self.netD_Rfeat], False) self.optimizer_G_trans.zero_grad() 
self.optimizer_G_task.zero_grad() self.backward_G() self.optimizer_G_trans.step() self.optimizer_G_task.step() self.set_requires_grad( [self.netD_S, self.netD_R, self.netD_Sfeat, self.netD_Rfeat], True) self.optimizer_D.zero_grad() self.backward_D_S() self.backward_D_R() self.backward_D_Sfeat() self.backward_D_Rfeat() self.optimizer_D.step()
class CycleGANModel(BaseModel): """ This class implements the CycleGAN model, for learning image-to-image translation without paired data. The model training requires '--dataset_mode unaligned' dataset. By default, it uses a '--netG resnet_9blocks' ResNet generator, a '--netD basic' discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective ('--gan_mode lsgan'). CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses. A (source domain), B (target domain). Generators: G_A: A -> B; G_B: B -> A. Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A. Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper) Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper) Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper) Dropout is not used in the original CycleGAN paper. """ parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout if is_train: parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)') parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. 
For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1') return parser def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. 
paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert(opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap domain A and domain B. 
""" AtoB = self.opt.direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG_A(self.real_A) # G_A(A) self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.fake_A = self.netG_B(self.real_B) # G_B(B) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): """Calculate the loss for generators G_A and G_B""" lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed: ||G_A(B) - B|| self.idt_A = self.netG_A(self.real_B) self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed: ||G_B(A) - A|| self.idt_B = self.netG_B(self.real_A) self.loss_idt_B = 
self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) # GAN loss D_B(G_B(B)) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) # Forward cycle loss || G_B(G_A(A)) - A|| self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss || G_A(G_B(B)) - B|| self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. # G_A and G_B self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero self.backward_G() # calculate gradients for G_A and G_B self.optimizer_G.step() # update G_A and G_B's weights # D_A and D_B self.set_requires_grad([self.netD_A, self.netD_B], True) self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero self.backward_D_A() # calculate gradients for D_A self.backward_D_B() # calculate graidents for D_B self.optimizer_D.step() # update D_A and D_B's weights
class Pix2PixHDModel(BaseModel):
    """pix2pixHD-style conditional GAN: generator, optional encoder and discriminator.

    The generator can additionally be conditioned (``opt.cond``) on the
    instance tensor, and training can run under mixed precision
    (``opt.fp16``) and with augmented discriminator inputs (``opt.ada``).
    """

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Build a function that drops the losses disabled by the options.

        Parameters:
            use_gan_feat_loss (bool) -- keep the GAN feature-matching loss
            use_vgg_loss (bool)      -- keep the VGG loss

        Returns:
            a callable ``(g_gan, g_gan_feat, g_vgg, d_real, d_fake) -> list``
            returning the enabled entries in that order. The GAN and
            discriminator entries are always kept.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            candidates = (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
            return [loss for (loss, keep) in zip(candidates, flags) if keep]

        return loss_filter

    def _load_checkpoints(self, which_epoch, pretrained_path):
        """Load G, plus D when training and E when an encoder exists, from one location."""
        self.load_network(self.netG, 'G', which_epoch, pretrained_path)
        if self.isTrain:
            self.load_network(self.netD, 'D', which_epoch, pretrained_path)
        if self.gen_features:
            self.load_network(self.netE, 'E', which_epoch, pretrained_path)

    def initialize(self, opt):
        """Create networks, load checkpoints and set up losses and optimizers.

        Parameters:
            opt -- parsed experiment options (read throughout this class as ``self.opt``)

        Raises:
            NotImplementedError: when an image pool is requested together with
                more than one GPU.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:
            # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features

        # Generator input: label channels (or raw input channels when
        # label_nc == 0), plus one edge channel and the feature channels.
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num

        self.netG = networks.define_G(
            netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
            opt.n_downsample_global, opt.n_blocks_global,
            opt.n_local_enhancers, opt.n_blocks_local, opt.norm,
            cond=opt.cond, n_self_attention=opt.n_self_attention,
            gpu_ids=self.gpu_ids, img_size=opt.fineSize,
            vocab_size=opt.vocab_size)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance or opt.cond:
                netD_input_nc += 1
            self.netD = networks.define_D(
                netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        # Encoder network
        if self.gen_features:
            # NOTE(review): img_size receives opt.vocab_size here while netG
            # receives opt.fineSize — confirm this is intended.
            self.netE = networks.define_G(
                netG_input_nc, opt.feat_num, opt.nef, 'encoder',
                opt.n_downsample_E, norm=opt.norm, cond=opt.cond,
                gpu_ids=self.gpu_ids, img_size=opt.vocab_size)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # Load this experiment's own checkpoints (test time or resumed training).
        if not self.isTrain or opt.continue_train:
            self._load_checkpoints(opt.which_epoch, '')
        # Then load any pretrained networks on top. At test time the path is
        # always '', i.e. the same files as above, so the second pass is
        # restricted to training to avoid loading everything twice.
        if self.isTrain and opt.load_pretrain:
            self._load_checkpoints(opt.which_epoch, opt.load_pretrain)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(
                not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(
                    self.gpu_ids, opt.vgg19_weights, opt.vocab_size)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter(
                'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')

            # optimizer G
            if opt.niter_fix_global > 0:
                # Only parameters whose name starts with 'model<n_local_enhancers>'
                # are optimized for the first niter_fix_global epochs.
                finetune_list = set()
                prefix = 'model' + str(opt.n_local_enhancers)
                params = []
                for key, value in self.netG.named_parameters():
                    if key.startswith(prefix):
                        params.append(value)
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.params_G = params
            self.optimizer_G = torch.optim.Adam(
                params, lr=opt.lr_G, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(
                params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        """Move a batch to ``opt.device`` and build the generator's label input.

        Parameters:
            label_map  -- label tensor; one-hot encoded over ``opt.label_nc``
                          channels unless ``label_nc == 0``, in which case it
                          is used as-is
            inst_map   -- instance tensor; required unless ``opt.no_instance``
            real_image -- optional target image
            feat_map   -- passed through unchanged
            infer      -- when True the label input does not require grad and
                          ``inst_map`` is not converted to float

        Returns:
            ``(input_label, inst_map, real_image, feat_map)``
        """
        device = self.opt.device
        if self.opt.label_nc == 0:
            input_label = label_map.data.to(device)
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.zeros(oneHot_size, dtype=torch.float32, device=device)
            input_label = input_label.scatter_(
                1, label_map.data.long().to(device), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.to(device)
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, requires_grad=not infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.to(device))

        # Loading of precomputed feature maps (opt.load_features) is disabled
        # in this version; feat_map is returned as received.
        if not infer and (self.use_features or self.opt.cond):
            inst_map = inst_map.float().to(device)

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run the discriminator on ``(input_label, test_image)``.

        The image is detached, so no gradient flows back into the generator.
        With ``use_pool`` the concatenated pair is first passed through the
        fake-image pool.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Compute all training losses for one batch.

        Returns:
            ``[losses, fake_image_or_None]`` where ``losses`` is the list
            produced by ``self.loss_filter`` and the generated image is only
            returned when ``infer`` is True.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat)

        # NOTE(review): the extent of the two autocast regions below was
        # reconstructed from whitespace-collapsed source — confirm.
        with autocast(enabled=self.opt.fp16):
            # Fake Generation
            if self.use_features:
                if not self.opt.load_features:
                    feat_map = self.netE.forward(inst_map, real_image)
                input_concat = torch.cat((input_label, feat_map), dim=1)
            else:
                input_concat = input_label
            if self.opt.cond:
                fake_image = self.netG.forward(inst_map, input_concat)
            else:
                fake_image = self.netG.forward(input_concat)

        input_concat_aug = input_concat
        real_image_aug = real_image
        fake_image_aug = fake_image
        if self.opt.ada:
            # One parameter set is shared by the fake, real and label tensors.
            params = get_params(self.opt, (self.opt.fineSize, self.opt.fineSize))
            transform = get_transform(self.opt, params, is_aug=True)
            fake_image_aug = batch_transform(fake_image, transform)
            real_image_aug = batch_transform(real_image, transform)
            input_concat_aug = batch_transform(input_label, transform)

        # TODO: send labels to discriminator as well
        if self.opt.cond:
            # Pad the condition tensor's last dimension up to the image width
            # and tile it into a single extra channel for the discriminator.
            dim = inst_map.size(1)
            img_size = input_concat.size(-1)
            pad_len = max(0, img_size - dim)
            v = F.pad(inst_map, (0, pad_len))
            dim = v.size(1)
            v = v.unsqueeze(2).repeat(1, 1, dim).view(-1, 1, dim, dim)
            input_label = torch.cat((v, input_concat_aug), dim=1)

        with autocast(enabled=self.opt.fp16):
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(
                input_label, fake_image_aug, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(input_label, real_image_aug)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss); not detached, so the gradient
            # reaches the generator.
            pred_fake = self.netD.forward(
                torch.cat((input_label, fake_image_aug), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

            # GAN feature matching loss
            loss_G_GAN_Feat = 0
            if not self.opt.no_ganFeat_loss:
                feat_weights = 4.0 / (self.opt.n_layers_D + 1)
                D_weights = 1.0 / self.opt.num_D
                for i in range(self.opt.num_D):
                    # The last entry of each discriminator's output is skipped.
                    for j in range(len(pred_fake[i]) - 1):
                        loss_G_GAN_Feat += D_weights * feat_weights * \
                            self.criterionFeat(
                                pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

            # VGG feature matching loss
            loss_G_VGG = 0
            if not self.opt.no_vgg_loss:
                loss_G_VGG = self.criterionVGG(
                    fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),
                None if not infer else fake_image]

    def inference(self, label, inst, image=None):
        """Generate an image for ``label``/``inst`` without tracking gradients.

        Parameters:
            label -- label tensor
            inst  -- instance (or condition) tensor
            image -- optional real image, encoded when ``opt.use_encoded_image``

        Returns:
            the generated image tensor.
        """
        # Encode Inputs
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(
            Variable(label), Variable(inst), image, infer=True)

        with torch.no_grad():
            # Fake Generation
            if self.use_features:
                if self.opt.use_encoded_image:
                    # Encode the device-resident copy returned by encode_input,
                    # not the caller's tensor, which may live on another device.
                    feat_map = self.netE.forward(inst_map, real_image)
                else:
                    # sample clusters from precomputed features
                    feat_map = self.sample_features(inst_map)
                input_concat = torch.cat((input_label, feat_map), dim=1)
            else:
                input_concat = input_label

            if self.opt.cond:
                fake_image = self.netG.forward(inst_map, input_concat)
            else:
                fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Fill a feature map by sampling one stored cluster centre per instance.

        Instances whose label has no stored clusters are left at zero.
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(
            self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        # The file holds a pickled dict, which NumPy refuses to load unless
        # allow_pickle is set. Unpickling executes code from the file: only
        # load cluster files produced by this project.
        features_clustered = np.load(
            cluster_path, encoding='latin1', allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(
            inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            # Ids of 1000 and above are reduced to a label by dividing by 1000.
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Collect one encoder feature vector per instance in ``inst``.

        Returns:
            dict mapping label -> array of shape (n, feat_num + 1); the last
            column is the instance's pixel count divided by
            ``h * w // block_num``.
        """
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        with torch.no_grad():
            # NOTE(review): argument order is (image, inst) here but
            # (inst_map, image) in forward() and inference() — confirm which
            # order the encoder expects.
            feat_map = self.netE.forward(
                image.to(self.opt.device), inst.to(self.opt.device))
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            # Sample the feature at the middle pixel of the instance.
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
            val[0, feat_num] = float(num) / (h * w // block_num)
            # Labels outside range(label_nc) start from an empty table instead
            # of raising KeyError.
            previous = feature.get(label, np.zeros((0, feat_num + 1)))
            feature[label] = np.append(previous, val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a 0/1 edge map marking pixels whose 4-neighbour differs in ``t``.

        The result has the same size as ``t`` and is half precision when
        ``opt.data_type == 16``, float otherwise.
        """
        edge = torch.zeros(t.size(), dtype=torch.bool, device=t.device)
        differs_w = t[:, :, :, 1:] != t[:, :, :, :-1]
        differs_h = t[:, :, 1:, :] != t[:, :, :-1, :]
        # Both pixels on either side of a boundary are marked.
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | differs_w
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | differs_w
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | differs_h
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | differs_h
        if self.opt.data_type == 16:
            return edge.half()
        return edge.float()

    def save(self, which_epoch):
        """Save G and D, plus E when an encoder is trained, for ``which_epoch``."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild the generator optimizer over all generator (and encoder) parameters."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        # NOTE(review): initialize() uses opt.lr_G and stores self.params_G;
        # this uses opt.lr and leaves self.params_G untouched — confirm.
        self.optimizer_G = torch.optim.Adam(
            params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class APDrawingGANModel(BaseModel):
    """Conditional GAN with a global generator and optional local part networks.

    When ``opt.use_local`` is set, six part generators (two eyes, nose, mouth,
    hair, background) plus a combiner network are built next to the global
    generator.  When ``opt.discriminator_local`` is set, six part
    discriminators are built next to the global discriminator.  During
    training, four auxiliary networks (``DT1``, ``DT2``, ``Line1``, ``Line2``)
    are created for the chamfer-matching terms of the generator loss.
    """

    def name(self):
        """Return the identifier string of this model."""
        return 'APDrawingGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Override several parser defaults and hand the parser back.

        ``is_train`` is accepted for interface compatibility and is not used.
        """
        # changing the default values (no_lsgan=True means use_lsgan=False)
        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')
        parser.set_defaults(dataset_mode='aligned')
        return parser

    def initialize(self, opt):
        """Register loss/visual/model names, then build networks and optimizers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        training = self.isTrain

        local_d_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
        local_g_names = ['GLEyel', 'GLEyer', 'GLNose', 'GLMouth', 'GLHair', 'GLBG',
                         'GCombine']

        # Training losses to print out; the program will call
        # base_model.get_current_losses.
        if training and self.opt.no_l1_loss:
            self.loss_names = ['G_GAN', 'D_real', 'D_fake']
        else:
            self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        if training:
            if self.opt.use_local and not self.opt.no_G_local_loss:
                self.loss_names.append('G_local')
            if self.opt.discriminator_local:
                self.loss_names.extend(
                    ['D_real_local', 'D_fake_local', 'G_GAN_local'])
            self.loss_names.extend(['G_chamfer', 'G_chamfer2', 'G'])
        print('loss_names', self.loss_names)

        # Images to save/display; the program will call
        # base_model.get_current_visuals.
        visuals = ['real_A', 'fake_B', 'real_B']
        if self.opt.use_local:
            visuals += ['fake_B0', 'fake_B1']
            visuals += ['fake_B_hair', 'real_B_hair', 'real_A_hair']
            visuals += ['fake_B_bg', 'real_B_bg', 'real_A_bg']
        if training:
            visuals += ['dt1', 'dt2', 'dt1gt', 'dt2gt']
        if not training and self.opt.save2:
            visuals = ['real_A', 'fake_B']
        self.visual_names = visuals
        print('visuals', self.visual_names)

        # Models to save to disk; the program will call
        # base_model.save_networks and base_model.load_networks.
        if training:
            self.model_names = ['G', 'D']
            if self.opt.discriminator_local:
                self.model_names += list(local_d_names)
            # auxiliary nets for loss calculation
            self.auxiliary_model_names = ['DT1', 'DT2', 'Line1', 'Line2']
        else:
            # during test time, only load Gs
            self.model_names = ['G']
            self.auxiliary_model_names = []
        if self.opt.use_local:
            self.model_names += list(local_g_names)
        print('model_names', self.model_names)
        print('auxiliary_model_names', self.auxiliary_model_names)

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids, opt.nnG)
        print('netG', opt.netG)

        if training:
            # Conditional GANs feed both input and output images to D, so the
            # number of channels for D is input_nc + output_nc.
            use_sigmoid = opt.no_lsgan

            def make_discriminator():
                return networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                         opt.netD, opt.n_layers_D, opt.norm,
                                         use_sigmoid, opt.init_type,
                                         opt.init_gain, self.gpu_ids)

            self.netD = make_discriminator()
            print('netD', opt.netD, opt.n_layers_D)
            if self.opt.discriminator_local:
                for d_name in local_d_names:
                    setattr(self, 'net' + d_name, make_discriminator())

        if self.opt.use_local:
            # (model name, architecture key, final positional argument)
            part_specs = (
                ('GLEyel', 'partunet', 3),
                ('GLEyer', 'partunet', 3),
                ('GLNose', 'partunet', 3),
                ('GLMouth', 'partunet', 3),
                ('GLHair', 'partunet2', 4),
                ('GLBG', 'partunet2', 4),
            )
            for g_name, arch, nng in part_specs:
                part_net = networks.define_G(opt.input_nc, opt.output_nc,
                                             opt.ngf, arch, opt.norm,
                                             not opt.no_dropout, opt.init_type,
                                             opt.init_gain, self.gpu_ids, nng)
                setattr(self, 'net' + g_name, part_net)
            # The combiner takes two stacked outputs (see forward), hence
            # 2 * output_nc input channels.
            self.netGCombine = networks.define_G(2 * opt.output_nc,
                                                 opt.output_nc, opt.ngf,
                                                 'combiner', opt.norm,
                                                 not opt.no_dropout,
                                                 opt.init_type, opt.init_gain,
                                                 self.gpu_ids, 2)

        if training:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizers = []
            if self.opt.use_local:
                G_params = []
                for g_name in ['G'] + local_g_names:
                    G_params += list(getattr(self, 'net' + g_name).parameters())
                print('G_params 8 components')
                self.optimizer_G = torch.optim.Adam(G_params, lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            else:
                print('G_params 1 components')
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))

            if self.opt.discriminator_local:
                D_params = []
                for d_name in ['D'] + local_d_names:
                    D_params += list(getattr(self, 'net' + d_name).parameters())
                print('D_params 7 components')
                self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            else:
                print('D_params 1 components')
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # ============ auxiliary nets (loaded, parameters fixed) ============
        if training:
            self.nc = 1

            def make_auxiliary(arch):
                return networks.define_G(self.nc, self.nc, opt.ngf, arch,
                                         opt.norm, not opt.no_dropout,
                                         opt.init_type, opt.init_gain,
                                         self.gpu_ids)

            self.netDT1 = make_auxiliary(opt.netG_dt)
            self.netDT2 = make_auxiliary(opt.netG_dt)
            for aux_net in (self.netDT1, self.netDT2):
                self.set_requires_grad(aux_net, False)
            self.netLine1 = make_auxiliary(opt.netG_line)
            self.netLine2 = make_auxiliary(opt.netG_line)
            for aux_net in (self.netLine1, self.netLine2):
                self.set_requires_grad(aux_net, False)

    def set_input(self, input):
        """Copy one batch from the ``input`` mapping onto ``self``.

        Tensors are moved to ``self.device``; ``A``/``B`` are swapped when
        ``opt.which_direction`` is not ``'AtoB'``.
        """
        device = self.device
        if self.opt.which_direction == 'AtoB':
            self.real_A = input['A'].to(device)
            self.real_B = input['B'].to(device)
            self.image_paths = input['A_paths']
        else:
            self.real_A = input['B'].to(device)
            self.real_B = input['A'].to(device)
            self.image_paths = input['B_paths']

        if self.opt.use_local:
            face_parts = (
                ('real_A_eyel', 'eyel_A'),
                ('real_A_eyer', 'eyer_A'),
                ('real_A_nose', 'nose_A'),
                ('real_A_mouth', 'mouth_A'),
                ('real_B_eyel', 'eyel_B'),
                ('real_B_eyer', 'eyer_B'),
                ('real_B_nose', 'nose_B'),
                ('real_B_mouth', 'mouth_B'),
            )
            for attr, key in face_parts:
                setattr(self, attr, input[key].to(device))
            # 'center' is stored as given, without a device transfer.
            self.center = input['center']
            region_parts = (
                ('real_A_hair', 'hair_A'),
                ('real_B_hair', 'hair_B'),
                ('real_A_bg', 'bg_A'),
                ('real_B_bg', 'bg_B'),
                ('mask', 'mask'),    # mask for non-eyes,nose,mouth
                ('mask2', 'mask2'),  # mask for non-bg
            )
            for attr, key in region_parts:
                setattr(self, attr, input[key].to(device))

        if self.isTrain:
            self.dt1gt = input['dt1gt'].to(device)
            self.dt2gt = input['dt2gt'].to(device)

    def forward(self):
        """Run the generator(s) and store the result in ``self.fake_B``."""
        if not self.opt.use_local:
            self.fake_B = self.netG(self.real_A)
            return

        self.fake_B0 = self.netG(self.real_A)

        # EYES, NOSE, MOUTH
        self.fake_B_eyel = self.netGLEyel(self.real_A_eyel)
        self.fake_B_eyer = self.netGLEyer(self.real_A_eyer)
        self.fake_B_nose = self.netGLNose(self.real_A_nose)
        self.fake_B_mouth = self.netGLMouth(self.real_A_mouth)

        # HAIR, BG AND PARTCOMBINE
        # The unmasked hair/bg outputs go to the combiner; the masked versions
        # are what is kept on self.
        hair_raw = self.netGLHair(self.real_A_hair)
        bg_raw = self.netGLBG(self.real_A_bg)
        self.fake_B_hair = self.masked(hair_raw, self.mask * self.mask2)
        self.fake_B_bg = self.masked(bg_raw, self.inverse_mask(self.mask2))
        self.fake_B1 = self.partCombiner2_bg(
            self.fake_B_eyel, self.fake_B_eyer, self.fake_B_nose,
            self.fake_B_mouth, hair_raw, bg_raw,
            self.mask * self.mask2, self.inverse_mask(self.mask2),
            self.opt.comb_op)

        # FUSION NET
        stacked = torch.cat([self.fake_B0, self.fake_B1], 1)
        self.fake_B = self.netGCombine(stacked)

    def backward_D(self):
        """Compute the discriminator loss and back-propagate it."""
        local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
        use_local_d = self.opt.discriminator_local

        # Fake: we use conditional GANs, so both input and output are fed to
        # the discriminator; detaching stops backprop to the generator.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.loss_D_fake = self.criterionGAN(self.netD(fake_AB.detach()), False)
        if use_local_d:
            self.loss_D_fake_local = 0
            for idx, fake_part in enumerate(self.getLocalParts(fake_AB)):
                part_net = getattr(self, 'net' + local_names[idx])
                part_pred = part_net(fake_part.detach())
                part_weight = self.getaddw(local_names[idx])
                self.loss_D_fake_local = (
                    self.loss_D_fake_local
                    + self.criterionGAN(part_pred, False) * part_weight)
            self.loss_D_fake = self.loss_D_fake + self.loss_D_fake_local

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.loss_D_real = self.criterionGAN(self.netD(real_AB), True)
        if use_local_d:
            self.loss_D_real_local = 0
            for idx, real_part in enumerate(self.getLocalParts(real_AB)):
                part_net = getattr(self, 'net' + local_names[idx])
                part_pred = part_net(real_part)
                part_weight = self.getaddw(local_names[idx])
                self.loss_D_real_local = (
                    self.loss_D_real_local
                    + self.criterionGAN(part_pred, True) * part_weight)
            self.loss_D_real = self.loss_D_real + self.loss_D_real_local

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN, L1, local, chamfer) and back-propagate."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        self.loss_G_GAN = self.criterionGAN(self.netD(fake_AB), True)
        if self.opt.discriminator_local:
            d_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            fake_AB_parts = self.getLocalParts(fake_AB)
            self.loss_G_GAN_local = 0
            for idx, fake_part in enumerate(fake_AB_parts):
                part_net = getattr(self, 'net' + d_names[idx])
                part_pred = part_net(fake_part)
                part_weight = self.getaddw(d_names[idx])
                self.loss_G_GAN_local = (
                    self.loss_G_GAN_local
                    + self.criterionGAN(part_pred, True) * part_weight)
            strategy = self.opt.gan_loss_strategy
            if strategy == 1:
                # average over the global and all local discriminators
                self.loss_G_GAN = (self.loss_G_GAN + self.loss_G_GAN_local) / (
                    len(fake_AB_parts) + 1)
            elif strategy == 2:
                self.loss_G_GAN_local = self.loss_G_GAN_local * 0.25
                self.loss_G_GAN = self.loss_G_GAN + self.loss_G_GAN_local

        # Second, G(A) = B
        if not self.opt.no_l1_loss:
            self.loss_G_L1 = self.criterionL1(
                self.fake_B, self.real_B) * self.opt.lambda_L1

        if self.opt.use_local and not self.opt.no_G_local_loss:
            self.loss_G_local = 0
            for part in ['eyel', 'eyer', 'nose', 'mouth', 'hair', 'bg']:
                fake_part = getattr(self, 'fake_B_' + part)
                real_part = getattr(self, 'real_B_' + part)
                part_weight = self.getaddw(part)
                self.loss_G_local = self.loss_G_local + self.criterionL1(
                    fake_part, real_part) * self.opt.lambda_local * part_weight

        # Third, distance transform loss (chamfer matching).  Three-channel
        # images are first reduced to one channel by a weighted channel sum.
        if self.fake_B.shape[1] == 3:
            fake_B_gray = (self.fake_B[:, 0, ...] * 0.299
                           + self.fake_B[:, 1, ...] * 0.587
                           + self.fake_B[:, 2, ...] * 0.114).unsqueeze(1)
        else:
            fake_B_gray = self.fake_B
        if self.real_B.shape[1] == 3:
            real_B_gray = (self.real_B[:, 0, ...] * 0.299
                           + self.real_B[:, 1, ...] * 0.587
                           + self.real_B[:, 2, ...] * 0.114).unsqueeze(1)
        else:
            real_B_gray = self.real_B

        # d_CM(a_i,G(p_i))
        self.dt1 = self.netDT1(fake_B_gray)
        self.dt2 = self.netDT2(fake_B_gray)
        dt1_unit = self.dt1 / 2.0 + 0.5  # [-1,1]->[0,1]
        dt2_unit = self.dt2 / 2.0 + 0.5
        batch_size = real_B_gray.shape[0]
        real_line1 = self.netLine1(real_B_gray)
        real_line2 = self.netLine2(real_B_gray)
        real_dark = (real_B_gray < 0) & (real_line1 < 0)
        real_light = (real_B_gray >= 0) & (real_line2 >= 0)
        self.loss_G_chamfer = (
            dt1_unit[real_dark].sum() + dt2_unit[real_light].sum()
        ) / batch_size * self.opt.lambda_chamfer

        # d_CM(G(p_i),a_i)
        # Keep the unscaled ground-truth maps for the loss; the attributes on
        # self are rescaled afterwards.
        dt1gt_raw = self.dt1gt
        dt2gt_raw = self.dt2gt
        self.dt1gt = (self.dt1gt - 0.5) * 2
        self.dt2gt = (self.dt2gt - 0.5) * 2
        fake_line1 = self.netLine1(fake_B_gray)
        fake_line2 = self.netLine2(fake_B_gray)
        fake_dark = (fake_B_gray < 0) & (fake_line1 < 0)
        fake_light = (fake_B_gray >= 0) & (fake_line2 >= 0)
        self.loss_G_chamfer2 = (
            dt1gt_raw[fake_dark].sum() + dt2gt_raw[fake_light].sum()
        ) / batch_size * self.opt.lambda_chamfer2

        # Sum every term that was registered in loss_names.
        self.loss_G = self.loss_G_GAN
        for term in ('G_L1', 'G_local', 'G_chamfer', 'G_chamfer2'):
            if term in self.loss_names:
                self.loss_G = self.loss_G + getattr(self, 'loss_' + term)

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: forward, update D, then update G."""

        def set_discriminator_grad(flag):
            self.set_requires_grad(self.netD, flag)
            if self.opt.discriminator_local:
                for d_name in ('DLEyel', 'DLEyer', 'DLNose', 'DLMouth',
                               'DLHair', 'DLBG'):
                    self.set_requires_grad(getattr(self, 'net' + d_name), flag)

        self.forward()

        # update D
        set_discriminator_grad(True)  # enable backprop for D
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update G
        set_discriminator_grad(False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
class CrowdganModel(BaseModel):
    """Fusion-stage model that blends a map generator with a flow generator.

    ``forward`` produces ``self.res`` from ``self.mapG`` and ``self.warp`` by
    resampling the last frame with the flow predicted by ``self.flowG``; the
    two are mixed with the weight returned by ``self.netG``.

    NOTE(review): ``self.input_*_set`` and ``self.resample`` are read but not
    assigned anywhere in this class; presumably they are provided by the base
    class or the caller -- confirm.
    """

    def name(self):
        """Return the identifier string of this model."""
        return 'CrowdganModel'

    def initialize(self, opt):
        """Build the generators, optional discriminators, losses and optimizers.

        ``mapG``, ``flowNet`` and ``flowG`` weights are loaded from the
        checkpoint paths given in ``opt``.
        """
        BaseModel.initialize(self, opt)
        self.log_para = opt.logPara
        self.tG = opt.n_frames_G
        self.output_nc = opt.output_nc
        self.P_input_nc = opt.P_input_nc
        self.BP_input_nc = opt.BP_input_nc

        flowG_input_nc = [opt.P_input_nc
                          + opt.BP_input_nc * (opt.n_frames_G)
                          + 2 * (opt.n_frames_G - 2)]
        mapG_input_nc = [opt.P_input_nc,
                         opt.BP_input_nc * 2,
                         opt.P_input_nc * (opt.n_frames_G - 1)]
        fusion_input_nc = [opt.ngf + opt.ngf]
        n_layers_flowG = [6]
        n_layers_mapG = [4, 4]
        n_layers_postG = [2]

        if self.isTrain:
            self.tD = opt.n_frames_D
            n_layers_D_PB = 3
            n_layers_D_PP = 3
            n_layers_D_T = 3
            netD_PB_input_nc = opt.output_nc + opt.BP_input_nc
            netD_PP_input_nc = opt.output_nc + opt.output_nc
            netD_T_input_nc = opt.output_nc * opt.n_frames_D

        self.mapG = networks.define_G(mapG_input_nc, self.output_nc, opt.ngf,
                                      'Transfer', n_layers_mapG, opt.norm,
                                      opt.init_type, self.gpu_ids,
                                      n_downsampling=opt.G_n_downsampling,
                                      use_dropout=opt.isDropout,
                                      fusion_stage=True)
        self.mapG.load_state_dict(torch.load(opt.mapG_ckpt), strict=False)

        # The flow network is put in eval mode and gets no optimizer below.
        self.flowNet = FlowSD()
        self.flowNet.load_state_dict(torch.load(opt.flownet_ckpt))
        self.flowNet.eval()
        self.flowNet = torch.nn.DataParallel(
            self.flowNet, device_ids=self.gpu_ids).cuda()

        self.flowG = networks.define_G(flowG_input_nc, 2, opt.ngf, 'FlowEst',
                                       n_layers_flowG, opt.norm, opt.init_type,
                                       self.gpu_ids,
                                       n_downsampling=opt.G_n_downsampling,
                                       fusion_stage=True)
        self.flowG.load_state_dict(torch.load(opt.flowG_ckpt))

        self.netG = networks.define_G(fusion_input_nc, 1, opt.ngf, 'Fusion',
                                      n_layers_postG, opt.norm, opt.init_type,
                                      self.gpu_ids,
                                      n_downsampling=opt.P_n_downsampling)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if opt.with_D_PB:
                self.netD_PB = networks.define_D(
                    netD_PB_input_nc, opt.ndf, 'resnet', n_layers_D_PB,
                    opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                    n_downsampling=opt.D_n_downsampling)
            if opt.with_D_PP:
                self.netD_PP = networks.define_D(
                    netD_PP_input_nc, opt.ndf, 'resnet', n_layers_D_PP,
                    opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                    n_downsampling=opt.D_n_downsampling)
            if opt.with_D_T:
                self.netD_T = networks.define_D(
                    netD_T_input_nc, opt.ndf, 'resnet', n_layers_D_T,
                    opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                    n_downsampling=opt.D_n_downsampling)

            self.old_lr = opt.lr
            self.fake_PP_pool = ImagePool(opt.pool_size)
            self.fake_PB_pool = ImagePool(opt.pool_size)
            self.fake_T_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_mapG = torch.optim.Adam(
                self.mapG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_flowG = torch.optim.Adam(
                self.flowG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.with_D_PB:
                self.optimizer_D_PB = torch.optim.Adam(
                    self.netD_PB.parameters(), lr=opt.lr,
                    betas=(opt.beta1, 0.999))
            if opt.with_D_PP:
                self.optimizer_D_PP = torch.optim.Adam(
                    self.netD_PP.parameters(), lr=opt.lr,
                    betas=(opt.beta1, 0.999))
            if opt.with_D_T:
                self.optimizer_D_T = torch.optim.Adam(
                    self.netD_T.parameters(), lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizers = [self.optimizer_G,
                               self.optimizer_mapG,
                               self.optimizer_flowG]
            if opt.with_D_PB:
                self.optimizers.append(self.optimizer_D_PB)
            if opt.with_D_PP:
                self.optimizers.append(self.optimizer_D_PP)
            if opt.with_D_T:
                self.optimizers.append(self.optimizer_D_T)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        networks.print_network(self.mapG)
        networks.print_network(self.flowG)
        if self.isTrain:
            if opt.with_D_PB:
                networks.print_network(self.netD_PB)
            if opt.with_D_PP:
                networks.print_network(self.netD_PP)
            if opt.with_D_T:
                networks.print_network(self.netD_T)
        print('-----------------------------------------------')

    def forward(self):
        """Run flow estimation, map generation and fusion; sets ``self.fake``."""
        self.input_prev_I = Variable(self.input_prev_I_set)
        self.input_prev_D = Variable(self.input_prev_D_set)
        self.input_last_I = Variable(self.input_last_I_set)
        self.input_last_D = Variable(self.input_last_D_set)
        self.input_curr_I = Variable(self.input_curr_I_set)
        self.input_curr_D = Variable(self.input_curr_D_set)

        # flowG inference
        b, _, h, w = self.input_curr_I.size()
        # Drop the first 3 channels of [prev, curr] and regroup both stacks
        # into 3-channel images before feeding them to the flow network.
        input_post_I = torch.cat(
            [self.input_prev_I, self.input_curr_I],
            dim=1)[:, 3:].contiguous().view(-1, 3, h, w)
        input_prev_I = self.input_prev_I.contiguous().view(-1, 3, h, w)
        flow_predict_input = torch.cat([input_prev_I, input_post_I], dim=1)
        flow = self.flowNet(flow_predict_input)
        # Regroup per batch element and drop the last two channels.
        flow_input = flow.contiguous().view(b, -1, h, w)[:, :-2]
        # The flow is detached, so no gradient reaches flowNet.
        flowG_input = torch.cat([self.input_last_I, self.input_prev_D,
                                 self.input_curr_D, flow_input.detach()],
                                dim=1)
        flow_output = self.flowG(flowG_input)
        flow_predict = flow_output['out']
        flow_feature = flow_output['fea']
        self.warp = self.resample(self.input_last_I, flow_predict)

        # mapG inference
        mapG_input = [self.input_last_I,
                      torch.cat((self.input_last_D, self.input_curr_D), dim=1),
                      self.input_prev_I]
        map_output = self.mapG(mapG_input)
        self.res = map_output['out']
        map_feature = map_output['fea']

        # netG inference: weight mixes the generated and the warped frame
        G_input = [map_feature, flow_feature]
        weight = self.netG(G_input)
        self.fake = self.res * weight + self.warp * (1 - weight)

    def backward_G(self):
        """Compute the generator loss (L1 plus enabled GAN terms) and backprop.

        Stores ``pair_L1loss`` and, when the matching discriminators are
        enabled, ``pair_GANloss`` / ``temporal_GANloss`` for logging.
        """
        # GAN loss
        if self.opt.with_D_PB:
            pred_fake_PB = self.netD_PB(
                torch.cat((self.fake, self.input_curr_D), 1))
            self.loss_G_GAN_PB = self.criterionGAN(pred_fake_PB, True)
        if self.opt.with_D_PP:
            pred_fake_PP = self.netD_PP(
                torch.cat((self.fake, self.input_last_I), 1))
            self.loss_G_GAN_PP = self.criterionGAN(pred_fake_PP, True)
        if self.opt.with_D_T:
            pred_fake_T = self.netD_T(
                torch.cat((self.input_prev_I, self.fake), 1))
            self.loss_G_GAN_T = self.criterionGAN(pred_fake_T, True)

        if self.opt.with_D_PB:
            pair_GANloss = self.loss_G_GAN_PB * self.opt.lambda_GAN
            if self.opt.with_D_PP:
                # average the two pair discriminators
                pair_GANloss = (
                    pair_GANloss + self.loss_G_GAN_PP * self.opt.lambda_GAN
                ) / 2
        elif self.opt.with_D_PP:
            pair_GANloss = self.loss_G_GAN_PP * self.opt.lambda_GAN
        if self.opt.with_D_T:
            temporal_GANloss = self.loss_G_GAN_T * self.opt.lambda_GAN_T

        # L1 loss
        self.loss_G_L1 = self.criterionL1(
            self.fake, self.input_curr_I) * self.opt.lambda_L1
        pair_L1loss = self.loss_G_L1

        # Accumulate out of place.  ``pair_loss += x`` on a tensor adds in
        # place, and pair_loss starts as the same object as pair_L1loss, so an
        # in-place add would overwrite the L1 value that is logged below.
        pair_loss = pair_L1loss
        if self.opt.with_D_PB or self.opt.with_D_PP:
            pair_loss = pair_loss + pair_GANloss
        if self.opt.with_D_T:
            pair_loss = pair_loss + temporal_GANloss

        pair_loss.backward()

        self.pair_L1loss = pair_L1loss.data
        if self.opt.with_D_PB or self.opt.with_D_PP:
            self.pair_GANloss = pair_GANloss.data
        if self.opt.with_D_T:
            self.temporal_GANloss = temporal_GANloss.data

    def backward_D_basic(self, netD, real, fake):
        """Backprop the loss of ``netD`` on a real/fake pair and return it.

        ``fake`` is detached before being passed to ``netD``.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    # D: take(P, B) as input
    def backward_D_PB(self):
        """Update step loss for ``netD_PB`` on (frame, current D input) pairs."""
        real_PB = torch.cat((self.input_curr_I, self.input_curr_D), 1)
        fake_PB = self.fake_PB_pool.query(
            torch.cat((self.fake, self.input_curr_D), 1).data)
        loss_D_PB = self.backward_D_basic(self.netD_PB, real_PB, fake_PB)
        self.loss_D_PB = loss_D_PB.data

    # D: take(P, P') as input
    def backward_D_PP(self):
        """Update step loss for ``netD_PP`` on (frame, last frame) pairs."""
        real_PP = torch.cat((self.input_curr_I, self.input_last_I), 1)
        fake_PP = self.fake_PP_pool.query(
            torch.cat((self.fake, self.input_last_I), 1).data)
        loss_D_PP = self.backward_D_basic(self.netD_PP, real_PP, fake_PP)
        self.loss_D_PP = loss_D_PP.data

    # D: take(prev, P`, flows) as input
    def backward_D_T(self):
        """Update step loss for ``netD_T`` on (previous frames, frame) stacks."""
        real_T = torch.cat((self.input_prev_I, self.input_curr_I), 1)
        fake_T = self.fake_T_pool.query(
            torch.cat((self.input_prev_I, self.fake), 1).data)
        loss_D_T = self.backward_D_basic(self.netD_T, real_T, fake_T)
        self.loss_D_T = loss_D_T.data

    def optimize_parameters(self):
        """Run one training step: generators first, then each enabled D.

        Every enabled discriminator is stepped ``opt.DG_ratio`` times.
        """
        # forward
        self.forward()

        self.optimizer_G.zero_grad()
        self.optimizer_mapG.zero_grad()
        self.optimizer_flowG.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.optimizer_mapG.step()
        self.optimizer_flowG.step()

        # D_PP
        if self.opt.with_D_PP:
            for _ in range(self.opt.DG_ratio):
                self.optimizer_D_PP.zero_grad()
                self.backward_D_PP()
                self.optimizer_D_PP.step()

        # D_BP
        if self.opt.with_D_PB:
            for _ in range(self.opt.DG_ratio):
                self.optimizer_D_PB.zero_grad()
                self.backward_D_PB()
                self.optimizer_D_PB.step()

        # D_T
        if self.opt.with_D_T:
            for _ in range(self.opt.DG_ratio):
                self.optimizer_D_T.zero_grad()
                self.backward_D_T()
                self.optimizer_D_T.step()

    def get_current_errors(self):
        """Return an ``OrderedDict`` of the most recently stored loss values."""
        ret_errors = OrderedDict([('pair_L1loss', self.pair_L1loss)])
        if self.opt.with_D_PP:
            ret_errors['D_PP'] = self.loss_D_PP
        if self.opt.with_D_PB:
            ret_errors['D_PB'] = self.loss_D_PB
        if self.opt.with_D_PB or self.opt.with_D_PP:
            ret_errors['pair_GANloss'] = self.pair_GANloss
        if self.opt.with_D_T:
            ret_errors['temporal_GANloss'] = self.temporal_GANloss
        return ret_errors

    def save(self, label):
        """Save the generators and every enabled discriminator under ``label``."""
        self.save_network(self.netG, 'netG', label, self.gpu_ids)
        self.save_network(self.mapG, 'mapG', label, self.gpu_ids)
        self.save_network(self.flowG, 'flowG', label, self.gpu_ids)
        if self.opt.with_D_PB:
            self.save_network(self.netD_PB, 'netD_PB', label, self.gpu_ids)
        if self.opt.with_D_PP:
            self.save_network(self.netD_PP, 'netD_PP', label, self.gpu_ids)
        if self.opt.with_D_T:
            self.save_network(self.netD_T, 'netD_T', label, self.gpu_ids)
class Pix2PixModel(BaseModel):
    """Paired image-to-image GAN: one generator, one conditional discriminator."""

    def name(self):
        """Return the identifier string of this model."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Allocate input buffers, build/load the networks, set up training state."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.normalize = opt.input_normalize

        # define tensors
        side = opt.fineSize
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, side, side)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, side, side)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      opt.use_dropout, self.gpu_ids)
        if self.isTrain:
            # The discriminator sees input and output stacked along channels.
            use_sigmoid = opt.no_lsgan
            d_input_nc = opt.input_nc + opt.output_nc
            self.netD = networks.define_D(d_input_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, self.gpu_ids)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

        if not self.isTrain or opt.continue_train:
            epoch_label = opt.which_epoch
            print('---------- Loading netG...')
            self.load_network(self.netG, 'G', epoch_label)
            print('---------- Loading netG success.')
            if self.isTrain:
                print('---------- Loading netD...')
                self.load_network(self.netD, 'D', epoch_label)
                print('---------- Loading netD success.')

        if not self.isTrain:
            return

        self.fake_AB_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionL1 = torch.nn.L1Loss()
        # initialize optimizers
        adam_betas = (opt.beta1, 0.999)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr, betas=adam_betas)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr, betas=adam_betas)

    def set_input(self, input):
        """Copy one batch into the preallocated ``input_A`` / ``input_B`` buffers."""
        if self.opt.which_direction == 'AtoB':
            batch_A, batch_B = input['A'], input['B']
            paths = input['A_paths']
        else:
            batch_A, batch_B = input['B'], input['A']
            paths = input['B_paths']
        self.input_A.resize_(batch_A.size()).copy_(batch_A)
        self.input_B.resize_(batch_B.size()).copy_(batch_B)
        self.image_paths = paths

    def forward(self):
        """Run the generator on the current batch (training path)."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run the generator with the inputs wrapped as volatile Variables."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backprop."""
        # Fake: detach so no gradient flows back into the generator
        pooled_fake = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(pooled_fake.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_pair = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_pair)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and backprop."""
        # First, G(A) should fake the discriminator
        fake_pair = torch.cat((self.real_A, self.fake_B), 1)
        self.loss_G_GAN = self.criterionGAN(self.netD.forward(fake_pair), True)

        # Second, G(A) = B
        l1_term = self.criterionL1(self.fake_B, self.real_B)
        self.loss_G_L1 = l1_term * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self, epoch):
        """Run one training step: forward, update D, then update G.

        ``epoch`` is accepted for interface compatibility and is not used.
        """
        self.forward()

        for optimizer, backward in ((self.optimizer_D, self.backward_D),
                                    (self.optimizer_G, self.backward_G)):
            optimizer.zero_grad()
            backward()
            optimizer.step()

    def get_current_errors(self, epoch):
        """Return the latest loss scalars as an ``OrderedDict``.

        ``epoch`` is accepted for interface compatibility and is not used.
        """
        errors = OrderedDict()
        for label in ('G_GAN', 'G_L1', 'D_real', 'D_fake'):
            errors[label] = getattr(self, 'loss_' + label).data[0]
        return errors

    def get_current_visuals(self):
        """Return input, output and target converted with ``util.tensor2im``."""
        visuals = OrderedDict()
        for label in ('real_A', 'fake_B', 'real_B'):
            visuals[label] = util.tensor2im(getattr(self, label).data,
                                            normalize=self.normalize)
        return visuals

    def get_test_visuals(self):
        """Return only the generated image converted with ``util.tensor2im``."""
        generated = util.tensor2im(self.fake_B.data, normalize=self.normalize)
        return OrderedDict([('fake_B', generated)])

    def save(self, label):
        """Save generator and discriminator weights under ``label``."""
        for net, tag in ((self.netG, 'G'), (self.netD, 'D')):
            self.save_network(net, tag, label, self.gpu_ids)

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by one linear decay step."""
        decay_step = self.opt.lr / self.opt.niter_decay
        new_lr = self.old_lr - decay_step
        for optimizer in (self.optimizer_D, self.optimizer_G):
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr
class CycleWGANModel(BaseModel):
    """CycleGAN variant trained with a Wasserstein-style (WGAN) objective.

    Two generators (``netG_A``: A -> B, ``netG_B``: B -> A) are trained
    against two critics (``netD_A``, ``netD_B``). Instead of a GAN criterion
    the mean critic score is optimised directly, RMSprop is used for every
    optimizer, and the critic weights are clamped after each update.
    """

    # Critic weights are clamped to [-clip_value, clip_value] after each step.
    clip_value = 0.01

    def name(self):
        """Return the display name of this model."""
        return 'CycleWGANModel'

    def initialize(self, opt):
        """Create networks, image pools, losses, optimizers and schedulers.

        Args:
            opt: options object; the attributes read here include
                ``batchSize``, ``fineSize``, ``input_nc``, ``output_nc``,
                ``ngf``, ``ndf``, ``lr``, ``pool_size`` and the network
                selection flags passed on to ``networks``.
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            # No GANLoss criterion here: the adversarial terms are plain
            # means of the critic outputs (see backward_D_basic/backward_G).
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (RMSprop is used in place of Adam)
            self.optimizer_G = torch.optim.RMSprop(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr)
            self.optimizer_D_A = torch.optim.RMSprop(
                self.netD_A.parameters(), lr=opt.lr)
            self.optimizer_D_B = torch.optim.RMSprop(
                self.netD_B.parameters(), lr=opt.lr)
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch from the ``input`` dict into the input buffers.

        ``opt.which_direction`` decides which of the ``'A'``/``'B'`` entries
        is treated as the source domain.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the current input buffers as ``real_A`` / ``real_B``."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without tracking gradients."""
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        """Return the image paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the critic loss for one critic and return it.

        The loss evaluates to ``0.5 * (mean(netD(fake)) - mean(netD(real)))``;
        ``fake`` is detached so no gradient reaches the generator.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = torch.mean(pred_real)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = -torch.mean(pred_fake)
        # Combined loss
        loss_D = -(loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update gradients of ``netD_A`` using a pooled fake B sample."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        """Update gradients of ``netD_B`` using a pooled fake A sample."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    def backward_G(self):
        """Compute the generator losses and back-propagate their sum.

        The total is the two adversarial terms, the two cycle terms and,
        when ``opt.identity > 0``, the two identity terms. Generated images
        and scalar loss values are stored on ``self`` for reporting.
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(
                idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(
                idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = -torch.mean(pred_fake)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = -torch.mean(pred_fake)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B

        # combined loss
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B +
                  loss_idt_A + loss_idt_B)
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        """Run one training iteration: generators, then each critic."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

        # WGAN weight clipping is a constraint on the critics only. The
        # generators used to be clamped here as well, which needlessly
        # restricted their weights; they are now left untouched.
        for critic in (self.netD_A, self.netD_B):
            for p in critic.parameters():
                p.data.clamp_(-self.clip_value, self.clip_value)

    def get_current_errors(self):
        """Return the latest scalar losses as an ordered name -> value map."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return the current images converted with ``util.tensor2im``."""
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        # Identity outputs only exist after a training backward_G pass.
        if self.opt.isTrain and self.opt.identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Save all four networks under the given checkpoint ``label``."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
class HalfGanStyleModel(BaseModel):
    """Conditional GAN whose input buffer is half the output resolution.

    The generator loss combines a GAN term, an L1 term and, when
    ``opt.use_style`` is set, a Gram-matrix style term computed on features
    from a frozen ``VGG`` network.
    """

    def name(self):
        """Return the display name of this model."""
        return 'HalfGanStyleModel'

    def initialize(self, opt):
        """Allocate buffers, load the VGG weights and build nets/optimizers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # define tensors: A is allocated at half the side length of B
        half_size = int(opt.fineSize / 2)
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   half_size, half_size)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        use_cuda = torch.cuda.is_available()

        # style-loss configuration (only style layers are used)
        self.style_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
        self.loss_layers = self.style_layers
        # one GramMSELoss instance shared by every layer slot
        self.loss_fns = [GramMSELoss()] * len(self.style_layers)
        if use_cuda:
            self.loss_fns = [fn.cuda() for fn in self.loss_fns]

        # frozen VGG feature extractor loaded from <cwd>/Models/vgg_conv.pth
        self.vgg = VGG()
        vgg_weights = torch.load(os.getcwd() + '/Models/' + 'vgg_conv.pth')
        self.vgg.load_state_dict(vgg_weights)
        for vgg_param in self.vgg.parameters():
            vgg_param.requires_grad = False
        if use_cuda:
            self.vgg.cuda()
        print(self.vgg.state_dict().keys())

        self.style_weights = [
            1e3 / width**2 for width in (64, 128, 256, 512, 512)
        ]
        self.weights = self.style_weights

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, self.gpu_ids)

        if opt.continue_train or not self.isTrain:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            adam_betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=adam_betas)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=adam_betas)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a batch into the input buffers and keep paths/start points."""
        if self.opt.which_direction == 'AtoB':
            src_key, dst_key, path_key = 'A', 'B', 'A_paths'
        else:
            src_key, dst_key, path_key = 'B', 'A', 'B_paths'
        batch_A = input[src_key]
        batch_B = input[dst_key]
        self.input_A.resize_(batch_A.size()).copy_(batch_A)
        self.input_B.resize_(batch_B.size()).copy_(batch_B)
        self.image_paths = input[path_key]
        self.start_points = input['A_start_point']

    def forward(self):
        """Generate ``fake_B`` from the current ``input_A`` buffer."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Same as ``forward`` but with volatile (no-grad) variables."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        """Return the image paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D(self):
        """Back-propagate the discriminator loss on pooled fakes vs. reals.

        The discriminator sees ``fake_B`` / ``real_B`` alone (they are not
        concatenated with ``real_A``).
        """
        # Fake: detach so no gradient flows back into the generator
        pooled_fake = self.fake_AB_pool.query(self.fake_B.clone())
        self.pred_fake = self.netD.forward(pooled_fake.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_sample = self.real_B.clone()
        self.pred_real = self.netD.forward(real_sample)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Back-propagate the (optional) style loss, then GAN + L1 losses."""
        if not self.opt.use_style:
            self.style_loss_value = 0
        else:
            # Gram matrices of the real image are the fixed style targets
            targets = [
                GramMatrix()(feat).detach()
                for feat in self.vgg(self.real_B, self.style_layers)
            ]
            fake_feats = self.vgg(self.fake_B, self.loss_layers)
            per_layer = [
                self.weights[idx] * self.loss_fns[idx](feat, targets[idx])
                for idx, feat in enumerate(fake_feats)
            ]
            self.style_loss = sum(per_layer)
            # keep the graph: loss_G below backpropagates through it again
            self.style_loss.backward(retain_graph=True)
            self.style_loss_value = self.style_loss.item()

        # First, G(A) should fake the discriminator
        generated = self.fake_B.clone()
        pred_fake = self.netD.forward(generated)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        l1_term = self.criterionL1(self.fake_B, self.real_B)
        self.loss_G_L1 = l1_term * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: discriminator first, then generator."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an ordered name -> value map."""
        errors = OrderedDict()
        errors['G_GAN'] = self.loss_G_GAN.item()
        errors['G_L1'] = self.loss_G_L1.item()
        errors['D_real'] = self.loss_D_real.item()
        errors['D_fake'] = self.loss_D_fake.item()
        errors['Style'] = self.style_loss_value
        return errors

    def get_current_visuals(self):
        """Return ``(images, start_points)`` for the current batch."""
        visuals = OrderedDict()
        visuals['real_A'] = util.tensor2im(self.real_A.data)
        visuals['fake_B'] = util.tensor2im(self.fake_B.data)
        visuals['real_B'] = util.tensor2im(self.real_B.data)
        return visuals, self.start_points

    def save(self, label):
        """Save generator and discriminator under the checkpoint ``label``."""
        for net, tag in ((self.netG, 'G'), (self.netD, 'D')):
            self.save_network(net, tag, label, self.gpu_ids)

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by one linear step."""
        step = self.opt.lr / self.opt.niter_decay
        new_lr = self.old_lr - step
        for optimizer in (self.optimizer_D, self.optimizer_G):
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr
class ReidHybridCycleGANModel(BaseModel):
    """CycleGAN between HR and LR images combined with a re-id classifier.

    ``netG_A`` maps HR -> LR and ``netG_B`` maps LR -> HR. The generator loss
    adds reconstruction terms against ground-truth LR/HR images and a
    cross-entropy term from ``netD_reid`` evaluated on all HR images.
    """

    def name(self):
        """Return the display name of this model."""
        return 'ReidHybridCycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register this model's options on ``parser`` and return it."""
        # default GAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if not is_train:
            return parser

        parser.add_argument('--lambda_A', type=float, default=10.0,
                            help='weight for cycle loss (A -> B -> A)')
        parser.add_argument('--lambda_B', type=float, default=10.0,
                            help='weight for cycle loss (B -> A -> B)')
        parser.add_argument(
            '--lambda_identity', type=float, default=0.5,
            help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
            'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        parser.add_argument('--lambda_rec', type=float, default=10.0,
                            help='weight for reconstruction loss')
        parser.add_argument('--lambda_G', type=float, default=1.0,
                            help='weight for Generator loss')
        # reid parameters
        parser.add_argument('--droprate', type=float, default=0.5,
                            help='the dropout ratio in reid model')
        return parser

    def initialize(self, opt):
        """Declare loss/visual/model names and build nets and optimizers."""
        BaseModel.initialize(self, opt)

        # training losses to print; read by base_model.get_current_losses
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B',
            'idt_B', 'rec_A', 'rec_B', 'reid'
        ]

        # images to save/display; read by base_model.get_current_visuals
        names_A = ['real_HR_A', 'fake_LR_A', 'rec_HR_A', 'real_LR_A']
        names_B = ['real_LR_B', 'fake_HR_B', 'rec_LR_B', 'real_HR_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            names_A.append('idt_A')
            names_B.append('idt_B')
        self.visual_names = names_A + names_B

        # models handled by base_model.save_networks / load_networks;
        # the discriminators are only present while training
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'D_reid']
        else:
            self.model_names = ['G_A', 'G_B', 'D_reid']

        # netG_A: HR -> LR, netG_B: LR -> HR
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        opt.init_gain, self.gpu_ids)

        # Load a pretrained resnet model and reset the final connected layer.
        # The reid network is trained on a single gpu because of BatchNorm.
        reid_net = networks_reid.ft_net(opt.num_classes, opt.droprate)
        self.netD_reid = reid_net
        self.netD_reid = self.netD_reid.to(self.device)

        if not self.isTrain:
            return

        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        use_sigmoid, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        use_sigmoid, opt.init_type,
                                        opt.init_gain, self.gpu_ids)

        # image pools: one for the plain GAN branch, two for the cycle branch
        self.fake_HR_A_pool = ImagePool(opt.pool_size)
        self.fake_LR_A_pool = ImagePool(opt.pool_size)  # fake_B_pool
        self.fake_HR_B_pool = ImagePool(opt.pool_size)  # fake_A_pool

        # define loss functions
        gan_loss = networks.GANLoss(use_lsgan=not opt.no_lsgan)
        self.criterionGAN = gan_loss.to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionRec = torch.nn.L1Loss()
        self.criterionReid = torch.nn.CrossEntropyLoss()

        # initialize optimizers
        adam_betas = (opt.beta1, 0.999)
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(),
                            self.netG_B.parameters()),
            lr=opt.lr, betas=adam_betas)
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.netD_A.parameters(),
                            self.netD_B.parameters()),
            lr=opt.lr, betas=adam_betas)

        # SR optimizers (self.optimizers is not created in this method)
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        # reid optimizer: backbone parameters get a 10x smaller learning
        # rate than the fc head and the classifier
        head_param_ids = [
            id(p) for p in itertools.chain(
                self.netD_reid.model.fc.parameters(),
                self.netD_reid.classifier.parameters())
        ]
        backbone_params = [
            p for p in self.netD_reid.parameters()
            if id(p) not in head_param_ids
        ]
        param_groups = [
            {'params': backbone_params, 'lr': 0.1 * opt.reid_lr},
            {'params': self.netD_reid.model.fc.parameters(),
             'lr': opt.reid_lr},
            {'params': self.netD_reid.classifier.parameters(),
             'lr': opt.reid_lr},
        ]
        self.optimizer_D_reid = torch.optim.SGD(param_groups,
                                                weight_decay=5e-4,
                                                momentum=0.9,
                                                nesterov=True)
        self.optimizer_reid.append(self.optimizer_D_reid)

    def reset_model_status(self):
        """Put the networks in the train/eval mode required by ``opt.stage``.

        Stage 1 keeps the re-id network in eval mode (for the BatchNorm);
        stages 0 and 2 train everything. Other stages change nothing.
        """
        stage = self.opt.stage
        if stage not in (0, 1, 2):
            return
        for net in (self.netG_A, self.netG_B, self.netD_A, self.netD_B):
            net.train()
        # for the BatchNorm
        if stage == 1:
            self.netD_reid.eval()
        else:
            self.netD_reid.train()

    def set_input(self, input):
        """Move one batch of images and identity labels to ``self.device``."""
        device = self.device
        self.real_HR_A = input['A'].to(device)
        self.real_LR_B = input['B'].to(device)
        # ground-truth low resolution A image
        self.real_LR_A = input['GT_A'].to(device)
        # ground-truth high resolution B image, used to test the SR quality
        self.real_HR_B = input['GT_B'].to(device)
        self.image_paths = input['A_paths']
        # id labels for person reid
        self.A_label = input['A_label'].to(device)
        self.B_label = input['B_label'].to(device)

    def forward(self):
        """Run both generators and classify every HR image with the re-id net."""
        # GAN: LR -> HR
        self.fake_HR_A = self.netG_B(self.real_LR_A)

        # cycleGAN, HR -> LR -> HR
        self.fake_LR_A = self.netG_A(self.real_HR_A)
        self.rec_HR_A = self.netG_B(self.fake_LR_A)
        # cycleGAN, LR -> HR -> LR
        self.fake_HR_B = self.netG_B(self.real_LR_B)
        self.rec_LR_B = self.netG_A(self.fake_HR_B)

        # all the HR images, with labels listed in the same order
        hr_images = [
            self.real_HR_A, self.fake_HR_B, self.rec_HR_A, self.fake_HR_A
        ]
        hr_labels = [
            self.A_label, self.B_label, self.A_label, self.A_label
        ]
        self.imgs = torch.cat(hr_images, 0)
        self.labels = torch.cat(hr_labels)
        self.pred_imgs = self.netD_reid(self.imgs)

    def psnr_eval(self):
        """Store PSNR of ``real_LR_A`` and of ``fake_HR_A`` vs. ``real_HR_A``."""
        reference = self.real_HR_A
        self.bicubic_psnr = networks.compute_psnr(reference, self.real_LR_A)
        self.psnr = networks.compute_psnr(reference, self.fake_HR_A)

    def ssim_eval(self):
        """Store SSIM of ``real_LR_A`` and of ``fake_HR_A`` vs. ``real_HR_A``."""
        reference = self.real_HR_A
        self.bicubic_ssim = networks.compute_ssim(reference, self.real_LR_A)
        self.ssim = networks.compute_ssim(reference, self.fake_HR_A)

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the averaged real/fake GAN loss of ``netD``.

        ``fake`` is detached so that the loss does not reach the generators.
        Returns the combined loss.
        """
        real_term = self.criterionGAN(netD(real), True)
        fake_term = self.criterionGAN(netD(fake.detach()), False)
        combined = (real_term + fake_term) * 0.5
        combined.backward()
        return combined

    def backward_D_A(self):
        """Train ``netD_A`` on real LR images (A and B) vs. pooled fake LR."""
        pooled_fake_LR = self.fake_LR_A_pool.query(self.fake_LR_A)
        real_LR = torch.cat([self.real_LR_A, self.real_LR_B], 0)
        self.loss_D_A = self.backward_D_basic(self.netD_A, real_LR,
                                              pooled_fake_LR)

    def backward_D_B(self):
        """Train ``netD_B`` on real HR A vs. pooled fake HR (GAN + cycle)."""
        pooled_fake_A = self.fake_HR_A_pool.query(self.fake_HR_A)  # GAN
        pooled_fake_B = self.fake_HR_B_pool.query(self.fake_HR_B)
        fake_HR = torch.cat([pooled_fake_A, pooled_fake_B], 0)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_HR_A,
                                              fake_HR)

    def backward_G(self):
        """Assemble the full generator loss and back-propagate it.

        Terms: identity (optional), GAN, cycle, reconstruction and re-id
        cross-entropy. Also accumulates the number of correct re-id
        predictions in ``self.corrects``.
        """
        opt = self.opt
        lambda_idt = opt.lambda_identity
        lambda_A = opt.lambda_A
        lambda_B = opt.lambda_B
        lambda_rec = opt.lambda_rec
        lambda_G = opt.lambda_G

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_LR_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_LR_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_HR_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_HR_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        pred_fake_LR = self.netD_A(self.fake_LR_A)
        self.loss_G_A = self.criterionGAN(pred_fake_LR, True) * lambda_G

        # GAN loss D_B(G_B(B)) on both the GAN and the CycleGAN fakes
        fake_HR = torch.cat([self.fake_HR_A, self.fake_HR_B], 0)
        pred_fake_HR = self.netD_B(fake_HR)
        self.loss_G_B = self.criterionGAN(pred_fake_HR, True) * lambda_G

        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(
            self.rec_HR_A, self.real_HR_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(
            self.rec_LR_B, self.real_LR_B) * lambda_B

        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)

        # reconstruct loss of low resolution fake_LR_A(G_A)
        self.loss_rec_A = self.criterionRec(
            self.fake_LR_A, self.real_LR_A) * lambda_rec
        # reconstruct loss of high resolution fake_HR_A(G_B)
        self.loss_rec_B = self.criterionRec(
            self.fake_HR_A, self.real_HR_A) * lambda_rec
        self.loss_rec = self.loss_rec_A + self.loss_rec_B
        self.loss_G += self.loss_rec

        # re-id accuracy bookkeeping and classification loss
        _, predicted_ids = torch.max(self.pred_imgs, 1)
        self.corrects += float(torch.sum(predicted_ids == self.labels))
        self.loss_reid = self.criterionReid(self.pred_imgs, self.labels)

        self.loss_G = self.loss_G + self.loss_reid
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration according to ``opt.stage``.

        Stage 1 freezes the re-id network and only steps the generators;
        stages 0 and 2 also step the re-id optimizer. Any other stage only
        runs the forward pass.
        """
        # forward
        self.forward()

        stage = self.opt.stage
        if stage == 1:
            frozen_nets = [self.netD_A, self.netD_B, self.netD_reid]
            step_reid = False
        elif stage == 0 or stage == 2:
            frozen_nets = [self.netD_A, self.netD_B]
            step_reid = True
        else:
            return

        # G_A and G_B
        self.set_requires_grad(frozen_nets, False)
        self.optimizer_G.zero_grad()
        if step_reid:
            self.optimizer_D_reid.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        if step_reid:
            self.optimizer_D_reid.step()

        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
class Pix2PixModel(BaseModel):
    """Hint-conditioned generator with classification and regression heads.

    The generator receives ``real_A``, ``hint_B`` and ``mask_B`` and returns
    a class-score map and a regression map for ``real_B``. A discriminator is
    only created and trained when ``opt.lambda_GAN > 0``.
    """

    def name(self):
        """Return the display name of this model."""
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Return ``parser`` unchanged; this model adds no options."""
        return parser

    def initialize(self, opt):
        """Declare loss/visual/model names and build nets and optimizers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.half = opt.half

        self.use_D = self.opt.lambda_GAN > 0

        # training losses to print; read by get_current_losses.
        # Each name must appear exactly once: get_current_losses updates a
        # running average per entry, so a repeated name ('G_entr' used to be
        # listed twice) would be updated twice per call and reported wrongly.
        if self.use_D:
            self.loss_names = ['G_GAN', ]
        else:
            self.loss_names = []
        self.loss_names += ['G_CE', 'G_entr', 'G_entr_hint', ]
        self.loss_names += ['G_L1_max', 'G_L1_mean', 'G_L1_reg', ]
        self.loss_names += ['G_fake_real', 'G_fake_hint', 'G_real_hint', ]
        self.loss_names += ['0', ]

        # images to save/display; read by base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        # models handled by base_model.save_networks / load_networks
        if self.isTrain:
            if self.use_D:
                self.model_names = ['G', 'D']
            else:
                self.model_names = ['G', ]
        else:  # during test time, only load Gs
            self.model_names = ['G']

        # load/define networks
        num_in = opt.input_nc + opt.output_nc + 1
        self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids, use_tanh=True,
                                      classification=opt.classification)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if self.use_D:
                self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                              opt.ndf, opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = networks.L1Loss()
            self.criterionHuber = networks.HuberLoss(delta=1. / opt.ab_norm)
            self.criterionCE = torch.nn.CrossEntropyLoss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

            if self.use_D:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)

        if self.half:
            for model_name in self.model_names:
                net = getattr(self, 'net' + model_name)
                net.half()
                # BatchNorm layers are converted back to float precision
                for layer in net.modules():
                    if isinstance(layer, torch.nn.BatchNorm2d):
                        layer.float()
                print('Net %s half precision' % model_name)

        # initialize average loss values
        self.avg_losses = OrderedDict()
        self.avg_loss_alpha = opt.avg_loss_alpha
        self.error_cnt = 0
        # Reference values for avg_loss_alpha:
        #   0.9993 -> half-life of 1000 iterations
        #   0.9965 -> half-life of 200 iterations
        #   0.986  -> half-life of 50 iterations
        #   0.     -> no averaging
        for loss_name in self.loss_names:
            self.avg_losses[loss_name] = 0

    def set_input(self, input):
        """Move a batch to ``self.device`` and derive mask/encoded targets.

        With ``self.half`` set, every entry of ``input`` is converted to half
        precision in place before use.
        """
        if self.half:
            for key in input.keys():
                input[key] = input[key].half()

        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.hint_B = input['hint_B'].to(self.device)

        self.mask_B = input['mask_B'].to(self.device)
        self.mask_B_nc = self.mask_B + self.opt.mask_cent

        # targets for the classification head: real_B sampled at every 4th
        # position along the last two dimensions, then encoded
        self.real_B_enc = util.encode_ab_ind(
            self.real_B[:, :, ::4, ::4], self.opt)

    def forward(self):
        """Run the generator and derive decoded / distribution outputs."""
        (self.fake_B_class, self.fake_B_reg) = self.netG(
            self.real_A, self.hint_B, self.mask_B)

        # Reach the generator's helper layers whether or not it is wrapped
        # (e.g. by DataParallel, which exposes the network as ``.module``).
        # The previous ``self.netG.module = self.netG`` registered the network
        # as its own sub-module, which makes recursive nn.Module operations
        # such as state_dict() never terminate and overwrote the wrapped
        # network of a DataParallel instance.
        net = getattr(self.netG, 'module', self.netG)

        self.fake_B_dec_max = net.upsample4(
            util.decode_max_ab(self.fake_B_class, self.opt))
        self.fake_B_distr = net.softmax(self.fake_B_class)

        self.fake_B_dec_mean = net.upsample4(
            util.decode_mean(self.fake_B_distr, self.opt))

        # entropy of the predicted distribution (epsilon avoids log(0))
        self.fake_B_entr = net.upsample4(-torch.sum(
            self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10),
            dim=1, keepdim=True))

    def backward_D(self):
        """Back-propagate the discriminator loss on pooled fakes vs. reals."""
        # Fake
        # stop backprop to the generator by detaching the fake pair.
        # The regression output is used as the generated B; the attribute
        # ``self.fake_B`` referenced here before is never assigned.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B_reg), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def compute_losses_G(self):
        """Compute all generator statistics and the optimised ``loss_G``.

        ``loss_G`` is ``loss_G_CE * lambda_A + loss_G_L1_reg`` plus, when the
        discriminator is enabled, ``loss_G_GAN * lambda_GAN``. The other
        ``loss_*`` attributes are reported but not optimised. All statistics
        are computed on ``torch.FloatTensor`` copies.
        """
        float_t = torch.FloatTensor
        mask = self.mask_B_nc
        mask_avg = torch.mean(mask.type(float_t)) + .000001

        self.loss_0 = 0  # 0 for plot

        # classification statistics
        # cross-entropy loss
        self.loss_G_CE = self.criterionCE(
            self.fake_B_class.type(float_t),
            self.real_B_enc[:, 0, :, :].type(torch.LongTensor))
        # entropy of predicted distribution
        self.loss_G_entr = torch.mean(self.fake_B_entr.type(float_t))
        # entropy of predicted distribution at hint points
        self.loss_G_entr_hint = torch.mean(
            self.fake_B_entr.type(float_t) * mask.type(float_t)) / mask_avg

        # regression statistics
        real_B = self.real_B.type(float_t)
        self.loss_G_L1_max = 10 * torch.mean(self.criterionL1(
            self.fake_B_dec_max.type(float_t), real_B))
        self.loss_G_L1_mean = 10 * torch.mean(self.criterionL1(
            self.fake_B_dec_mean.type(float_t), real_B))
        self.loss_G_L1_reg = 10 * torch.mean(self.criterionL1(
            self.fake_B_reg.type(float_t), real_B))

        # L1 loss at given points
        self.loss_G_fake_real = 10 * torch.mean(self.criterionL1(
            self.fake_B_reg * mask,
            self.real_B * mask).type(float_t)) / mask_avg
        self.loss_G_fake_hint = 10 * torch.mean(self.criterionL1(
            self.fake_B_reg * mask,
            self.hint_B * mask).type(float_t)) / mask_avg
        self.loss_G_real_hint = 10 * torch.mean(self.criterionL1(
            self.real_B * mask,
            self.hint_B * mask).type(float_t)) / mask_avg

        # loss_G must exist on every path: backward_G calls
        # self.loss_G.backward() unconditionally, but it used to be assigned
        # only when the discriminator was disabled.
        self.loss_G = self.loss_G_CE * self.opt.lambda_A + self.loss_G_L1_reg

        if self.use_D:
            # NOTE(review): the regression output stands in for the never
            # assigned ``self.fake_B`` and the GAN term is weighted by
            # ``lambda_GAN`` - confirm both match the intended objective.
            fake_AB = torch.cat((self.real_A, self.fake_B_reg), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
            self.loss_G = (self.loss_G +
                           self.loss_G_GAN.type(float_t) * self.opt.lambda_GAN)

    def backward_G(self):
        """Compute the generator losses and back-propagate ``loss_G``."""
        self.compute_losses_G()
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration (discriminator first, if enabled)."""
        self.forward()

        if self.use_D:
            # update D
            self.set_requires_grad(self.netD, True)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

            self.set_requires_grad(self.netD, False)

        # update G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_visuals(self):
        """Return an ordered dict of visualisation tensors.

        Every image entry is ``util.lab2rgb`` applied to two tensors
        concatenated along dim 1; ``*_ab`` entries replace ``real_A`` with
        zeros and ``gray`` replaces ``real_B`` with zeros.
        """
        def as_float(tensor):
            return tensor.type(torch.FloatTensor)

        def to_rgb(first, second):
            return util.lab2rgb(torch.cat((first, second), dim=1), self.opt)

        real_A = as_float(self.real_A)
        real_B = as_float(self.real_B)
        fake_max = as_float(self.fake_B_dec_max)
        fake_mean = as_float(self.fake_B_dec_mean)
        fake_reg = as_float(self.fake_B_reg)
        hint = as_float(self.hint_B)
        zeros_A = torch.zeros_like(real_A)

        visual_ret = OrderedDict()

        visual_ret['gray'] = to_rgb(real_A, torch.zeros_like(real_B))
        visual_ret['real'] = to_rgb(real_A, real_B)

        visual_ret['fake_max'] = to_rgb(real_A, fake_max)
        visual_ret['fake_mean'] = to_rgb(real_A, fake_mean)
        visual_ret['fake_reg'] = to_rgb(real_A, fake_reg)

        visual_ret['hint'] = to_rgb(real_A, hint)

        visual_ret['real_ab'] = to_rgb(zeros_A, real_B)

        visual_ret['fake_ab_max'] = to_rgb(zeros_A, fake_max)
        visual_ret['fake_ab_mean'] = to_rgb(zeros_A, fake_mean)
        visual_ret['fake_ab_reg'] = to_rgb(zeros_A, fake_reg)

        visual_ret['mask'] = as_float(self.mask_B_nc.expand(-1, 3, -1, -1))
        visual_ret['hint_ab'] = visual_ret['mask'] * to_rgb(zeros_A, hint)

        C = self.fake_B_distr.shape[1]
        # scale to [-1, 2], then clamped to [-1, 1]
        visual_ret['fake_entr'] = torch.clamp(
            3 * self.fake_B_entr.expand(-1, 3, -1, -1) / np.log(C) - 1,
            -1, 1)

        return visual_ret

    # return training losses/errors. train.py will print out these errors as
    # debugging information
    def get_current_losses(self):
        """Update and return bias-corrected running averages of the losses.

        Each call folds the current ``loss_<name>`` into an exponentially
        weighted sum (factor ``avg_loss_alpha``) and rescales it by
        ``(1 - alpha) / (1 - alpha ** error_cnt)``.
        """
        self.error_cnt += 1
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                self.avg_losses[name] = float(
                    getattr(self, 'loss_' + name)
                ) + self.avg_loss_alpha * self.avg_losses[name]
                errors_ret[name] = (
                    (1 - self.avg_loss_alpha) /
                    (1 - self.avg_loss_alpha**self.error_cnt) *
                    self.avg_losses[name])

        return errors_ret
class CycleGANModel(BaseModel):
    """Two-generator / two-discriminator model trained with GAN, cycle and identity losses.

    Naming used in the code (name used in the paper):
    G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    ``netG_A`` maps ``real_A`` to ``fake_B`` and ``netG_B`` maps ``real_B`` to ``fake_A``.
    """

    def name(self):
        """Return the model's display name."""
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build networks, optionally load checkpoints, and set up losses/optimizers.

        Args:
            opt: option namespace; this method reads batchSize, fineSize,
                input_nc, output_nc, ngf, ndf, which_model_netG,
                which_model_netD, n_layers_D, norm, use_dropout, no_lsgan,
                continue_train, which_epoch, lr, beta1 and pool_size.
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated input buffers; set_input() resizes and fills them.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        opt.use_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        opt.use_dropout, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; both generators share one optimizer
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        # The discriminators are only created in training mode, so only
        # print them then; otherwise the attribute lookup would fail.
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one data batch into the input buffers.

        Args:
            input: mapping with keys 'A', 'B' and 'A_paths' / 'B_paths'.
                ``opt.which_direction`` decides which side is treated as A.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the current inputs; the generators are run in backward_G()."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without recording gradients."""
        # torch.no_grad() replaces Variable(..., volatile=True), which no
        # longer disables autograd.
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG_A.forward(self.real_A)
            self.rec_A = self.netG_B.forward(self.fake_B)

            self.real_B = Variable(self.input_B)
            self.fake_A = self.netG_B.forward(self.real_B)
            self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last set_input() call."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Compute one discriminator's loss, backpropagate it, and return it.

        Args:
            netD: discriminator network to evaluate.
            real: batch of real images.
            fake: batch of generated images; detached before use so no
                gradient reaches the generator.
        """
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update loss_D_A using real_B and a pooled fake_B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Update loss_D_B using real_A and a pooled fake_A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute identity, GAN and cycle losses for both generators and backpropagate."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: generators first, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest loss values as an OrderedDict of Python floats.

        The identity entries are included only when ``opt.identity > 0``.
        """
        # .item() is used instead of .data[0]: indexing a 0-dim loss tensor
        # raises in current PyTorch.
        D_A = self.loss_D_A.item()
        G_A = self.loss_G_A.item()
        Cyc_A = self.loss_cycle_A.item()
        D_B = self.loss_D_B.item()
        G_B = self.loss_G_B.item()
        Cyc_B = self.loss_cycle_B.item()
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.item()
            idt_B = self.loss_idt_B.item()
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                            ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return the current real/fake/reconstructed images converted by util.tensor2im.

        The identity images are included only when ``opt.identity > 0``.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                            ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        """Save all four networks under the given checkpoint label."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Lower the learning rate of every optimizer by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for optimizer in (self.optimizer_D_A, self.optimizer_D_B, self.optimizer_G):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class Pix2PixModel(BaseModel):
    """Conditional GAN: generator ``netG`` maps ``real_A`` to ``fake_B``;
    discriminator ``netD`` scores the channel-wise concatenation of A and B."""

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Override parser defaults and, for training, add ``--lambda_L1``.

        The defaults are changed to match the pix2pix paper
        (https://phillipi.github.io/pix2pix/).
        """
        parser.set_defaults(
            pool_size=0,
            no_lsgan=True,
            norm='batch',
            dataset_mode='aligned',
            which_model_netG='unet_256',
        )
        if is_train:
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
        return parser

    def initialize(self, opt):
        """Create the generator and, in training mode, the discriminator,
        image pool, loss functions and optimizers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # Names of the training losses to print out
        # (see base_model.get_current_losses).
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # Names of the images to save/display
        # (see base_model.get_current_visuals).
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # Names of the networks to save/load
        # (see base_model.save_networks / load_networks);
        # at test time only the generator is needed.
        self.model_names = ['G', 'D'] if self.isTrain else ['G']

        # generator
        self.netG = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
            opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        # Everything below is training-only.
        if not self.isTrain:
            return

        # discriminator sees A and B stacked along the channel axis
        self.netD = networks.define_D(
            opt.input_nc + opt.output_nc, opt.ndf, opt.which_model_netD,
            opt.n_layers_D, opt.norm, opt.no_lsgan, opt.init_type,
            self.gpu_ids)

        self.fake_AB_pool = ImagePool(opt.pool_size)

        # loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionL1 = torch.nn.L1Loss()

        # optimizers
        adam_betas = (opt.beta1, 0.999)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=adam_betas)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=adam_betas)
        self.optimizers = [self.optimizer_G, self.optimizer_D]

    def set_input(self, input):
        """Move one batch to ``self.device``; ``opt.which_direction`` picks
        which side of the pair is treated as A."""
        if self.opt.which_direction == 'AtoB':
            src_key, dst_key = 'A', 'B'
        else:
            src_key, dst_key = 'B', 'A'
        self.real_A = input[src_key].to(self.device)
        self.real_B = input[dst_key].to(self.device)
        self.image_paths = input[src_key + '_paths']

    def forward(self):
        """Generate ``fake_B`` from ``real_A``."""
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backpropagate."""
        # Fake pair: detached so no gradient flows back into the generator.
        generated_pair = torch.cat((self.real_A, self.fake_B), 1)
        pooled_pair = self.fake_AB_pool.query(generated_pair)
        self.loss_D_fake = self.criterionGAN(self.netD(pooled_pair.detach()), False)

        # Real pair.
        target_pair = torch.cat((self.real_A, self.real_B), 1)
        self.loss_D_real = self.criterionGAN(self.netD(target_pair), True)

        # Average of the two terms.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator's GAN + weighted L1 loss and backpropagate."""
        # Adversarial term: G(A) should be classified as real.
        generated_pair = torch.cat((self.real_A, self.fake_B), 1)
        self.loss_G_GAN = self.criterionGAN(self.netD(generated_pair), True)

        # Reconstruction term: G(A) should match B.
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: discriminator update, then generator update."""
        self.forward()

        # Discriminator step (its parameters need gradients here).
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # Generator step (discriminator gradients switched off).
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
class TransferModel(nn.Module):
    """Generator/discriminator wrapper for image transfer guided by keypoint and parsing maps.

    The generator receives a source image (P1), the concatenated source and
    target keypoint maps (KP1, KP2) and both one-hot parsing maps, and
    returns a generated image (``fake_p2``) and a parsing prediction
    (``fake_parse``).
    """

    def __init__(self):
        super(TransferModel, self).__init__()

    def name(self):
        """Return the model's display name."""
        return 'TransferModel'

    def initialize(self, opt):
        """Build networks, optionally load checkpoints, and set up training state.

        Args:
            opt: option namespace (gpu_ids, isTrain, checkpoints_dir, name,
                batchSize, fineSize, P_input_nc, BP_input_nc, network and
                optimizer settings, with_D_PB / with_D_PP, L1_type, ...).

        Raises:
            ValueError: if ``opt.L1_type`` is neither 'origin' nor
                'l1_plus_perL1' (training mode only).
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated input buffers; set_input() resizes and fills them.
        self.input_P1_set = self.Tensor(nb, opt.P_input_nc, size, size)
        self.input_KP1_set = self.Tensor(nb, opt.BP_input_nc, size, size)
        self.input_P2_set = self.Tensor(nb, opt.P_input_nc, size, size)
        self.input_KP2_set = self.Tensor(nb, opt.BP_input_nc, size, size)
        self.input_SPL1_set = self.Tensor(nb, 1, size, size)
        self.input_SPL2_set = self.Tensor(nb, 1, size, size)
        self.input_SPL1_onehot_set = self.Tensor(nb, 12, size, size)
        self.input_SPL2_onehot_set = self.Tensor(nb, 12, size, size)
        self.input_syn_set = self.Tensor(nb, opt.P_input_nc, size, size)

        input_nc = [
            opt.P_input_nc, opt.BP_input_nc + opt.BP_input_nc, opt.P_input_nc
        ]
        self.netG = networks.define_G(input_nc, opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if opt.with_D_PB:
                self.netD_PB = networks.define_D(
                    3 + 18, opt.ndf, opt.which_model_netD, opt.n_layers_D,
                    'instance', use_sigmoid, opt.init_type, self.gpu_ids,
                    not opt.no_dropout_D,
                    n_downsampling=opt.D_n_downsampling)
            if opt.with_D_PP:
                self.netD_PP = networks.define_D(
                    opt.P_input_nc + opt.P_input_nc, opt.ndf,
                    opt.which_model_netD, opt.n_layers_D, 'instance',
                    use_sigmoid, opt.init_type, self.gpu_ids,
                    not opt.no_dropout_D,
                    n_downsampling=opt.D_n_downsampling)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'netG', which_epoch)
            if self.isTrain:
                if opt.with_D_PB:
                    self.load_network(self.netD_PB, 'netD_PB', which_epoch)
                if opt.with_D_PP:
                    self.load_network(self.netD_PP, 'netD_PP', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_PP_pool = ImagePool(opt.pool_size)
            self.fake_PB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            # Shape (parsing) loss. The BCELoss alternative was permanently
            # disabled (`if False`) in the original code, so it is dropped.
            self.parseLoss = CrossEntropyLoss2d()

            if opt.L1_type == 'origin':
                # NOTE(review): backward_G() indexes the criterion's result
                # ([0], [1], [2]); a plain L1Loss output may not support
                # that - confirm before using this option.
                self.criterionL1 = torch.nn.L1Loss()
            elif opt.L1_type == 'l1_plus_perL1':
                self.criterionL1 = L1_plus_perceptualLoss(
                    opt.lambda_A, opt.lambda_B, opt.perceptual_layers,
                    self.gpu_ids, opt.percep_is_l1)
            else:
                # Was `raise Excption(...)`, a typo that produced a NameError
                # instead of the intended error.
                raise ValueError('Unsurportted type of L1!')

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            if opt.with_D_PB:
                self.optimizer_D_PB = torch.optim.Adam(
                    self.netD_PB.parameters(), lr=opt.lr,
                    betas=(opt.beta1, 0.999))
            if opt.with_D_PP:
                self.optimizer_D_PP = torch.optim.Adam(
                    self.netD_PP.parameters(), lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            if opt.with_D_PB:
                self.optimizers.append(self.optimizer_D_PB)
            if opt.with_D_PP:
                self.optimizers.append(self.optimizer_D_PP)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            if opt.with_D_PB:
                networks.print_network(self.netD_PB)
            if opt.with_D_PP:
                networks.print_network(self.netD_PP)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one data batch into the input buffers.

        Args:
            input: mapping with keys 'P1', 'KP1', 'SPL1', 'P2', 'KP2',
                'SPL2', 'SPL1_onehot', 'SPL2_onehot', 'P1_path', 'P2_path'.
        """
        input_P1, input_KP1, input_SPL1 = input['P1'], input['KP1'], input['SPL1']
        input_P2, input_KP2, input_SPL2 = input['P2'], input['KP2'], input['SPL2']
        input_SPL1_onehot = input['SPL1_onehot']
        input_SPL2_onehot = input['SPL2_onehot']

        self.input_SPL1_onehot_set.resize_(
            input_SPL1_onehot.size()).copy_(input_SPL1_onehot)
        self.input_SPL2_onehot_set.resize_(
            input_SPL2_onehot.size()).copy_(input_SPL2_onehot)
        self.input_SPL1_set.resize_(input_SPL1.size()).copy_(input_SPL1)
        self.input_SPL2_set.resize_(input_SPL2.size()).copy_(input_SPL2)
        self.input_P1_set.resize_(input_P1.size()).copy_(input_P1)
        self.input_KP1_set.resize_(input_KP1.size()).copy_(input_KP1)
        self.input_P2_set.resize_(input_P2.size()).copy_(input_P2)
        self.input_KP2_set.resize_(input_KP2.size()).copy_(input_KP2)

        # Only the first sample's paths are kept.
        self.image_paths = input['P1_path'][0] + '___' + input['P2_path'][0]

    def _wrap_inputs(self):
        """Expose the input buffers under the attribute names the losses and visuals read."""
        self.input_P1 = Variable(self.input_P1_set)
        self.input_KP1 = Variable(self.input_KP1_set)
        self.input_SPL1 = Variable(self.input_SPL1_set)
        self.input_P2 = Variable(self.input_P2_set)
        self.input_KP2 = Variable(self.input_KP2_set)
        self.input_SPL2 = Variable(self.input_SPL2_set)
        self.input_SPL1_onehot = Variable(self.input_SPL1_onehot_set)
        self.input_SPL2_onehot = Variable(self.input_SPL2_onehot_set)

    def _run_generator(self):
        """Run netG on the current inputs and store fake_p2 / fake_parse."""
        G_input = [
            self.input_P1,
            torch.cat((self.input_KP1, self.input_KP2), 1),
            self.input_SPL1_onehot,
            self.input_SPL2_onehot,
        ]
        self.fake_p2, self.fake_parse = self.netG(G_input)

    def forward(self):
        """Training-time forward pass (gradients are recorded)."""
        self._wrap_inputs()
        self._run_generator()

    def test(self):
        """Inference-time forward pass; no autograd graph is built."""
        self._wrap_inputs()
        with torch.no_grad():
            self._run_generator()

    # get image paths
    def get_image_paths(self):
        """Return the path string stored by the last set_input() call."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Compute one discriminator's loss, backpropagate it, and return it.

        Both terms are scaled by ``opt.lambda_GAN``; ``fake`` is detached so
        no gradient reaches the generator.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        """Backpropagate both discriminators (netD_PB, then netD_PP).

        Despite their names, ``loss_DPB_fake`` / ``loss_DPP_fake`` hold the
        combined real+fake loss returned by backward_D_basic(), as floats.
        """
        self.pred_fake = self.fake_PB_pool.query(
            torch.cat((self.input_KP2, self.fake_p2), 1).data)
        self.pred_real = torch.cat((self.input_KP2, self.input_P2), 1)
        self.loss_DPB_fake = self.backward_D_basic(self.netD_PB,
                                                   self.pred_real,
                                                   self.pred_fake).item()

        self.pred_fake = self.fake_PP_pool.query(
            torch.cat((self.fake_p2, self.input_P1), 1).data)
        self.pred_real = torch.cat((self.input_P2, self.input_P1), 1)
        self.loss_DPP_fake = self.backward_D_basic(self.netD_PP,
                                                   self.pred_real,
                                                   self.pred_fake).item()

    def backward_G(self):
        """Compute the generator loss (L1 term + GAN term + parsing loss) and backpropagate."""
        mask = self.input_SPL2.squeeze(1).long()
        self.maskloss1 = self.parseLoss(self.fake_parse, mask)

        # The criterion result is indexed: [0] is used for optimisation,
        # [1] and [2] are only stored for reporting.
        L1_per = self.criterionL1(self.fake_p2, self.input_P2)
        self.loss_G_L1 = L1_per[0]

        pred_fake = self.netD_PB(torch.cat((self.input_KP2, self.fake_p2), 1))
        pred_fake_pp = self.netD_PP(torch.cat((self.fake_p2, self.input_P1), 1))

        self.L1 = L1_per[1]
        self.per = L1_per[2]

        self.loss_G_GAN = (self.criterionGAN(pred_fake, True) +
                           self.criterionGAN(pred_fake_pp, True)) / 2
        self.loss_mask = (self.loss_G_L1
                          + self.loss_G_GAN * self.opt.lambda_GAN
                          + self.maskloss1)
        self.loss_mask.backward()

    def optimize_parameters(self):
        """Run one training step: both discriminators first, then the generator."""
        self.forward()

        self.optimizer_D_PB.zero_grad()
        self.optimizer_D_PP.zero_grad()
        self.backward_D()
        self.optimizer_D_PB.step()
        self.optimizer_D_PP.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest loss values as an OrderedDict.

        The dict is empty when neither ``with_D_PB`` nor ``with_D_PP`` is set.
        """
        ret_errors = OrderedDict()
        if self.opt.with_D_PB or self.opt.with_D_PP:
            ret_errors['L1_plus_perceptualLoss'] = self.loss_G_L1
            ret_errors['percetual'] = self.per
            ret_errors['L1'] = self.L1
            ret_errors['PB'] = self.loss_DPB_fake
            ret_errors['PP'] = self.loss_DPP_fake
            ret_errors['pair_GANloss'] = self.loss_G_GAN.data.item()
            ret_errors['parsing1'] = self.maskloss1.data.item()
        return ret_errors

    def get_current_visuals(self):
        """Return one image strip of eight tiles under the key 'vis'.

        Tile order: P1, KP1, SPL1, P2, KP2, SPL2, predicted parsing, fake P2.
        """
        height, width = self.input_P1.size(2), self.input_P1.size(3)
        input_P1 = util.tensor2im(self.input_P1.data)
        input_P2 = util.tensor2im(self.input_P2.data)

        input_SPL1 = util.tensor2im(
            torch.argmax(self.input_SPL1_onehot, axis=1, keepdim=True).data,
            True)
        input_SPL2 = util.tensor2im(
            torch.argmax(self.input_SPL2_onehot, axis=1, keepdim=True).data,
            True)

        input_KP1 = util.draw_pose_from_map(self.input_KP1.data)[0]
        input_KP2 = util.draw_pose_from_map(self.input_KP2.data)[0]

        fake_shape2 = util.tensor2im(
            torch.argmax(self.fake_parse, axis=1, keepdim=True).data, True)
        fake_p2 = util.tensor2im(self.fake_p2.data)

        vis = np.zeros((height, width * 8, 3)).astype(np.uint8)  # h, w, c
        vis[:, :width, :] = input_P1
        vis[:, width:width * 2, :] = input_KP1
        vis[:, width * 2:width * 3, :] = input_SPL1
        # A 256-wide P2 image is cropped to columns 40:216 so it fits the tile.
        if input_P2.shape[1] == 256:
            vis[:, width * 3:width * 4, :] = input_P2[:, 40:216, :]
        else:
            vis[:, width * 3:width * 4, :] = input_P2
        vis[:, width * 4:width * 5, :] = input_KP2
        vis[:, width * 5:width * 6, :] = input_SPL2
        vis[:, width * 6:width * 7, :] = fake_shape2
        vis[:, width * 7:, :] = fake_p2

        ret_visuals = OrderedDict([('vis', vis)])
        return ret_visuals

    def save(self, label):
        """Save the generator and whichever discriminators are enabled."""
        self.save_network(self.netG, 'netG', label, self.gpu_ids)
        if self.opt.with_D_PB:
            self.save_network(self.netD_PB, 'netD_PB', label, self.gpu_ids)
        if self.opt.with_D_PP:
            self.save_network(self.netD_PP, 'netD_PP', label, self.gpu_ids)

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Write the network's state dict to '<epoch>_net_<label>.pth' in save_dir.

        The network is moved to the CPU for saving and moved back to
        ``gpu_ids[0]`` afterwards when CUDA is available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load the state dict from '<epoch>_net_<label>.pth' in save_dir into the network."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step every scheduler and print the first optimizer's learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)
class Pix2PixModel(BaseModel):
    """Paired image-to-image model whose GAN branch can be switched off with ``opt.no_gan``.

    With ``no_gan`` set, no discriminator is created and only the
    reconstruction loss trains the generator.
    """

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build networks, optionally load checkpoints, and set up losses/optimizers.

        Args:
            opt: option namespace; this method reads isTrain, no_gan,
                continue_train, which_epoch, use_l2, lr, beta1, pool_size
                and the network-definition options.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if self.isTrain and (not opt.no_gan):
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain and (not opt.no_gan):
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            # The attribute keeps the name criterionL1 even when use_l2
            # selects MSE, because backward_G() reads it under that name.
            if opt.use_l2:
                self.criterionL1 = torch.nn.MSELoss()
            else:
                self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            if not opt.no_gan:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain and (not opt.no_gan):
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store one data batch, moving it to the first GPU when GPUs are configured.

        Args:
            input: mapping with keys 'A', 'B' and 'A_paths' / 'B_paths'.
                ``opt.which_direction`` decides which side is treated as A.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async` is a reserved word since Python 3.7; `non_blocking`
            # is the current name of the same argument.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Generate ``fake_B`` from ``real_A`` with gradients recorded."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Generate ``fake_B`` from ``real_A`` without recording gradients."""
        # torch.no_grad() replaces Variable(..., volatile=True), which no
        # longer disables autograd.
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last set_input() call."""
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backpropagate."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss and backpropagate.

        The GAN term is 0 when ``opt.no_gan`` is set; the reconstruction
        term is weighted by ``opt.lambda_A``.
        """
        if not self.opt.no_gan:
            # First, G(A) should fake the discriminator
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            self.loss_G_GAN = 0

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step; the discriminator update is skipped when ``no_gan`` is set."""
        self.forward()

        if not self.opt.no_gan:
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest loss values as an OrderedDict of Python floats.

        Only 'G_L1' is reported when ``opt.no_gan`` is set.
        """
        # .item() is used instead of .data[0]: indexing a 0-dim loss tensor
        # raises in current PyTorch.
        if not self.opt.no_gan:
            return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                                ('G_L1', self.loss_G_L1.item()),
                                ('D_real', self.loss_D_real.item()),
                                ('D_fake', self.loss_D_fake.item())])
        return OrderedDict([('G_L1', self.loss_G_L1.item())])

    def get_current_visuals(self):
        """Return the current images as an OrderedDict.

        For single-channel outputs (``opt.output_nc == 1``) the
        post-processed and colour-painted parsing maps are included too.
        """
        # util.tensor2im is unpacked into two values for real_A here.
        real_A_img, real_A_prior = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)

        if self.opt.output_nc == 1:
            fake_B_postprocessed = util.postprocess_parsing(fake_B, self.isTrain)
            fake_B_color = util.paint_color(fake_B_postprocessed)
            real_B_color = util.paint_color(
                util.postprocess_parsing(real_B, self.isTrain))
            return OrderedDict([
                ('real_A_img', real_A_img),
                ('real_A_prior', real_A_prior),
                ('fake_B', fake_B),
                ('fake_B_postprocessed', fake_B_postprocessed),
                ('fake_B_color', fake_B_color),
                ('real_B', real_B),
                ('real_B_color', real_B_color),
            ])
        return OrderedDict([
            ('real_A_img', real_A_img),
            ('real_A_prior', real_A_prior),
            ('fake_B', fake_B),
            ('real_B', real_B),
        ])

    def save(self, label):
        """Save the generator and, unless ``no_gan`` is set, the discriminator."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        if not self.opt.no_gan:
            self.save_network(self.netD, 'D', label, self.gpu_ids)
class MultiModel(BaseModel): def name(self): return 'MultiGANModel' def initialize(self, opt): BaseModel.initialize(self, opt) nb = opt.batchSize size = opt.fineSize self.opt = opt self.input_A = self.Tensor(nb, opt.input_nc, size, size) self.input_B = self.Tensor(nb, opt.output_nc, size, size) if opt.vgg > 0: self.vgg_loss = networks.PerceptualLoss() self.vgg_loss.cuda() self.vgg = networks.load_vgg16("./model") self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) skip = True if opt.skip > 0 else False self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions if opt.use_wgan: self.criterionGAN = networks.DiscLossWGANGP() else: self.criterionGAN = networks.GANLoss( use_lsgan=not opt.no_lsgan, tensor=self.Tensor) if opt.use_mse: self.criterionCycle = torch.nn.MSELoss() else: self.criterionCycle = torch.nn.L1Loss() 
self.criterionL1 = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- Networks initialized -------------') networks.print_network(self.netG_A) networks.print_network(self.netG_B) if self.isTrain: networks.print_network(self.netD_A) networks.print_network(self.netD_B) if opt.isTrain: self.netG_A.train() self.netG_B.train() else: self.netG_A.eval() self.netG_B.eval() print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): self.real_A = Variable(self.input_A, volatile=True) # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) if self.opt.skip == 1: self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) else: self.fake_B = self.netG_A.forward(self.real_A) self.rec_A = self.netG_B.forward(self.fake_B) self.real_B = Variable(self.input_B, volatile=True) self.fake_A = self.netG_B.forward(self.real_B) if self.opt.skip == 1: self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) else: self.rec_B = self.netG_A.forward(self.fake_A) def predict(self): self.real_A = Variable(self.input_A, volatile=True) # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) if self.opt.skip == 1: self.fake_B, 
self.latent_real_A = self.netG_A.forward(self.real_A) else: self.fake_B = self.netG_A.forward(self.real_A) self.rec_A = self.netG_B.forward(self.fake_B) real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) rec_A = util.tensor2im(self.rec_A.data) if self.opt.skip == 1: latent_real_A = util.tensor2im(self.latent_real_A.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("latent_real_A", latent_real_A), ("rec_A", rec_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("rec_A", rec_A)]) # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD.forward(real) if self.opt.use_wgan: loss_D_real = pred_real.mean() else: loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD.forward(fake.detach()) if self.opt.use_wgan: loss_D_fake = pred_fake.mean() else: loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss if self.opt.use_wgan: loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty( netD, real.data, fake.data) else: loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): lambda_idt = self.opt.identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. if self.opt.skip == 1: self.idt_A, _ = self.netG_A.forward(self.real_B) else: self.idt_A = self.netG_A.forward(self.real_B) self.loss_idt_A = self.criterionIdt( self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. 
self.idt_B = self.netG_B.forward(self.real_A) self.loss_idt_B = self.criterionIdt( self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss # D_A(G_A(A)) if self.opt.skip == 1: self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) else: self.fake_B = self.netG_A.forward(self.real_A) # = self.latent_real_A + self.opt.skip * self.real_A pred_fake = self.netD_A.forward(self.fake_B) if self.opt.use_wgan: self.loss_G_A = -pred_fake.mean() else: self.loss_G_A = self.criterionGAN(pred_fake, True) self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1 # D_B(G_B(B)) self.fake_A = self.netG_B.forward(self.real_B) pred_fake = self.netD_B.forward(self.fake_A) self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1 if self.opt.use_wgan: self.loss_G_B = -pred_fake.mean() else: self.loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss if lambda_A > 0: self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A else: self.loss_cycle_A = 0 # Backward cycle loss # = self.latent_fake_A + self.opt.skip * self.fake_A if lambda_B > 0: if self.opt.skip == 1: self.rec_B, self.latent_fake_A = self.netG_A.forward( self.fake_A) else: self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B else: self.loss_cycle_B = 0 self.loss_vgg_a = self.vgg_loss.compute_vgg_loss( self.vgg, self.fake_A, self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0 self.loss_vgg_b = self.vgg_loss.compute_vgg_loss( self.vgg, self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 # combined loss self.loss_G = self.loss_G_A + self.loss_G_B + self.L1_AB + self.L1_BA + self.loss_cycle_A + self.loss_cycle_B + \ self.loss_vgg_a + self.loss_vgg_b + \ self.loss_idt_A + self.loss_idt_B # self.loss_G = self.L1_AB + self.L1_BA self.loss_G.backward() def optimize_parameters(self): # 
forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() self.backward_D_A() self.optimizer_D_A.step() # D_B self.optimizer_D_B.zero_grad() self.backward_D_B() self.optimizer_D_B.step() def get_current_errors(self): D_A = self.loss_D_A.data[0] G_A = self.loss_G_A.data[0] L1 = (self.L1_AB + self.L1_BA).data[0] Cyc_A = self.loss_cycle_A.data[0] D_B = self.loss_D_B.data[0] G_B = self.loss_G_B.data[0] Cyc_B = self.loss_cycle_B.data[0] vgg = (self.loss_vgg_a.data[0] + self.loss_vgg_b.data[0] ) / self.opt.vgg if self.opt.vgg > 0 else 0 if self.opt.identity > 0: idt = self.loss_idt_A.data[0] + self.loss_idt_B.data[0] if self.opt.lambda_A > 0.0: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg), ("idt", idt)]) else: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg), ("idt", idt)) else: if self.opt.lambda_A > 0.0: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg)]) else: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg)) def get_current_visuals(self): real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) if self.opt.skip > 0: latent_real_A = util.tensor2im(self.latent_real_A.data) real_B = util.tensor2im(self.real_B.data) fake_A = util.tensor2im(self.fake_A.data) if self.opt.identity > 0: idt_A = util.tensor2im(self.idt_A.data) idt_B = util.tensor2im(self.idt_B.data) if self.opt.lambda_A > 0.0: rec_A = util.tensor2im(self.rec_A.data) rec_B = util.tensor2im(self.rec_B.data) if self.opt.skip > 0: latent_fake_A = util.tensor2im(self.latent_fake_A.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), ('real_B', real_B), 
('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ("idt_A", idt_A), ("idt_B", idt_B)]) else: if self.opt.skip > 0: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) else: if self.opt.lambda_A > 0.0: rec_A = util.tensor2im(self.rec_A.data) rec_B = util.tensor2im(self.rec_B.data) if self.opt.skip > 0: latent_fake_A = util.tensor2im(self.latent_fake_A.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) else: if self.opt.skip > 0: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('real_B', real_B), ('fake_A', fake_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), ('fake_A', fake_A)]) def save(self, label): self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) def update_learning_rate(self): if self.opt.new_lr: lr = self.old_lr / 2 else: lrd = self.opt.lr / self.opt.niter_decay lr = self.old_lr - lrd for param_group in self.optimizer_D_A.param_groups: param_group['lr'] = lr for param_group in self.optimizer_D_B.param_groups: param_group['lr'] = lr for 
param_group in self.optimizer_G.param_groups: param_group['lr'] = lr print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr
class Pix2PixHDModel(BaseModel):
    """Conditional GAN model made of a generator, a discriminator and an
    optional feature encoder.

    The generator (``netG``) maps a label map (optionally concatenated with an
    instance edge map and an encoded feature map) to an image.  During
    training a discriminator (``netD``) and, when features are generated
    rather than loaded, an encoder (``netE``) are created as well.
    """

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Build a function that drops the losses which are switched off.

        Parameters:
            use_gan_feat_loss (bool) -- keep the GAN feature-matching entry
            use_vgg_loss (bool)      -- keep the VGG entry

        Returns:
            A callable ``f(g_gan, g_gan_feat, g_vgg, d_real, d_fake)`` that
            returns the enabled entries as a list, in that order.  The GAN
            and discriminator entries are always kept.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            losses = (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
            return [loss for loss, keep in zip(losses, flags) if keep]

        return loss_filter

    def initialize(self, opt):
        """Create the networks, optionally load weights, and set up the
        losses and optimizers for training.

        Parameters:
            opt -- option object; the attributes read are visible below.

        Raises:
            NotImplementedError -- when an image pool is requested together
                                   with more than one GPU.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none':  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        # label_nc == 0 means the label map is fed as-is (treated as 3 channels)
        input_nc = opt.label_nc if opt.label_nc != 0 else 3

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1  # one extra channel for the instance edge map
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global,
                                      opt.n_local_enhancers, opt.n_blocks_local, opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.num_D, not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm,
                                          gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            encoder_params = list(self.netE.parameters()) if self.gen_features else []
            if opt.niter_fix_global > 0:
                if self.opt.verbose:
                    print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                # Parameters whose name starts with this prefix keep the real
                # learning rate; every other generator parameter is frozen
                # by giving it lr 0.
                trainable_prefix = 'model' + str(opt.n_local_enhancers)
                params = []
                for key, value in self.netG.named_parameters():
                    if key.startswith(trainable_prefix):
                        params += [{'params': [value], 'lr': opt.lr}]
                    else:
                        params += [{'params': [value], 'lr': 0.0}]
                # The optimizer needs either all tensors or all group dicts;
                # the encoder therefore gets its own group (default lr)
                # instead of being appended as bare tensors.
                if encoder_params:
                    params += [{'params': encoder_params}]
            else:
                params = list(self.netG.parameters()) + encoder_params
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        """Move the inputs to the GPU and build the generator's label input.

        Parameters:
            label_map  -- label tensor; one-hot encoded over ``opt.label_nc``
                          channels unless ``opt.label_nc == 0``
            inst_map   -- instance map, required unless ``opt.no_instance``
            real_image -- optional real image
            feat_map   -- optional precomputed feature map (used only when
                          ``opt.load_features`` is set)
            infer      -- when True the returned label input does not
                          require gradients

        Returns:
            ``(input_label, inst_map, real_image, feat_map)``.
        """
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        # NOTE(review): outside inference this marks the label input itself as
        # requiring gradients; kept as in the original.
        input_label = Variable(input_label, requires_grad=not infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run the discriminator on ``input_label`` concatenated with a
        detached ``test_image``; with ``use_pool`` the pair is first passed
        through the image pool."""
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Compute all training losses for one batch.

        Returns:
            ``[losses, fake_image_or_None]`` where ``losses`` is the filtered
            list (see ``init_loss_filter``) and the generated image is only
            returned when ``infer`` is True.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                # the last entry of each discriminator output is skipped here
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),
                None if not infer else fake_image]

    def inference(self, label, inst):
        """Generate an image for ``label``/``inst`` without tracking gradients."""
        # Encode Inputs
        input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)

        # Fake Generation
        if self.use_features:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        # No backward pass follows, so do not record the graph.
        with torch.no_grad():
            fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Build a feature map by sampling one stored cluster per instance.

        The clusters are read from
        ``<checkpoints_dir>/<name>/<cluster_path>``.  Pixels whose label has
        no stored cluster are left at zero.
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        # The file holds a pickled dict, which numpy only loads with
        # allow_pickle=True.  SECURITY: unpickling executes code from the
        # file, so the cluster file must come from a trusted source.
        features_clustered = np.load(cluster_path, allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        # zero-initialised so unmatched pixels do not carry uninitialised memory
        feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            # ids >= 1000 are reduced to a label by integer division
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                # plain int: comparing a tensor against a numpy scalar is not
                # reliable across torch versions
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Encode ``image`` and collect one feature row per instance.

        Returns:
            dict mapping label -> array of shape ``(n, feat_num + 1)``; the
            last column is the instance's pixel count divided by
            ``h * w // 32``.
        """
        image = Variable(image.cuda())
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        # Pure feature extraction: no gradients needed.
        with torch.no_grad():
            feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            # sample the feature at the middle entry of the instance's pixel list
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                # float() works for a 0-dim tensor, where `.data[0]` raises
                val[0, k] = float(feat_map[idx[0], idx[1] + k, idx[2], idx[3]])
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a map that is 1 wherever a pixel differs from a horizontal
        or vertical neighbour in ``t`` (half precision when
        ``opt.data_type == 16``, float otherwise)."""
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        """Save the generator, the discriminator and, if present, the encoder."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Recreate the generator optimizer over all generator (and encoder)
        parameters at the base learning rate."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by
        ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class RefinedDCPModel(BaseModel): """ This class implements the RefineDNet model, for learning single image dehazing without paired data. It adopts the basic backbone networks provided by CycleGAN. The model training requires '--dataset_mode unpaired' dataset. By default, it uses a '--netR_T unet_trans_256' U-Net refiner, a '--netR_J resnet_9blocks' ResNet refiner, and a '--netD basic' discriminator (PatchGAN introduced by pix2pix). """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ parser.set_defaults( no_dropout=True) # default CycleGAN did not use dropout if is_train: parser.add_argument('--lambda_G', type=float, default=0.05, help='weight for loss_G_single') parser.add_argument( '--lambda_identity', type=float, default=1, help= 'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1' ) parser.add_argument('--lambda_rec_I', type=float, default=1, help='weight for loss_rec_I') parser.add_argument('--lambda_tv', type=float, default=1, help='weight for TV loss of refine_T') parser.add_argument('--lambda_vgg', type=float, default=0, help='weight for loss_vgg') parser.add_argument('--netR_T', type=str, default='unet_trans_256', help='specify generator architecture') parser.add_argument('--netR_J', type=str, default='resnet_9blocks', help='specify generator architecture') return parser def __init__(self, opt): """Initialize the RefineDNet class. 
Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [ 'D_single', 'G_single', 'rec_I', 'TV_T', 'idt_J', 'vgg' ] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> if self.isTrain: self.visual_names = [ 'real_I', 'dcp_T_vis', 'refine_T_vis', 'out_T_vis', 'dcp_J', 'refine_J', 'rec_I', 'rec_J', 'map_A', 'real_J', 'ref_real_J' ] else: self.visual_names = [ 'real_I', 'dcp_T_vis', 'refine_T_vis', 'out_T_vis', 'dcp_J', 'refine_J', 'rec_I', 'rec_J', 'map_A' ] # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. if self.isTrain: self.model_names = ['Refiner_T', 'Refiner_J', 'D'] else: # during test time, only load Gs self.model_names = ['Refiner_T', 'Refiner_J'] # define networks (both Generators and discriminators) self.netG_DCP = networks.init_net( networks.DCPDehazeGenerator(), gpu_ids=self.gpu_ids) # use default setting for DCP self.netRefiner_T = networks.define_G(opt.input_nc + 1, 1, opt.ngf, opt.netR_T, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) self.netRefiner_J = networks.define_G(opt.input_nc + opt.output_nc, opt.output_nc, opt.ngf, opt.netR_J, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: # define discriminators self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert (opt.input_nc == opt.output_nc) self.fake_I_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images 
self.fake_J_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images # # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to( self.device) # define GAN loss. self.criterionRec = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() self.criterionTV = networks.TVLoss() self.criterionVGG = networks.VGGLoss( ) if self.opt.lambda_vgg > 0.0 else None # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(itertools.chain( self.netRefiner_T.parameters(), self.netRefiner_J.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) # display the architecture of each part # print(self.netRefiner_T) # print(self.netRefiner_J) # if self.isTrain: # print(self.netD) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. 
""" self.real_I = input['haze'].to(self.device) # [-1, 1] self.image_paths = input['paths'] if self.isTrain: self.real_J = input['clear'].to(self.device) # [-1, 1] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" dcp_J, self.dcp_T, self.dcp_A = self.netG_DCP(self.real_I) #scale to [-1,1] self.dcp_J = (torch.clamp(dcp_J, 0, 1) - 0.5) / 0.5 # output scale [0,1] self.refine_T, self.out_T = self.netRefiner_T( torch.cat((self.real_I, self.dcp_T), 1)) self.refine_J = self.netRefiner_J( torch.cat((self.real_I, self.dcp_J), 1)) # reconstruct haze image shape = self.refine_J.shape dcp_A_scale = self.dcp_A self.map_A = (dcp_A_scale).reshape( (1, 3, 1, 1)).repeat(1, 1, shape[2], shape[3]) refine_T_map = self.refine_T.repeat(1, 3, 1, 1) self.rec_I = util.synthesize_fog(self.refine_J, refine_T_map, self.map_A) self.rec_J = util.reverse_fog(self.real_I, refine_T_map, self.map_A) def test(self): """Forward function used in test time. This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop It also calls <compute_visuals> to produce additional visualization results """ with torch.no_grad(): self.forward() self.compute_visuals() def compute_visuals(self): """Calculate additional output images for visdom and HTML visualization""" # rescale to [-1,1] for visdom self.refine_T_vis = (self.refine_T - 0.5) / 0.5 self.out_T_vis = (self.out_T - 0.5) / 0.5 self.dcp_T_vis = (self.dcp_T - 0.5) / 0.5 # self.map_A_vis = (self.map_A - 0.5)/0.5 def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. 
""" # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D(self): fake_J = self.fake_I_pool.query(self.refine_J) self.loss_D_single = self.backward_D_basic(self.netD, self.real_J, fake_J) def backward_G(self): lambda_idt = self.opt.lambda_identity lambda_tv = self.opt.lambda_tv lambda_G = self.opt.lambda_G lambda_rec_I = self.opt.lambda_rec_I lambda_vgg = self.opt.lambda_vgg # Generator losses for rec_I and refine_J self.loss_G_single = self.criterionGAN(self.netD(self.refine_J), True) * lambda_G # Reconstrcut loss self.loss_rec_I = self.criterionRec(self.rec_I, self.real_I) * lambda_rec_I # perecptual loss self.loss_vgg = self.criterionVGG( self.refine_J, self.dcp_J) * lambda_vgg if lambda_vgg > 0.0 else 0 # TV loss self.loss_TV_T = self.criterionTV( self.out_T) * lambda_tv if lambda_tv > 0.0 else 0 # Identity loss, ||refiner_J(real_J) - real_J|| self.ref_real_J = self.netRefiner_J( torch.cat((self.real_I, self.real_J), 1)) self.loss_idt_J = self.criterionIdt(self.ref_real_J, self.real_J)*lambda_idt \ if lambda_idt > 0.0 \ else 0 self.loss_G = self.loss_G_single + self.loss_rec_I + self.loss_idt_J \ + self.loss_TV_T \ + self.loss_vgg self.loss_G.backward() def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. 
# G_A and G_B self.set_requires_grad( self.netD, False) # Ds require no gradients when optimizing Gs self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero self.backward_G() # calculate gradients for G_A and G_B self.optimizer_G.step() # update G_A and G_B's weights # D_A and D_B self.set_requires_grad(self.netD, True) self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero self.backward_D() # calculate gradients for D_A self.optimizer_D.step() # update D_A and D_B's weights
class Pix2PixModel(BaseModel):
    """Paired image-to-image translation model: a generator trained with a
    conditional GAN loss plus an L1 loss against the target image."""

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Create the networks and, for training, the losses, optimizers and
        schedulers; load saved weights when testing or resuming.

        Parameters:
            opt -- option object; the attributes read are visible below.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # the discriminator sees input and output images concatenated
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)

    def set_input(self, input):
        """Store one batch, swapping A and B unless the direction is 'AtoB'.

        Parameters:
            input (dict) -- reads 'A', 'B' and 'A_paths' / 'B_paths'.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async` is a reserved word since Python 3.7; `non_blocking` is
            # the current name of the same argument.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run the generator on the stored input batch."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run the generator without recording a graph."""
        # `volatile=True` no longer disables autograd; no_grad() does.
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    def backward_D(self):
        """Compute the discriminator loss on a pooled fake pair and the real
        pair, and backpropagate it."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and backpropagate it."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward pass, then a discriminator step
        followed by a generator step."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
class ImageRefineModel(BaseModel):
    """Warps input A with a fixed pose and the map in C, then refines the
    warped image with a generator trained with GAN + L1 losses against B."""

    def name(self):
        """Return the model's display name."""
        return 'ImageRefineModel'

    def initialize(self, opt):
        """Create the networks, load weights when testing or resuming, set up
        the training losses/optimizers and precompute the sampling grid and
        camera intrinsics.

        Parameters:
            opt -- option object; note that ``opt.output_nc`` is overwritten
                   with ``opt.input_nc``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        opt.output_nc = opt.input_nc

        # load/define networks
        self.netG = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, self.gpu_ids, tanh=True)

        self.flow_remapper = networks.flow_remapper(size=opt.fineSize, batch=opt.batchSize,
                                                    gpu_ids=opt.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        # Pixel-coordinate grid: grid[i, j] = (j, i), then divided by
        # fineSize / 2 and shifted by -1.
        coords = np.arange(opt.fineSize, dtype=np.float64)
        grid = np.stack(np.meshgrid(coords, coords), axis=-1)
        grid /= (opt.fineSize / 2)
        grid -= 1
        self.grid = torch.from_numpy(grid).cuda().float()
        self.grid = self.grid.view(1, self.grid.size(0), self.grid.size(1), self.grid.size(2))
        self.grid = Variable(self.grid)

        # 3x3 camera matrix with hard-coded focal length and principal point;
        # built once and inverted from the same array.
        focal = 128. / 32. * 60
        camera = np.array([focal, 0., 64.,
                           0., focal, 64.,
                           0., 0., 1.]).reshape((3, 3))
        intrinsics = camera.reshape((1, 3, 3))
        intrinsics_inv = np.linalg.inv(camera).reshape((1, 3, 3))
        self.intrinsics = Variable(
            torch.from_numpy(intrinsics.astype(np.float32)).cuda()).expand(opt.batchSize, 3, 3)
        self.intrinsics_inv = Variable(
            torch.from_numpy(intrinsics_inv.astype(np.float32)).cuda()).expand(opt.batchSize, 3, 3)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store one batch, swapping A and B unless the direction is 'AtoB'.

        Parameters:
            input (dict) -- reads 'A', 'B', 'C' and 'A_paths' / 'B_paths'.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        input_C = input['C']
        if len(self.gpu_ids) > 0:
            # `async` is a reserved word since Python 3.7; `non_blocking` is
            # the current name of the same argument.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
            input_C = input_C.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        self.input_C = input_C

    def _warp_and_refine(self, pose_angle):
        """Warp real_A with a fixed pose and refine the result with netG.

        ``pose_angle`` goes into the fifth slot of the 6-element pose vector;
        all other slots are zero.  (Presumably a rotation about the y axis --
        confirm against ``inverse_warp``.)
        """
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_C = Variable(self.input_C)
        pose = np.array([0, 0, 0, 0, pose_angle, 0]).reshape((1, 6))
        pose = Variable(torch.from_numpy(pose.astype(np.float32)).cuda()).expand(
            self.opt.batchSize, 6)
        self.forward_map = inverse_warp(self.real_A, self.real_C, pose,
                                        self.intrinsics, self.intrinsics_inv)
        self.backward_map = self.flow_remapper(self.forward_map, self.forward_map)
        self.fake_B_raw = F.grid_sample(self.real_A, self.backward_map)
        self.fake_B = self.netG(self.fake_B_raw)

    def forward(self):
        """Training forward pass (pose angle -pi/4)."""
        self._warp_and_refine(-np.pi / 4.)

    # no backprop gradients
    def test(self):
        """Test-time forward pass (pose angle -pi/8), run without autograd."""
        with torch.no_grad():
            self._warp_and_refine(-np.pi / 8.)

    # get image paths
    def get_image_paths(self):
        """Return the image paths of the current batch."""
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss (scaled by ``opt.lambda_gan``) and
        backpropagate it."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.opt.lambda_gan * self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.opt.lambda_gan * self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and backpropagate it."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.opt.lambda_gan * self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward(retain_graph=True)

    def optimize_parameters(self):
        """One training iteration: forward pass, then a discriminator step
        followed by a generator step."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest loss values as an ordered name -> float mapping."""
        # float() reads a scalar loss whether it is 0-dim or one-element;
        # `.data[0]` raises on 0-dim tensors.
        return OrderedDict([('G_GAN', float(self.loss_G_GAN)),
                            ('G_L1', float(self.loss_G_L1)),
                            ('D_real', float(self.loss_D_real)),
                            ('D_fake', float(self.loss_D_fake))])

    def get_current_visuals(self):
        """Return the current images, converted with ``util.tensor2im``, as an
        ordered name -> image mapping."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        fake_B_raw = util.tensor2im(self.fake_B_raw.data)
        real_B = util.tensor2im(self.real_B.data)
        # the maps are permuted to channel-first before conversion
        forward_map = util.tensor2im(self.forward_map.permute(0, 3, 1, 2).data)
        backward_map = util.tensor2im(self.backward_map.permute(0, 3, 1, 2).data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
                            ('forward_map', forward_map), ('backward_map', backward_map),
                            ('fake_B_raw', fake_B_raw)])

    def save(self, label):
        """Save the generator and the discriminator under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
class CycleDRPANModel(BaseModel):
    """CycleGAN-style model extended with two reviser networks (R_A, R_B).

    Two generators (G_A: A -> B, G_B: B -> A) and two discriminators
    (D_A, D_B) are trained with GAN, cycle and optional identity losses.
    In addition, a ``Proposal`` object turns discriminator outputs into
    real/fake proposal pairs that train the revisers, and the revisers in
    turn provide an extra adversarial loss for the generators.
    """

    def name(self):
        """Return the model's display name."""
        return 'CycleDRPANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register model-specific options on ``parser`` and return it.

        Dropout is disabled by default; the loss weights are only added
        when ``is_train`` is true.
        """
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def initialize(self, opt):
        """Build networks, losses and optimizers from the options ``opt``."""
        BaseModel.initialize(self, opt)

        # Losses to print out; the program will call
        # base_model.get_current_losses.
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'R_A', 'GR_A']

        # Images to save/display; the program will call
        # base_model.get_current_visuals.
        if self.isTrain:
            visual_names_A = ['real_A', 'fake_B', 'rec_A',
                              'fake_Br', 'real_Ar', 'fake_Bf', 'real_Af']
            visual_names_B = ['real_B', 'fake_A', 'rec_B',
                              'fake_Ar', 'real_Br', 'fake_Af', 'real_Bf']
        else:
            visual_names_A = ['real_A', 'fake_B', 'rec_A']
            visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B

        # Models to save to disk; the program will call
        # base_model.save_networks and base_model.load_networks.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'R_A', 'R_B']
        else:
            # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming convention is different from the one used in the paper.
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X),
        #               R_A (R_Y), R_B (R_X)
        self.netG_A = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(
            opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(
                opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(
                opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netR_A = networks.define_R(
                opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D,
                opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                self.gpu_ids)
            self.netR_B = networks.define_R(
                opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D,
                opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; G_A/G_B share one optimizer, as do
            # D_A/D_B, while each reviser has its own.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.netD_A.parameters(),
                                self.netD_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_A = torch.optim.Adam(
                self.netR_A.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_B = torch.optim.Adam(
                self.netR_B.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_R_A)
            self.optimizers.append(self.optimizer_R_B)

        self.proposal = Proposal()
        # self.batchsize = opt.batchSize
        # self.label_r = torch.FloatTensor(self.batchsize)

    def set_input(self, input):
        """Move one batch to the device, honouring ``opt.direction``.

        ``input`` is indexed with the keys 'A', 'B' and 'A_paths' /
        'B_paths'; the A/B roles are swapped unless direction is 'AtoB'.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run both translation cycles: A -> B -> A and B -> A -> B."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the discriminator loss for ``netD`` and return it.

        The loss is the mean of the GAN loss on ``real`` (target True) and
        on ``fake`` (target False); ``fake`` is detached so no gradient
        reaches the generator.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator loss for D_A, using a pooled ``fake_B``."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Discriminator loss for D_B, using a pooled ``fake_A``."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def reviser_A(self):
        """Run three reviser/generator updates for the A -> B direction.

        Each step generates a fresh fake_B, asks the proposal module for
        real/fake proposal pairs based on D_A's output, updates R_A on
        those pairs, then updates the generators so that R_A classifies
        the fake pair as real.
        """
        for _ in range(3):
            fake_B_ = self.netG_A(self.real_A)
            output = self.netD_A(fake_B_.detach())
            # proposal
            (self.fake_Br, self.real_Ar, self.fake_Bf, self.real_Af,
             self.fake_ABf, self.real_ABr) = self.proposal.forward_A(
                self.real_B, fake_B_, self.real_A, output)

            # ---- reviser update ----
            self.netD_A.zero_grad()
            # Clear R_A's gradients every step: otherwise the gradients of
            # the previous step (including those produced by the generator
            # loss below, which has the opposite target) accumulate into
            # this update.
            self.optimizer_R_A.zero_grad()
            # train with real
            output_r = self.netR_A(self.real_ABr.detach())
            self.loss_errR_real_A = self.criterionGAN(output_r, True)
            self.loss_errR_real_A.backward()
            # train with fake
            output_r = self.netR_A(self.fake_ABf.detach())
            self.loss_errR_fake_A = self.criterionGAN(output_r, False)
            self.loss_errR_fake_A.backward()
            self.loss_R_A = (self.loss_errR_real_A + self.loss_errR_fake_A) / 2
            self.optimizer_R_A.step()

            # ---- generator update through the reviser ----
            # optimizer_G owns both generators, so clear both; zeroing only
            # G_A would let the step re-apply G_B's stale gradients.
            self.optimizer_G.zero_grad()
            output_r = self.netR_A(self.fake_ABf)
            self.loss_GR_A = self.criterionGAN(output_r, True)
            self.loss_GR_A.backward()
            self.optimizer_G.step()

    def reviser_B(self):
        """Run three reviser/generator updates for the B -> A direction.

        Mirror image of ``reviser_A`` using G_B, D_B and R_B.
        """
        for _ in range(3):
            fake_A_ = self.netG_B(self.real_B)
            output = self.netD_B(fake_A_.detach())
            # proposal
            (self.fake_Ar, self.real_Br, self.fake_Af, self.real_Bf,
             self.fake_BAf, self.real_BAr) = self.proposal.forward_B(
                self.real_A, fake_A_, self.real_B, output)

            # ---- reviser update ----
            self.netD_B.zero_grad()
            # Clear R_B's gradients every step so earlier steps and the
            # generator loss below do not leak into this update.
            self.optimizer_R_B.zero_grad()
            # train with real
            output_r = self.netR_B(self.real_BAr.detach())
            self.loss_errR_real_B = self.criterionGAN(output_r, True)
            self.loss_errR_real_B.backward()
            # train with fake
            output_r = self.netR_B(self.fake_BAf.detach())
            self.loss_errR_fake_B = self.criterionGAN(output_r, False)
            self.loss_errR_fake_B.backward()
            self.loss_R_B = (self.loss_errR_real_B + self.loss_errR_fake_B) / 2
            self.optimizer_R_B.step()

            # ---- generator update through the reviser ----
            # Clear both generators' gradients; optimizer_G steps both.
            self.optimizer_G.zero_grad()
            output_r = self.netR_B(self.fake_BAf)
            self.errGAN_r = self.criterionGAN(output_r, True)
            self.loss_GR_B = self.errGAN_r
            self.loss_GR_B.backward()
            self.optimizer_G.step()

    def backward_G(self):
        """Backpropagate the combined generator loss.

        The loss sums the two GAN losses, the two cycle losses (weighted
        by ``lambda_A`` / ``lambda_B``) and, when ``lambda_identity`` is
        positive, the two identity losses.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: generators, discriminators, revisers."""
        # forward
        self.forward()
        # G_A and G_B (discriminators frozen while the generators train)
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
        # R_A and R_B
        self.set_requires_grad([self.netR_A, self.netR_B], True)
        self.optimizer_R_A.zero_grad()
        self.optimizer_R_B.zero_grad()
        self.reviser_A()
        self.reviser_B()
class VIGANModel(BaseModel):
    """CycleGAN-style model combined with a two-view autoencoder (AE).

    Two generators (G_A: A -> B, G_B: B -> A) and two discriminators are
    trained together with an autoencoder that takes a pair of views and
    returns a pair of reconstructions. The class offers separate
    optimisation entry points for CycleGAN pre-training, AE pre-training
    and the joint phase.
    """

    def name(self):
        """Return the model's display name."""
        return 'VIGANModel'

    def initialize(self, opt):
        """Build networks, optionally load weights, set up training state."""
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated input buffers; set_input() resizes and fills them.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG,
                                        opt.norm, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG,
                                        opt.norm, self.gpu_ids)
        self.AE = networks.define_AE(28*28, 28*28, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid,
                                            self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.AE, 'AE', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionAE = torch.nn.MSELoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(
                self.netD_A.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(
                self.netD_B.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            # Separate optimizer state for the discriminators in the joint
            # AE phase (same parameters as optimizer_D_A / optimizer_D_B).
            self.optimizer_D_A_AE = torch.optim.Adam(
                self.netD_A.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B_AE = torch.optim.Adam(
                self.netD_B.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(
                self.AE.parameters(),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_AE_GA_GB = torch.optim.Adam(
                itertools.chain(self.AE.parameters(),
                                self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))

            # The discriminators only exist in training mode, so the
            # network summary is printed here.
            print('---------- Networks initialized -------------')
            networks.print_network(self.netG_A)
            networks.print_network(self.netG_B)
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
            networks.print_network(self.AE)
            print('-----------------------------------------------')

    def set_input(self, images_a, images_b):
        """Copy one batch of each view into the pre-allocated buffers."""
        input_A = images_a
        input_B = images_b
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def forward(self):
        """Wrap the input buffers as ``real_A`` / ``real_B``."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both cycles and the autoencoder with volatile inputs."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

        # Autoencoder output for (fake_A, real_B); only the A side is kept.
        self.AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        # Autoencoder output for (real_A, fake_B); only the B side is kept.
        AErealA, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)

    # get image paths
    def get_image_paths(self):
        """Return ``self.image_paths``.

        NOTE(review): nothing in this class assigns ``image_paths``;
        confirm that the base class or the caller sets it.
        """
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the discriminator loss for ``netD`` and return it.

        The loss is the mean of the GAN loss on ``real`` (target True) and
        on ``fake`` (target False); ``fake`` is detached first.
        """
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator loss for D_A, using a pooled ``fake_B``."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Discriminator loss for D_B, using a pooled ``fake_A``."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Backpropagate the CycleGAN generator loss.

        Sums the two GAN losses, the two cycle losses (weighted by
        ``lambda_A`` / ``lambda_B``) and, when ``opt.identity`` is
        positive, the two identity losses.
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    ############################################################################
    # Define backward function for VIGAN
    ############################################################################

    def backward_AE_pretrain(self):
        """Backpropagate the AE reconstruction loss on the real pair."""
        # Autoencoder loss
        AErealA, AErealB = self.AE.forward(self.real_A, self.real_B)
        # Each reconstruction is compared with its own view: the B output
        # must be scored against real_B, not real_A.
        self.loss_AE_pre = (self.criterionAE(AErealA, self.real_A)
                            + self.criterionAE(AErealB, self.real_B))
        self.loss_AE_pre.backward()

    def backward_AE(self):
        """Backpropagate the AE loss on (fake, real) pairs of both views."""
        # fake data
        self.fake_B = self.netG_A.forward(self.real_A)
        self.fake_A = self.netG_B.forward(self.real_B)

        # Autoencoder loss: fakeA
        AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = (self.criterionAE(AEfakeA, self.real_A)
                              + self.criterionAE(AErealB, self.real_B)) * 1
        # Autoencoder loss: fakeB
        AErealA, AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = (self.criterionAE(AErealA, self.real_A)
                              + self.criterionAE(AEfakeB, self.real_B)) * 1
        # combined loss
        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB) * 0.5
        self.loss_AE.backward()

    # input is vector
    def backward_D_A_AE(self):
        """Discriminator loss for D_A with the AE output as the fake."""
        fake_B = self.AEfakeB
        self.loss_D_A_AE = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B_AE(self):
        """Discriminator loss for D_B with the AE output as the fake."""
        fake_A = self.AEfakeA
        self.loss_D_B_AE = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_AE_GA_GB(self):
        """Backpropagate the joint loss for the AE and both generators.

        The loss is ``lambda_C`` times the GAN losses on the AE outputs,
        plus ``lambda_D`` times the AE reconstruction loss, plus the
        unweighted cycle losses.
        """
        lambda_C = self.opt.lambda_C
        lambda_D = self.opt.lambda_D

        # fake data
        # G_A(A)
        self.fake_B = self.netG_A.forward(self.real_A)
        # G_B(B)
        self.fake_A = self.netG_B.forward(self.real_B)

        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A_AE = self.criterionCycle(self.rec_A, self.real_A)
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B_AE = self.criterionCycle(self.rec_B, self.real_B)

        # Autoencoder loss: fakeA
        self.AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = (self.criterionAE(self.AEfakeA, self.real_A)
                              + self.criterionAE(AErealB, self.real_B)) * 1
        # Autoencoder loss: fakeB
        AErealA, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = (self.criterionAE(AErealA, self.real_A)
                              + self.criterionAE(self.AEfakeB, self.real_B)) * 1
        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB)

        # D loss
        pred_fake = self.netD_A.forward(self.AEfakeB)
        self.loss_AE_GA = self.criterionGAN(pred_fake, True)
        pred_fake = self.netD_B.forward(self.AEfakeA)
        self.loss_AE_GB = self.criterionGAN(pred_fake, True)

        self.loss_AE_GA_GB = lambda_C * (self.loss_AE_GA + self.loss_AE_GB) + \
            lambda_D * self.loss_AE + \
            1 * (self.loss_cycle_A_AE + self.loss_cycle_B_AE)
        self.loss_AE_GA_GB.backward()

    #########################################################################################################

    def optimize_parameters_pretrain_cycleGAN(self):
        """One CycleGAN pre-training iteration: G, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    ############################################################################
    # Define optimize function for VIGAN
    ############################################################################

    def optimize_parameters_pretrain_AE(self):
        """One autoencoder pre-training iteration."""
        # forward
        self.forward()
        # AE
        self.optimizer_AE.zero_grad()
        self.backward_AE_pretrain()
        self.optimizer_AE.step()

    def optimize_parameters(self):
        """One joint iteration: two AE+G updates, then one D_A/D_B update."""
        # forward
        self.forward()
        # AE+G_A+G_B
        for _ in range(2):
            self.optimizer_AE_GA_GB.zero_grad()
            self.backward_AE_GA_GB()
            self.optimizer_AE_GA_GB.step()

        for _ in range(1):
            # D_A
            self.optimizer_D_A_AE.zero_grad()
            self.backward_D_A_AE()
            self.optimizer_D_A_AE.step()
            # D_B
            self.optimizer_D_B_AE.zero_grad()
            self.backward_D_B_AE()
            self.optimizer_D_B_AE.step()

    ############################################################################################
    # Get errors for visualization
    ############################################################################################

    def get_current_errors_cycle(self):
        """Return the CycleGAN pre-training losses as an OrderedDict.

        Identity losses are included only when ``opt.identity`` is
        positive.
        """
        AE_D_A = self.loss_D_A.data[0]
        AE_G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        AE_D_B = self.loss_D_B.data[0]
        AE_G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', AE_D_A), ('G_A', AE_G_A),
                                ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', AE_D_B), ('G_B', AE_G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', AE_D_A), ('G_A', AE_G_A),
                                ('Cyc_A', Cyc_A),
                                ('D_B', AE_D_B), ('G_B', AE_G_B),
                                ('Cyc_B', Cyc_B)])

    def get_current_errors(self):
        """Return the joint-phase losses as an OrderedDict.

        Identity losses are included only when ``opt.identity`` is
        positive; they are read from the attributes set by ``backward_G``.
        """
        D_A = self.loss_D_A_AE.data[0]
        G_A = self.loss_AE_GA.data[0]
        Cyc_A = self.loss_cycle_A_AE.data[0]
        D_B = self.loss_D_B_AE.data[0]
        G_B = self.loss_AE_GB.data[0]
        Cyc_B = self.loss_cycle_B_AE.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A),
                                ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A),
                                ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return the current images as an OrderedDict of converted arrays.

        The AE outputs are reshaped to ``(1, 1, 28, 28)`` before
        conversion.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)

        AE_fake_A = util.tensor2im(self.AEfakeA.view(1, 1, 28, 28).data)
        AE_fake_B = util.tensor2im(self.AEfakeB.view(1, 1, 28, 28).data)

        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A),
                                ('AE_fake_A', AE_fake_A),
                                ('AE_fake_B', AE_fake_B)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B),
                                ('AE_fake_A', AE_fake_A),
                                ('AE_fake_B', AE_fake_B)])

    def save(self, label):
        """Save all five networks under the checkpoint tag ``label``."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.AE, 'AE', label, self.gpu_ids)

    def update_learning_rate(self):
        """Lower the learning rate by ``opt.lr / opt.niter_decay``.

        The new rate is written to every optimizer this model owns,
        including the ones stepped by ``optimize_parameters`` and
        ``optimize_parameters_pretrain_AE``; decaying only the CycleGAN
        pre-training optimizers would leave those phases at the initial
        rate.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        optimizers = (
            self.optimizer_D_A,
            self.optimizer_D_B,
            self.optimizer_G,
            self.optimizer_D_A_AE,
            self.optimizer_D_B_AE,
            self.optimizer_AE,
            self.optimizer_AE_GA_GB,
        )
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr