def initialize(self, opt):
    """Build the CycleGAN training setup (legacy template).

    Creates generators G_A/G_B (and, when training, discriminators
    D_A/D_B), optionally restores checkpoints, then sets up image
    pools, loss functions and one Adam optimizer per network group.

    Args:
        opt: parsed options; fields used include batchSize, fineSize,
            input_nc/output_nc, ngf/ndf, which_model_netG/netD, norm,
            no_dropout, no_lsgan, isTrain, continue_train, which_epoch,
            pool_size, lr, beta1.
    """
    BaseModel.initialize(self, opt)
    nb = opt.batchSize
    size = opt.fineSize
    # Pre-allocated input buffers; presumably filled later by a
    # set_input() counterpart — confirm against the full class.
    self.input_A = self.Tensor(nb, opt.input_nc, size, size)
    self.input_B = self.Tensor(nb, opt.output_nc, size, size)
    # load/define networks
    # The naming conversion is different from those used in the paper
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    not opt.no_dropout, self.gpu_ids)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    not opt.no_dropout, self.gpu_ids)
    if self.isTrain:
        # Vanilla GAN loss needs sigmoid discriminator outputs; LSGAN does not.
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, self.gpu_ids)
    if not self.isTrain or opt.continue_train:
        # Restore weights for testing, or to resume an interrupted run.
        which_epoch = opt.which_epoch
        self.load_network(self.netG_A, 'G_A', which_epoch)
        self.load_network(self.netG_B, 'G_B', which_epoch)
        if self.isTrain:
            self.load_network(self.netD_A, 'D_A', which_epoch)
            self.load_network(self.netD_B, 'D_B', which_epoch)
    if self.isTrain:
        self.old_lr = opt.lr
        # History pools of previously generated images for the D updates.
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers: both generators share one optimizer,
        # each discriminator gets its own.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.lr, betas=(opt.beta1, 0.999))
    print('---------- Networks initialized -------------')
    networks.print_network(self.netG_A)
    networks.print_network(self.netG_B)
    if self.isTrain:
        networks.print_network(self.netD_A)
        networks.print_network(self.netD_B)
    print('-----------------------------------------------')
def initialize(self, opt):
    """Build the CycleGAN training setup (newer template).

    Unlike the legacy variant, this declares loss_names / visual_names /
    model_names so the base class can report losses, collect visuals and
    save/load checkpoints, and it uses one combined optimizer for both
    discriminators instead of one each.
    """
    BaseModel.initialize(self, opt)
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
    # specify the images you want to save/display. The program will call base_model.get_current_visuals
    visual_names_A = ['real_A', 'fake_B', 'rec_A']
    visual_names_B = ['real_B', 'fake_A', 'rec_B']
    if self.isTrain and self.opt.lambda_identity > 0.0:
        # Identity outputs only exist when the identity loss is enabled.
        visual_names_A.append('idt_A')
        visual_names_B.append('idt_B')
    self.visual_names = visual_names_A + visual_names_B
    # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
    if self.isTrain:
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
    else:  # during test time, only load Gs
        self.model_names = ['G_A', 'G_B']
    # load/define networks
    # The naming conversion is different from those used in the paper
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    not opt.no_dropout, opt.init_type,
                                    self.gpu_ids)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    not opt.no_dropout, opt.init_type,
                                    self.gpu_ids)
    if self.isTrain:
        # Vanilla GAN loss needs sigmoid discriminator outputs; LSGAN does not.
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, opt.init_type,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, opt.init_type,
                                        self.gpu_ids)
    if self.isTrain:
        # History pools of previously generated images for the D updates.
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers: one for both Gs, one for both Ds.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
def initialize(self, opt):
    """Build the pix2pix generator/discriminator pair plus training machinery.

    Allocates input buffers, constructs the networks, optionally restores
    checkpoints, then (train only) sets up the image pool, loss functions,
    Adam optimizers and LR schedulers before printing the architectures.
    """
    BaseModel.initialize(self, opt)
    self.isTrain = opt.isTrain

    # Pre-sized buffers that set_input() fills in place.
    batch, fine = opt.batchSize, opt.fineSize
    self.input_A = self.Tensor(batch, opt.input_nc, fine, fine)
    self.input_B = self.Tensor(batch, opt.output_nc, fine, fine)

    # The generator is always needed; the conditional discriminator sees
    # A and B concatenated along channels, hence input_nc + output_nc.
    self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                  opt.which_model_netG, opt.norm,
                                  not opt.no_dropout, opt.init_type,
                                  self.gpu_ids)
    if self.isTrain:
        # A vanilla (non-LSGAN) discriminator needs a sigmoid output head.
        use_sigmoid = opt.no_lsgan
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.which_model_netD, opt.n_layers_D,
                                      opt.norm, use_sigmoid, opt.init_type,
                                      self.gpu_ids)

    # Restore weights when testing, or when resuming an interrupted run.
    if not self.isTrain or opt.continue_train:
        self.load_network(self.netG, 'G', opt.which_epoch)
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch)

    if self.isTrain:
        self.fake_AB_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr
        # Loss functions: adversarial + L1 reconstruction.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionL1 = torch.nn.L1Loss()
        # One Adam optimizer per network, each paired with a scheduler.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        self.schedulers = [networks.get_scheduler(optimizer, opt)
                           for optimizer in self.optimizers]

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG)
    if self.isTrain:
        networks.print_network(self.netD)
    print('-----------------------------------------------')
def initialize(self, opt):
    """Build the pix2pix training setup (newer template with schedulers).

    Declares loss/visual/model name lists for base-class reporting, builds
    G (and D when training), losses, optimizers and LR schedulers, then
    restores checkpoints via the base-class load_networks helper.
    """
    BaseModel.initialize(self, opt)
    self.isTrain = opt.isTrain
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
    # specify the images you want to save/display. The program will call base_model.get_current_visuals
    self.visual_names = ['real_A', 'fake_B', 'real_B']
    # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
    if self.isTrain:
        self.model_names = ['G', 'D']
    else:  # during test time, only load Gs
        self.model_names = ['G']
    # load/define networks
    self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                  opt.which_model_netG, opt.norm,
                                  not opt.no_dropout, opt.init_type,
                                  self.gpu_ids)
    if self.isTrain:
        # Vanilla GAN loss needs sigmoid discriminator outputs; LSGAN does not.
        use_sigmoid = opt.no_lsgan
        # Conditional D sees A and B concatenated along channels.
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.which_model_netD, opt.n_layers_D,
                                      opt.norm, use_sigmoid, opt.init_type,
                                      self.gpu_ids)
    if self.isTrain:
        self.fake_AB_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionL1 = torch.nn.L1Loss()
        # initialize optimizers
        self.schedulers = []
        self.optimizers = []
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))
    if not self.isTrain or opt.continue_train:
        self.load_networks(opt.which_epoch)
    self.print_networks(opt.verbose)
class Pix2PixModel(BaseModel):
    """Conditional GAN for paired image-to-image translation (pix2pix).

    Trains a generator G: A -> B against a conditional discriminator D
    that scores (A, B) channel-concatenated pairs, with an additional
    L1 reconstruction term weighted by --lambda_L1.
    """

    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Override option defaults and add pix2pix-specific flags.

        # changing the default values to match the pix2pix paper
        # (https://phillipi.github.io/pix2pix/)
        """
        parser.set_defaults(pool_size=0)
        parser.set_defaults(no_lsgan=True)
        parser.set_defaults(norm='batch')
        parser.set_defaults(dataset_mode='aligned')
        parser.set_defaults(which_model_netG='unet_256')
        if is_train:
            parser.add_argument('--lambda_L1', type=float, default=100.0,
                                help='weight for L1 loss')
        return parser

    def initialize(self, opt):
        """Build networks, pools, losses and optimizers for pix2pix."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if self.isTrain:
            # Vanilla GAN loss needs sigmoid D outputs; LSGAN does not.
            use_sigmoid = opt.no_lsgan
            # Conditional D sees A and B concatenated along channels.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack a dataloader batch dict onto the device.

        `which_direction` decides which side of the pair is the source.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Generate the translated image fake_B = G(real_A)."""
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        """Discriminator step on a pooled fake pair and the real pair."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator step: fool D plus L1 match to the ground truth."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One full training iteration: update D, then update G."""
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # update G — freeze D so its gradients are not computed needlessly
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
class VIGANModel(BaseModel):
    """VIGAN: a CycleGAN pair (G_A, G_B, D_A, D_B) coupled with a two-view
    autoencoder (AE) so cross-view generations stay consistent with a shared
    reconstruction.

    Training runs in stages: CycleGAN pretraining
    (optimize_parameters_pretrain_cycleGAN), AE pretraining
    (optimize_parameters_pretrain_AE), then joint AE+G_A+G_B updates with
    the discriminators trained on AE outputs (optimize_parameters).

    Fixes relative to the previous revision:
    * backward_AE_pretrain compared the B-view reconstruction against
      real_A instead of real_B (asymmetric with backward_AE).
    * initialize printed netD_A/netD_B unconditionally, crashing at test
      time when the discriminators are never constructed.
    """

    def name(self):
        return 'VIGANModel'

    def initialize(self, opt):
        """Construct networks, losses and the per-stage optimizers."""
        BaseModel.initialize(self, opt)
        nb = opt.batchSize
        size = opt.fineSize
        # Pre-sized input buffers, filled in set_input().
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        self.gpu_ids)
        # Flattened 28*28 in/out — presumably MNIST-sized views; confirm
        # against the dataset used by the caller.
        self.AE = networks.define_AE(28*28, 28*28, self.gpu_ids)
        if self.isTrain:
            # Vanilla GAN loss needs sigmoid D outputs; LSGAN does not.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, use_sigmoid,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.AE, 'AE', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionAE = torch.nn.MSELoss()
            # One optimizer per training stage / sub-objective.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            # Separate optimizer state for the AE-driven D updates.
            self.optimizer_D_A_AE = torch.optim.Adam(self.netD_A.parameters(),
                                                     lr=opt.lr,
                                                     betas=(opt.beta1, 0.999))
            self.optimizer_D_B_AE = torch.optim.Adam(self.netD_B.parameters(),
                                                     lr=opt.lr,
                                                     betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(self.AE.parameters(),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
            self.optimizer_AE_GA_GB = torch.optim.Adam(
                itertools.chain(self.AE.parameters(),
                                self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        # FIX: discriminators only exist when training; guard the printing
        # so test-time initialization does not raise AttributeError.
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        networks.print_network(self.AE)
        print('-----------------------------------------------')

    def set_input(self, images_a, images_b):
        """Copy a batch of paired views into the pre-allocated buffers."""
        input_A = images_a
        input_B = images_b
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def forward(self):
        """Wrap the input buffers as Variables for a training step."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions plus AE reconstructions (no grads;
        legacy volatile=True API)."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)
        # Autoencoder loss: fakeA
        self.AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        # Autoencoder loss: fakeB
        AErealA, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator step; `fake` is detached so no gradient
        reaches the generator. Returns the combined loss."""
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        # D_A judges domain-B images; fakes come from the history pool.
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """CycleGAN generator step: GAN + cycle (+ optional identity) losses."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss: D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # GAN loss: D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    ############################################################################
    # Define backward function for VIGAN
    ############################################################################
    def backward_AE_pretrain(self):
        """Pretrain the AE to reconstruct each real view.

        FIX: the second term previously compared the B-view reconstruction
        AErealB against real_A; the correct target is real_B, mirroring
        backward_AE below.
        """
        AErealA, AErealB = self.AE.forward(self.real_A, self.real_B)
        self.loss_AE_pre = self.criterionAE(AErealA, self.real_A) + self.criterionAE(AErealB, self.real_B)
        self.loss_AE_pre.backward()

    def backward_AE(self):
        """AE step on generator outputs (generators are not updated here)."""
        # fake data
        self.fake_B = self.netG_A.forward(self.real_A)
        self.fake_A = self.netG_B.forward(self.real_B)
        # Autoencoder loss: fakeA
        AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = (self.criterionAE(AEfakeA, self.real_A) +
                              self.criterionAE(AErealB, self.real_B)) * 1
        # Autoencoder loss: fakeB
        AErealA, AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = (self.criterionAE(AErealA, self.real_A) +
                              self.criterionAE(AEfakeB, self.real_B)) * 1
        # combined loss
        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB) * 0.5
        self.loss_AE.backward()

    # input is vector
    def backward_D_A_AE(self):
        # Train D_A against the AE's reconstruction of the fake B view.
        fake_B = self.AEfakeB
        self.loss_D_A_AE = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B_AE(self):
        fake_A = self.AEfakeA
        self.loss_D_B_AE = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_AE_GA_GB(self):
        """Joint AE + generator step: GAN on AE outputs, AE reconstruction,
        and cycle consistency, weighted by lambda_C / lambda_D."""
        lambda_C = self.opt.lambda_C
        lambda_D = self.opt.lambda_D
        # fake data
        # G_A(A)
        self.fake_B = self.netG_A.forward(self.real_A)
        # G_B(B)
        self.fake_A = self.netG_B.forward(self.real_B)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A_AE = self.criterionCycle(self.rec_A, self.real_A)
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B_AE = self.criterionCycle(self.rec_B, self.real_B)
        # Autoencoder loss: fakeA
        self.AEfakeA, AErealB = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = (self.criterionAE(self.AEfakeA, self.real_A) +
                              self.criterionAE(AErealB, self.real_B)) * 1
        # Autoencoder loss: fakeB
        AErealA, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = (self.criterionAE(AErealA, self.real_A) +
                              self.criterionAE(self.AEfakeB, self.real_B)) * 1
        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB)
        # D loss: AE reconstructions should also fool the discriminators.
        pred_fake = self.netD_A.forward(self.AEfakeB)
        self.loss_AE_GA = self.criterionGAN(pred_fake, True)
        pred_fake = self.netD_B.forward(self.AEfakeA)
        self.loss_AE_GB = self.criterionGAN(pred_fake, True)
        self.loss_AE_GA_GB = lambda_C * (self.loss_AE_GA + self.loss_AE_GB) + \
            lambda_D * self.loss_AE + 1 * (self.loss_cycle_A_AE + self.loss_cycle_B_AE)
        self.loss_AE_GA_GB.backward()

    #########################################################################################################
    def optimize_parameters_pretrain_cycleGAN(self):
        """Stage 1: plain CycleGAN updates (G pair, then D_A, then D_B)."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    ############################################################################
    # Define optimize function for VIGAN
    ############################################################################
    def optimize_parameters_pretrain_AE(self):
        """Stage 2: autoencoder pretraining on real pairs."""
        # forward
        self.forward()
        # AE
        self.optimizer_AE.zero_grad()
        self.backward_AE_pretrain()
        self.optimizer_AE.step()

    def optimize_parameters(self):
        """Stage 3: two joint AE+G updates, then one update of each D."""
        # forward
        self.forward()
        # AE+G_A+G_B
        for i in range(2):
            self.optimizer_AE_GA_GB.zero_grad()
            self.backward_AE_GA_GB()
            self.optimizer_AE_GA_GB.step()
        for i in range(1):
            # D_A
            self.optimizer_D_A_AE.zero_grad()
            self.backward_D_A_AE()
            self.optimizer_D_A_AE.step()
            # D_B
            self.optimizer_D_B_AE.zero_grad()
            self.backward_D_B_AE()
            self.optimizer_D_B_AE.step()

    ############################################################################################
    # Get errors for visualization
    ############################################################################################
    def get_current_errors_cycle(self):
        """Scalar losses from the CycleGAN pretraining stage (legacy
        .data[0] indexing)."""
        AE_D_A = self.loss_D_A.data[0]
        AE_G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        AE_D_B = self.loss_D_B.data[0]
        AE_G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', AE_D_A), ('G_A', AE_G_A),
                                ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', AE_D_B), ('G_B', AE_G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', AE_D_A), ('G_A', AE_G_A),
                                ('Cyc_A', Cyc_A), ('D_B', AE_D_B),
                                ('G_B', AE_G_B), ('Cyc_B', Cyc_B)])

    def get_current_errors(self):
        """Scalar losses from the joint VIGAN stage."""
        D_A = self.loss_D_A_AE.data[0]
        G_A = self.loss_AE_GA.data[0]
        Cyc_A = self.loss_cycle_A_AE.data[0]
        D_B = self.loss_D_B_AE.data[0]
        G_B = self.loss_AE_GB.data[0]
        Cyc_B = self.loss_cycle_B_AE.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('idt_A', idt_A), ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Current images as numpy arrays; AE outputs are reshaped from
        flat vectors to 1x1x28x28 before conversion."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        AE_fake_A = util.tensor2im(self.AEfakeA.view(1, 1, 28, 28).data)
        AE_fake_B = util.tensor2im(self.AEfakeB.view(1, 1, 28, 28).data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A),
                                ('AE_fake_A', AE_fake_A),
                                ('AE_fake_B', AE_fake_B)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B),
                                ('AE_fake_A', AE_fake_A),
                                ('AE_fake_B', AE_fake_B)])

    def save(self, label):
        """Persist all five networks under the given checkpoint label."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.AE, 'AE', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay the learning rate of the CycleGAN-stage optimizers.

        NOTE(review): the AE-stage optimizers (optimizer_AE,
        optimizer_*_AE, optimizer_AE_GA_GB) are intentionally (?) left
        undecayed here — confirm this matches the training schedule.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
