def backward_D_B(self):
    if self.opt.eval_to_dis:
        # regenerate fake_A with the generator in eval mode before feeding D_B
        set_eval(self.netG_B)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.netG_B.train()
    fake_A = self.fake_A_pool.query(self.fake_A)
    if self.opt.lambda_gp > 0:
        self.loss_D_B_real, self.loss_D_B_fake, self.loss_D_B_gp = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
    else:
        self.loss_D_B_real, self.loss_D_B_fake = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B_gp = 0.0

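# NOTE: the fake_A_pool / fake_B_pool objects queried by the backward_D_* methods
# are ImagePool buffers (created in initialize below); their implementation is not
# part of this excerpt. A minimal sketch, assuming the usual CycleGAN-style history
# buffer that returns either the newly generated fake or, with probability 0.5, an
# older one. It operates on detached tensors purely for illustration.
import random
import torch

class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:                      # buffer disabled: pass fakes through
            return images
        out = []
        for image in images.detach():
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:    # still filling the buffer
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:              # swap in an old fake
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:                                    # keep the new fake
                out.append(image)
        return torch.cat(out, 0)
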
def backward_D_A(self):
    if self.opt.eval_to_dis:
        # regenerate fake_B with the generator in eval mode before feeding D_A
        set_eval(self.netG_A)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.netG_A.train()
    fake_B = self.fake_B_pool.query(self.fake_B)
    # fake_B_sr = self.fake_B_pool.query(self.fake_B_sr)
    if self.opt.lambda_gp > 0:
        self.loss_D_A_real, self.loss_D_A_fake, self.loss_D_A_gp = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
    else:
        self.loss_D_A_real, self.loss_D_A_fake = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A_gp = 0.0

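# NOTE: the backward_D_* methods in this file all delegate to backward_D_basic,
# which is not included in this excerpt. A minimal sketch of one plausible
# implementation, assuming the standard real/fake GAN losses plus an optional
# WGAN-GP gradient penalty gated by opt.lambda_gp; the exact weighting in the
# original code may differ.
def backward_D_basic(self, netD, real, fake):
    # real samples should be classified as real
    pred_real = netD(real)
    loss_D_real = self.criterionGAN(pred_real, True)
    # fakes are detached so no gradient flows back into the generator
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    if self.opt.lambda_gp > 0:
        # WGAN-GP style penalty on random interpolates between real and fake
        alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
        interp = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
        grads = torch.autograd.grad(outputs=netD(interp).sum(), inputs=interp,
                                    create_graph=True)[0]
        loss_D_gp = ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean() * self.opt.lambda_gp
        (loss_D + loss_D_gp).backward()
        return loss_D_real, loss_D_fake, loss_D_gp
    loss_D.backward()
    return loss_D_real, loss_D_fake
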
def test_img(img, net, trans, size, eval_mode=True, bn_eval=True, drop_eval=True):
    img = trans(img)
    # input_ = img.view(-1, 3, size, size)
    input_ = img.unsqueeze(0)
    real = Variable(input_.cuda())
    util.set_eval(net, bn_eval, drop_eval)
    with torch.no_grad():
        fake = net.forward(real)
    return fake

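# Example of how test_img might be called; the preprocessing pipeline, file name,
# and crop size below are illustrative placeholders, and netG is assumed to be a
# generator that has already been loaded and moved to the GPU.
import torchvision.transforms as transforms
from PIL import Image

trans = transforms.Compose([
    transforms.Resize(286),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img = Image.open('example.jpg').convert('RGB')
fake = test_img(img, netG, trans, size=256)   # fake: 1 x C x H x W tensor on the GPU
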
def backward_D_C(self):
    if self.opt.eval_to_dis:
        # regenerate fake_A with the generator in eval mode before feeding D_C
        set_eval(self.netG_B)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.netG_B.train()
    # D_C operates on encoder features rather than images
    self.real_fea = self.netE_C(self.real_C)
    self.fake_fea = self.netE_A(self.real_A)
    # fake_A = self.fake_A_pool.query(self.fake_A)
    if self.opt.lambda_gp > 0:
        self.loss_D_C_real, self.loss_D_C_fake, self.loss_D_C_gp = self.backward_D_basic(self.netD_C, self.real_fea, self.fake_fea)
        self.loss_D_C_real *= 5
        self.loss_D_C_fake *= 5
        self.loss_D_C_gp *= 5
    else:
        # scale each returned loss by 5; multiplying the returned tuple directly
        # would replicate the tuple instead of scaling its values
        self.loss_D_C_real, self.loss_D_C_fake = self.backward_D_basic(self.netD_C, self.real_fea, self.fake_fea)
        self.loss_D_C_real *= 5
        self.loss_D_C_fake *= 5
        self.loss_D_C_gp = 0.0

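# NOTE: set_eval (also used as util.set_eval in test_img above) is not defined in
# this excerpt. A rough sketch of what it plausibly does, assuming it switches only
# the normalization and dropout layers to eval mode so the generator produces
# deterministic outputs while the rest of the module keeps its training-mode flags:
import torch.nn as nn

def set_eval(net, bn_eval=True, drop_eval=True):
    for m in net.modules():
        if bn_eval and isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.eval()    # use running statistics instead of batch statistics
        if drop_eval and isinstance(m, nn.Dropout):
            m.eval()    # disable stochastic dropout
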
def initialize(self, opt):
    BaseModel.initialize(self, opt)
    nb = opt.batchSize
    size = opt.fineSize
    self.target_weight = []
    self.input_A = self.Tensor(nb, opt.input_nc, size, size)
    self.input_B = self.Tensor(nb, opt.output_nc, size, size)
    if opt.aux:
        self.A_aux = self.Tensor(nb, opt.input_nc, size, size)
        self.B_aux = self.Tensor(nb, opt.output_nc, size, size)

    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm, not opt.no_dropout,
                                    opt.init_type, self.gpu_ids, opt=opt)
    self.netG_A_running = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                            opt.which_model_netG, opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids, opt=opt)
    set_eval(self.netG_A_running)
    accumulate(self.netG_A_running, self.netG_A, 0)

    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm, not opt.no_dropout,
                                    opt.init_type, self.gpu_ids, opt=opt)
    self.netG_B_running = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                            opt.which_model_netG, opt.norm, not opt.no_dropout,
                                            opt.init_type, self.gpu_ids, opt=opt)
    set_eval(self.netG_B_running)
    accumulate(self.netG_B_running, self.netG_B, 0)

    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid,
                                        opt.init_type, self.gpu_ids, opt=opt)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid,
                                        opt.init_type, self.gpu_ids, opt=opt)

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG_A, opt, (opt.input_nc, opt.fineSize, opt.fineSize))
    # networks.print_network(self.netG_B, opt)
    if self.isTrain:
        networks.print_network(self.netD_A, opt)
        # networks.print_network(self.netD_B, opt)
    print('-----------------------------------------------')

    if not self.isTrain or opt.continue_train:
        print('Loaded model')
        which_epoch = opt.which_epoch
        self.load_network(self.netG_A, 'G_A', which_epoch)
        self.load_network(self.netG_B, 'G_B', which_epoch)
        if self.isTrain:
            self.load_network(self.netG_A_running, 'G_A', which_epoch)
            self.load_network(self.netG_B_running, 'G_B', which_epoch)
            self.load_network(self.netD_A, 'D_A', which_epoch)
            self.load_network(self.netD_B, 'D_B', which_epoch)

    if self.isTrain and opt.load_path != '':
        print('Loaded model from load_path')
        which_epoch = opt.which_epoch
        load_network_with_path(self.netG_A, 'G_A', opt.load_path, epoch_label=which_epoch)
        load_network_with_path(self.netG_B, 'G_B', opt.load_path, epoch_label=which_epoch)
        load_network_with_path(self.netD_A, 'D_A', opt.load_path, epoch_label=which_epoch)
        load_network_with_path(self.netD_B, 'D_B', opt.load_path, epoch_label=which_epoch)

    if self.isTrain:
        self.old_lr = opt.lr
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        if len(self.target_weight) == opt.num_D:
            print(self.target_weight)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor,
                                                 target_weight=self.target_weight,
                                                 gan=opt.gan)
        else:
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor, gan=opt.gan)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionColor = networks.ColorLoss()
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(),
                                                            self.netG_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))

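# NOTE: accumulate(...) keeps netG_A_running / netG_B_running as moving-average
# copies of the training generators; initialize calls it with decay 0 so each
# running copy starts as an exact clone. The function itself is not in this
# excerpt; a minimal sketch, assuming a parameter-wise exponential moving average:
def accumulate(model_ema, model, decay=0.999):
    ema_params = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for name, param in params.items():
        # new_ema = decay * old_ema + (1 - decay) * current
        ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)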