def initialize(self, opt):
    """Build a three-domain CycleGAN variant (extra G_C_A/G_C_B pair and
    discriminator D_C for domain C) with shared optimizers.

    FIX: 'cycle_C' appeared twice in loss_names, which would make
    base_model.get_current_losses report the same entry twice.
    NOTE(review): given the paired G_C_A/G_C_B generators and the
    rec_C_A/rec_C_B visuals, the duplicate may have been intended as
    'cycle_C_A'/'cycle_C_B' — confirm against the backward pass that
    assigns the loss_* attributes.
    """
    BaseModel.initialize(self, opt)
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    self.loss_names = [
        'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
        'D_C_A', 'D_C_B', 'G_C', 'cycle_C', 'idt_C_A', 'idt_C_B'
    ]
    # specify the images you want to save/display. The program will call base_model.get_current_visuals
    visual_names_A = ['real_A', 'fake_B', 'rec_A']
    visual_names_B = ['real_B', 'fake_A', 'rec_B']
    visual_names_C = ['real_C', 'fake_C_A', 'fake_C_B', 'rec_C_A', 'rec_C_B']
    if self.isTrain and self.opt.lambda_identity > 0.0:
        # Identity outputs only exist when the identity loss is enabled.
        visual_names_A.append('idt_A')
        visual_names_B.append('idt_B')
        visual_names_C.append('idt_C_A')
        visual_names_C.append('idt_C_B')
    self.visual_names = visual_names_A + visual_names_B + visual_names_C
    # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
    if self.isTrain:
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'G_C_A', 'G_C_B', 'D_C']
    else:  # during test time, only load Gs
        self.model_names = ['G_A', 'G_B', 'G_C_A', 'G_C_B']
    # load/define networks
    # The naming conversion is different from those used in the paper
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.netG, opt.norm, not opt.no_dropout,
                                    opt.init_type, opt.init_gain,
                                    self.gpu_ids)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.netG, opt.norm, not opt.no_dropout,
                                    opt.init_type, opt.init_gain,
                                    self.gpu_ids)
    # Domain-C generators map each domain onto itself (same nc in and out).
    self.netG_C_A = networks.define_G(opt.input_nc, opt.input_nc, opt.ngf,
                                      opt.netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids)
    self.netG_C_B = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf,
                                      opt.netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids)
    if self.isTrain:
        # Vanilla GAN loss needs sigmoid D outputs; LSGAN does not.
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        use_sigmoid, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
        self.netD_B = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        use_sigmoid, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
        self.netD_C = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        use_sigmoid, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
    if self.isTrain:
        # History pools of previously generated images for the D updates.
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        self.fake_C_A_pool = ImagePool(opt.pool_size)
        self.fake_C_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(
            use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers: one for all Gs, one for all Ds.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters(),
            self.netG_C_A.parameters(), self.netG_C_B.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters(),
            self.netD_C.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
class gea_ganModel(BaseModel):
    """Paired translation GAN with an extra Sobel-edge L1 loss whose weight
    can start at zero and be ramped up over epochs (update_sobel_lambda).

    Input buffers are allocated with five dimensions
    (batch, channels, fineSize, fineSize, fineSize) — presumably volumetric
    data; confirm against the dataset loader.

    Fixes relative to the previous revision:
    * get_current_img contained a duplicated, unreachable `return`.
    * initialize printed self.netD unconditionally, crashing at test time
      when the discriminator is never constructed.
    """

    def name(self):
        return 'gea_ganModel'

    def initialize(self, opt):
        """Build networks, losses and optimizers; restore checkpoints as
        requested."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.batchSize = opt.batchSize
        self.fineSize = opt.fineSize
        # define tensors (5-D: see class docstring)
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize, opt.fineSize)
        # Start the Sobel term at zero when it is to be ramped up later.
        if self.opt.rise_sobelLoss:
            self.sobelLambda = 0
        else:
            self.sobelLambda = self.opt.lambda_sobel
        # load/define networks
        which_netG = opt.which_model_netG
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      which_netG, opt.norm, opt.use_dropout,
                                      self.gpu_ids)
        if self.isTrain:
            # Conditional D sees A and B concatenated along channels.
            self.D_channel = opt.input_nc + opt.output_nc
            # Vanilla GAN loss needs sigmoid D outputs; LSGAN does not.
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(self.D_channel, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)
        if not self.isTrain:
            self.netG.eval()
        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions (optionally label-smoothed GAN loss)
            if self.opt.labelSmooth:
                self.criterionGAN = networks.GANLoss_smooth(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        # FIX: netD only exists when training; guard the printing so
        # test-time initialization does not raise AttributeError.
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a dataloader batch dict into the pre-allocated buffers."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Generate fake_B = G(real_A) for a training step."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        # NOTE(review): despite the comment, plain Variables are used here
        # (no volatile/no_grad), so graph state is still tracked — confirm
        # whether that is intended.
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Discriminator step; also caches fake/real Sobel maps that
        backward_G consumes later in the same iteration."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        self.fake_sobel = networks.sobelLayer(self.fake_B)
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
        # Real
        self.real_sobel = networks.sobelLayer(self.real_B).detach()
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator step: GAN + weighted L1 + weighted Sobel-edge L1.

        Relies on self.fake_sobel / self.real_sobel set by backward_D, so
        it must run after backward_D within an iteration (as
        optimize_parameters does).
        """
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        # Edge-preservation term; sobelLambda may be 0 early in training.
        self.loss_sobelL1 = self.criterionL1(self.fake_sobel,
                                             self.real_sobel) * self.sobelLambda
        self.loss_G += self.loss_sobelL1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: update D, then update G."""
        self.forward()
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Scalar losses for logging (legacy .data[0] indexing)."""
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('G_sobelL1', self.loss_sobelL1.data[0]),
                            ('D_GAN', self.loss_D.data[0])])

    def get_current_visuals(self):
        """Current tensors as raw arrays (tensor2array)."""
        real_A = util.tensor2array(self.real_A.data)
        fake_B = util.tensor2array(self.fake_B.data)
        real_B = util.tensor2array(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def get_current_img(self):
        """Current tensors converted to displayable images (tensor2im).

        FIX: a second, unreachable `return` of the same dict was removed.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Persist both networks under the given checkpoint label."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay both optimizers' learning rates by
        lr/niter_decay per call."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

    def update_sobel_lambda(self, epochNum):
        """Ramp the Sobel-loss weight linearly: full lambda_sobel is
        reached at epoch 20."""
        self.sobelLambda = self.opt.lambda_sobel / 20 * epochNum
        print('update sobel lambda: %f' % (self.sobelLambda))
class CycleGANModel(BaseModel):
    """CycleGAN with optional perceptual loss (PrcpLoss).

    Naming differs from the paper: code (paper) = G_A (G), G_B (F),
    D_A (D_Y), D_B (D_X).
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build generators/discriminators, losses and optimizers from opt."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
        # FIX: only move the buffers to a GPU when one was requested;
        # the original called .cuda(device=opt.gpu_ids[0]) unconditionally,
        # which raises IndexError on CPU-only runs (empty gpu_ids).
        if opt.gpu_ids:
            self.input_A = self.input_A.cuda(device=opt.gpu_ids[0])
            self.input_B = self.input_B.cuda(device=opt.gpu_ids[0])
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        opt.use_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        opt.use_dropout, self.gpu_ids)
        # If this is training phase
        if self.isTrain:
            # vanilla GAN (sigmoid output) is used only when LSGAN is disabled
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, self.gpu_ids)
        # If this is non-training phase/continue training phase
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            # build up so called history pool (buffer of past generator outputs)
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor,
                                                 gpu_ids=opt.gpu_ids)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            if opt.use_prcp:
                self.criterionPrcp = networks.PrcpLoss(opt.weight_path, opt.bias_path,
                                                       opt.perceptual_level,
                                                       tensor=self.Tensor,
                                                       gpu_ids=opt.gpu_ids)
            # initialize optimizers (both generators share one optimizer)
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        # FIX: the discriminators exist only in training mode; the original
        # printed them unconditionally and crashed at test time with an
        # AttributeError (the sibling CycleGAN class guards these prints).
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch from the dataloader dict into the preallocated buffers."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap inputs in Variables; generator passes happen in backward_G."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions without building a graph."""
        self.real_A = Variable(self.input_A, volatile=True)  # no back propagation
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)  # A recover
        self.real_B = Variable(self.input_B, volatile=True)  # no back propagation
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)  # B recover

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Shared discriminator step: 0.5 * (real + fake) GAN loss, backward, return loss."""
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        # stop back propagate this part of loss back to generator, as we only
        # care about discriminator here
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Generator step: identity + GAN + cycle losses, one backward pass."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Cycle loss
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: generators first, then both discriminators."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return scalar losses for logging (identity terms only when enabled)."""
        D_A = self.loss_D_A.data[0]
        G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        D_B = self.loss_D_B.data[0]
        G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('idt_A', idt_A), ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return current images converted for display."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay lr on all three optimizers by lr/niter_decay per call."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class Pix2PixModel(BaseModel):
    """Conditional GAN (pix2pix): G maps A->B; D judges concatenated (A, B) pairs."""

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build G (and D, losses, optimizers when training) from opt."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, self.gpu_ids)
        if self.isTrain:
            # sigmoid output only for vanilla GAN (LSGAN disabled)
            use_sigmoid = opt.no_lsgan
            # D sees the A/B pair concatenated along channels
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            # history buffer of past (A, fake_B) pairs for the discriminator
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one dataloader batch into the preallocated buffers."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run G on the current input and wrap targets (training graph kept)."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Discriminator step on a pooled fake pair and the current real pair."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Generator step: GAN loss plus lambda_A-weighted L1 to the target."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: D step first, then G step."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return scalar losses for logging."""
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])
                            ])

    def get_current_visuals(self):
        """Return current images converted for display."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay lr on both optimizers by lr/niter_decay per call."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
# NOTE(review): this is the SECOND definition of CycleGANModel in this file
# (an earlier one, with perceptual-loss support, appears above). At import
# time this later definition silently shadows the earlier one — confirm which
# variant callers expect and rename or remove one of them.
class CycleGANModel(BaseModel):
    """Standard CycleGAN: two generators (G_A: A->B, G_B: B->A) and two
    discriminators trained with GAN + cycle-consistency (+ optional identity) losses.
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build generators/discriminators, losses and optimizers from opt."""
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)

        if self.isTrain:
            # sigmoid output only for vanilla GAN (LSGAN disabled)
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            # history buffers of past generator outputs for the discriminators
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (both generators share one optimizer)
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one dataloader batch into the preallocated buffers."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap inputs in Variables; generator passes happen in backward_G."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions and reconstructions without a graph."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Shared discriminator step: 0.5 * (real + fake) GAN loss, backward, return loss."""
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detach so no gradient reaches the generator)
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Generator step: identity + GAN + cycle losses, one backward pass."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: generators first, then both discriminators."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return scalar losses for logging (identity terms only when enabled)."""
        D_A = self.loss_D_A.data[0]
        G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        D_B = self.loss_D_B.data[0]
        G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return current images converted for display."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay lr on all three optimizers by lr/niter_decay per call."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
# NOTE(review): this class uses Python 2 `print` STATEMENTS while the rest of
# the file uses print() calls — the file cannot run under Python 3 as-is.
# It also mixes .item() (newer PyTorch) with Variable/volatile (PyTorch <=0.3);
# confirm the intended torch version.
class StackGANModel(BaseModel):
    """Two-stage font-transfer GAN (MC-GAN style): a GlyphNet generator (netG,
    optionally preceded by a grouped 3d conv, netG_3d) produces all glyphs, and
    an OrnaNet encoder/decoder (netE1/netDE1) stylizes them, judged by netD1.
    Gradients are routed manually from OrnaNet's input back into GlyphNet.
    """

    def name(self):
        return 'StackGANModel'

    def initialize(self, opt):
        """Build both stages, load pretrained weights, set up losses/optimizers."""
        BaseModel.initialize(self, opt)
        # define tensors
        self.input_A0 = self.Tensor(opt.batchSize, opt.input_nc,
                                    opt.fineSize, opt.fineSize)
        self.input_B0 = self.Tensor(opt.batchSize, opt.output_nc,
                                    opt.fineSize, opt.fineSize)
        self.input_base = self.Tensor(opt.batchSize, opt.output_nc,
                                      opt.fineSize, opt.fineSize)
        # load/define networks
        if self.opt.conv3d:
            # one layer for considering a conv filter for each of the 26 channels
            self.netG_3d = networks.define_G_3d(opt.input_nc, opt.input_nc,
                                                norm=opt.norm, groups=opt.grps,
                                                gpu_ids=self.gpu_ids)
        # Generator of the GlyphNet
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      opt.use_dropout, self.gpu_ids)
        # Generator of the OrnaNet as an Encoder and a Decoder
        self.netE1 = networks.define_Enc(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.use_dropout1, self.gpu_ids)
        self.netDE1 = networks.define_Dec(opt.input_nc_1, opt.output_nc_1, opt.ngf,
                                          opt.which_model_netG, opt.norm,
                                          opt.use_dropout1, self.gpu_ids)
        if self.opt.conditional:
            # not applicable for non-conditional case
            use_sigmoid = opt.no_lsgan
            if opt.which_model_preNet != 'none':
                # optional transformation applied to D's input pair
                self.preNet_A = networks.define_preNet(self.opt.input_nc_1+self.opt.output_nc_1,
                                                       self.opt.input_nc_1+self.opt.output_nc_1,
                                                       which_model_preNet=opt.which_model_preNet,
                                                       norm=opt.norm, gpu_ids=self.gpu_ids)
            nif = opt.input_nc_1+opt.output_nc_1
            netD_norm = opt.norm
            self.netD1 = networks.define_D(nif, opt.ndf, opt.which_model_netD,
                                           opt.n_layers_D, netD_norm, use_sigmoid,
                                           True, self.gpu_ids)
        if self.isTrain:
            # training starts from a pretrained GlyphNet
            if self.opt.conv3d:
                self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.opt.print_weights:
                # debug dump of OrnaNet's initial weight statistics
                # NOTE(review): Python 2 print statements below
                for key in self.netE1.state_dict().keys():
                    print key, 'random_init, mean,std:', torch.mean(self.netE1.state_dict()[key]),torch.std(self.netE1.state_dict()[key])
                for key in self.netDE1.state_dict().keys():
                    print key, 'random_init, mean,std:', torch.mean(self.netDE1.state_dict()[key]),torch.std(self.netDE1.state_dict()[key])
        if not self.isTrain:
            print "Load generators from their pretrained models..."
            if opt.no_Style2Glyph:
                # GlyphNet checkpoint from epoch which_epoch; OrnaNet from which_epoch1
                if self.opt.conv3d:
                    self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
            else:
                # GlyphNet was fine-tuned jointly: its label is the epoch SUM
                if self.opt.conv3d:
                    self.load_network(self.netG_3d, 'G_3d', str(int(opt.which_epoch)+int(opt.which_epoch1)))
                self.load_network(self.netG, 'G', str(int(opt.which_epoch)+int(opt.which_epoch1)))
                self.load_network(self.netE1, 'E1', str(int(opt.which_epoch1)))
                self.load_network(self.netDE1, 'DE1', str(int(opt.which_epoch1)))
                self.load_network(self.netD1, 'D1', str(int(opt.which_epoch1)))
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
        if self.isTrain:
            if opt.continue_train:
                print "Load StyleNet from its pretrained model..."
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        if self.isTrain:
            self.fake_AB1_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()
            # initialize optimizers
            if self.opt.conv3d:
                self.optimizer_G_3d = torch.optim.Adam(self.netG_3d.parameters(),
                                                       lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_E1 = torch.optim.Adam(self.netE1.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.which_model_preNet != 'none':
                self.optimizer_preA = torch.optim.Adam(self.preNet_A.parameters(),
                                                       lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_DE1 = torch.optim.Adam(self.netDE1.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
        print('---------- Networks initialized -------------')
        if self.opt.conv3d:
            networks.print_network(self.netG_3d)
        networks.print_network(self.netG)
        networks.print_network(self.netE1)
        networks.print_network(self.netDE1)
        if opt.which_model_preNet != 'none':
            networks.print_network(self.preNet_A)
        networks.print_network(self.netD1)
        print('-----------------------------------------------')
        # True until the first forward pass; gates one-time initialization
        self.initial = True

    def set_input(self, input):
        """Copy one batch and (in training) map glyph ids found in the file
        names to batch indices (id_) and record the observed glyph ids (obs)."""
        input_A0 = input['A']
        input_B0 = input['B']
        self.input_A0.resize_(input_A0.size()).copy_(input_A0)
        self.input_B0.resize_(input_B0.size()).copy_(input_B0)
        self.image_paths = input['B_paths']
        if self.opt.base_font:
            input_base = input['A_base']
            self.input_base.resize_(input_base.size()).copy_(input_base)
            b,c,m,n = self.input_base.size()
            # one grayscale glyph channel replicated to 3 RGB channels
            real_base = self.Tensor(self.opt.output_nc,self.opt.input_nc_1, m,n)
            for batch in range(self.opt.output_nc):
                if not self.opt.rgb_in and self.opt.rgb_out:
                    real_base[batch,0,:,:] = self.input_base[0,batch,:,:]
                    real_base[batch,1,:,:] = self.input_base[0,batch,:,:]
                    real_base[batch,2,:,:] = self.input_base[0,batch,:,:]
            self.real_base = Variable(real_base, requires_grad=False)
        if self.opt.isTrain:
            self.id_ = {}
            self.obs = []
            # glyph id is the trailing _<n> in the file name before .png
            for i,im in enumerate(self.image_paths):
                self.id_[int(im.split('/')[-1].split('.png')[0].split('_')[-1])]=i
                self.obs += [int(im.split('/')[-1].split('.png')[0].split('_')[-1])]
            # unobserved glyphs borrow a random observed example
            for i in list(set(range(self.opt.output_nc))-set(self.obs)):
                self.id_[i] = np.random.randint(low=0, high=len(self.image_paths))
            self.num_disc = self.opt.output_nc +1

    def all2observed(self, tensor_all):
        """Gather, per batch item, the observed glyph channel of tensor_all
        replicated to input_nc_1 channels; returns a plain tensor."""
        b,c,m,n = self.real_A0.size()
        self.out_id = self.obs
        tensor_gt = self.Tensor(b,self.opt.input_nc_1, m,n)
        for batch in range(b):
            if not self.opt.rgb_in and self.opt.rgb_out:
                tensor_gt[batch,0,:,:] = tensor_all.data[batch,self.out_id[batch],:,:]
                tensor_gt[batch,1,:,:] = tensor_all.data[batch,self.out_id[batch],:,:]
                tensor_gt[batch,2,:,:] = tensor_all.data[batch,self.out_id[batch],:,:]
            else:
                #TODO
                tensor_gt[batch,:,:,:] = tensor_all.data[batch,self.out_id[batch]*np.array(self.opt.input_nc_1):(self.out_id[batch]+1)*np.array(self.opt.input_nc_1),:,:]
        return tensor_gt

    def forward0(self):
        """Stage-1 (GlyphNet) forward; snapshots fake_B0_init on the first call."""
        self.real_A0 = Variable(self.input_A0)
        if self.opt.conv3d:
            self.real_A0_indep = self.netG_3d.forward(self.real_A0.unsqueeze(2))
            self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
        else:
            self.fake_B0 = self.netG.forward(self.real_A0)
        if self.initial:
            if self.opt.orna:
                self.fake_B0_init = self.real_A0
            else:
                self.fake_B0_init = self.fake_B0

    def forward1(self, inp_grad=False):
        """Stage-2 (OrnaNet) forward. Builds real_A1 from stage-1 output
        (requires_grad=inp_grad so gradients can be harvested and routed
        back to GlyphNet) and runs encoder/decoder on it and on the
        ground-truth-aligned variant real_A1_gt."""
        b,c,m,n = self.real_A0.size()
        self.batch_ = b
        self.out_id = self.obs
        real_A1 = self.Tensor(self.opt.output_nc,self.opt.input_nc_1, m,n)
        if self.opt.orna:
            inp_orna = self.fake_B0_init
        else:
            inp_orna = self.fake_B0
        for batch in range(self.opt.output_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                # replicate the single glyph channel across 3 RGB channels
                real_A1[batch,0,:,:] = inp_orna.data[self.id_[batch],batch,:,:]
                real_A1[batch,1,:,:] = inp_orna.data[self.id_[batch],batch,:,:]
                real_A1[batch,2,:,:] = inp_orna.data[self.id_[batch],batch,:,:]
            else:
                #TODO
                real_A1[batch,:,:,:] = inp_orna.data[batch,self.out_id[batch]*np.array(self.opt.input_nc_1):(self.out_id[batch]+1)*np.array(self.opt.input_nc_1),:,:]
        if self.initial:
            self.real_A1_init = Variable(real_A1, requires_grad=False)
            self.initial = False
        self.real_A1_s = Variable(real_A1, requires_grad=inp_grad)
        self.real_A1 = self.real_A1_s
        self.fake_B1_emb = self.netE1.forward(self.real_A1)
        self.fake_B1 = self.netDE1.forward(self.fake_B1_emb)
        self.real_B1 = Variable(self.input_B0)
        # leaf with requires_grad=True: its .grad feeds backward_G1's routing
        self.real_A1_gt_s = Variable(self.all2observed(inp_orna), requires_grad=True)
        self.real_A1_gt = (self.real_A1_gt_s)
        self.fake_B1_gt_emb = self.netE1.forward(self.real_A1_gt)
        self.fake_B1_gt = self.netDE1.forward(self.fake_B1_gt_emb)
        obs_ = torch.cuda.LongTensor(self.obs) if self.opt.gpu_ids else LongTensor(self.obs)
        if self.opt.base_font:
            real_base_gt = index_select(self.real_base, 0, obs_)
            self.real_base_gt = (Variable(real_base_gt.data, requires_grad=False))

    def add_noise_disc(self,real):
        """Return the (possibly flipped) target label for the discriminator."""
        #add noise to the discriminator target labels
        #real: True/False?
        if self.opt.noisy_disc:
            rand_lbl = random.random()
            if rand_lbl<0.6:
                label = (not real)  # flip the label 60% of the time
            else:
                label = (real)
        else:
            label = (real)
        return label

    # no backprop gradients
    def test(self):
        """Full two-stage inference for ALL glyph positions of the input image."""
        self.real_A0 = Variable(self.input_A0, volatile=True)
        if self.opt.conv3d:
            self.real_A0_indep = self.netG_3d.forward(self.real_A0.unsqueeze(2))
            self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
        else:
            self.fake_B0 = self.netG.forward(self.real_A0)
        b,c,m,n = self.fake_B0.size()
        #for test time: we need to generate output for all of the glyphs in each input image
        if self.opt.rgb_in:
            self.batch_ = c/self.opt.input_nc_1
        else:
            self.batch_ = c
        self.out_id = range(self.batch_)
        real_A1 = self.Tensor(self.batch_,self.opt.input_nc_1, m,n)
        if self.opt.orna:
            inp_orna = self.real_A0
        else:
            inp_orna = self.fake_B0
        for batch in range(self.batch_):
            if not self.opt.rgb_in and self.opt.rgb_out:
                real_A1[batch,0,:,:] = inp_orna.data[:,self.out_id[batch],:,:]
                real_A1[batch,1,:,:] = inp_orna.data[:,self.out_id[batch],:,:]
                real_A1[batch,2,:,:] = inp_orna.data[:,self.out_id[batch],:,:]
            else:
                real_A1[batch,:,:,:] = inp_orna.data[:,self.out_id[batch]*np.array(self.opt.input_nc_1):(self.out_id[batch]+1)*np.array(self.opt.input_nc_1),:,:]
        self.real_A1 = Variable(real_A1, volatile=True)
        fake_B1_emb = self.netE1.forward(self.real_A1.detach())
        self.fake_B1 = self.netDE1.forward(fake_B1_emb)
        self.real_B1 = Variable(self.input_B0, volatile=True)

    #get image paths
    def get_image_paths(self):
        return self.image_paths

    def prepare_data(self):
        """Choose the conditioning pair for D1: the base font if available,
        otherwise detached copies of the current stage-2 inputs."""
        if self.opt.conditional:
            if self.opt.base_font:
                self.first_pair = self.real_base
                self.first_pair_gt = self.real_base_gt
            else:
                self.first_pair = Variable(self.real_A1.data, requires_grad=False)
                self.first_pair_gt = Variable(self.real_A1_gt.data,requires_grad=False)

    def backward_D1(self):
        """OrnaNet discriminator step (optionally also on preNet-transformed pairs)."""
        b,c,m,n = self.fake_B1.size()
        # Fake
        # stop backprop to the generator by detaching fake_B
        label_fake = self.add_noise_disc(False)
        if self.opt.conditional:
            fake_AB1 = self.fake_AB1_pool.query(torch.cat((self.first_pair, self.fake_B1),1))
            self.pred_fake1 = self.netD1.forward(fake_AB1.detach())
            if self.opt.which_model_preNet != 'none':
                #transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB1.detach())
                self.pred_fake_GL = self.netD1.forward(transformed_AB1)
        self.loss_D1_fake = 0
        self.loss_D1_fake += self.criterionGAN(self.pred_fake1, label_fake)
        if self.opt.which_model_preNet != 'none':
            self.loss_D1_fake += self.criterionGAN(self.pred_fake_GL, label_fake)
        # Real
        label_real = self.add_noise_disc(True)
        if self.opt.conditional:
            real_AB1 = torch.cat((self.first_pair_gt, self.real_B1), 1).detach()
            self.pred_real1 = self.netD1.forward(real_AB1)
            if self.opt.which_model_preNet != 'none':
                transformed_real_AB1 = self.preNet_A.forward(real_AB1)
                self.pred_real1_GL = self.netD1.forward(transformed_real_AB1)
        self.loss_D1_real = 0
        self.loss_D1_real += self.criterionGAN(self.pred_real1, label_real)
        if self.opt.which_model_preNet != 'none':
            self.loss_D1_real += self.criterionGAN(self.pred_real1_GL, label_real)
        # Combined loss
        self.loss_D1 = (self.loss_D1_fake + self.loss_D1_real) * 0.5
        self.loss_D1.backward()

    def backward_G(self, pass_grad, iter):
        """GlyphNet step: optional weighted-L1 pull toward its initial output
        (observed channels weighted 10x, disabled after iter 700), then inject
        the externally computed gradient pass_grad into fake_B0."""
        b,c,m,n = self.fake_B0.size()
        if not self.opt.lambda_C or (iter>700):
            self.loss_G_L1 = Variable(torch.zeros(1))
        else:
            weight_val = 10.0
            weights = torch.ones(b,c,m,n).cuda() if self.opt.gpu_ids else torch.ones(b,c,m,n)
            obs_ = torch.cuda.LongTensor(self.obs) if self.opt.gpu_ids else LongTensor(self.obs)
            weights.index_fill_(1,obs_,weight_val)
            weights=Variable(weights, requires_grad=False)
            self.loss_G_L1 = self.criterionL1(weights * self.fake_B0, weights * self.fake_B0_init.detach()) * self.opt.lambda_C
            self.loss_G_L1.backward(retain_graph=True)
        self.fake_B0.backward(pass_grad)

    def backward_G1(self,iter):
        """OrnaNet generator step: GAN (applied only every rate_gen iters) +
        L1-to-target + grayscale-consistency MSE terms; afterwards harvest the
        gradient on real_A1_gt_s into real_A1_grad for routing to GlyphNet."""
        # First, G(A) should fake the discriminator
        if self.opt.conditional:
            fake_AB = torch.cat((self.first_pair.detach(), self.fake_B1), 1)
            pred_fake = self.netD1.forward(fake_AB)
            if self.opt.which_model_preNet != 'none':
                #transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB)
                pred_fake_GL = self.netD1.forward(transformed_AB1)
            self.loss_G1_GAN = 0
            self.loss_G1_GAN += self.criterionGAN(pred_fake, True)
            if self.opt.which_model_preNet != 'none':
                self.loss_G1_GAN += self.criterionGAN(pred_fake_GL, True)
        self.loss_G1_L1 = self.criterionL1(self.fake_B1_gt, self.real_B1) * self.opt.lambda_A
        # soft "ink mask": steep sigmoid around intensity 0.9 approximates a
        # grayscale foreground/background split
        fake_B1_gray = 1-torch.nn.functional.sigmoid(100*(torch.mean(self.fake_B1,dim=1,keepdim=True)-0.9))
        real_A1_gray = 1-torch.nn.functional.sigmoid(100*(torch.mean(self.real_A1,dim=1,keepdim=True)-0.9))
        self.loss_G1_MSE_rgb2gay = self.criterionMSE(fake_B1_gray, real_A1_gray.detach())* self.opt.lambda_A/3.0
        real_A1_gt_gray = 1-torch.nn.functional.sigmoid(100*(torch.mean(self.real_A1_gt,dim=1,keepdim=True)-0.9))
        real_B1_gray = 1-torch.nn.functional.sigmoid(100*(torch.mean(self.real_B1,dim=1,keepdim=True)-0.9))
        self.loss_G1_MSE_gt = self.criterionMSE(real_A1_gt_gray, real_B1_gray)* self.opt.lambda_A
        # update generator less frequently
        if iter<200:
            rate_gen = 90
        else:
            rate_gen = 60
        if (iter%rate_gen)==0:
            self.loss_G1 = self.loss_G1_GAN + self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True
            G1_GAN_update = True
        else:
            self.loss_G1 = self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True
            G1_GAN_update = False
        if (iter<200):
            self.loss_G1 += self.loss_G1_MSE_rgb2gay
        else:
            self.loss_G1 += 0.01*self.loss_G1_MSE_rgb2gay
        self.loss_G1.backward(retain_graph=True)
        (b,c,m,n) = self.real_A1_s.size()
        self.real_A1_grad = torch.zeros(b,c,m,n).cuda() if self.opt.gpu_ids else torch.zeros(b,c,m,n)
        if G1_L1_update:
            for batch in self.obs:
                self.real_A1_grad[batch,:,:,:] = self.real_A1_gt_s.grad.data[self.id_[batch],:,:,:]

    def optimize_parameters(self,iter):
        """One OrnaNet-only iteration (GlyphNet frozen): D1 step then E1/DE1 step."""
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.zero_grad()
        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.step()
        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()
        self.loss_G_L1 = Variable(torch.zeros(1))

    def optimize_parameters_Stacked(self,iter):
        """One end-to-end iteration: OrnaNet steps as above, then the harvested
        real_A1_grad is scattered into fake_B0's layout and pushed through
        GlyphNet via backward_G."""
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.zero_grad()
        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if self.opt.which_model_preNet != 'none':
            self.optimizer_preA.step()
        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()
        b,c,m,n = self.fake_B0.size()
        self.optimizer_G.zero_grad()
        if self.opt.conv3d:
            self.optimizer_G_3d.zero_grad()
        b,c,m,n = self.fake_B0.size()
        fake_B0_grad = torch.zeros(b,c,m,n).cuda() if self.opt.gpu_ids else torch.zeros(b,c,m,n)
        real_A_grad = self.real_A1_grad
        for batch in range(self.opt.input_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                # average the 3 replicated RGB gradient channels back into the
                # single glyph channel (x3 to compensate for the mean)
                fake_B0_grad[self.id_[batch], batch,:,:] += torch.mean(real_A_grad[batch,:,:,:],0)*3
            else:
                #TODO
                fake_B0_grad[batch, self.obs[batch]*np.array(self.opt.input_nc_1):(self.obs[batch]+1)*np.array(self.opt.input_nc_1),:,:] = real_A_grad[batch,:,:,:]
        self.backward_G(fake_B0_grad, iter)
        self.optimizer_G.step()
        if self.opt.conv3d:
            self.optimizer_G_3d.step()

    def get_current_errors(self):
        """Return scalar losses for logging."""
        return OrderedDict([('G1_GAN', self.loss_G1_GAN.item()), ('G1_L1', self.loss_G1_L1.item()),
                            ('G1_MSE_gt', self.loss_G1_MSE_gt.item()), ('G1_MSE', self.loss_G1_MSE_rgb2gay.item()),
                            ('D1_real', self.loss_D1_real.item()), ('D1_fake', self.loss_D1_fake.item()),
                            ('G_L1', self.loss_G_L1.item())
                            ])

    def get_current_visuals(self):
        """Return display images; at test time, tile per-glyph outputs into one
        wide image ordered by glyph id."""
        real_A1 = self.real_A1.data.clone()
        g,c,m,n = real_A1.size()
        fake_B = self.fake_B1.data.clone()
        real_B = self.real_B1.data.clone()
        if self.opt.isTrain:
            real_A_all = real_A1
            fake_B_all = fake_B
        else:
            real_A_all = self.Tensor(real_B.size(0),real_B.size(1),real_A1.size(2),real_A1.size(2)*real_A1.size(0))
            fake_B_all = self.Tensor(real_B.size(0),real_B.size(1),real_A1.size(2),fake_B.size(2)*fake_B.size(0))
            for b in range(g):
                real_A_all[:,:,:,self.out_id[b]*m:m*(self.out_id[b]+1)] = real_A1[b,:,:,:]
                fake_B_all[:,:,:,self.out_id[b]*m:m*(self.out_id[b]+1)] = fake_B[b,:,:,:]
        real_A = util.tensor2im(real_A_all)
        fake_B = util.tensor2im(fake_B_all)
        real_B = util.tensor2im(self.real_B1.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])

    def save(self, label):
        """Save all sub-networks; GlyphNet's label is offset by which_epoch
        when joint fine-tuning is enabled (and the labels are numeric)."""
        if not self.opt.no_Style2Glyph:
            try:
                G_label = str(int(label)+int(self.opt.which_epoch))
            except:
                G_label = label
            if self.opt.conv3d:
                self.save_network(self.netG_3d, 'G_3d', G_label, self.gpu_ids)
            self.save_network(self.netG, 'G', G_label, self.gpu_ids)
        self.save_network(self.netE1, 'E1', label, self.gpu_ids)
        self.save_network(self.netDE1, 'DE1', label, self.gpu_ids)
        self.save_network(self.netD1, 'D1', label, self.gpu_ids)
        if self.opt.which_model_preNet != 'none':
            self.save_network(self.preNet_A, 'PRE_A', label, gpu_ids=self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay lr on every active optimizer by lr/niter_decay per call."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        if self.opt.which_model_preNet != 'none':
            for param_group in self.optimizer_preA.param_groups:
                param_group['lr'] = lr
        for param_group in self.optimizer_D1.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_E1.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_DE1.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
    def initialize(self, opt):
        """Build generator/discriminator/encoder, optionally load weights, and
        (in training mode) set up losses and optimizers.

        NOTE(review): original line-wrapping was destroyed; indentation is
        reconstructed from the control flow — verify against upstream.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        # label maps are one-hot encoded to label_nc channels; otherwise raw input_nc
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        ##### define networks
        # Generator network: extra channels for the instance edge map and features
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
        # Discriminator network (conditional: sees input + output)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
        ### Encoder network (per-instance feature encoder)
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')
        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions; loss_filter drops disabled terms from tuples
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')
            # initialize optimizers
            # optimizer G: optionally train only the local enhancer at first
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    # Python 2 fallback only; `sets` does not exist on Python 3
                    from sets import Set
                    finetune_list = Set()
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    # local-enhancer submodules are named 'model<n_local_enhancers>...'
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
class Pix2PixHDModel(BaseModel):
    """pix2pixHD: coarse-to-fine conditional GAN with multiscale discriminators,
    optional instance-edge channel and per-instance feature encoding.

    NOTE(review): original line-wrapping was destroyed; indentation is
    reconstructed from the control flow — verify against upstream.
    """

    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Return a closure that keeps only the enabled loss terms, in a fixed order."""
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        return loss_filter

    def initialize(self, opt):
        """Build networks, optionally load weights, and (in training mode)
        set up losses and optimizers."""
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        ##### define networks
        # Generator network: extra channels for instance edges and features
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
        # Discriminator network (conditional: sees input + output)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
        ### Encoder network (per-instance feature encoder)
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')
        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')
            # initialize optimizers
            # optimizer G: optionally train only the local enhancer at first
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    # Python 2 fallback only; `sets` does not exist on Python 3
                    from sets import Set
                    finetune_list = Set()
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        """Move inputs to GPU, one-hot encode label maps, and append the
        instance edge channel. Returns (input_label, inst_map, real_image, feat_map)."""
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()
        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        if infer:
            with torch.no_grad():
                input_label = Variable(input_label)
        else:
            input_label = Variable(input_label)
        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())
        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
                inst_map = label_map.cuda()
        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run D on (condition, image); `test_image` is detached so D's update
        does not backprop into G. Optionally draw the pair from the image pool."""
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Full training step computation: returns filtered losses and
        (optionally) the generated image."""
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)
        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)
        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)
        # GAN loss (Fake Passability Loss) — not detached, so G receives gradients
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)
        # GAN feature matching loss over intermediate D activations
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
        # Only return the fake_B image if necessary to save BW
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

    def inference(self, label, inst, image=None):
        """Generate an image for a label/instance map without computing losses."""
        # Encode Inputs
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
        # Fake Generation
        if self.use_features:
            if self.opt.use_encoded_image:
                # encode the real image to get feature map
                feat_map = self.netE.forward(real_image, inst_map)
            else:
                # sample clusters from precomputed features
                feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_concat)
        else:
            fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Randomly sample one precomputed feature cluster per instance label
        and paint it into a per-pixel feature map."""
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        features_clustered = np.load(cluster_path, encoding='latin1').item()
        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):
            # instance ids encode their semantic label as id // 1000
            label = i if i < 1000 else i//1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
        if self.opt.data_type==16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Encode an image and collect one (feature, relative-size) row per
        instance, grouped by semantic label."""
        with torch.no_grad():
            image = Variable(image.cuda())
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num+1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            # take the feature at the instance's median pixel
            idx = idx[num//2,:]
            val = np.zeros((1, feat_num+1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
            # last column: instance area relative to image size
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a binary map marking pixels whose 4-neighbour differs (instance boundaries)."""
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        if self.opt.data_type==16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        """Persist G, D and (if used) E checkpoints for the given epoch label."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """After fixing the global generator for a number of iterations, also start finetuning it."""
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        # rebuild optimizer_G so it now covers the full generator
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Linearly decay both optimizers' learning rate by lr/niter_decay."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
    def __init__(self, opt):
        """Multi-domain (ComboGAN-style) model setup: buffers for three styles,
        shared G/D with per-domain heads, classifier and VGG feature extractor.

        NOTE(review): original line-wrapping was destroyed; indentation is
        reconstructed from the control flow — verify against upstream.
        """
        # raise problems using super(),so use BaseModel.__init__(self.opt) instead
        # super(ComboGANModel, self).__init__(opt)
        BaseModel.__init__(self, opt)
        self.n_domains = opt.n_domains
        self.d_domains = opt.d_domains
        self.batchSize = opt.batchSize
        self.DA, self.DB, self.DC = None, None, None # classify the domains
        # pre-allocated input buffers
        self.real = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) # images in style 1
        self.real_B = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) # images in style 2
        self.real_C = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) # images in style 3
        # images without edges
        self.edge_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.edge_B = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.edge_C = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        # load/define networks
        self.netG = networks.define_G(opt.netG_framework, opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.netG_n_blocks, opt.netG_n_shared,
                                      self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids)
        if self.isTrain:
            self.netD = networks.define_D(opt.netD_framework, opt.output_nc, opt.ndf,
                                          opt.netD_n_layers, self.d_domains, opt.norm, self.gpu_ids)
            self.classifier = networks.define_classifier(opt.classifier_framework, gpu_ids=self.gpu_ids) # for image classification
            self.vgg = networks.define_VGG(init_weights_=opt.vgg_pretrained_mode, feature_mode_=True, gpu_id_=self.gpu_ids) # using conv4_4 layer
        # load model weights
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'G', which_epoch)
            if self.isTrain and not opt.init:
                self.load_network(self.netD, 'D', which_epoch)
                self.load_network(self.classifier, 'A', which_epoch)
            print("load weights of pretrained model successfully")
        # test the function of encoder part
        if opt.encoder_test:
            which_epoch = opt.which_epoch
            # load only the encoder sub-part (index 0) of G
            self.load_part_network(self.netG, 'G', which_epoch, 0)
            print("load weights of encoder successfully")
        # ======================training initialization==========================================
        if self.isTrain:
            self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)]
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) # use not opt.no_lsgan
            self.criterionContent = torch.nn.L1Loss()
            self.classGAN = networks.ClassLoss(tensor=self.Tensor)
            # initialize optimizers (each wrapper owns its per-domain optimizers)
            self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            self.classifier.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            # initialize loss storage
            self.loss_D, self.loss_G_gan = [0]*self.n_domains, [0]*self.n_domains
            # discriminator loss in details
            self.loss_D_real = [0]*self.n_domains
            self.loss_D_fake = [0]*self.n_domains
            self.loss_D_edge = [0]*self.n_domains
            self.loss_D_class_real = [0]*self.n_domains
            self.loss_G_class = [0]*self.n_domains
            self.loss_D_class_edge_fake = [0]*self.n_domains
            # generator loss in details
            self.loss_content = [0]*self.n_domains
            self.loss_content_2 = [0] * self.n_domains
            self.loss_content_3 = [0] * self.n_domains
            # initialize loss multipliers
            self.lambda_con = opt.lambda_content
            self.lambda_cla = opt.lambda_classfication
        print('---------- Networks initialized -------------')
        print(self.netG)
        if self.isTrain:
            print(self.netD)
            print(self.classifier)
        print('-----------------------------------------------')
def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [ 'D_A', 'G_A', 'cycle_A', 'idt_A', 'low_freq_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'low_freq_B' ] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = [ 'real_A', 'fake_B', 'rec_A', 'blur_real_A', 'blur_fake_B' ] visual_names_B = [ 'real_B', 'fake_A', 'rec_B', 'blur_real_B', 'blur_fake_A' ] if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B self.use_lowfreq_loss = True if opt.lambda_low_freq > 0 else False # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. if self.isTrain: self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] else: # during test time, only load Gs self.model_names = ['G_A', 'G_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. 
paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, noise_generator=True) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, noise_generator=False) self.netGaussian = SimpleGaussian(gaussian_std=opt.low_pass_std) self.netGaussian.apply(weights_init_Gaussian) if len(self.gpu_ids) > 0: assert (torch.cuda.is_available()) self.netGaussian.to(self.gpu_ids[0]) self.netGaussian = torch.nn.DataParallel( self.netGaussian, self.gpu_ids) # multi-GPUs if self.isTrain: # define discriminators use_sigmoid = False if (opt.gan_mode == 'lsgan') else True self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert (opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images self.fake_B_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to( self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() self.criterionLowFreq = torch.nn.MSELoss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. 
self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D)
    def initialize(self, opt):
        """CycleGAN/CycleWGAN/ICycleWGAN setup: networks, optional weight
        loading, losses, and GAN-variant-specific optimizers.

        NOTE(review): original line-wrapping was destroyed; indentation is
        reconstructed from the control flow — verify against upstream.
        """
        BaseModel.initialize(self, opt)
        # Parameters for WGAN
        self.use_which_gan = opt.use_which_gan  # CycleGAN or CycleWGAN or ICycleWGAN
        self.wgan_clip_upper = opt.wgan_clip_upper
        self.wgan_clip_lower = opt.wgan_clip_lower
        self.wgan_n_critic = opt.wgan_n_critic
        self.wgan_optimizer = opt.wgan_optimizer  # rmsprop
        self.wgan_train_critics = True  # Not sure about this part
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_which_gan=self.use_which_gan,
                use_lsgan=not opt.no_lsgan,
                tensor=self.Tensor)
            # L1 norm
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            # NOTE(review): if use_which_gan == 'CycleWGAN' but wgan_optimizer is
            # not 'rmsprop' (or use_which_gan is some other value), no optimizers
            # are created and the appends below raise AttributeError — confirm the
            # option parser restricts these values.
            if (self.use_which_gan == 'CycleWGAN'):
                if (self.wgan_optimizer == 'rmsprop'):
                    # WGAN: RMSprop, no momentum, per the original WGAN recipe
                    self.optimizer_G = torch.optim.RMSprop(itertools.chain(
                        self.netG_A.parameters(),
                        self.netG_B.parameters()), lr=opt.wgan_lrG)
                    self.optimizer_D_A = torch.optim.RMSprop(
                        self.netD_A.parameters(), lr=opt.wgan_lrD)
                    self.optimizer_D_B = torch.optim.RMSprop(
                        self.netD_B.parameters(), lr=opt.wgan_lrD)
            elif (self.use_which_gan == 'CycleGAN' or self.use_which_gan == 'ICycleWGAN'):
                self.optimizer_G = torch.optim.Adam(itertools.chain(
                    self.netG_A.parameters(), self.netG_B.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')
class Pix2PixModel(BaseModel):
    """pix2pix with optional no-GAN (pure L1/L2) mode and parsing-map visuals.

    Fixes vs. original:
    - `tensor.cuda(device, async=True)`: `async` is a reserved keyword on
      Python >= 3.7 (SyntaxError); the PyTorch 0.4+ spelling is `non_blocking`.
    - `loss.data[0]`: fails with an IndexError on 0-dim loss tensors in modern
      torch; replaced with `.item()`, matching the sibling model in this file.
    - `Variable(..., volatile=True)`: removed in torch 0.4; replaced with
      `torch.no_grad()`, matching the style used elsewhere in this file.
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build G (and D unless no_gan), optionally load weights, and set up
        losses/optimizers/schedulers in training mode."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, self.gpu_ids)
        if self.isTrain and (not opt.no_gan):
            use_sigmoid = opt.no_lsgan
            # conditional D sees input and output stacked along channels
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain and (not opt.no_gan):
                self.load_network(self.netD, 'D', opt.which_epoch)
        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            # note: criterionL1 keeps its name even when it is actually MSE
            if opt.use_l2:
                self.criterionL1 = torch.nn.MSELoss()
            else:
                self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            if not opt.no_gan:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain and (not opt.no_gan):
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Stash a data-loader batch, honoring which_direction (AtoB/BtoA)."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # FIX: `async=True` is a SyntaxError on Python >= 3.7; PyTorch 0.4+
            # renamed the argument to `non_blocking` (same semantics).
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        # FIX: `volatile=True` was removed in torch 0.4; torch.no_grad() is the
        # replacement (and is already used elsewhere in this file).
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        if not self.opt.no_gan:
            # First, G(A) should fake the discriminator
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            self.loss_G_GAN = 0
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: D update (if GAN enabled), then G update."""
        self.forward()
        if not self.opt.no_gan:
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return an ordered name->scalar mapping of the latest loss values."""
        # FIX: `.data[0]` breaks on 0-dim loss tensors in modern torch; `.item()`
        # is the supported accessor and matches the sibling model in this file.
        if not self.opt.no_gan:
            return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                                ('G_L1', self.loss_G_L1.item()),
                                ('D_real', self.loss_D_real.item()),
                                ('D_fake', self.loss_D_fake.item())
                                ])
        else:
            return OrderedDict([
                ('G_L1', self.loss_G_L1.item())
            ])

    def get_current_visuals(self):
        """Return displayable images; 1-channel outputs get parsing post-processing."""
        real_A_img, real_A_prior = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        if self.opt.output_nc == 1:
            fake_B_postprocessed = util.postprocess_parsing(fake_B, self.isTrain)
            fake_B_color = util.paint_color(fake_B_postprocessed)
            real_B_color = util.paint_color(util.postprocess_parsing(real_B, self.isTrain))
        if self.opt.output_nc == 1:
            return OrderedDict([
                ('real_A_img', real_A_img), ('real_A_prior', real_A_prior),
                ('fake_B', fake_B),
                ('fake_B_postprocessed', fake_B_postprocessed),
                ('fake_B_color', fake_B_color), ('real_B', real_B),
                ('real_B_color', real_B_color)]
            )
        else:
            return OrderedDict([
                ('real_A_img', real_A_img), ('real_A_prior', real_A_prior),
                ('fake_B', fake_B), ('real_B', real_B)]
            )

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        if not self.opt.no_gan:
            self.save_network(self.netD, 'D', label, self.gpu_ids)
class Pix2PixHDModel(BaseModel):
    """pix2pixHD translation model: coarse-to-fine generator, multi-scale
    discriminator, and an optional instance-wise feature encoder."""

    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Build a closure that drops disabled loss terms.

        The returned callable takes the five loss values in a fixed order and
        keeps only those whose flag is on (GAN and D terms are always kept).
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            values = (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
            return [value for value, keep in zip(values, flags) if keep]

        return loss_filter

    def initialize(self, opt):
        """Construct networks, losses and optimizers from the option set."""
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none':
            # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else 3

        ##### define networks
        # Generator: label channels, plus edge map and/or encoded features.
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global,
                                      opt.n_local_enhancers, opt.n_blocks_local,
                                      opt.norm, gpu_ids=self.gpu_ids)

        # Discriminator: judges (label, image) pairs.
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D,
                                          opt.norm, use_sigmoid, opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        ### Encoder network (instance-wise features)
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef,
                                          'encoder', opt.n_downsample_E,
                                          norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG',
                                               'D_real', 'D_fake')

            # optimizer G: optionally freeze the global generator and train
            # only the local enhancer for the first niter_fix_global epochs.
            if opt.niter_fix_global > 0:
                if self.opt.verbose:
                    print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [{'params': [value], 'lr': opt.lr}]
                    else:
                        params += [{'params': [value], 'lr': 0.0}]
            else:
                params = list(self.netG.parameters())
                if self.gen_features:
                    params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        """Move inputs to the GPU and build the conditioning label volume.

        Returns (input_label, inst_map, real_image, feat_map); input_label is
        a Variable holding the one-hot label map plus, unless disabled, the
        instance edge map.
        """
        if self.opt.label_nc == 0:
            # labels are already continuous-valued images
            input_label = label_map.data.cuda()
        else:
            # scatter the integer label map into a one-hot volume
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, requires_grad=not infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # precomputed instance-wise feature maps
        if self.use_features and self.opt.load_features:
            feat_map = Variable(feat_map.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run D on a (label, image) pair; optionally mix fakes via the pool."""
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            return self.netD.forward(self.fake_pool.query(input_concat))
        return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Training pass: returns the filtered losses and, when `infer` is
        set, the generated fake image (None otherwise, to save bandwidth)."""
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss (pool-mixed fakes, detached from G)
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss: match intermediate D activations.
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j],
                                           pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG,
                                 loss_D_real, loss_D_fake),
                None if not infer else fake_image]

    def inference(self, label, inst):
        """Generate a fake image for a (label, instance) pair at test time."""
        # Encode Inputs
        input_label, inst_map, _, _ = self.encode_input(Variable(label),
                                                        Variable(inst), infer=True)
        # Fake Generation
        if self.use_features:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        return self.netG.forward(input_concat)

    def sample_features(self, inst):
        """For every instance id, draw one random precomputed feature cluster."""
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name,
                                    self.opt.cluster_path)
        features_clustered = np.load(cluster_path).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num,
                                          inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):
            # instance ids >= 1000 encode the semantic label as id // 1000
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])
                idx = (inst == i).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Encode an image and pool features per instance.

        Returns {label: array with feat_num + 1 columns}; the extra column is
        the instance's relative size.
        """
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == i).nonzero()
            num = idx.size()[0]
            # representative pixel: the median entry of the instance mask
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a binary map marking boundaries between instance ids."""
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        return edge.float()

    def save(self, which_epoch):
        """Save G, D and (if present) E checkpoints."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Linearly decay the learning rate of both optimizers by one step."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
def initialize(self, opt): BaseModel.initialize(self, opt) # define tensors self.input_A0 = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) self.input_B0 = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) self.input_base = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) # load/define networks if self.opt.conv3d: # one layer for considering a conv filter for each of the 26 channels self.netG_3d = networks.define_G_3d(opt.input_nc, opt.input_nc, norm=opt.norm, groups=opt.grps, gpu_ids=self.gpu_ids) # Generator of the GlyphNet self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids) #Generator of the OrnaNet as an Encoder and a Decoder self.netE1 = networks.define_Enc(opt.input_nc_1, opt.output_nc_1, opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout1, self.gpu_ids) self.netDE1 = networks.define_Dec(opt.input_nc_1, opt.output_nc_1, opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout1, self.gpu_ids) if self.opt.conditional: # not applicable for non-conditional case use_sigmoid = opt.no_lsgan if opt.which_model_preNet != 'none': self.preNet_A = networks.define_preNet(self.opt.input_nc_1+self.opt.output_nc_1, self.opt.input_nc_1+self.opt.output_nc_1, which_model_preNet=opt.which_model_preNet,norm=opt.norm, gpu_ids=self.gpu_ids) nif = opt.input_nc_1+opt.output_nc_1 netD_norm = opt.norm self.netD1 = networks.define_D(nif, opt.ndf, opt.which_model_netD, opt.n_layers_D, netD_norm, use_sigmoid, True, self.gpu_ids) if self.isTrain: if self.opt.conv3d: self.load_network(self.netG_3d, 'G_3d', opt.which_epoch) self.load_network(self.netG, 'G', opt.which_epoch) if self.opt.print_weights: for key in self.netE1.state_dict().keys(): print key, 'random_init, mean,std:', torch.mean(self.netE1.state_dict()[key]),torch.std(self.netE1.state_dict()[key]) for key in self.netDE1.state_dict().keys(): print key, 'random_init, mean,std:', 
torch.mean(self.netDE1.state_dict()[key]),torch.std(self.netDE1.state_dict()[key]) if not self.isTrain: print "Load generators from their pretrained models..." if opt.no_Style2Glyph: if self.opt.conv3d: self.load_network(self.netG_3d, 'G_3d', opt.which_epoch) self.load_network(self.netG, 'G', opt.which_epoch) self.load_network(self.netE1, 'E1', opt.which_epoch1) self.load_network(self.netDE1, 'DE1', opt.which_epoch1) self.load_network(self.netD1, 'D1', opt.which_epoch1) if opt.which_model_preNet != 'none': self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1) else: if self.opt.conv3d: self.load_network(self.netG_3d, 'G_3d', str(int(opt.which_epoch)+int(opt.which_epoch1))) self.load_network(self.netG, 'G', str(int(opt.which_epoch)+int(opt.which_epoch1))) self.load_network(self.netE1, 'E1', str(int(opt.which_epoch1))) self.load_network(self.netDE1, 'DE1', str(int(opt.which_epoch1))) self.load_network(self.netD1, 'D1', str(int(opt.which_epoch1))) if opt.which_model_preNet != 'none': self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1) if self.isTrain: if opt.continue_train: print "Load StyleNet from its pretrained model..." 
self.load_network(self.netE1, 'E1', opt.which_epoch1) self.load_network(self.netDE1, 'DE1', opt.which_epoch1) self.load_network(self.netD1, 'D1', opt.which_epoch1) if opt.which_model_preNet != 'none': self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1) self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) if self.isTrain: self.fake_AB1_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.criterionL1 = torch.nn.L1Loss() self.criterionMSE = torch.nn.MSELoss() # initialize optimizers if self.opt.conv3d: self.optimizer_G_3d = torch.optim.Adam(self.netG_3d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_E1 = torch.optim.Adam(self.netE1.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) if opt.which_model_preNet != 'none': self.optimizer_preA = torch.optim.Adam(self.preNet_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_DE1 = torch.optim.Adam(self.netDE1.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- Networks initialized -------------') if self.opt.conv3d: networks.print_network(self.netG_3d) networks.print_network(self.netG) networks.print_network(self.netE1) networks.print_network(self.netDE1) if opt.which_model_preNet != 'none': networks.print_network(self.preNet_A) networks.print_network(self.netD1) print('-----------------------------------------------') self.initial = True
class CycleGANModel(BaseModel): def name(self): return 'CycleGANModel' def initialize(self, opt): BaseModel.initialize(self, opt) # Parameters for WGAN self.use_which_gan = opt.use_which_gan # CycleGAN or CycleWGAN or ICycleWGAN self.wgan_clip_upper = opt.wgan_clip_upper self.wgan_clip_lower = opt.wgan_clip_lower self.wgan_n_critic = opt.wgan_n_critic self.wgan_optimizer = opt.wgan_optimizer # rmsprop self.wgan_train_critics = True # Not sure about this part # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = networks.GANLoss( use_which_gan=self.use_which_gan, use_lsgan=not opt.no_lsgan, tensor=self.Tensor) # L1 norm self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers if (self.use_which_gan == 'CycleWGAN'): if (self.wgan_optimizer == 'rmsprop'): self.optimizer_G = 
torch.optim.RMSprop(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.wgan_lrG) self.optimizer_D_A = torch.optim.RMSprop( self.netD_A.parameters(), lr=opt.wgan_lrD) self.optimizer_D_B = torch.optim.RMSprop( self.netD_B.parameters(), lr=opt.wgan_lrD) elif (self.use_which_gan == 'CycleGAN' or self.use_which_gan == 'ICycleWGAN'): self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.schedulers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D_A) self.optimizers.append(self.optimizer_D_B) for optimizer in self.optimizers: self.schedulers.append(networks.get_scheduler(optimizer, opt)) print('---------- Networks initialized -------------') networks.print_network(self.netG_A) networks.print_network(self.netG_B) if self.isTrain: networks.print_network(self.netD_A) networks.print_network(self.netD_B) print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] if len(self.gpu_ids) > 0: input_A = input_A.cuda(self.gpu_ids[0], async=True) input_B = input_B.cuda(self.gpu_ids[0], async=True) self.input_A = input_A self.input_B = input_B self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): real_A = Variable(self.input_A, volatile=True) fake_B = self.netG_A(real_A) self.rec_A = self.netG_B(fake_B).data self.fake_B = fake_B.data real_B = Variable(self.input_B, volatile=True) fake_A = self.netG_B(real_B) self.rec_B = self.netG_A(fake_A).data self.fake_A = 
fake_A.data # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D_wasserstein(self, netD, real, fake): # Real pred_real = netD.forward(real) pred_fake = netD.forward(fake) loss_D = self.criterionGAN(pred_fake, pred_real, generator_loss=False) return loss_D def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = loss_D_A.data[0] def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = loss_D_B.data[0] # Backward for discriminator, wgan def backward_wgan_D(self, critic_iter): # D_A fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_wasserstein(self.netD_A, self.real_B, fake_B) # D_B fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_wasserstein(self.netD_B, self.real_A, fake_A) loss_D = (self.loss_D_A + self.loss_D_B) * 0.5 loss_D.backward(retain_variables=True) def backward_G(self): lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. idt_A = self.netG_A(self.real_B) loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. 
idt_B = self.netG_B(self.real_A) loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt self.idt_A = idt_A.data self.idt_B = idt_B.data self.loss_idt_A = loss_idt_A.data[0] self.loss_idt_B = loss_idt_B.data[0] else: loss_idt_A = 0 loss_idt_B = 0 self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) fake_B = self.netG_A(self.real_A) pred_fake = self.netD_A(fake_B) loss_G_A = self.criterionGAN(pred_fake, True) # GAN loss D_B(G_B(B)) fake_A = self.netG_B(self.real_B) pred_fake = self.netD_B(fake_A) loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss rec_A = self.netG_B(fake_B) loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A # Backward cycle loss rec_B = self.netG_A(fake_A) loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B # combined loss loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B loss_G.backward() # Save all the data from the previous tensors self.fake_B = fake_B.data self.fake_A = fake_A.data self.rec_A = rec_A.data self.rec_B = rec_B.data self.loss_G_A = loss_G_A.data[0] self.loss_G_B = loss_G_B.data[0] self.loss_cycle_A = loss_cycle_A.data[0] self.loss_cycle_B = loss_cycle_B.data[0] def backward_wgan_G(self, do_backward=True): lambda_idt = self.opt.identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. self.idt_A = self.netG_A.forward(self.real_B) self.loss_idt_A = self.criterionIdt( self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. 
self.idt_B = self.netG_B.forward(self.real_A) self.loss_idt_B = self.criterionIdt( self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # Wasserstein-GAN loss # G_A(A) self.fake_B = self.netG_A.forward(self.real_A) self.loss_G_A = self.criterionGAN(self.fake_B, generator_loss=True) # G_B(B) self.fake_A = self.netG_B.forward(self.real_B) self.loss_G_B = self.criterionGAN(self.fake_A, generator_loss=True) # Forward cycle loss self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # Combined loss self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B if do_backward: # Backprop self.loss_G.backward() def optimize_parameters(self): # forward self.forward() if (self.use_which_gan == 'CycleGAN'): # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() self.backward_D_A() self.optimizer_D_A.step() # D_B self.optimizer_D_B.zero_grad() self.backward_D_B() self.optimizer_D_B.step() # The changes here are that we need to add a bound for weights in the range [-c, c] elif (self.use_which_gan == 'CycleWGAN'): # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() for t in range(self.wgan_n_critic): # D_A self.optimizer_D_A.zero_grad() self.backward_D_A() self.optimizer_D_A.step() # clip for p in self.netD_A.parameters(): p.data.clamp_(self.wgan_clip_lower, self.wgan_clip_upper) # D_B self.optimizer_D_B.zero_grad() self.backward_D_B() self.optimizer_D_B.step() # clip for p in self.netD_B.parameters(): p.data.clamp_(self.wgan_clip_lower, self.wgan_clip_upper) def get_current_errors(self): ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', 
self.loss_cycle_A), ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) if self.opt.lambda_identity > 0.0: ret_errors['idt_A'] = self.loss_idt_A ret_errors['idt_B'] = self.loss_idt_B return ret_errors def get_current_visuals(self): real_A = util.tensor2im(self.input_A) fake_B = util.tensor2im(self.fake_B) rec_A = util.tensor2im(self.rec_A) real_B = util.tensor2im(self.input_B) fake_A = util.tensor2im(self.fake_A) rec_B = util.tensor2im(self.rec_B) ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) if self.opt.isTrain and self.opt.lambda_identity > 0.0: ret_visuals['idt_A'] = util.tensor2im(self.idt_A) ret_visuals['idt_B'] = util.tensor2im(self.idt_B) return ret_visuals def save(self, label): self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
def initialize(self, opt): BaseModel.initialize(self, opt) if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM torch.backends.cudnn.benchmark = True self.isTrain = opt.isTrain input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc self.count = 0 ##### define networks # Generator network netG_input_nc = input_nc # Main Generator with torch.no_grad(): self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval() self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval() self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval() self.G = networks.define_Refine(24, 3, self.gpu_ids).eval() self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() self.BCE = torch.nn.BCEWithLogitsLoss() # Discriminator network if self.isTrain: use_sigmoid = opt.no_lsgan netD_input_nc = input_nc + opt.output_nc netB_input_nc = opt.output_nc * 2 # self.D1 = self.get_D(17, opt) # self.D2 = self.get_D(4, opt) # self.D3=self.get_D(7+3,opt) # self.D = self.get_D(20, opt) # self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids) if self.opt.verbose: print('---------- Networks initialized -------------') # load networks if not self.isTrain or opt.continue_train or opt.load_pretrain: pretrained_path = '' if not self.isTrain else opt.load_pretrain self.load_network(self.Unet, 'U', opt.which_epoch, pretrained_path) self.load_network(self.G1, 'G1', opt.which_epoch, pretrained_path) self.load_network(self.G2, 'G2', opt.which_epoch, pretrained_path) self.load_network(self.G, 'G', opt.which_epoch, pretrained_path) # set loss functions and optimizers if self.isTrain: if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: raise NotImplementedError( "Fake Pool Not Implemented for MultiGPU") self.fake_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) self.criterionGAN = networks.GANLoss(use_lsgan=not 
opt.no_lsgan, tensor=self.Tensor) self.criterionFeat = torch.nn.L1Loss() if not opt.no_vgg_loss: self.criterionVGG = networks.VGGLoss(self.gpu_ids) self.criterionStyle = networks.StyleLoss(self.gpu_ids) # Names so we can breakout loss self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake') # initialize optimizers # optimizer G if opt.niter_fix_global > 0: import sys if sys.version_info >= (3, 0): finetune_list = set() else: from sets import Set finetune_list = Set() params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): if key.startswith('model' + str(opt.n_local_enhancers)): params += [value] finetune_list.add(key.split('.')[0]) print( '------------- Only training the local enhancer ork (for %d epochs) ------------' % opt.niter_fix_global) print('The layers that are finetuned are ', sorted(finetune_list))
class Pix2PixModel(BaseModel):
    """Conditional pix2pix model whose L1 term takes the minimum over several
    candidate ground-truth channels stacked in real_B."""

    def name(self):
        """Return the name of this model."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Create networks, losses, optimizers and schedulers from `opt`."""
        # initialize the base class with the given parameter set
        BaseModel.initialize(self, opt)
        # train or test phase
        self.isTrain = opt.isTrain

        # load/define Generator
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type)
        # define the Discriminator (conditioned on the input image)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, opt.init_type)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        # deploy networks to the target device
        self.netG = self.netG.to(self.device)
        if self.isTrain:
            self.netD = self.netD.to(self.device)

        if self.isTrain:
            # buffer that stores previously generated images
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # initial learning rate for adam
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 device=self.device)
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            # initialize optimizers and their schedulers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Unpack a batch dict onto the device; remember paths and size."""
        AtoB = self.opt.which_direction == 'AtoB'
        self.input_A = input['A' if AtoB else 'B'].to(self.device)
        self.input_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if 'w' in input:
            self.input_w = input['w']
        if 'h' in input:
            self.input_h = input['h']

    def forward(self):
        """Generate fake_B = G(real_A)."""
        self.real_A = self.input_A
        self.fake_B = self.netG(self.real_A)
        self.real_B = self.input_B

    # no backprop gradients
    def test(self):
        """Forward pass without gradient bookkeeping."""
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        """Return the paths of the current batch's images."""
        return self.image_paths

    def backward_D(self):
        """Discriminator step: fake pair vs. every real candidate channel."""
        # Fake: stop backprop to the generator by detaching fake_B.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).detach())
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real: average the loss over all ground-truth candidates.
        n = self.real_B.shape[1]
        loss_D_real_set = torch.empty(n, device=self.device)
        for i in range(n):
            sel_B = self.real_B[:, i, :, :].unsqueeze(1)
            real_AB = torch.cat((self.real_A, sel_B), 1)
            loss_D_real_set[i] = self.criterionGAN(self.netD(real_AB), True)
        self.loss_D_real = torch.mean(loss_D_real_set)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 * self.opt.lambda_G
        self.loss_D.backward()

    def backward_G(self):
        """Generator step: GAN term plus min-over-candidates L1 term."""
        # First, G(A) should fool the discriminator.
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        self.loss_G_GAN = self.criterionGAN(self.netD(fake_AB), True) * self.opt.lambda_G

        # Second, L1 against the closest ground-truth candidate channel.
        n = self.real_B.shape[1]
        fake_B_expand = self.fake_B.expand(-1, n, -1, -1)
        L1 = torch.abs(fake_B_expand - self.real_B)
        L1 = L1.view(-1, n, self.real_B.shape[2] * self.real_B.shape[3])
        L1 = torch.mean(L1, 2)
        min_L1, min_idx = torch.min(L1, 1)
        self.loss_G_L1 = torch.mean(min_L1) * self.opt.lambda_A
        self.min_idx = min_idx  # remembered for get_current_visuals

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One full training iteration: forward, D step, then G step."""
        self.forward()
        # train the discriminator
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # train the generator
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed by name."""
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        """Return displayable real_A, fake_B and the selected real_B channel."""
        real_A = util.tensor2im(self.real_A.detach())
        fake_B = util.tensor2im(self.fake_B.detach())
        if self.isTrain:
            sel_B = self.real_B[:, self.min_idx[0], :, :]
        else:
            sel_B = self.real_B[:, 0, :, :]
        real_B = util.tensor2im(sel_B.unsqueeze(1).detach())
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save generator and discriminator checkpoints."""
        self.save_network(self.netG, 'G', label)
        self.save_network(self.netD, 'D', label)

    def write_image(self, out_dir):
        """Write the first fake image as an 8-bit PNG resized to the input size."""
        image_numpy = self.fake_B.detach()[0][0].cpu().float().numpy()
        # map [-1, 1] -> [0, 255]
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
        image_pil = Image.fromarray(image_numpy.astype(np.uint8))
        image_pil = image_pil.resize((self.input_w[0], self.input_h[0]), Image.BICUBIC)
        name, _ = os.path.splitext(os.path.basename(self.image_paths[0]))
        out_path = os.path.join(out_dir, name + self.opt.suffix + '.png')
        image_pil.save(out_path)
class Pix2PixHDModel(BaseModel):
    """ACGPN-style try-on pipeline built on the pix2pixHD skeleton.

    Stages: G1 predicts an arm/segmentation layout, G2 predicts the warped
    clothes mask, Unet warps the clothes, G composes the final image.

    Fixes vs. original: ``np.float`` (removed in NumPy >= 1.20) replaced by
    builtin ``float``; deprecated ``size_average=`` mapped to ``reduction=``;
    composite-mask channel dim kept so batches > 1 broadcast correctly;
    removed ``volatile=`` kwarg (deleted from torch.autograd).
    """

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Build a closure that drops disabled loss terms from a 5-tuple
        (G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake)."""
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [
                l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake),
                                    flags) if f
            ]

        return loss_filter

    def get_G(self, in_C, out_c, n_blocks, opt, L=1, S=1):
        """Factory wrapper for a generator (see networks.define_G)."""
        return networks.define_G(in_C, out_c, opt.ngf, opt.netG, L, S,
                                 opt.n_downsample_global, n_blocks,
                                 opt.n_local_enhancers, opt.n_blocks_local,
                                 opt.norm, gpu_ids=self.gpu_ids)

    def get_D(self, inc, opt):
        """Factory wrapper for a multiscale discriminator."""
        netD = networks.define_D(inc, opt.ndf, opt.n_layers_D, opt.norm,
                                 opt.no_lsgan, opt.num_D,
                                 not opt.no_ganFeat_loss,
                                 gpu_ids=self.gpu_ids)
        return netD

    def cross_entropy2d(self, input, target, weight=None, size_average=True):
        """Per-pixel cross entropy over an (N, C, H, W) prediction.

        Upsamples ``input`` to the target resolution if they disagree;
        label value 250 is ignored.
        """
        n, c, h, w = input.size()
        nt, ht, wt = target.size()
        # Handle inconsistent size between input and target
        if h != ht or w != wt:
            input = F.interpolate(input, size=(ht, wt), mode="bilinear",
                                  align_corners=True)
        input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        target = target.view(-1)
        # FIX: size_average= is deprecated in torch; reduction= is the exact
        # equivalent of the legacy semantics (True -> mean, False -> sum).
        loss = F.cross_entropy(input, target, weight=weight,
                               reduction='mean' if size_average else 'sum',
                               ignore_index=250)
        return loss

    def ger_average_color(self, mask, arms):
        """Return per-sample mean color of ``arms`` inside ``mask``.

        Samples with fewer than 10 masked pixels get black (avoids dividing
        by a near-zero count).
        """
        color = torch.zeros(arms.shape).cuda()
        for i in range(arms.shape[0]):
            count = len(torch.nonzero(mask[i, :, :, :]))
            if count < 10:
                color[i, 0, :, :] = 0
                color[i, 1, :, :] = 0
                color[i, 2, :, :] = 0
            else:
                color[i, 0, :, :] = arms[i, 0, :, :].sum() / count
                color[i, 1, :, :] = arms[i, 1, :, :].sum() / count
                color[i, 2, :, :] = arms[i, 2, :, :].sum() / count
        return color

    def initialize(self, opt):
        """Construct the frozen sub-networks and (if training) losses/pools."""
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:
            # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        self.count = 0
        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        # Main sub-networks are built (and kept) in eval mode under no_grad.
        with torch.no_grad():
            self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
            self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval()
            self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()
            self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.BCE = torch.nn.BCEWithLogitsLoss()

        # Discriminator network (disabled upstream; kept for reference)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            netB_input_nc = opt.output_nc * 2
            # self.D1 = self.get_D(17, opt)
            # self.D2 = self.get_D(4, opt)
            # self.D3=self.get_D(7+3,opt)
            # self.D = self.get_D(20, opt)
            # self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids)

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.Unet, 'U', opt.which_epoch, pretrained_path)
            self.load_network(self.G1, 'G1', opt.which_epoch, pretrained_path)
            self.load_network(self.G2, 'G2', opt.which_epoch, pretrained_path)
            self.load_network(self.G, 'G', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
                self.criterionStyle = networks.StyleLoss(self.gpu_ids)
            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG',
                                               'D_real', 'D_fake')
            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()
                # NOTE(review): self.netG is never assigned in this class, so
                # this branch would raise AttributeError if ever taken —
                # confirm whether niter_fix_global is used with this model.
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                # FIX: typo "enhancer ork" -> "enhancer network"
                print('------------- Only training the local enhancer '
                      'network (for %d epochs) ------------'
                      % opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))

    def encode_input(self, label_map, clothes_mask, all_clothes_label):
        """One-hot encode (14 classes) the full, clothes-masked, and
        all-clothes label maps. Returns (input_label, masked_label, c_label).
        """
        size = label_map.size()
        oneHot_size = (size[0], 14, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(),
                                           1.0)
        masked_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        masked_label = masked_label.scatter_(
            1, (label_map * (1 - clothes_mask)).data.long().cuda(), 1.0)
        c_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        c_label = c_label.scatter_(1, all_clothes_label.data.long().cuda(),
                                   1.0)
        input_label = Variable(input_label)
        return input_label, masked_label, c_label

    def encode_input_test(self, label_map, label_map_ref, real_image_ref,
                          infer=False):
        """One-hot encode test-time label maps (label_nc classes) and move
        the reference image to GPU."""
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
            input_label_ref = label_map_ref.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(
                torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(),
                                               1.0)
            input_label_ref = torch.cuda.FloatTensor(
                torch.Size(oneHot_size)).zero_()
            input_label_ref = input_label_ref.scatter_(
                1, label_map_ref.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()
                input_label_ref = input_label_ref.half()
        # FIX: `volatile=` was removed from torch.autograd; inference() already
        # disables grad with torch.no_grad(), so plain Variables are enough.
        input_label = Variable(input_label)
        input_label_ref = Variable(input_label_ref)
        real_image_ref = Variable(real_image_ref.data.cuda())
        return input_label, input_label_ref, real_image_ref

    def discriminate(self, netD, input_label, test_image, use_pool=False):
        """Run D on the (label, image) pair; optionally sample the fake pool."""
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return netD.forward(fake_query)
        else:
            return netD.forward(input_concat)

    def gen_noise(self, shape):
        """Return Gaussian noise of ``shape``, binarized to {0, 1} floats.

        NOTE(review): dividing a uint8 cv2.randn image by 255 and re-casting
        to uint8 collapses values to 0/1 — upstream behavior, kept as-is.
        """
        noise = np.zeros(shape, dtype=np.uint8)
        ### noise
        noise = cv2.randn(noise, 0, 255)
        noise = np.asarray(noise / 255, dtype=np.uint8)
        noise = torch.tensor(noise, dtype=torch.float32)
        return noise.cuda()

    def multi_scale_blend(self, fake_img, fake_c, mask, number=4):
        """Feather fake_c into fake_img over ``number`` eroded rings of mask,
        using increasing alpha toward the mask interior."""
        alpha = [0, 0.1, 0.3, 0.6, 0.9]
        smaller = mask
        out = 0
        for i in range(1, number + 1):
            bigger = smaller
            smaller = morpho(smaller, 2, False)
            mid = bigger - smaller
            out += mid * (alpha[i] * fake_c + (1 - alpha[i]) * fake_img)
        out += smaller * fake_c
        out += (1 - mask) * fake_img
        return out

    def forward(self, label, pre_clothes_mask, img_fore, clothes_mask, clothes,
                all_clothes_label, real_image, pose, grid, mask_fore):
        """Full try-on pass: layout (G1) -> clothes mask (G2) -> warp (Unet)
        -> composition (G). Returns [filtered losses, fake_image, clothes,
        arm_label, L1_loss, style_loss, fake_cl, CE_loss, real_image,
        warped_grid]."""
        # Encode Inputs
        input_label, masked_label, all_clothes_label = self.encode_input(
            label, clothes_mask, all_clothes_label)
        # FIX (throughout): np.float was removed in NumPy 1.20+; builtin
        # `float` is the exact same dtype (float64).
        arm1_mask = torch.FloatTensor(
            (label.cpu().numpy() == 11).astype(float)).cuda()
        arm2_mask = torch.FloatTensor(
            (label.cpu().numpy() == 13).astype(float)).cuda()
        pre_clothes_mask = torch.FloatTensor(
            (pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(
                float)).cuda()
        clothes = clothes * pre_clothes_mask
        shape = pre_clothes_mask.shape

        # Stage 1: predict the arm/label layout.
        G1_in = torch.cat([
            pre_clothes_mask, clothes, all_clothes_label, pose,
            self.gen_noise(shape)
        ], dim=1)
        arm_label = self.G1.refine(G1_in)
        # NOTE(review): sigmoid before cross_entropy2d (which applies its own
        # softmax) is unusual but is the upstream behavior — kept.
        arm_label = self.sigmoid(arm_label)
        CE_loss = self.cross_entropy2d(
            arm_label,
            (label * (1 - clothes_mask)).transpose(0, 1)[0].long()) * 10

        armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
        dis_label = generate_discrete_label(arm_label.detach(), 14)

        # Stage 2: predict the warped-clothes mask.
        G2_in = torch.cat([
            pre_clothes_mask, clothes, dis_label, pose, self.gen_noise(shape)
        ], 1)
        fake_cl = self.G2.refine(G2_in)
        fake_cl = self.sigmoid(fake_cl)
        CE_loss += self.BCE(fake_cl, clothes_mask) * 10

        # Binarize + dilate the predicted clothes mask, then carve out arms.
        fake_cl_dis = torch.FloatTensor(
            (fake_cl.detach().cpu().numpy() > 0.5).astype(float)).cuda()
        fake_cl_dis = morpho(fake_cl_dis, 1, True)

        new_arm1_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 11).astype(float)).cuda()
        new_arm2_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 13).astype(float)).cuda()
        fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
        fake_cl_dis *= mask_fore

        # Rebuild the label map: restore arms, then stamp the clothes region.
        arm1_occ = clothes_mask * new_arm1_mask
        arm2_occ = clothes_mask * new_arm2_mask
        bigger_arm1_occ = morpho(arm1_occ, 10)
        bigger_arm2_occ = morpho(arm2_occ, 10)
        arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
        arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
        armlabel_map *= (1 - new_arm1_mask)
        armlabel_map *= (1 - new_arm2_mask)
        armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
        armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
        armlabel_map *= (1 - fake_cl_dis)
        dis_label = encode(armlabel_map, armlabel_map.shape)

        # Stage 3: warp clothes; channel 3 of fake_c is a blend mask.
        fake_c, warped, warped_mask, warped_grid = self.Unet(
            clothes, fake_cl_dis, pre_clothes_mask, grid)
        # FIX: keep the channel dim (3:4). The original [:, 3, :, :] gave a
        # (B, H, W) tensor whose product with (B, 1, H, W) broadcasts to
        # (B, B, H, W) for batch > 1; identical values for batch == 1.
        mask = fake_c[:, 3:4, :, :]
        mask = self.sigmoid(mask) * fake_cl_dis
        fake_c = self.tanh(fake_c[:, 0:3, :, :])
        fake_c = fake_c * (1 - mask) + mask * warped

        # Stage 4: compose the final person image.
        skin_color = self.ger_average_color(
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask),
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
        occlude = (1 - bigger_arm1_occ *
                   (arm2_mask + arm1_mask + clothes_mask)) * (
                       1 - bigger_arm2_occ *
                       (arm2_mask + arm1_mask + clothes_mask))
        img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (
            1 - fake_cl_dis)
        G_in = torch.cat([
            img_hole_hand, dis_label, fake_c, skin_color,
            self.gen_noise(shape)
        ], 1)
        fake_image = self.G.refine(G_in.detach())
        fake_image = self.tanh(fake_image)

        # GAN/VGG losses are disabled upstream; only CE_loss is real.
        loss_D_fake = 0
        loss_D_real = 0
        loss_G_GAN = 0
        loss_G_VGG = 0
        L1_loss = 0
        style_loss = L1_loss
        return [
            self.loss_filter(loss_G_GAN, 0, loss_G_VGG, loss_D_real,
                             loss_D_fake), fake_image, clothes, arm_label,
            L1_loss, style_loss, fake_cl, CE_loss, real_image, warped_grid
        ]

    def inference(self, label, label_ref, image_ref):
        """Reference-guided inference.

        NOTE(review): uses self.netG, which this class never assigns —
        confirm this path is actually exercised.
        """
        # Encode Inputs
        image_ref = Variable(image_ref)
        input_label, input_label_ref, real_image_ref = self.encode_input_test(
            Variable(label), Variable(label_ref), image_ref, infer=True)

        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_label, input_label_ref,
                                               real_image_ref)
        else:
            fake_image = self.netG.forward(input_label, input_label_ref,
                                           real_image_ref)
        return fake_image

    def save(self, which_epoch):
        """Checkpointing is disabled upstream; kept as a documented no-op."""
        # self.save_network(self.Unet, 'U', which_epoch, self.gpu_ids)
        # self.save_network(self.G, 'G', which_epoch, self.gpu_ids)
        # self.save_network(self.G1, 'G1', which_epoch, self.gpu_ids)
        # self.save_network(self.G2, 'G2', which_epoch, self.gpu_ids)
        # self.save_network(self.D, 'D', which_epoch, self.gpu_ids)
        # self.save_network(self.D1, 'D1', which_epoch, self.gpu_ids)
        # self.save_network(self.D2, 'D2', which_epoch, self.gpu_ids)
        # self.save_network(self.D3, 'D3', which_epoch, self.gpu_ids)
        pass
        # self.save_network(self.netB, 'B', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """After fixing the global generator for a number of iterations,
        also start finetuning it.

        NOTE(review): references self.netG / self.netE / self.gen_features,
        none of which this class initializes — confirm before use.
        """
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print(
                '------------ Now also finetuning global generator -----------'
            )

    def update_learning_rate(self):
        """Linearly decay the LR by lr/niter_decay per call.

        NOTE(review): assumes self.optimizer_D / self.optimizer_G exist;
        they are not created in initialize() as shipped.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
def initialize(self, opt): BaseModel.initialize(self, opt) self.isTrain = opt.isTrain self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize).cuda(device=opt.gpu_ids[0]) self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize).cuda(device=opt.gpu_ids[0]) # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids) # If this is training phase if self.isTrain: use_sigmoid = opt.no_lsgan # do not use least square GAN by default self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) # If this is non-training phase/continue training phase if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: # build up so called history pool self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, gpu_ids=opt.gpu_ids) self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() if opt.use_prcp: self.criterionPrcp = networks.PrcpLoss(opt.weight_path, opt.bias_path, opt.perceptual_level, tensor=self.Tensor, gpu_ids=opt.gpu_ids) # initialize optimizers 
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- Networks initialized -------------') networks.print_network(self.netG_A) networks.print_network(self.netG_B) networks.print_network(self.netD_A) networks.print_network(self.netD_B) print('-----------------------------------------------')
    def initialize(self, opt):
        """Pix2pixHD-style setup (partially ported: Keras loss, most
        torch-era blocks commented out).

        NOTE(review): BaseModel.initialize is not called here, yet the
        training branch reads self.gpu_ids — confirm it is set elsewhere.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network: input widens by 1 for the instance edge map and
        # by feat_num when encoded features are concatenated.
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf,
                                      opt.netG, opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm)

        # Discriminator network: conditioned on (label, image) pair.
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.num_D,
                                          not opt.no_ganFeat_loss)

        ### Encoder network
        # if self.gen_features:
        #     self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
        #                                   opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        # if not self.isTrain or opt.continue_train or opt.load_pretrain:
        #     pretrained_path = '' if not self.isTrain else opt.load_pretrain
        #     self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
        #     if self.isTrain:
        #         self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
        #     if self.gen_features:
        #         self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            # self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            # self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            # NOTE(review): K.losses is a Keras API — presumably this model is
            # mid-port from torch to Keras; verify K is imported accordingly.
            self.criterionFeat = K.losses.MeanAbsoluteError()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss()

            # Names so we can breakout loss
            self.loss_names = [
                'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'
            ]

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                # Python-2 fallback for a set type (dead on Python 3).
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()
                # NOTE(review): the parameter filtering below is commented
                # out, so finetune_list is built but never used.
                # params_dict = dict(self.netG.named_parameters())
                # params = []
                # for key, value in params_dict.items():
                #     if key.startswith('model' + str(opt.n_local_enhancers)):
                #         params += [value]
                #         finetune_list.add(key.split('.')[0])
                # print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                # print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                pass
def initialize(self, opt): BaseModel.initialize(self, opt) self.isTrain = opt.isTrain self.batchSize = opt.batchSize self.fineSize = opt.fineSize # define tensors self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize, opt.fineSize) self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize, opt.fineSize) if self.opt.rise_sobelLoss: self.sobelLambda = 0 else: self.sobelLambda = self.opt.lambda_sobel # load/define networks which_netG = opt.which_model_netG self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, which_netG, opt.norm, opt.use_dropout, self.gpu_ids) if self.isTrain: self.D_channel = opt.input_nc + opt.output_nc use_sigmoid = opt.no_lsgan self.netD = networks.define_D(self.D_channel, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) if not self.isTrain or opt.continue_train: self.load_network(self.netG, 'G', opt.which_epoch) if self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch) if not self.isTrain: self.netG.eval() if self.isTrain: self.fake_AB_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions if self.opt.labelSmooth: self.criterionGAN = networks.GANLoss_smooth( use_lsgan=not opt.no_lsgan, tensor=self.Tensor) else: self.criterionGAN = networks.GANLoss( use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionL1 = torch.nn.L1Loss() # initialize optimizers self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- Networks initialized -------------') networks.print_network(self.netG) networks.print_network(self.netD) print('-----------------------------------------------')
class ReCycleGANModel(BaseModel):
    """Recycle-GAN: CycleGAN over frame triplets (t0, t1, t2) with temporal
    prediction networks P_A / P_B that extrapolate the third frame.

    Fixes vs. original: test() used ``Variable(..., volatile=True)`` — the
    ``volatile`` kwarg was removed from torch.autograd; replaced with the
    equivalent ``torch.no_grad()`` block. get_current_visuals listed
    'real_A2'/'real_B2' twice; the duplicates (identical values) are removed.
    """

    def name(self):
        """Return the model's display name."""
        return 'ReCycleGANModel'

    def initialize(self, opt):
        """Build generators, predictors, discriminators, losses, optimizers
        and schedulers."""
        BaseModel.initialize(self, opt)
        nb = opt.batchSize
        size = opt.fineSize
        # Three consecutive frames per domain.
        self.input_A0 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A1 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A2 = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B0 = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B1 = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B2 = self.Tensor(nb, opt.output_nc, size, size)
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        # Temporal predictors: either a dedicated 'prediction' net taking two
        # frames, or a U-Net over the channel-concatenated pair.
        self.which_model_netP = opt.which_model_netP
        if opt.which_model_netP == 'prediction':
            self.netP_A = networks.define_G(opt.input_nc, opt.input_nc,
                                            opt.npf, opt.which_model_netP,
                                            opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids)
            self.netP_B = networks.define_G(opt.output_nc, opt.output_nc,
                                            opt.npf, opt.which_model_netP,
                                            opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids)
        else:
            self.netP_A = networks.define_G(2 * opt.input_nc, opt.input_nc,
                                            opt.ngf, 'unet_128', opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            self.gpu_ids)
            self.netP_B = networks.define_G(2 * opt.output_nc, opt.output_nc,
                                            opt.ngf, 'unet_128', opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.netP_A, 'P_A', which_epoch)
            self.load_network(self.netP_B, 'P_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers — P nets share the generator optimizer.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters(),
                                self.netP_A.parameters(),
                                self.netP_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        networks.print_network(self.netP_A)
        networks.print_network(self.netP_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a triplet batch for both domains into the pre-allocated
        buffers (resizing them to the batch's actual size)."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A0 = input['A0']
        input_A1 = input['A1']
        input_A2 = input['A2']
        input_B0 = input['B0']
        input_B1 = input['B1']
        input_B2 = input['B2']
        self.input_A0.resize_(input_A0.size()).copy_(input_A0)
        self.input_A1.resize_(input_A1.size()).copy_(input_A1)
        self.input_A2.resize_(input_A2.size()).copy_(input_A2)
        self.input_B0.resize_(input_B0.size()).copy_(input_B0)
        self.input_B1.resize_(input_B1.size()).copy_(input_B1)
        self.input_B2.resize_(input_B2.size()).copy_(input_B2)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Expose the six input frames as Variables; the actual network
        passes happen in backward_G / backward_D_*."""
        self.real_A0 = Variable(self.input_A0)
        self.real_A1 = Variable(self.input_A1)
        self.real_A2 = Variable(self.input_A2)
        self.real_B0 = Variable(self.input_B0)
        self.real_B1 = Variable(self.input_B1)
        self.real_B2 = Variable(self.input_B2)

    def test(self):
        """Inference: translate, predict the third frame, and cycle back —
        all without building an autograd graph."""
        # FIX: Variable(volatile=True) was removed; torch.no_grad() is the
        # modern equivalent.
        with torch.no_grad():
            real_A0 = Variable(self.input_A0)
            real_A1 = Variable(self.input_A1)
            fake_B0 = self.netG_A(real_A0)
            fake_B1 = self.netG_A(real_A1)
            if self.which_model_netP == 'prediction':
                fake_B2 = self.netP_B(fake_B0, fake_B1)
            else:
                fake_B2 = self.netP_B(torch.cat((fake_B0, fake_B1), 1))
            self.rec_A = self.netG_B(fake_B2)
            self.fake_B0 = fake_B0
            self.fake_B1 = fake_B1
            self.fake_B2 = fake_B2

            real_B0 = Variable(self.input_B0)
            real_B1 = Variable(self.input_B1)
            fake_A0 = self.netG_B(real_B0)
            fake_A1 = self.netG_B(real_B1)
            if self.which_model_netP == 'prediction':
                fake_A2 = self.netP_A(fake_A0, fake_A1)
            else:
                fake_A2 = self.netP_A(torch.cat((fake_A0, fake_A1), 1))
            self.rec_B = self.netG_A(fake_A2)
            self.fake_A0 = fake_A0
            self.fake_A1 = fake_A1
            self.fake_A2 = fake_A2

            if self.which_model_netP == 'prediction':
                pred_A2 = self.netP_A(real_A0, real_A1)
            else:
                pred_A2 = self.netP_A(torch.cat((real_A0, real_A1), 1))
            self.pred_A2 = pred_A2
            if self.which_model_netP == 'prediction':
                pred_B2 = self.netP_B(real_B0, real_B1)
            else:
                pred_B2 = self.netP_B(torch.cat((real_B0, real_B1), 1))
            self.pred_B2 = pred_B2

    # get image paths
    def get_image_paths(self):
        """Return the file paths of the current source batch."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator step on one (real, fake) pair."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """D_A sees translated frames 0/1, the predicted translated frame 2,
        and the prediction from real frames — all via the history pool."""
        fake_B0 = self.fake_B_pool.query(self.fake_B0)
        loss_D_A0 = self.backward_D_basic(self.netD_A, self.real_B0, fake_B0)
        fake_B1 = self.fake_B_pool.query(self.fake_B1)
        loss_D_A1 = self.backward_D_basic(self.netD_A, self.real_B1, fake_B1)
        fake_B2 = self.fake_B_pool.query(self.fake_B2)
        loss_D_A2 = self.backward_D_basic(self.netD_A, self.real_B2, fake_B2)
        pred_B = self.fake_B_pool.query(self.pred_B2)
        loss_D_A3 = self.backward_D_basic(self.netD_A, self.real_B2, pred_B)
        self.loss_D_A = loss_D_A0 + loss_D_A1 + loss_D_A2 + loss_D_A3

    def backward_D_B(self):
        """Mirror of backward_D_A for domain A."""
        fake_A0 = self.fake_A_pool.query(self.fake_A0)
        loss_D_B0 = self.backward_D_basic(self.netD_B, self.real_A0, fake_A0)
        fake_A1 = self.fake_A_pool.query(self.fake_A1)
        loss_D_B1 = self.backward_D_basic(self.netD_B, self.real_A1, fake_A1)
        fake_A2 = self.fake_A_pool.query(self.fake_A2)
        loss_D_B2 = self.backward_D_basic(self.netD_B, self.real_A2, fake_A2)
        pred_A = self.fake_A_pool.query(self.pred_A2)
        loss_D_B3 = self.backward_D_basic(self.netD_B, self.real_A2, pred_A)
        self.loss_D_B = loss_D_B0 + loss_D_B1 + loss_D_B2 + loss_D_B3

    def backward_G(self):
        """Generator/predictor step: GAN + recycle + prediction + cycle
        (+ optional identity) losses."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A0 = self.netG_A(self.real_B0)
            idt_A1 = self.netG_A(self.real_B1)
            loss_idt_A = (self.criterionIdt(idt_A0, self.real_B0) +
                          self.criterionIdt(idt_A1, self.real_B1)
                          ) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B0 = self.netG_B(self.real_A0)
            idt_B1 = self.netG_B(self.real_A1)
            loss_idt_B = (self.criterionIdt(idt_B0, self.real_A0) +
                          self.criterionIdt(idt_B1, self.real_A1)
                          ) * lambda_A * lambda_idt
            self.idt_A = idt_A0
            self.idt_B = idt_B0
            self.loss_idt_A = loss_idt_A
            self.loss_idt_B = loss_idt_B
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B0 = self.netG_A(self.real_A0)
        pred_fake = self.netD_A(fake_B0)
        loss_G_A0 = self.criterionGAN(pred_fake, True)
        fake_B1 = self.netG_A(self.real_A1)
        pred_fake = self.netD_A(fake_B1)
        loss_G_A1 = self.criterionGAN(pred_fake, True)
        if self.which_model_netP == 'prediction':
            fake_B2 = self.netP_B(fake_B0, fake_B1)
        else:
            fake_B2 = self.netP_B(torch.cat((fake_B0, fake_B1), 1))
        pred_fake = self.netD_A(fake_B2)
        loss_G_A2 = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A0 = self.netG_B(self.real_B0)
        pred_fake = self.netD_B(fake_A0)
        loss_G_B0 = self.criterionGAN(pred_fake, True)
        fake_A1 = self.netG_B(self.real_B1)
        pred_fake = self.netD_B(fake_A1)
        loss_G_B1 = self.criterionGAN(pred_fake, True)
        if self.which_model_netP == 'prediction':
            fake_A2 = self.netP_A(fake_A0, fake_A1)
        else:
            fake_A2 = self.netP_A(torch.cat((fake_A0, fake_A1), 1))
        pred_fake = self.netD_B(fake_A2)
        loss_G_B2 = self.criterionGAN(pred_fake, True)

        # prediction loss — predictors must extrapolate real frame 2.
        if self.which_model_netP == 'prediction':
            pred_A2 = self.netP_A(self.real_A0, self.real_A1)
        else:
            pred_A2 = self.netP_A(torch.cat((self.real_A0, self.real_A1), 1))
        loss_pred_A = self.criterionCycle(pred_A2, self.real_A2) * lambda_A
        if self.which_model_netP == 'prediction':
            pred_B2 = self.netP_B(self.real_B0, self.real_B1)
        else:
            pred_B2 = self.netP_B(torch.cat((self.real_B0, self.real_B1), 1))
        loss_pred_B = self.criterionCycle(pred_B2, self.real_B2) * lambda_B

        # Forward recycle loss: A0,A1 -> B0,B1 -> predicted B2 -> back to A2.
        rec_A = self.netG_B(fake_B2)
        loss_recycle_A = self.criterionCycle(rec_A, self.real_A2) * lambda_A
        # Backward recycle loss
        rec_B = self.netG_A(fake_A2)
        loss_recycle_B = self.criterionCycle(rec_B, self.real_B2) * lambda_B

        # Fwd cycle loss
        rec_A0 = self.netG_B(fake_B0)
        loss_cycle_A0 = self.criterionCycle(rec_A0, self.real_A0) * lambda_A
        rec_A1 = self.netG_B(fake_B1)
        loss_cycle_A1 = self.criterionCycle(rec_A1, self.real_A1) * lambda_A
        rec_B0 = self.netG_A(fake_A0)
        loss_cycle_B0 = self.criterionCycle(rec_B0, self.real_B0) * lambda_B
        rec_B1 = self.netG_A(fake_A1)
        loss_cycle_B1 = self.criterionCycle(rec_B1, self.real_B1) * lambda_B

        # combined loss
        loss_G = (loss_G_A0 + loss_G_A1 + loss_G_A2 + loss_G_B0 + loss_G_B1 +
                  loss_G_B2 + loss_recycle_A + loss_recycle_B + loss_pred_A +
                  loss_pred_B + loss_idt_A + loss_idt_B + loss_cycle_A0 +
                  loss_cycle_A1 + loss_cycle_B0 + loss_cycle_B1)
        loss_G.backward()

        # Stash intermediates for the D steps and for visualization.
        self.fake_B0 = fake_B0
        self.fake_B1 = fake_B1
        self.fake_B2 = fake_B2
        self.pred_B2 = pred_B2
        self.fake_A0 = fake_A0
        self.fake_A1 = fake_A1
        self.fake_A2 = fake_A2
        self.pred_A2 = pred_A2
        self.rec_A = rec_A
        self.rec_B = rec_B
        self.loss_G_A = loss_G_A0 + loss_G_A1 + loss_G_A2
        self.loss_G_B = loss_G_B0 + loss_G_B1 + loss_G_B2
        self.loss_recycle_A = loss_recycle_A
        self.loss_recycle_B = loss_recycle_B
        self.loss_pred_A = loss_pred_A
        self.loss_pred_B = loss_pred_B
        self.loss_cycle_A = loss_cycle_A0 + loss_cycle_A1
        self.loss_cycle_B = loss_cycle_B0 + loss_cycle_B1

    def optimize_parameters(self):
        """One training iteration: G/P step, then D_A step, then D_B step."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the per-term losses of the last iteration for logging."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Recyc_A', self.loss_recycle_A),
                                  ('Pred_A', self.loss_pred_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Recyc_B', self.loss_recycle_B),
                                  ('Pred_B', self.loss_pred_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return all frames/translations/predictions as display images."""
        real_A0 = util.tensor2im(self.input_A0)
        real_A1 = util.tensor2im(self.input_A1)
        real_A2 = util.tensor2im(self.input_A2)
        fake_B0 = util.tensor2im(self.fake_B0)
        fake_B1 = util.tensor2im(self.fake_B1)
        fake_B2 = util.tensor2im(self.fake_B2)
        rec_A = util.tensor2im(self.rec_A)
        real_B0 = util.tensor2im(self.input_B0)
        real_B1 = util.tensor2im(self.input_B1)
        real_B2 = util.tensor2im(self.input_B2)
        fake_A0 = util.tensor2im(self.fake_A0)
        fake_A1 = util.tensor2im(self.fake_A1)
        fake_A2 = util.tensor2im(self.fake_A2)
        rec_B = util.tensor2im(self.rec_B)
        pred_A2 = util.tensor2im(self.pred_A2)
        pred_B2 = util.tensor2im(self.pred_B2)
        # FIX: the original listed 'real_A2' and 'real_B2' twice; the
        # duplicates carried identical values, so the resulting dict (keys,
        # order, values) is unchanged.
        ret_visuals = OrderedDict([('real_A0', real_A0),
                                   ('fake_B0', fake_B0),
                                   ('real_A1', real_A1),
                                   ('fake_B1', fake_B1),
                                   ('fake_B2', fake_B2), ('rec_A', rec_A),
                                   ('real_A2', real_A2),
                                   ('real_B0', real_B0),
                                   ('fake_A0', fake_A0),
                                   ('real_B1', real_B1),
                                   ('fake_A1', fake_A1),
                                   ('fake_A2', fake_A2), ('rec_B', rec_B),
                                   ('real_B2', real_B2),
                                   ('pred_A2', pred_A2),
                                   ('pred_B2', pred_B2)])
        if self.opt.isTrain and self.opt.identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Checkpoint all six networks under the given label."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.netP_A, 'P_A', label, self.gpu_ids)
        self.save_network(self.netP_B, 'P_B', label, self.gpu_ids)
class CycleGANcdModel(BaseModel):
    """Three-domain CycleGAN variant routed through a shared middle domain C.

    Generators:     G_A: A -> C,  G_B: B -> C,  G_C_A: C -> A,  G_C_B: C -> B.
    Discriminators: D_A judges domain A, D_B judges domain B, D_C judges
                    domain C (fed fakes generated from both A and B).

    Cycles: A -> C -> A (lambda_A), B -> C -> B (lambda_B),
            C -> A -> C and C -> B -> C (lambda_C).
    """

    def name(self):
        return 'CycleGANcdModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add cd-model options and set defaults.

        Parameters:
            parser   -- original option parser
            is_train (bool) -- training phase or test phase

        Returns the modified parser.
        """
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0,
                                help='weight for cycle loss (A -> C -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> C -> B)')
            parser.add_argument(
                '--lambda_C', type=float, default=10.0,
                help='weight for cycle loss (C -> A -> C) and (C -> B -> C)')
            parser.add_argument(
                '--lambda_identity', type=float, default=0.5,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )
        return parser

    def initialize(self, opt):
        """Build networks, losses and optimizers from the parsed options."""
        BaseModel.initialize(self, opt)
        # Losses printed by base_model.get_current_losses.
        # FIX: 'cycle_C' was listed twice in the original.
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B',
            'idt_B', 'D_C_A', 'D_C_B', 'G_C', 'cycle_C', 'idt_C_A', 'idt_C_B'
        ]
        # Images saved/displayed by base_model.get_current_visuals.
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        visual_names_C = [
            'real_C', 'fake_C_A', 'fake_C_B', 'rec_C_A', 'rec_C_B'
        ]
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
            visual_names_C.append('idt_C_A')
            visual_names_C.append('idt_C_B')
        self.visual_names = visual_names_A + visual_names_B + visual_names_C
        # Models handled by base_model.save_networks / load_networks.
        if self.isTrain:
            self.model_names = [
                'G_A', 'G_B', 'D_A', 'D_B', 'G_C_A', 'G_C_B', 'D_C'
            ]
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B', 'G_C_A', 'G_C_B']
        # Generators into the shared domain C and back out of it.
        # NOTE(review): netD_C is built with input_nc while fake_C_A/fake_C_B
        # come out with output_nc channels; this only lines up when
        # input_nc == output_nc — confirm against the options used.
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        opt.init_gain, self.gpu_ids)
        self.netG_C_A = networks.define_G(opt.input_nc, opt.input_nc, opt.ngf,
                                          opt.netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
        self.netG_C_B = networks.define_G(opt.output_nc, opt.output_nc,
                                          opt.ngf, opt.netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_C = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # Buffers of previously generated images for discriminator updates.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_C_A_pool = ImagePool(opt.pool_size)
            self.fake_C_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters(),
                self.netG_C_A.parameters(), self.netG_C_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters(),
                self.netD_C.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack a dataloader batch; 'direction' swaps domains A and B."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.real_C = input['C'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run both half-cycles through the shared domain C."""
        if self.isTrain:
            self.fake_C_A = self.netG_A(self.real_A)      # A -> C
            self.rec_A = self.netG_C_A(self.fake_C_A)     # A -> C -> A
            self.fake_A = self.netG_C_A(self.real_C)      # C -> A
            self.rec_C_A = self.netG_A(self.fake_A)       # C -> A -> C
            self.fake_C_B = self.netG_B(self.real_B)      # B -> C
            self.rec_B = self.netG_C_B(self.fake_C_B)     # B -> C -> B
            self.fake_B = self.netG_C_B(self.real_C)      # C -> B
            self.rec_C_B = self.netG_B(self.fake_B)       # C -> B -> C
        else:
            self.fake_C_A = self.netG_A(self.real_A)
            self.rec_A = self.netG_C_A(self.fake_C_A)
            self.fake_A = self.netG_C_A(self.real_C)
            self.rec_C_A = self.netG_A(self.fake_A)
            self.fake_C_B = self.netG_B(self.real_B)
            self.rec_B = self.netG_C_B(self.fake_C_B)
            # NOTE(review): at test time fake_B is chained from fake_C_A
            # (A -> C -> B) instead of from real_C as in training — this
            # looks intentional (full A-to-B translation) but confirm.
            self.fake_B = self.netG_C_B(self.fake_C_A)
            self.rec_C_B = self.netG_B(self.fake_B)

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator loss: real -> True, detached fake -> False.

        Returns the (already back-propagated) combined loss.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detach so gradients do not reach the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Train D_A on domain A.

        FIX: the original trained netD_A on (real_B, fake_B) although netD_A
        is built for input_nc (domain A) and the generator loss evaluates
        netD_A(fake_A); the pairing was swapped with backward_D_B and would
        crash whenever input_nc != output_nc.
        """
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_A,
                                              fake_A)

    def backward_D_B(self):
        """Train D_B on domain B (see FIX note on backward_D_A)."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_B,
                                              fake_B)

    def backward_D_C(self):
        """Train D_C on domain C against fakes from both A and B."""
        fake_C_A = self.fake_C_A_pool.query(self.fake_C_A)
        self.loss_D_C_A = self.backward_D_basic(self.netD_C, self.real_C,
                                                fake_C_A)
        fake_C_B = self.fake_C_B_pool.query(self.fake_C_B)
        self.loss_D_C_B = self.backward_D_basic(self.netD_C, self.real_C,
                                                fake_C_B)

    def backward_G(self):
        """Compute and back-propagate all generator losses."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_C = self.opt.lambda_C
        # Identity loss
        if lambda_idt > 0:
            # G_A / G_B should be identity when real_C is fed.
            self.idt_A = self.netG_A(self.real_C)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_C) * lambda_C * lambda_idt
            self.idt_B = self.netG_B(self.real_C)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_C) * lambda_C * lambda_idt
            # G_C_A / G_C_B should be identity when real_A / real_B is fed.
            self.idt_C_A = self.netG_C_A(self.real_A)
            self.loss_idt_C_A = self.criterionIdt(
                self.idt_C_A, self.real_A) * lambda_A * lambda_idt / 2.
            self.idt_C_B = self.netG_C_B(self.real_B)
            # FIX: the B-side term was scaled by lambda_A; use lambda_B to
            # mirror loss_idt_C_A (identical behavior at the 10.0 defaults).
            self.loss_idt_C_B = self.criterionIdt(
                self.idt_C_B, self.real_B) * lambda_B * lambda_idt / 2.
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
            self.loss_idt_C_A = 0
            self.loss_idt_C_B = 0
        self.loss_idt = self.loss_idt_A + self.loss_idt_B + self.loss_idt_C_A + self.loss_idt_C_B
        # GAN losses. Unlike the original CycleGAN pairing, D_A judges
        # domain A, D_B judges domain B, and D_C judges domain C.
        self.loss_G_C = (
            self.criterionGAN(self.netD_C(self.fake_C_A), True) +
            self.criterionGAN(self.netD_C(self.fake_C_B), True)) / 2.
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_A), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_B), True)
        # Cycle losses for A -> C -> A and B -> C -> B.
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # Cycle loss for C through both A and B.
        self.loss_cycle_C = (
            self.criterionCycle(self.rec_C_A, self.real_C) +
            self.criterionCycle(self.rec_C_B, self.real_C)) * lambda_C / 2.
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_G_C + self.loss_cycle_A + self.loss_cycle_B + self.loss_cycle_C + self.loss_idt
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: update generators, then discriminators."""
        # forward
        self.forward()
        # Generators: freeze Ds so their gradients are not computed.
        self.set_requires_grad([self.netD_A, self.netD_B, self.netD_C],
                               False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # Discriminators.
        # FIX: backward_D_A/backward_D_B were each called twice, silently
        # doubling their accumulated gradients per step.
        self.set_requires_grad([self.netD_A, self.netD_B, self.netD_C], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.backward_D_C()
        self.optimizer_D.step()
def initialize(self, opt):
    """Set up the recycle-GAN networks: two translators (G_A, G_B), two
    temporal predictors (P_A, P_B) and, when training, two discriminators
    plus their optimizers and LR schedulers."""
    BaseModel.initialize(self, opt)
    batch = opt.batchSize
    side = opt.fineSize
    # Triplet input buffers: three consecutive frames per domain.
    self.input_A0 = self.Tensor(batch, opt.input_nc, side, side)
    self.input_A1 = self.Tensor(batch, opt.input_nc, side, side)
    self.input_A2 = self.Tensor(batch, opt.input_nc, side, side)
    self.input_B0 = self.Tensor(batch, opt.output_nc, side, side)
    self.input_B1 = self.Tensor(batch, opt.output_nc, side, side)
    self.input_B2 = self.Tensor(batch, opt.output_nc, side, side)

    # Translators between the two domains.
    # Naming versus the paper: G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    not opt.no_dropout, opt.init_type,
                                    self.gpu_ids)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    not opt.no_dropout, opt.init_type,
                                    self.gpu_ids)

    # Temporal predictors: either a dedicated 'prediction' architecture, or
    # a unet_128 fed the two previous frames concatenated channel-wise.
    self.which_model_netP = opt.which_model_netP
    if opt.which_model_netP == 'prediction':
        self.netP_A = networks.define_G(opt.input_nc, opt.input_nc, opt.npf,
                                        opt.which_model_netP, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netP_B = networks.define_G(opt.output_nc, opt.output_nc,
                                        opt.npf, opt.which_model_netP,
                                        opt.norm, not opt.no_dropout,
                                        opt.init_type, self.gpu_ids)
    else:
        self.netP_A = networks.define_G(2 * opt.input_nc, opt.input_nc,
                                        opt.ngf, 'unet_128', opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netP_B = networks.define_G(2 * opt.output_nc, opt.output_nc,
                                        opt.ngf, 'unet_128', opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, opt.init_type,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, opt.init_type,
                                        self.gpu_ids)

    # Resume from checkpoints when testing or continuing a training run.
    if not self.isTrain or opt.continue_train:
        which_epoch = opt.which_epoch
        self.load_network(self.netG_A, 'G_A', which_epoch)
        self.load_network(self.netG_B, 'G_B', which_epoch)
        self.load_network(self.netP_A, 'P_A', which_epoch)
        self.load_network(self.netP_B, 'P_B', which_epoch)
        if self.isTrain:
            self.load_network(self.netD_A, 'D_A', which_epoch)
            self.load_network(self.netD_B, 'D_B', which_epoch)

    if self.isTrain:
        self.old_lr = opt.lr
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # Loss functions: LSGAN unless disabled, L1 for cycle and identity.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # One optimizer jointly drives both translators and both predictors;
        # each discriminator gets its own.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(),
                            self.netG_B.parameters(),
                            self.netP_A.parameters(),
                            self.netP_B.parameters()),
            lr=opt.lr,
            betas=(opt.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG_A)
    networks.print_network(self.netG_B)
    networks.print_network(self.netP_A)
    networks.print_network(self.netP_B)
    if self.isTrain:
        networks.print_network(self.netD_A)
        networks.print_network(self.netD_B)
    print('-----------------------------------------------')
def initialize(self, opt):
    """Set up a CycleGAN variant augmented with an autoencoder (self.AE):
    generators, discriminators, the AE, their losses and optimizers."""
    BaseModel.initialize(self, opt)
    batch = opt.batchSize
    side = opt.fineSize
    self.input_A = self.Tensor(batch, opt.input_nc, side, side)
    self.input_B = self.Tensor(batch, opt.output_nc, side, side)

    # load/define networks
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    self.gpu_ids)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm,
                                    self.gpu_ids)
    # Autoencoder over flattened 28x28 images — presumably MNIST-sized
    # inputs; confirm against the dataset used.
    self.AE = networks.define_AE(28*28, 28*28, self.gpu_ids)

    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD,
                                        opt.n_layers_D, use_sigmoid,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                        opt.which_model_netD,
                                        opt.n_layers_D, use_sigmoid,
                                        self.gpu_ids)

    if not self.isTrain or opt.continue_train:
        which_epoch = opt.which_epoch
        self.load_network(self.netG_A, 'G_A', which_epoch)
        self.load_network(self.netG_B, 'G_B', which_epoch)
        self.load_network(self.AE, 'AE', which_epoch)
        if self.isTrain:
            self.load_network(self.netD_A, 'D_A', which_epoch)
            self.load_network(self.netD_B, 'D_B', which_epoch)

    if self.isTrain:
        self.old_lr = opt.lr
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # Losses: GAN criterion, L1 for cycle/identity, MSE for the AE.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionAE = torch.nn.MSELoss()
        # Optimizers. NOTE(review): optimizer_D_A_AE / optimizer_D_B_AE
        # wrap the same parameters as optimizer_D_A / optimizer_D_B —
        # presumably separate Adam state for the AE training phase; each
        # keeps independent moment estimates. Confirm this is intended.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(),
                            self.netG_B.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
        self.optimizer_D_A_AE = torch.optim.Adam(self.netD_A.parameters(),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
        self.optimizer_D_B_AE = torch.optim.Adam(self.netD_B.parameters(),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
        self.optimizer_AE = torch.optim.Adam(self.AE.parameters(),
                                             lr=opt.lr,
                                             betas=(opt.beta1, 0.999))
        self.optimizer_AE_GA_GB = torch.optim.Adam(
            itertools.chain(self.AE.parameters(),
                            self.netG_A.parameters(),
                            self.netG_B.parameters()),
            lr=opt.lr, betas=(opt.beta1, 0.999))

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG_A)
    networks.print_network(self.netG_B)
    networks.print_network(self.netD_A)
    networks.print_network(self.netD_B)
    networks.print_network(self.AE)
    print('-----------------------------------------------')
class CycleGANModel(BaseModel): """ This class implements the CycleGAN model, for learning image-to-image translation without paired data. The model training requires '--dataset_mode unaligned' dataset. By default, it uses a '--netG resnet_9blocks' ResNet generator, a '--netD basic' discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective ('--gan_mode lsgan'). CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses. A (source domain), B (target domain). Generators: G_A: A -> B; G_B: B -> A. Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A. Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper) Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper) Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper) Dropout is not used in the original CycleGAN paper. """ parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout if is_train: parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)') parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. 
For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1') return parser def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. if self.isTrain: self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] else: # during test time, only load Gs self.model_names = ['G_A', 'G_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. 
paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, opt.upsample_conv_type, opt.conv_dilation_G, opt.upsample_conv_dilation_G, opt.resnet_activation_G, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, opt.upsample_conv_type, opt.conv_dilation_G, opt.upsample_conv_dilation_G, opt.resnet_activation_G, self.gpu_ids) if self.isTrain: # define discriminators self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert(opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. 
Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap domain A and domain B. """ AtoB = self.opt.direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG_A(self.real_A) # G_A(A) self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.fake_A = self.netG_B(self.real_B) # G_B(B) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): """Calculate the loss for generators G_A and G_B""" lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed: ||G_A(B) - B|| self.idt_A = self.netG_A(self.real_B) self.loss_idt_A = self.criterionIdt(self.idt_A, 
self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed: ||G_B(A) - A|| self.idt_B = self.netG_B(self.real_A) self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) # GAN loss D_B(G_B(B)) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) # Forward cycle loss || G_B(G_A(A)) - A|| self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss || G_A(G_B(B)) - B|| self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. # G_A and G_B self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero self.backward_G() # calculate gradients for G_A and G_B self.optimizer_G.step() # update G_A and G_B's weights # D_A and D_B self.set_requires_grad([self.netD_A, self.netD_B], True) self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero self.backward_D_A() # calculate gradients for D_A self.backward_D_B() # calculate graidents for D_B self.optimizer_D.step() # update D_A and D_B's weights
class CycleGANModel(BaseModel):
    """CycleGAN (legacy implementation).

    NOTE(review): this class uses pre-0.4 PyTorch idioms — `Variable`,
    `volatile=True`, `loss.data[0]` — which are deprecated in modern
    PyTorch; they are kept as-is to preserve behavior on the torch
    version this code targets. Only the Python-level syntax error is
    fixed (see set_input).
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build generators, discriminators, losses, optimizers, schedulers."""
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Unpack a batch; 'which_direction' swaps domains A and B."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # FIX: the original passed async=True; `async` is a reserved
            # keyword since Python 3.7 (SyntaxError). The kwarg was renamed
            # to non_blocking in PyTorch 0.4 with identical semantics.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the current inputs as Variables (legacy autograd API)."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both cycles without gradients (legacy volatile mode)."""
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator loss; back-propagates and returns it."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no gradient reaches the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update loss for D_A using pooled fake_B images."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]  # legacy scalar extraction

    def backward_D_B(self):
        """Update loss for D_B using pooled fake_A images."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]  # legacy scalar extraction

    def backward_G(self):
        """Compute and back-propagate all generator losses."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)
        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()
        # stash tensors for visuals and scalars for logging
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data
        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        """One training iteration: G step, then each D separately."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return an ordered dict of the latest scalar losses."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return an ordered dict of the latest images, converted for display."""
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Checkpoint all four networks under the given label."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
class T2NetModel(BaseModel):
    """Synthetic-to-real translation plus task (depth) network, with an
    optional Gaussian-process regularizer (GPStruct) over encoder features."""

    def name(self):
        return 'T2Net model'

    def initialize(self, opt, labeled_dataset=None, unlabeled_dataset=None):
        """Build generators, discriminators, losses, optimizers and schedulers.

        labeled_dataset / unlabeled_dataset are only consulted when
        ``opt.gp`` is set (they seed the GPStruct feature store).
        """
        BaseModel.initialize(self, opt)
        # Scalar losses reported during training.
        self.loss_names = [
            'img_rec', 'img_G', 'img_D', 'lab_s', 'lab_t', 'f_G', 'f_D',
            'lab_smooth'
        ]
        # Tensors exposed for visualization.
        self.visual_names = [
            'img_s', 'img_t', 'lab_s', 'lab_t', 'img_s2t', 'img_t2t',
            'lab_s_g', 'lab_t_g'
        ]
        # Discriminators are only checkpointed while training.
        if self.isTrain:
            self.model_names = ['img2task', 's2t', 'img_D', 'f_D']
        else:
            self.model_names = ['img2task', 's2t']
        # define the transform (synthetic -> real style) network
        self.net_s2t = network.define_G(opt.image_nc, opt.image_nc, opt.ngf,
                                        opt.transform_layers, opt.norm,
                                        opt.activation, opt.trans_model_type,
                                        opt.init_type, opt.drop_rate, False,
                                        opt.gpu_ids, opt.U_weight)
        # define the task (image -> label/depth) network
        self.net_img2task = network.define_G(opt.image_nc, opt.label_nc,
                                             opt.ngf, opt.task_layers,
                                             opt.norm, opt.activation,
                                             opt.task_model_type,
                                             opt.init_type, opt.drop_rate,
                                             False, opt.gpu_ids, opt.U_weight)
        # define the image-level and feature-level discriminators
        if self.isTrain:
            self.net_img_D = network.define_D(opt.image_nc, opt.ndf,
                                              opt.image_D_layers, opt.num_D,
                                              opt.norm, opt.activation,
                                              opt.init_type, opt.gpu_ids)
            self.net_f_D = network.define_featureD(opt.image_feature,
                                                   opt.feature_D_layers,
                                                   opt.norm, opt.activation,
                                                   opt.init_type,
                                                   opt.gpu_ids)
        if self.isTrain:
            self.fake_img_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.l1loss = torch.nn.L1Loss()
            self.nonlinearity = torch.nn.ReLU()
            # One Adam drives both generators; the task network gets its own
            # lr/betas through a second parameter group.
            self.optimizer_T2Net = torch.optim.Adam([{
                'params':
                filter(lambda p: p.requires_grad, self.net_s2t.parameters())
            }, {
                'params':
                filter(lambda p: p.requires_grad,
                       self.net_img2task.parameters()),
                'lr':
                opt.lr_task,
                'betas': (0.95, 0.999)
            }],
                                                    lr=opt.lr_trans,
                                                    betas=(0.5, 0.9))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                filter(lambda p: p.requires_grad,
                       self.net_img_D.parameters()),
                filter(lambda p: p.requires_grad, self.net_f_D.parameters())),
                                                lr=opt.lr_trans,
                                                betas=(0.5, 0.9))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_T2Net)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(network.get_scheduler(optimizer, opt))
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        # initializing GPstruct (optional feature-space GP regularizer)
        if self.isTrain and opt.gp:
            self.labeled_dataset = labeled_dataset
            self.unlabeled_dataset = unlabeled_dataset
            self.gp_struct = GPStruct(num_lbl=len(labeled_dataset),
                                      num_unlbl=len(unlabeled_dataset),
                                      train_batch_size=self.opt.batch_size,
                                      version=self.opt.version,
                                      kernel_type=self.opt.kernel_type,
                                      pre_trained_enc=opt.pre_trained_enc,
                                      img_size=opt.load_size)

    def set_input(self, input):
        """Unpack a dataloader batch and move it onto the first GPU.

        Labels are only present (and only loaded) during training.
        """
        self.input = input
        self.img_source = input['img_source'].cuda(self.gpu_ids[0])
        self.img_target = input['img_target'].cuda(self.gpu_ids[0])
        if self.isTrain:
            self.lab_source = input['lab_source'].cuda(self.gpu_ids[0])
            self.lab_target = input['lab_target'].cuda(self.gpu_ids[0])
        # if len(self.gpu_ids) > 0:
        #     self.img_source = self.img_source.cuda(self.gpu_ids[0], async=True)
        #     self.img_target = self.img_target.cuda(self.gpu_ids[0], async=True)
        #     if self.isTrain:
        #         self.lab_source = self.lab_source.cuda(self.gpu_ids[0], async=True)
        #         self.lab_target = self.lab_target.cuda(self.gpu_ids[0], async=True)

    def forward(self):
        """Wrap inputs as Variables (legacy pre-0.4 PyTorch style).

        NOTE(review): reads self.lab_source/lab_target unconditionally, which
        set_input only defines when isTrain — presumably forward() is only
        called in training; confirm before using at test time.
        """
        self.img_s = Variable(self.img_source)
        self.img_t = Variable(self.img_target)
        self.lab_s = Variable(self.lab_source)
        self.lab_t = Variable(self.lab_target)

    def backward_D_basic(self, netD, real, fake):
        """LSGAN discriminator loss over lists of (multi-scale) real/fake
        inputs; backpropagates and returns the accumulated loss.

        NOTE(review): raises if `real`/`fake` are empty (int 0 has no
        .backward()); callers always pass at least one pair.
        """
        D_loss = 0
        for (real_i, fake_i) in zip(real, fake):
            # Real
            D_real = netD(real_i.detach())
            # fake
            D_fake = netD(fake_i.detach())
            for (D_real_i, D_fake_i) in zip(D_real, D_fake):
                # least-squares GAN targets: real -> 1, fake -> 0
                D_loss += (torch.mean((D_real_i - 1.0)**2) + torch.mean(
                    (D_fake_i - 0.0)**2)) * 0.5
        D_loss.backward()
        return D_loss

    def backward_D_image(self):
        """Update the image discriminator on pooled translated images vs a
        scale pyramid of real target images."""
        network._freeze(self.net_s2t, self.net_img2task, self.net_f_D)
        network._unfreeze(self.net_img_D)
        size = len(self.img_s2t)
        fake = []
        for i in range(size):
            # history pool decorrelates D updates from the current G output
            fake.append(self.fake_img_pool.query(self.img_s2t[i]))
        real = task.scale_pyramid(self.img_t, size)
        self.loss_img_D = self.backward_D_basic(self.net_img_D, real, fake)

    def backward_D_feature(self):
        """Update the feature discriminator: target-domain features are
        treated as 'real', source-domain features as 'fake'."""
        network._freeze(self.net_s2t, self.net_img2task, self.net_img_D)
        network._unfreeze(self.net_f_D)
        self.loss_f_D = self.backward_D_basic(self.net_f_D, [self.lab_f_t],
                                              [self.lab_f_s])

    def foreward_G_basic(self, net_G, img_s, img_t):
        """Run G on a batch-concatenated [source; target] input and split each
        output back into its source and target halves.

        (Name keeps the original 'foreward' spelling for compatibility.)

        Returns (source_outputs, target_outputs, source_feat, target_feat, n)
        where fake[0] is the shared feature map and fake[1:] are images.
        """
        img = torch.cat([img_s, img_t], 0)
        fake = net_G(img)
        size = len(fake)
        f_s, f_t = fake[0].chunk(2)
        img_fake = fake[1:]
        img_s_fake = []
        img_t_fake = []
        for img_fake_i in img_fake:
            img_s, img_t = img_fake_i.chunk(2)
            img_s_fake.append(img_s)
            img_t_fake.append(img_t)
        return img_s_fake, img_t_fake, f_s, f_t, size

    def backward_synthesis2real(self):
        """Generator step for the s->t translator: GAN loss on translated
        source images plus L1 reconstruction of target images."""
        # image to image transform
        network._freeze(self.net_img2task, self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_s2t)
        self.img_s2t, self.img_t2t, self.img_f_s, self.img_f_t, size = \
            self.foreward_G_basic(self.net_s2t, self.img_s, self.img_t)
        # image GAN loss and reconstruction loss
        img_real = task.scale_pyramid(self.img_t, size - 1)
        G_loss = 0
        rec_loss = 0
        for i in range(size - 1):
            rec_loss += self.l1loss(self.img_t2t[i], img_real[i])
            D_fake = self.net_img_D(self.img_s2t[i])
            for D_fake_i in D_fake:
                G_loss += torch.mean((D_fake_i - 1.0)**2)
        self.loss_img_G = G_loss * self.opt.lambda_gan_img
        self.loss_img_rec = rec_loss * self.opt.lambda_rec_img
        total_loss = self.loss_img_G + self.loss_img_rec
        # retain_graph: img_s2t is reused by backward_translated2depth below
        total_loss.backward(retain_graph=True)

    def backward_translated2depth(self):
        """Task-network step on the translated source image: feature-GAN loss
        plus supervised multi-scale L1 against the source labels."""
        # task network
        network._freeze(self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_s2t, self.net_img2task)
        fake = self.net_img2task.forward(self.img_s2t[-1])
        size = len(fake)
        self.lab_f_s = fake[0]
        self.lab_s_g = fake[1:]
        # feature GAN loss (push source features toward the target manifold)
        D_fake = self.net_f_D(self.lab_f_s)
        G_loss = 0
        for D_fake_i in D_fake:
            G_loss += torch.mean((D_fake_i - 1.0)**2)
        self.loss_f_G = G_loss * self.opt.lambda_gan_feature
        # task loss
        lab_real = task.scale_pyramid(self.lab_s, size - 1)
        task_loss = 0
        for (lab_fake_i, lab_real_i) in zip(self.lab_s_g, lab_real):
            task_loss += self.l1loss(lab_fake_i, lab_real_i)
        self.loss_lab_s = task_loss * self.opt.lambda_rec_lab
        total_loss = self.loss_f_G + self.loss_lab_s
        total_loss.backward()

    def backward_real2depth(self):
        """Task-network step on real target images: only the edge-aware
        smoothness prior (no target ground truth during training)."""
        # image2depth
        network._freeze(self.net_s2t, self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_img2task)
        fake = self.net_img2task.forward(self.img_t)
        size = len(fake)
        # Gan depth
        self.lab_f_t = fake[0]
        self.lab_t_g = fake[1:]
        img_real = task.scale_pyramid(self.img_t, size - 1)
        self.loss_lab_smooth = task.get_smooth_weight(
            self.lab_t_g, img_real, size - 1) * self.opt.lambda_smooth
        total_loss = self.loss_lab_smooth
        total_loss.backward()

    def optimize_parameters(self, epoch_iter):
        """One training iteration: three generator backward passes, then
        (every 5th iteration) a discriminator step with WGAN-style weight
        clamping on the feature discriminator."""
        self.forward()
        # T2Net
        self.optimizer_T2Net.zero_grad()
        self.backward_synthesis2real()
        self.backward_translated2depth()
        self.backward_real2depth()
        self.optimizer_T2Net.step()
        # Discriminator
        self.optimizer_D.zero_grad()
        self.backward_D_feature()
        self.backward_D_image()
        # self.optimizer_D.step()
        # for p in self.net_f_D.parameters():
        #     p.data.clamp_(-0.01,0.01)
        # D grads still accumulate every iteration; the step is only applied
        # every 5th iteration.
        if epoch_iter % 5 == 0:
            self.optimizer_D.step()
            for p in self.net_f_D.parameters():
                p.data.clamp_(-0.01, 0.01)

    def validation_target(self):
        """Compute the (held-out) target-domain task loss for validation."""
        lab_real = task.scale_pyramid(self.lab_t, len(self.lab_t_g))
        task_loss = 0
        for (lab_fake_i, lab_real_i) in zip(self.lab_t_g, lab_real):
            task_loss += task.rec_loss(lab_fake_i, lab_real_i)
        self.loss_lab_t = task_loss * self.opt.lambda_rec_lab

    def generate_fmaps_GP(self):
        """Refresh GPStruct's stored feature maps for both datasets."""
        self.gp_struct.gen_featmaps(self.labeled_dataset, self.net_img2task,
                                    self.device)
        self.gp_struct.gen_featmaps_unlbl(self.unlabeled_dataset,
                                          self.net_img2task, self.device)

    def optimize_parameters_GP(self, iter, data):
        """Extra optimization step minimizing the GP loss on target images;
        only the task network is unfrozen."""
        input_im = data['img_target'].cuda(self.gpu_ids[0])
        # gt = data['lab_target'].cuda(self.device)
        imgid = data['img_target_paths']
        self.optimizer_T2Net.zero_grad()
        network._freeze(self.net_s2t, self.net_img_D, self.net_f_D)
        network._unfreeze(self.net_img2task)
        self.net_img2task.train()
        ### center in
        # outputs = self.netTask(input_im)
        # zy_in = outputs[0]
        ### center_out
        _, zy_in = self.net_img2task(input_im, gp=True)
        loss_gp = self.gp_struct.compute_gploss(zy_in, imgid, iter, 0)
        self.loss_gp = loss_gp * self.opt.lambda_gp
        self.loss_gp.backward()
        self.optimizer_T2Net.step()
netD_A_train_function = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_A, opt.finesize, opt.input_nc) # create discriminator B train function netD_B_train_function = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_B, opt.finesize, opt.input_nc) # train loop time_start = time.time() how_many_epochs = 5 iteration_count = 0 epoch_count = 0 batch_size = opt.batch_size display_freq = 10000 netG_A_function = get_generater_function(netG_A) netG_B_functionr = get_generater_function(netG_B) fake_A_pool = ImagePool() fake_B_pool = ImagePool() while epoch_count < how_many_epochs: target_label = np.zeros((batch_size, 1)) epoch_count, A, B = next(train_batch) tmp_fake_B = netG_A_function([A])[0] tmp_fake_A = netG_B_functionr([B])[0] _fake_B = fake_B_pool.query(tmp_fake_B) _fake_A = fake_A_pool.query(tmp_fake_A) netG_train_function.train_on_batch([A, B], target_label) netD_B_train_function.train_on_batch([B, _fake_B], target_label)
def initialize(self, opt, labeled_dataset=None, unlabeled_dataset=None):
    """Set up the T2Net translation/task networks, losses and optimizers.

    The labeled/unlabeled datasets are only consulted when ``opt.gp`` is
    enabled, where they seed the GPStruct feature store.
    """
    BaseModel.initialize(self, opt)

    # Scalar losses reported during training and tensors exposed for display.
    self.loss_names = [
        'img_rec', 'img_G', 'img_D', 'lab_s', 'lab_t', 'f_G', 'f_D',
        'lab_smooth'
    ]
    self.visual_names = [
        'img_s', 'img_t', 'lab_s', 'lab_t', 'img_s2t', 'img_t2t',
        'lab_s_g', 'lab_t_g'
    ]
    # Discriminators are only checkpointed while training.
    self.model_names = (['img2task', 's2t', 'img_D', 'f_D']
                        if self.isTrain else ['img2task', 's2t'])

    # Synthetic-to-real translation generator.
    self.net_s2t = network.define_G(opt.image_nc, opt.image_nc, opt.ngf,
                                    opt.transform_layers, opt.norm,
                                    opt.activation, opt.trans_model_type,
                                    opt.init_type, opt.drop_rate, False,
                                    opt.gpu_ids, opt.U_weight)
    # Image-to-label (task) generator.
    self.net_img2task = network.define_G(opt.image_nc, opt.label_nc,
                                         opt.ngf, opt.task_layers,
                                         opt.norm, opt.activation,
                                         opt.task_model_type, opt.init_type,
                                         opt.drop_rate, False, opt.gpu_ids,
                                         opt.U_weight)

    if self.isTrain:
        # Image-level and feature-level discriminators.
        self.net_img_D = network.define_D(opt.image_nc, opt.ndf,
                                          opt.image_D_layers, opt.num_D,
                                          opt.norm, opt.activation,
                                          opt.init_type, opt.gpu_ids)
        self.net_f_D = network.define_featureD(opt.image_feature,
                                               opt.feature_D_layers,
                                               opt.norm, opt.activation,
                                               opt.init_type, opt.gpu_ids)
        self.fake_img_pool = ImagePool(opt.pool_size)

        # Loss functions.
        self.l1loss = torch.nn.L1Loss()
        self.nonlinearity = torch.nn.ReLU()

        # One Adam instance drives both generators; the task network gets its
        # own learning rate / betas through a second parameter group.
        def trainable(net):
            return (p for p in net.parameters() if p.requires_grad)

        self.optimizer_T2Net = torch.optim.Adam(
            [{'params': trainable(self.net_s2t)},
             {'params': trainable(self.net_img2task),
              'lr': opt.lr_task,
              'betas': (0.95, 0.999)}],
            lr=opt.lr_trans,
            betas=(0.5, 0.9))
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(trainable(self.net_img_D),
                            trainable(self.net_f_D)),
            lr=opt.lr_trans,
            betas=(0.5, 0.9))
        self.optimizers = [self.optimizer_T2Net, self.optimizer_D]
        self.schedulers = [
            network.get_scheduler(o, opt) for o in self.optimizers
        ]

    if not self.isTrain or opt.continue_train:
        self.load_networks(opt.which_epoch)

    # Optional Gaussian-process regularizer over encoder features.
    if self.isTrain and opt.gp:
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.gp_struct = GPStruct(num_lbl=len(labeled_dataset),
                                  num_unlbl=len(unlabeled_dataset),
                                  train_batch_size=self.opt.batch_size,
                                  version=self.opt.version,
                                  kernel_type=self.opt.kernel_type,
                                  pre_trained_enc=opt.pre_trained_enc,
                                  img_size=opt.load_size)
def initialize(self, opt):
    """Build the pix2pixHD-style generator/discriminator/encoder trio,
    load checkpoints when requested, and (in training mode) set up losses
    and the two Adam optimizers."""
    BaseModel.initialize(self, opt)
    if opt.resize_or_crop != 'none':  # when training at full res this causes OOM
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.use_features = opt.instance_feat or opt.label_feat
    # encoder is only needed when features are used but not precomputed
    self.gen_features = self.use_features and not self.opt.load_features
    # label_nc == 0 means RGB input instead of a one-hot label map
    input_nc = opt.label_nc if opt.label_nc != 0 else 3

    ##### define networks
    # Generator network: input channels grow with the instance edge map
    # and the feature map when enabled.
    netG_input_nc = input_nc
    if not opt.no_instance:
        netG_input_nc += 1
    if self.use_features:
        netG_input_nc += opt.feat_num
    self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf,
                                  opt.netG, opt.n_downsample_global,
                                  opt.n_blocks_global, opt.n_local_enhancers,
                                  opt.n_blocks_local, opt.norm,
                                  gpu_ids=self.gpu_ids)

    # Discriminator network: conditioned on (input, output) pairs.
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        netD_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm, use_sigmoid, opt.num_D,
                                      not opt.no_ganFeat_loss,
                                      gpu_ids=self.gpu_ids)

    ### Encoder network (produces per-instance feature maps)
    if self.gen_features:
        self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef,
                                      'encoder', opt.n_downsample_E,
                                      norm=opt.norm, gpu_ids=self.gpu_ids)
    if self.opt.verbose:
        print('---------- Networks initialized -------------')

    # load networks
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        # at test time checkpoints come from the default save dir
        pretrained_path = '' if not self.isTrain else opt.load_pretrain
        self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch,
                              pretrained_path)
        if self.gen_features:
            self.load_network(self.netE, 'E', opt.which_epoch,
                              pretrained_path)

    # set loss functions and optimizers
    if self.isTrain:
        if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
            raise NotImplementedError(
                "Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        # define loss functions; loss_filter drops disabled terms
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                 not opt.no_vgg_loss)
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.gpu_ids)

        # Names so we can breakout loss
        self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG',
                                           'D_real', 'D_fake')

        # initialize optimizers
        # optimizer G: optionally freeze the global generator (lr=0) and
        # train only the local enhancer for the first niter_fix_global epochs
        if opt.niter_fix_global > 0:
            if self.opt.verbose:
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
            params_dict = dict(self.netG.named_parameters())
            params = []
            for key, value in params_dict.items():
                # local-enhancer submodules are named 'model<n_local_enhancers>...'
                if key.startswith('model' + str(opt.n_local_enhancers)):
                    params += [{'params': [value], 'lr': opt.lr}]
                else:
                    params += [{'params': [value], 'lr': 0.0}]
        else:
            params = list(self.netG.parameters())
            if self.gen_features:
                # encoder is trained jointly with G
                params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        # optimizer D
        params = list(self.netD.parameters())
        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
def initialize(self, opt):
    """Construct the pix2pix generator/discriminator pair, optionally
    restore checkpoints, and (when training) build losses, the image
    buffer, optimizers and their schedulers."""
    BaseModel.initialize(self, opt)
    self.isTrain = opt.isTrain

    # Generator.
    self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                  opt.which_model_netG, opt.norm,
                                  not opt.no_dropout, opt.init_type)
    # Discriminator sees the (input, output) pair concatenated on channels.
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                      opt.ndf, opt.which_model_netD,
                                      opt.n_layers_D, opt.norm,
                                      use_sigmoid, opt.init_type)

    # Restore weights when testing or resuming.
    if not self.isTrain or opt.continue_train:
        self.load_network(self.netG, 'G', opt.which_epoch)
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch)

    # Move networks onto the target device.
    self.netG = self.netG.to(self.device)
    if self.isTrain:
        self.netD = self.netD.to(self.device)

    if self.isTrain:
        # Buffer of previously generated images to stabilize D updates.
        self.fake_AB_pool = ImagePool(opt.pool_size)
        # Starting learning rate for the linear decay schedule.
        self.old_lr = opt.lr

        # Losses.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             device=self.device)
        self.criterionL1 = torch.nn.L1Loss().to(self.device)

        # Optimizers and their schedulers.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        self.schedulers = [
            networks.get_scheduler(optim, opt) for optim in self.optimizers
        ]

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG)
    if self.isTrain:
        networks.print_network(self.netD)
    print('-----------------------------------------------')
class PairModel(BaseModel): def name(self): return 'CycleGANModel' def initialize(self, opt): BaseModel.initialize(self, opt) nb = opt.batchSize size = opt.fineSize self.opt = opt self.input_A = self.Tensor(nb, opt.input_nc, size, size) self.input_B = self.Tensor(nb, opt.output_nc, size, size) if opt.vgg > 0: self.vgg_loss = networks.PerceptualLoss() self.vgg_loss.cuda() self.vgg = networks.load_vgg16("./model") self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) skip = True if opt.skip > 0 else False self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions if opt.use_wgan: self.criterionGAN = networks.DiscLossWGANGP() else: self.criterionGAN = networks.GANLoss( use_lsgan=not opt.no_lsgan, tensor=self.Tensor) if opt.use_mse: self.criterionCycle = torch.nn.MSELoss() else: self.criterionCycle = torch.nn.L1Loss() 
self.criterionL1 = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- Networks initialized -------------') networks.print_network(self.netG_A) networks.print_network(self.netG_B) if self.isTrain: networks.print_network(self.netD_A) networks.print_network(self.netD_B) if opt.isTrain: self.netG_A.train() self.netG_B.train() else: self.netG_A.eval() self.netG_B.eval() print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): self.real_A = Variable(self.input_A, volatile=True) # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) if self.opt.skip == 1: self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) else: self.fake_B = self.netG_A.forward(self.real_A) self.rec_A = self.netG_B.forward(self.fake_B) self.real_B = Variable(self.input_B, volatile=True) self.fake_A = self.netG_B.forward(self.real_B) if self.opt.skip == 1: self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) else: self.rec_B = self.netG_A.forward(self.fake_A) def predict(self): self.real_A = Variable(self.input_A, volatile=True) # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) if self.opt.skip == 1: self.fake_B, 
self.latent_real_A = self.netG_A.forward(self.real_A) else: self.fake_B = self.netG_A.forward(self.real_A) self.rec_A = self.netG_B.forward(self.fake_B) real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) rec_A = util.tensor2im(self.rec_A.data) if self.opt.skip == 1: latent_real_A = util.tensor2im(self.latent_real_A.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("latent_real_A", latent_real_A), ("rec_A", rec_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("rec_A", rec_A)]) # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD.forward(real) if self.opt.use_wgan: loss_D_real = pred_real.mean() else: loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD.forward(fake.detach()) if self.opt.use_wgan: loss_D_fake = pred_fake.mean() else: loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss if self.opt.use_wgan: loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty( netD, real.data, fake.data) else: loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): lambda_idt = self.opt.identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. if self.opt.skip == 1: self.idt_A, _ = self.netG_A.forward(self.real_B) else: self.idt_A = self.netG_A.forward(self.real_B) self.loss_idt_A = self.criterionIdt( self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. 
self.idt_B = self.netG_B.forward(self.real_A) self.loss_idt_B = self.criterionIdt( self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss # D_A(G_A(A)) if self.opt.skip == 1: self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) else: self.fake_B = self.netG_A.forward(self.real_A) # = self.latent_real_A + self.opt.skip * self.real_A pred_fake = self.netD_A.forward(self.fake_B) if self.opt.use_wgan: self.loss_G_A = -pred_fake.mean() else: self.loss_G_A = self.criterionGAN(pred_fake, True) self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1 # D_B(G_B(B)) self.fake_A = self.netG_B.forward(self.real_B) pred_fake = self.netD_B.forward(self.fake_A) self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1 if self.opt.use_wgan: self.loss_G_B = -pred_fake.mean() else: self.loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss if lambda_A > 0: self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A else: self.loss_cycle_A = 0 # Backward cycle loss # = self.latent_fake_A + self.opt.skip * self.fake_A if lambda_B > 0: if self.opt.skip == 1: self.rec_B, self.latent_fake_A = self.netG_A.forward( self.fake_A) else: self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B else: self.loss_cycle_B = 0 self.loss_vgg_a = self.vgg_loss.compute_vgg_loss( self.vgg, self.fake_A, self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0 self.loss_vgg_b = self.vgg_loss.compute_vgg_loss( self.vgg, self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 # combined loss self.loss_G = self.loss_G_A + self.loss_G_B + self.L1_AB + self.L1_BA + self.loss_cycle_A + self.loss_cycle_B + \ self.loss_vgg_a + self.loss_vgg_b + \ self.loss_idt_A + self.loss_idt_B # self.loss_G = self.L1_AB + self.L1_BA self.loss_G.backward() def optimize_parameters(self): # 
forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() self.backward_D_A() self.optimizer_D_A.step() # D_B self.optimizer_D_B.zero_grad() self.backward_D_B() self.optimizer_D_B.step() def get_current_errors(self): D_A = self.loss_D_A.data[0] G_A = self.loss_G_A.data[0] L1 = (self.L1_AB + self.L1_BA).data[0] Cyc_A = self.loss_cycle_A.data[0] D_B = self.loss_D_B.data[0] G_B = self.loss_G_B.data[0] Cyc_B = self.loss_cycle_B.data[0] vgg = (self.loss_vgg_a.data[0] + self.loss_vgg_b.data[0] ) / self.opt.vgg if self.opt.vgg > 0 else 0 if self.opt.identity > 0: idt = self.loss_idt_A.data[0] + self.loss_idt_B.data[0] if self.opt.lambda_A > 0.0: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg), ("idt", idt)]) else: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg), ("idt", idt)) else: if self.opt.lambda_A > 0.0: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg)]) else: return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg)) def get_current_visuals(self): real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) if self.opt.skip > 0: latent_real_A = util.tensor2im(self.latent_real_A.data) real_B = util.tensor2im(self.real_B.data) fake_A = util.tensor2im(self.fake_A.data) if self.opt.identity > 0: idt_A = util.tensor2im(self.idt_A.data) idt_B = util.tensor2im(self.idt_B.data) if self.opt.lambda_A > 0.0: rec_A = util.tensor2im(self.rec_A.data) rec_B = util.tensor2im(self.rec_B.data) if self.opt.skip > 0: latent_fake_A = util.tensor2im(self.latent_fake_A.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), ('real_B', real_B), 
('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ("idt_A", idt_A), ("idt_B", idt_B)]) else: if self.opt.skip > 0: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) else: if self.opt.lambda_A > 0.0: rec_A = util.tensor2im(self.rec_A.data) rec_B = util.tensor2im(self.rec_B.data) if self.opt.skip > 0: latent_fake_A = util.tensor2im(self.latent_fake_A.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) else: if self.opt.skip > 0: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('real_B', real_B), ('fake_A', fake_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), ('fake_A', fake_A)]) def save(self, label): self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) def update_learning_rate(self): if self.opt.new_lr: lr = self.old_lr / 2 else: lrd = self.opt.lr / self.opt.niter_decay lr = self.old_lr - lrd for param_group in self.optimizer_D_A.param_groups: param_group['lr'] = lr for param_group in self.optimizer_D_B.param_groups: param_group['lr'] = lr for 
param_group in self.optimizer_G.param_groups: param_group['lr'] = lr print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr
class ObjectVariedGANModel(BaseModel):
    """Instance-aware unpaired translation (CycleGAN-style).

    Images are concatenated with per-object segmentation masks and translated
    in sequential mini-batches: each forward pass handles `ins_per` object
    masks, repeated `ins_iter` times until `ins_max` objects are covered.
    Loss = GAN + cycle + identity + feature-similarity (context-preserving
    terms are computed but currently excluded from the total — see backward_G).
    """

    def name(self):
        return 'ObjectVariedGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific CLI options; dropout is disabled by default."""
        parser.set_defaults(no_dropout=True)
        parser.add_argument('--set_order', type=str, default='decreasing', help='order of segmentation')
        parser.add_argument('--ins_max', type=int, default=1, help='maximum number of object to forward')
        parser.add_argument('--ins_per', type=int, default=1, help='number of object to forward, for one pass')
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_idt', type=float, default=1.0, help='use identity mapping. Setting lambda_idt other than 0 has an effect of scaling the weight of the identity mapping loss')
            parser.add_argument('--lambda_ctx', type=float, default=1.0, help='use context preserving. Setting lambda_ctx other than 0 has an effect of scaling the weight of the context preserving loss')
            parser.add_argument('--lambda_fs', type=float, default=10.0, help='use feature similarity. Setting lambda_fs other than 0 has an effect of scaling the weight of the feature similarity loss')
        return parser

    def initialize(self, opt):
        """Create generators/discriminators, losses and optimizers."""
        BaseModel.initialize(self, opt)
        # Number of sequential forward passes (floor division, e.g. 4 // 2 == 2).
        self.ins_iter = self.opt.ins_max // self.opt.ins_per
        # Losses reported by base_model.get_current_losses.
        # NOTE(review): ctx_A/ctx_B are listed here but excluded from the
        # combined generator loss in backward_G — confirm this is intended.
        self.loss_names = ['D_A', 'G_A', 'cyc_A', 'idt_A', 'ctx_A', 'fs_A',
                           'D_B', 'G_B', 'cyc_B', 'idt_B', 'ctx_B', 'fs_B']
        # Images saved/displayed by base_model.get_current_visuals.
        visual_names_A_img = ['real_A_img', 'fake_B_img', 'rec_A_img']
        visual_names_B_img = ['real_B_img', 'fake_A_img', 'rec_B_img']
        visual_names_A_seg = ['real_A_seg', 'fake_B_seg', 'rec_A_seg']
        visual_names_B_seg = ['real_B_seg', 'fake_A_seg', 'rec_B_seg']
        self.visual_names = visual_names_A_img + visual_names_A_seg + visual_names_B_img + visual_names_B_seg
        # Networks persisted by base_model.save_networks/load_networks.
        if self.isTrain:  # isTrain is True for train.py, False for test.py
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']  # save G and D when training
        else:
            self.model_names = ['G_A', 'G_B']  # only generators are needed at test time
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.ins_per, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.ins_per, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            # Vanilla GAN loss needs a sigmoid output from D; LSGAN does not.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ins_per, opt.ndf, opt.netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ins_per, opt.ndf, opt.netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            # Buffers of previously generated images used to update D (--pool_size, default 50).
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # Loss functions: MSE (LSGAN) or BCE depending on --no_lsgan.
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCyc = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # Optimizers: only parameters with requires_grad=True are handed to Adam
            # (filter with a lambda over the chained generator/discriminator parameters).
            self.optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                       itertools.chain(self.netG_A.parameters(), self.netG_B.parameters())),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                       itertools.chain(self.netD_A.parameters(), self.netD_B.parameters())),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def select_masks(self, segs_batch):
        """Select object masks to use, ordered per --set_order."""
        if self.opt.set_order == 'decreasing':
            return self.select_masks_decreasing(segs_batch)
        elif self.opt.set_order == 'random':
            return self.select_masks_random(segs_batch)
        else:
            raise NotImplementedError('Set order name [%s] is not recognized' % self.opt.set_order)

    def select_masks_decreasing(self, segs_batch):
        """Select the ins_max masks with the largest mean value, in decreasing order."""
        ret = list()
        for segs in segs_batch:
            # Mean over the two spatial dims -> one score per mask.
            mean = segs.mean(-1).mean(-1)
            # i holds the indices of the ins_max highest-scoring masks.
            m, i = mean.topk(self.opt.ins_max)
            ret.append(segs[i, :, :])
        # Stack along a new batch dim: (B, ins_max, W, H).
        return torch.stack(ret)

    def select_masks_random(self, segs_batch):
        """Select the top ins_max masks, then shuffle the non-empty ones."""
        ret = list()
        for segs in segs_batch:
            # Empty masks are all -1 (see the dataset's read_segs), so +1 makes
            # their mean exactly 0 and non-empty masks positive.
            mean = (segs + 1).mean(-1).mean(-1)
            m, i = mean.topk(self.opt.ins_max)
            # Permute only the non-empty masks; keep empty slots at the end.
            num = min(len(mean.nonzero()), self.opt.ins_max)
            reorder = np.concatenate((np.random.permutation(num), np.arange(num, self.opt.ins_max)))
            ret.append(segs[i[reorder], :, :])
        return torch.stack(ret)

    def merge_masks(self, segs):
        """Merge masks (B, N, W, H) -> (B, 1, W, H); values stay in [-1, 1]."""
        # Map each mask from [-1,1] to [0,1], sum over N, clamp, map back.
        ret = torch.sum((segs + 1)/2, dim=1, keepdim=True)  # (B, 1, W, H)
        return ret.clamp(max=1, min=0) * 2 - 1

    def get_weight_for_ctx(self, x, y):
        """Weight map for the context preserving loss: 1 on background, 0 on objects."""
        z = self.merge_masks(torch.cat([x, y], dim=1))
        return (1 - z) / 2  # [-1,1] -> [1,0]

    def weighted_L1_loss(self, src, tgt, weight):
        """L1 loss with a per-pixel weight (used for the context preserving loss)."""
        return torch.mean(weight * torch.abs(src - tgt))

    def get_weight_for_cx(self, x, y):
        """Same weight map as get_weight_for_ctx (kept for compatibility)."""
        z = self.merge_masks(torch.cat([x, y], dim=1))
        return (1 - z) / 2  # [-1,1] -> [1,0]

    def multiply_cx(self, src, weight):
        """Weighted mean of |src| (helper kept alongside get_weight_for_cx)."""
        return torch.mean(weight * torch.abs(src))

    def split(self, x):
        """Split a fused tensor into image and mask channels (image is assumed 3-channel)."""
        return x[:, :3, :, :], x[:, 3:, :, :]

    def set_input(self, input):
        """Unpack a dataset item (see data/unaligned_seg_dataset.py __getitem__).

        Builds the fused image+masks tensors real_A/real_B and the merged
        single-channel masks real_A_seg/real_B_seg.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A_img = input['A' if AtoB else 'B'].to(self.device)  # (B, 3, H, W) source image
        self.real_B_img = input['B' if AtoB else 'A'].to(self.device)
        real_A_segs = input['A_segs' if AtoB else 'B_segs']  # all segs of one image, concatenated
        real_B_segs = input['B_segs' if AtoB else 'A_segs']
        self.real_A_segs = self.select_masks(real_A_segs).to(self.device)  # (B, ins_max, H, W)
        self.real_B_segs = self.select_masks(real_B_segs).to(self.device)
        # Fused image + masks, e.g. (B, 3 + ins_max, H, W).
        self.real_A = torch.cat([self.real_A_img, self.real_A_segs], dim=1)
        self.real_B = torch.cat([self.real_B_img, self.real_B_segs], dim=1)
        # Merged mask: (B, N, W, H) -> (B, 1, W, H).
        self.real_A_seg = self.merge_masks(self.real_A_segs)
        self.real_B_seg = self.merge_masks(self.real_B_segs)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']  # list of length 1

    def forward(self, idx=0):
        """Translate the idx-th group of ins_per masks in both directions.

        Relies on self.real_A_img_sng / self.real_B_img_sng being set by the
        caller (test / optimize_parameters) before each invocation.
        """
        N = self.opt.ins_per  # number of objects handled per pass
        self.real_A_seg_sng = self.real_A_segs[:, N*idx:N*(idx+1), :, :]  # idx-th mask group ('sng' = single pass)
        self.real_B_seg_sng = self.real_B_segs[:, N*idx:N*(idx+1), :, :]
        # NOTE(review): `empty` is allocated but never used below — dead code?
        empty = -torch.ones(self.real_A_seg_sng.size()).to(self.device)
        # Only run a direction if its mask group contains a remaining object
        # (absent segs are filled with -1 by the dataset, so +1 sums to 0).
        self.forward_A = (self.real_A_seg_sng + 1).sum() > 0
        self.forward_B = (self.real_B_seg_sng + 1).sum() > 0
        # forward A: A -> B -> A
        if self.forward_A:
            self.real_A_fuse_sng = torch.cat([self.real_A_img_sng, self.real_A_seg_sng], dim=1)
            self.fake_B_fuse_sng = self.netG_A(self.real_A_fuse_sng)  # image+mask fed to G as one tensor
            self.fake_B_img_sng, self.fake_B_seg_sng = self.split(self.fake_B_fuse_sng)
            self.rec_A_fuse_sng = self.netG_B(self.fake_B_fuse_sng)  # reconstruct through G_B
            self.rec_A_img_sng, self.rec_A_seg_sng = self.split(self.rec_A_fuse_sng)
            self.fake_B_seg_mul = self.fake_B_seg_sng
            self.fake_B_mul = self.fake_B_fuse_sng  # fused fake B used for the GAN losses
        # forward B: B -> A -> B
        if self.forward_B:
            self.real_B_fuse_sng = torch.cat([self.real_B_img_sng, self.real_B_seg_sng], dim=1)
            self.fake_A_fuse_sng = self.netG_B(self.real_B_fuse_sng)
            self.fake_A_img_sng, self.fake_A_seg_sng = self.split(self.fake_A_fuse_sng)
            self.rec_B_fuse_sng = self.netG_A(self.fake_A_fuse_sng)
            self.rec_B_img_sng, self.rec_B_seg_sng = self.split(self.rec_B_fuse_sng)
            self.fake_A_seg_mul = self.fake_A_seg_sng
            self.fake_A_mul = self.fake_A_fuse_sng

    def test(self):
        """Inference loop used by test.py (no gradients, no parameter updates)."""
        # Same initialization as optimize_parameters.
        self.real_A_img_sng = self.real_A_img  # (B, 3, H, W)
        self.real_B_img_sng = self.real_B_img
        self.fake_A_seg_list = list()
        self.fake_B_seg_list = list()
        self.rec_A_seg_list = list()
        self.rec_B_seg_list = list()
        # Sequential mini-batch translation: each pass consumes the previous
        # pass's fake image as its input.
        for i in range(self.ins_iter):
            # forward — no grad: parameters are not updated at test time
            with torch.no_grad():
                self.forward(i)
            # update setting for next iteration
            self.real_A_img_sng = self.fake_B_img_sng.detach()
            self.real_B_img_sng = self.fake_A_img_sng.detach()
            self.fake_A_seg_list.append(self.fake_A_seg_sng.detach())
            self.fake_B_seg_list.append(self.fake_B_seg_sng.detach())
            self.rec_A_seg_list.append(self.rec_A_seg_sng.detach())
            self.rec_B_seg_list.append(self.rec_B_seg_sng.detach())
            # save visuals
            if i == 0:  # first pass
                self.rec_A_img = self.rec_A_img_sng
                self.rec_B_img = self.rec_B_img_sng
            if i == self.ins_iter - 1:  # last pass
                self.fake_A_img = self.fake_A_img_sng
                self.fake_B_img = self.fake_B_img_sng
                # NOTE(review): fake_*_seg merges only the final pass's masks
                # (*_mul), while rec_*_seg merges all passes — confirm intended.
                self.fake_A_seg = self.merge_masks(self.fake_A_seg_mul)
                self.fake_B_seg = self.merge_masks(self.fake_B_seg_mul)
                self.rec_A_seg = self.merge_masks(torch.cat(self.rec_A_seg_list, dim=1))
                self.rec_B_seg = self.merge_masks(torch.cat(self.rec_B_seg_list, dim=1))

    def backward_G(self):
        """Compute the total generator loss and backpropagate it."""
        lambda_A = self.opt.lambda_A      # cycle A weight
        lambda_B = self.opt.lambda_B      # cycle B weight
        lambda_idt = self.opt.lambda_idt  # identity weight
        lambda_ctx = self.opt.lambda_ctx  # context preserving weight
        lambda_fs = self.opt.lambda_fs    # feature similarity weight
        # backward A
        if self.forward_A:
            self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B_mul), True)
            self.loss_cyc_A = self.criterionCyc(self.rec_A_fuse_sng, self.real_A_fuse_sng) * lambda_A
            # Identity: G_B applied to a real A sample should be (near) identity.
            self.fake_A_fuse_sng_idt = self.netG_B(self.real_A_fuse_sng)
            self.fake_A_img_idt, self.fake_A_seg_idt = self.split(self.fake_A_fuse_sng_idt)
            self.loss_idt_B = self.criterionIdt(self.fake_A_fuse_sng_idt, self.real_A_fuse_sng.detach()) * lambda_A * lambda_idt
            # Context preserving: penalize changes outside the object masks.
            weight_A = self.get_weight_for_ctx(self.real_A_seg_sng, self.fake_B_seg_sng)
            self.loss_ctx_A = self.weighted_L1_loss(self.real_A_img_sng, self.fake_B_img_sng, weight=weight_A) * lambda_A * lambda_ctx
            # Feature similarity between masked fake-B and masked real-B regions.
            layers = {"conv_1_1": 1.0, "conv_3_2": 1.0}
            I = self.fake_B_img_sng  # generated image in domain B
            T = self.real_B_img_sng  # real image in target domain B
            I_multiply = self.fake_B_seg_mul * I
            T_multiply = self.real_B_seg_sng * T
            # NOTE(review): a new Feature_Similarity_Loss module is built (and
            # moved to CUDA) on every call — consider constructing it once.
            feature_similarity_loss = Feature_Similarity_Loss(layers, max_1d_size=64).cuda()
            self.loss_fs_A = feature_similarity_loss(I_multiply, T_multiply)[0] * lambda_fs
        else:
            self.loss_G_A = 0
            self.loss_cyc_A = 0
            self.loss_idt_B = 0
            self.loss_ctx_A = 0
            self.loss_fs_A = 0
        # backward B (mirror of backward A)
        if self.forward_B:
            self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A_mul), True)
            self.loss_cyc_B = self.criterionCyc(self.rec_B_fuse_sng, self.real_B_fuse_sng) * lambda_B
            self.fake_B_fuse_sng_idt = self.netG_A(self.real_B_fuse_sng)
            self.fake_B_img_idt, self.fake_B_seg_idt = self.split(self.fake_B_fuse_sng_idt)
            self.loss_idt_A = self.criterionIdt(self.fake_B_fuse_sng_idt, self.real_B_fuse_sng.detach()) * lambda_B * lambda_idt
            weight_B = self.get_weight_for_ctx(self.real_B_seg_sng, self.fake_A_seg_sng)
            self.loss_ctx_B = self.weighted_L1_loss(self.real_B_img_sng, self.fake_A_img_sng, weight=weight_B) * lambda_B * lambda_ctx
            layers = {"conv_1_1": 1.0, "conv_3_2": 1.0}
            I = self.fake_A_img_sng  # generated image in domain A
            T = self.real_A_img_sng  # real image in target domain A
            I_multiply = self.fake_A_seg_mul * I
            T_multiply = self.real_A_seg_sng * T
            feature_similarity_loss = Feature_Similarity_Loss(layers, max_1d_size=64).cuda()
            self.loss_fs_B = feature_similarity_loss(I_multiply, T_multiply)[0] * lambda_fs
        else:
            self.loss_G_B = 0
            self.loss_cyc_B = 0
            self.loss_idt_A = 0
            self.loss_ctx_B = 0
            self.loss_fs_B = 0
        # combined loss
        # NOTE(review): the ctx terms are deliberately left out of the total;
        # the original variant including them was:
        # self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cyc_A + self.loss_cyc_B + self.loss_idt_A + self.loss_idt_B + self.loss_ctx_A + self.loss_ctx_B + self.loss_fs_A + self.loss_fs_B
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cyc_A + self.loss_cyc_B + self.loss_idt_A + self.loss_idt_B + self.loss_fs_A + self.loss_fs_B
        self.loss_G.backward()

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator update: 0.5 * (loss on real + loss on detached fake)."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no gradient flows back into the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        # Query the history pool so D also sees older generated samples.
        fake_B = self.fake_B_pool.query(self.fake_B_mul)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A_mul)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def optimize_parameters(self):
        """One training step (train.py); mirrors test() plus G/D updates."""
        # Same initialization as test().
        self.real_A_img_sng = self.real_A_img  # (B, 3, H, W)
        self.real_B_img_sng = self.real_B_img
        self.fake_A_seg_list = list()
        self.fake_B_seg_list = list()
        self.rec_A_seg_list = list()
        self.rec_B_seg_list = list()
        # Sequential mini-batch translation.
        for i in range(self.ins_iter):
            # forward
            self.forward(i)
            # G_A and G_B — discriminators are frozen while updating G
            if self.forward_A or self.forward_B:
                self.set_requires_grad([self.netD_A, self.netD_B], False)
                self.optimizer_G.zero_grad()
                self.backward_G()
                self.optimizer_G.step()
            # D_A and D_B — re-enable and update the discriminators
            if self.forward_A or self.forward_B:
                self.set_requires_grad([self.netD_A, self.netD_B], True)
                self.optimizer_D.zero_grad()
                if self.forward_A:
                    self.backward_D_A()
                if self.forward_B:
                    self.backward_D_B()
                self.optimizer_D.step()
            # update setting for next iteration
            self.real_A_img_sng = self.fake_B_img_sng.detach()
            self.real_B_img_sng = self.fake_A_img_sng.detach()
            self.fake_A_seg_list.append(self.fake_A_seg_sng.detach())
            self.fake_B_seg_list.append(self.fake_B_seg_sng.detach())
            self.rec_A_seg_list.append(self.rec_A_seg_sng.detach())
            self.rec_B_seg_list.append(self.rec_B_seg_sng.detach())
            # save visuals
            if i == 0:  # first pass
                self.rec_A_img = self.rec_A_img_sng
                self.rec_B_img = self.rec_B_img_sng
            if i == self.ins_iter - 1:  # last pass
                self.fake_A_img = self.fake_A_img_sng
                self.fake_B_img = self.fake_B_img_sng
                self.fake_A_seg = self.merge_masks(self.fake_A_seg_mul)
                self.fake_B_seg = self.merge_masks(self.fake_B_seg_mul)
                self.rec_A_seg = self.merge_masks(torch.cat(self.rec_A_seg_list, dim=1))
                self.rec_B_seg = self.merge_masks(torch.cat(self.rec_B_seg_list, dim=1))
    def initialize(self, opt):
        """Build networks, losses and optimizers for this CycleGAN variant.

        Optionally loads a frozen VGG16 for a perceptual loss (--vgg > 0) and
        enables a skip connection in G_A only (--skip > 0).
        """
        BaseModel.initialize(self, opt)
        nb = opt.batchSize
        size = opt.fineSize
        self.opt = opt
        # Pre-allocated input buffers, filled by set_input.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        if opt.vgg > 0:
            # Frozen VGG16 feature extractor for the perceptual loss.
            self.vgg_loss = networks.PerceptualLoss()
            self.vgg_loss.cuda()
            self.vgg = networks.load_vgg16("./model")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        skip = True if opt.skip > 0 else False  # skip connection only in G_A
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
                                        self.gpu_ids, skip=skip, opt=opt)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
                                        self.gpu_ids, skip=False, opt=opt)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan  # vanilla GAN needs sigmoid output; LSGAN does not
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            # Resume from (or test with) a saved checkpoint.
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            self.old_lr = opt.lr
            # History buffers of generated images used for the D updates.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            if opt.use_wgan:
                self.criterionGAN = networks.DiscLossWGANGP()
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            if opt.use_mse:
                self.criterionCycle = torch.nn.MSELoss()
            else:
                self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (one joint optimizer for both generators)
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        # Put the generators into the correct train/eval mode (affects
        # dropout and norm layers).
        if opt.isTrain:
            self.netG_A.train()
            self.netG_B.train()
        else:
            self.netG_A.eval()
            self.netG_B.eval()
        print('-----------------------------------------------')
class FlowRefineModel(BaseModel):
    """Refine a backward warping field with a pix2pix-style GAN.

    Pipeline: `inverse_warp` produces a forward map from `real_A`, the
    depth-like input `real_C` and a fixed camera pose; `flow_remapper`
    inverts it into a backward map; `netG` refines that 2-channel field;
    `F.grid_sample` then warps `real_A` with the refined field to get `fake_B`.
    """

    def name(self):
        # NOTE(review): returns 'PVHMModel', not 'FlowRefineModel'. Left
        # unchanged because checkpoint directories / logs may key on it.
        return 'PVHMModel'

    def initialize(self, opt):
        """Create networks, losses, optimizers and the constant camera tensors."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        opt.output_nc = opt.input_nc  # refined output has the input's channel count

        # load/define networks: G maps a 2-channel flow field to a refined one.
        self.netG = networks.define_G(2, 2, opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout, opt.init_type,
                                      self.gpu_ids, tanh=True)
        self.flow_remapper = networks.flow_remapper(size=opt.fineSize, batch=opt.batchSize,
                                                    gpu_ids=opt.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan  # vanilla GAN needs a sigmoid D output; LSGAN does not
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers + schedulers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        # Identity sampling grid in grid_sample convention: channel 0 = x
        # (column index), channel 1 = y (row index), normalized to [-1, 1].
        # Vectorized replacement for the original per-pixel double loop.
        xs, ys = np.meshgrid(np.arange(opt.fineSize, dtype=np.float64),
                             np.arange(opt.fineSize, dtype=np.float64))  # xs[i, j] = j, ys[i, j] = i
        grid = np.stack((xs, ys), axis=-1)
        grid /= (opt.fineSize / 2)
        grid -= 1
        self.grid = torch.from_numpy(grid).cuda().float()
        self.grid = self.grid.view(1, self.grid.size(0), self.grid.size(1), self.grid.size(2))
        self.grid = Variable(self.grid)

        # Pinhole intrinsics (fx = fy = 128/32*60, principal point (64, 64)):
        # build the 3x3 matrix once and invert it, instead of constructing the
        # same array twice as before.
        K = np.array([128. / 32. * 60, 0., 64.,
                      0., 128. / 32. * 60, 64.,
                      0., 0., 1.]).reshape((3, 3))
        intrinsics = K.reshape((1, 3, 3))
        intrinsics_inv = np.linalg.inv(K).reshape((1, 3, 3))
        self.intrinsics = Variable(torch.from_numpy(intrinsics.astype(np.float32)).cuda()).expand(opt.batchSize, 3, 3)
        self.intrinsics_inv = Variable(torch.from_numpy(intrinsics_inv.astype(np.float32)).cuda()).expand(opt.batchSize, 3, 3)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Unpack a data batch; A/B are the image pair, C the depth/flow source."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async=True` is a SyntaxError on Python 3.7+; `non_blocking=True`
            # is the torch >= 0.4 spelling with identical semantics.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        input_C = input['C']
        if len(self.gpu_ids) > 0:
            input_C = input_C.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_C = input_C

    def forward(self):
        """Warp real_A into fake_B via the refined backward map."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_C = Variable(self.input_C)
        # Fixed camera rotation of -pi/4 about the y axis.
        pose = np.array([0, 0, 0, 0, -np.pi / 4., 0, ]).reshape((1, 6))
        pose = Variable(torch.from_numpy(pose.astype(np.float32)).cuda()).expand(self.opt.batchSize, 6)
        self.forward_map = inverse_warp(self.real_A, self.real_C, pose,
                                        self.intrinsics, self.intrinsics_inv)
        self.backward_map = self.flow_remapper(self.forward_map, self.forward_map)
        # netG consumes/produces NCHW; the maps are stored NHWC for grid_sample.
        self.backward_map_refined = self.netG(self.backward_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        self.fake_B = F.grid_sample(self.real_A, self.backward_map_refined)

    # no backprop gradients
    def test(self):
        """Same pipeline as forward().

        NOTE(review): uses a -pi/8 rotation where forward() uses -pi/4 —
        confirm the discrepancy is intentional.
        """
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_C = Variable(self.input_C)
        pose = np.array([0, 0, 0, 0, -np.pi / 8., 0, ]).reshape((1, 6))
        pose = Variable(torch.from_numpy(pose.astype(np.float32)).cuda()).expand(self.opt.batchSize, 6)
        self.forward_map = inverse_warp(self.real_A, self.real_C, pose,
                                        self.intrinsics, self.intrinsics_inv)
        self.backward_map = self.flow_remapper(self.forward_map, self.forward_map)
        self.backward_map_refined = self.netG(self.backward_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        self.fake_B = F.grid_sample(self.real_A, self.backward_map_refined)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Discriminator loss on conditional (A, B) pairs; fake_B is detached."""
        # Fake — stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.opt.lambda_gan * self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.opt.lambda_gan * self.criterionGAN(pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator loss: fool D on (A, G(A)) plus L1 to the target B."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.opt.lambda_gan * self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        # retain_graph: the graph through fake_B was already traversed by backward_D
        self.loss_G.backward(retain_graph=True)

    def optimize_parameters(self):
        """One training step: D update followed by G update."""
        self.forward()
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Scalar losses for logging.

        `.item()` replaces the legacy `.data[0]` indexing, which fails on
        0-dim tensors in PyTorch >= 0.5.
        """
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        """Images (as numpy arrays) for display: inputs, output and flow maps."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        forward_map = util.tensor2im(self.forward_map.permute(0, 3, 1, 2).data)
        backward_map = util.tensor2im(self.backward_map.permute(0, 3, 1, 2).data)
        backward_map_refined = util.tensor2im(self.backward_map_refined.permute(0, 3, 1, 2).data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
                            ('forward_map', forward_map), ('backward_map', backward_map),
                            ('backward_map_refined', backward_map_refined), ])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
class Pix2PixModel(BaseModel):
    """Paired image-to-image translation (pix2pix): conditional GAN + L1 loss."""

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Create the generator (and, when training, the discriminator),
        losses, optimizers and LR schedulers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Losses reported by base_model.get_current_losses.
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # Images saved/displayed by base_model.get_current_visuals.
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # Networks persisted by base_model.save_networks/load_networks.
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan  # vanilla GAN needs a sigmoid D output; LSGAN does not
            # D is conditional: it sees the (input, output) pair concatenated.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.init_type, self.gpu_ids)
        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers + schedulers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)

    def set_input(self, input):
        """Unpack a data batch and move it to the GPU when one is configured."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async=True` is a SyntaxError on Python 3.7+; `non_blocking=True`
            # is the torch >= 0.4 spelling with identical semantics.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run G without tracking gradients.

        `volatile=True` was removed from PyTorch; `torch.no_grad()` is the
        supported way to disable autograd during inference.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    def backward_D(self):
        """Discriminator loss on conditional (A, B) pairs."""
        # Fake — stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator loss: fool D on (A, G(A)) plus L1 to the target B."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: D update followed by G update."""
        self.forward()
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
def initialize(self, opt):
    """Set up the flow-prediction networks, optimizers, and the fixed
    sampling grid / camera intrinsics used for warping.

    NOTE(review): this method mutates `opt` in place (forces
    opt.output_nc = opt.input_nc); callers relying on the original
    output_nc should be aware.
    """
    BaseModel.initialize(self, opt)
    self.isTrain = opt.isTrain
    opt.output_nc = opt.input_nc
    # load/define networks: G maps a 2-channel input to a 2-channel flow
    # field, squashed through tanh.
    self.netG = networks.define_G(2, 2, opt.ngf, opt.which_model_netG,
                                  opt.norm, not opt.no_dropout, opt.init_type,
                                  self.gpu_ids, tanh=True)
    self.flow_remapper = networks.flow_remapper(size=opt.fineSize,
                                                batch=opt.batchSize,
                                                gpu_ids=opt.gpu_ids)
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.which_model_netD, opt.n_layers_D,
                                      opt.norm, use_sigmoid, opt.init_type,
                                      self.gpu_ids)
    if not self.isTrain or opt.continue_train:
        self.load_network(self.netG, 'G', opt.which_epoch)
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch)
    if self.isTrain:
        self.fake_AB_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.criterionL1 = torch.nn.L1Loss()
        # initialize optimizers
        self.schedulers = []
        self.optimizers = []
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))
    # Identity sampling grid normalized to [-1, 1]: grid[i, j] = (x, y), i.e.
    # channel 0 holds the column index, channel 1 the row index.
    # Vectorized with meshgrid instead of the original O(fineSize^2) Python loop.
    xs, ys = np.meshgrid(np.arange(opt.fineSize), np.arange(opt.fineSize))
    grid = np.stack((xs, ys), axis=-1).astype(np.float64)
    grid /= (opt.fineSize / 2)
    grid -= 1
    self.grid = torch.from_numpy(grid).cuda().float()
    self.grid = self.grid.view(1, self.grid.size(0), self.grid.size(1),
                               self.grid.size(2))
    self.grid = Variable(self.grid)
    # Pinhole camera intrinsics: fx = fy = 128/32*60, principal point (64, 64).
    # Built once and inverted, instead of duplicating the 3x3 literal.
    K = np.array([128. / 32. * 60, 0., 64.,
                  0., 128. / 32. * 60, 64.,
                  0., 0., 1.]).reshape((3, 3))
    intrinsics = K.reshape((1, 3, 3))
    intrinsics_inv = np.linalg.inv(K).reshape((1, 3, 3))
    self.intrinsics = Variable(
        torch.from_numpy(intrinsics.astype(np.float32)).cuda()
    ).expand(opt.batchSize, 3, 3)
    self.intrinsics_inv = Variable(
        torch.from_numpy(intrinsics_inv.astype(np.float32)).cuda()
    ).expand(opt.batchSize, 3, 3)
    print('---------- Networks initialized -------------')
    networks.print_network(self.netG)
    if self.isTrain:
        networks.print_network(self.netD)
    print('-----------------------------------------------')
class CycleDRPANModel(BaseModel):
    """CycleGAN augmented with DRPAN-style 'reviser' networks (R_A, R_B)
    that re-score proposal regions of the generated images and give the
    generators an extra adversarial signal.
    """

    def name(self):
        return 'CycleDRPANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add CycleGAN-style loss-weight options to the option parser."""
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        return parser

    def initialize(self, opt):
        """Build generators, discriminators, revisers, losses and optimizers."""
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will
        # call base_model.get_current_losses
        # NOTE(review): reviser_B also sets loss_R_B / loss_GR_B (and
        # loss_GR_A is only set inside reviser_A), but only 'R_A'/'GR_A'
        # are listed here — confirm the asymmetry is intended.
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'R_A', 'GR_A']
        # specify the images you want to save/display. The program will call
        # base_model.get_current_visuals
        if self.isTrain:
            visual_names_A = ['real_A', 'fake_B', 'rec_A', 'fake_Br', 'real_Ar', 'fake_Bf', 'real_Af']
            visual_names_B = ['real_B', 'fake_A', 'rec_B', 'fake_Ar', 'real_Br', 'fake_Af', 'real_Bf']
        else:
            visual_names_A = ['real_A', 'fake_B', 'rec_A']
            visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will
        # call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'R_A', 'R_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X), R_A(R_Y), R_B(R_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            # Revisers score proposal regions produced by Proposal below.
            self.netR_A = networks.define_R(opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netR_B = networks.define_R(opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_A = torch.optim.Adam(self.netR_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_B = torch.optim.Adam(self.netR_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_R_A)
            self.optimizers.append(self.optimizer_R_B)
            self.proposal = Proposal()
            # self.batchsize = opt.batchSize
            # self.label_r = torch.FloatTensor(self.batchsize)

    def set_input(self, input):
        """Unpack a batch dict, move the images to self.device and stash them."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Both translation directions plus cycle reconstructions."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator update for one D; returns the loss."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so the generator gets no gradient here)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        # Fakes come from the history pool, per the CycleGAN training recipe.
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def reviser_A(self):
        """Three alternating update steps for reviser R_A and generator G_A
        on proposal regions of the A->B translation."""
        # training with reviser
        for n_step in range(3):
            fake_B_ = self.netG_A(self.real_A)
            output = self.netD_A(fake_B_.detach())
            # proposal: crop/mask real-vs-fake region pairs using D's score map
            self.fake_Br, self.real_Ar, self.fake_Bf, self.real_Af, self.fake_ABf, self.real_ABr = self.proposal.forward_A(self.real_B, fake_B_, self.real_A, output)
            # train with real
            # NOTE(review): this zeroes D_A's grads but the step below is
            # optimizer_R_A — R_A's grads accumulate across the 3 steps
            # (beyond the single zero_grad in optimize_parameters). Confirm
            # this is intended and not meant to be netR_A.zero_grad().
            self.netD_A.zero_grad()
            output_r = self.netR_A(self.real_ABr.detach())
            self.loss_errR_real_A = self.criterionGAN(output_r, True)
            self.loss_errR_real_A.backward()
            # train with fake
            output_r = self.netR_A(self.fake_ABf.detach())
            self.loss_errR_fake_A = self.criterionGAN(output_r, False)
            self.loss_errR_fake_A.backward()
            self.loss_R_A = (self.loss_errR_real_A + self.loss_errR_fake_A) / 2
            self.optimizer_R_A.step()
            # train Generator with reviser
            self.netG_A.zero_grad()
            output_r = self.netR_A(self.fake_ABf)
            self.loss_GR_A = self.criterionGAN(output_r, True)
            self.loss_GR_A.backward()
            self.optimizer_G.step()

    def reviser_B(self):
        """Mirror of reviser_A for the B->A direction (R_B and G_B)."""
        # training with reviser
        for n_step in range(3):
            fake_A_ = self.netG_B(self.real_B)
            output = self.netD_B(fake_A_.detach())
            # proposal
            self.fake_Ar, self.real_Br, self.fake_Af, self.real_Bf, self.fake_BAf, self.real_BAr = self.proposal.forward_B(self.real_A, fake_A_, self.real_B, output)
            # train with real
            # NOTE(review): same pattern as reviser_A — zeroes D_B but steps
            # optimizer_R_B; confirm intended.
            self.netD_B.zero_grad()
            output_r = self.netR_B(self.real_BAr.detach())
            self.loss_errR_real_B = self.criterionGAN(output_r, True)
            self.loss_errR_real_B.backward()
            # train with fake
            output_r = self.netR_B(self.fake_BAf.detach())
            self.loss_errR_fake_B = self.criterionGAN(output_r, False)
            self.loss_errR_fake_B.backward()
            self.loss_R_B = (self.loss_errR_real_B + self.loss_errR_fake_B) / 2
            self.optimizer_R_B.step()
            # train Generator with reviser
            self.netG_B.zero_grad()
            output_r = self.netR_B(self.fake_BAf)
            self.errGAN_r = self.criterionGAN(output_r, True)
            self.loss_GR_B = self.errGAN_r
            self.loss_GR_B.backward()
            self.optimizer_G.step()

    def backward_G(self):
        """Combined generator loss: GAN + cycle + (optional) identity terms."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One iteration: G step (Ds frozen), D step, then reviser steps."""
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
        # R_A and R_B
        self.set_requires_grad([self.netR_A, self.netR_B], True)
        self.optimizer_R_A.zero_grad()
        self.optimizer_R_B.zero_grad()
        self.reviser_A()
        self.reviser_B()
class ReCycleGANModel(BaseModel):
    """CycleGAN variant whose generators take an `is_cycle` flag on the
    reconstruction pass (skip-connection reuse); requires a
    'recycle_skips' generator architecture.
    """

    def name(self):
        return 'ReCycleGANModel'

    def initialize(self, opt):
        """Build the two generators/discriminators, losses and optimizers."""
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        assert 'recycle_skips' in opt.which_model_netG
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, opt.init_type,
                                            self.gpu_ids, fourier_mode=opt.fourier_mode)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, opt.init_type,
                                            self.gpu_ids, fourier_mode=opt.fourier_mode)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Unpack a data-loader batch, move it to the GPU and stash it."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # BUGFIX: `async` is a reserved keyword since Python 3.7, so
            # `.cuda(dev, async=True)` is a SyntaxError. PyTorch renamed the
            # argument to `non_blocking` (same semantics) in 0.4.0.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the stored inputs; translations happen in backward_G/test."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Inference for both directions without backprop (legacy volatile)."""
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data
        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    def get_image_paths(self):
        """Return the paths of the images in the current batch."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator update for one D; returns the loss."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so the generator gets no gradient here)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        # Fakes come from the history pool, per the CycleGAN training recipe.
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    def backward_G(self):
        """Generator loss: GAN + cycle (+ optional identity); also caches the
        generated images and scalar losses for logging/visuals."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)
        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss (is_cycle=True engages the recycled skips)
        rec_A = self.netG_B(fake_B, is_cycle=True)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        rec_B = self.netG_A(fake_A, is_cycle=True)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data
        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        """One training iteration: G step, then D_A step, then D_B step."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return an ordered dict of the latest scalar losses for logging."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return an ordered dict of displayable images for the visualizer."""
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Persist all four networks under the given checkpoint label."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)