def optimizeG(self, allLosses):
    r"""
    Update the generator G on the current real batch and return the
    updated loss dictionary.
    """
    batch_size = self.real_input.size(0)

    # Update the generator
    self.optimizerG.zero_grad()
    self.optimizerD.zero_grad()

    # #1 Image generation
    inputLatent, targetCatNoise = self.buildNoiseData(batch_size,
                                                      self.realLabels,
                                                      skipAtts=True)
    if getattr(self.config, 'style_mixing', False):
        # Draw a second latent so the generator mixes two styles
        inputLatent2, targetRandCat2 = self.buildNoiseData(batch_size,
                                                           self.realLabels,
                                                           skipAtts=True)
        predFakeG = self.netG([inputLatent, inputLatent2])
    else:
        predFakeG = self.netG(inputLatent)

    # #2 Status evaluation
    predFakeD, phiGFake = self.netD(predFakeG, True)

    # #3 Classification criterion
    if self.config.ac_gan:
        G_classif_fake = \
            self.classificationPenalty(predFakeD,
                                       targetCatNoise,
                                       self.config.weightConditionG,
                                       backward=True,
                                       skipAtts=True)
        allLosses["lossG_classif"] = G_classif_fake

    # #4 GAN criterion
    lossGFake = self.lossCriterion.getCriterion(predFakeD, True)
    allLosses["lossG_fake"] = lossGFake.item()
    allLosses["Spread_R-F"] = allLosses["lossD_real"] - allLosses["lossG_fake"]

    # Back-propagate the generator loss
    lossGFake.backward()
    finiteCheck(self.getOriginalG().parameters())
    self.register_G_grads()
    self.optimizerG.step()

    # Aggregate all the generator losses for logging
    lossG = 0
    for key, val in allLosses.items():
        if key.find("lossG") == 0:
            lossG += val
    allLosses["lossG"] = lossG

    # Update the moving average of the generator, if relevant
    if isinstance(self.avgG, nn.DataParallel):
        avgGparams = self.avgG.module.parameters()
    else:
        avgGparams = self.avgG.parameters()

    for p, avg_p in zip(self.getOriginalG().parameters(), avgGparams):
        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

    return allLosses
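# The tail of optimizeG keeps an exponential moving average (EMA) of the
# generator weights, as popularised by progressive-growing GANs; the averaged
# copy is typically the one used for evaluation and sampling. Below is a
# minimal standalone sketch of that update. `update_ema` is a hypothetical
# helper written for illustration, not part of this class.
def update_ema(model, avg_model, decay=0.999):
    """EMA update: avg <- decay * avg + (1 - decay) * param."""
    with torch.no_grad():
        for p, avg_p in zip(model.parameters(), avg_model.parameters()):
            avg_p.mul_(decay).add_(p, alpha=1.0 - decay)
# Iterating over getOriginalG().parameters() rather than self.netG.parameters()
# keeps the pairing correct under nn.DataParallel, which is why the method
# above unwraps the module before zipping.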
def optimizeParameters(self, input_batch, inputLabels=None, **args):
    r"""
    Update both the discriminator D and the generator G using the given
    "real" inputs.

    Args:
        input_batch (torch.tensor): input batch of real data
        inputLabels (torch.tensor): labels of the real data
    """
    allLosses = {}
    try:
        allLosses['alpha'] = self.getOriginalD().alpha
    except AttributeError:
        pass

    # Retrieve the input data
    self.real_input, self.realLabels = input_batch.to(self.device), None
    if self.config.attribKeysOrder is not None:
        self.realLabels = inputLabels.to(self.device)

    n_samples = self.real_input.size()[0]

    # Update the discriminator
    self.optimizerD.zero_grad()

    # #1 Real data
    predRealD = self.netD(self.real_input, False)

    # Classification criterion
    allLosses["lossD_classif"] = \
        self.classificationPenalty(predRealD,
                                   self.realLabels,
                                   self.config.weightConditionD,
                                   backward=True)

    lossD = self.lossCriterion.getCriterion(predRealD, True)
    allLosses["lossD_real"] = lossD.item()

    # #2 Fake data
    inputLatent, targetRandCat = self.buildNoiseData(n_samples,
                                                     self.realLabels)
    predFakeG = self.netG(inputLatent).detach()
    predFakeD = self.netD(predFakeG, False)

    lossDFake = self.lossCriterion.getCriterion(predFakeD, False)
    allLosses["lossD_fake"] = lossDFake.item()
    lossD += lossDFake

    # #3 WGAN-GP gradient loss
    if self.config.lambdaGP > 0:
        allLosses["lossD_Grad"], allLosses["lipschitz_norm"] = \
            WGANGPGradientPenalty(input=self.real_input,
                                  fake=predFakeG,
                                  discriminator=self.netD,
                                  weight=self.config.lambdaGP,
                                  backward=True)

    # #4 Epsilon loss
    if self.config.epsilonD > 0:
        lossEpsilon = (predRealD[:, 0] ** 2).sum() * self.config.epsilonD
        lossD += lossEpsilon
        allLosses["lossD_Epsilon"] = lossEpsilon.item()

    lossD.backward(retain_graph=True)
    finiteCheck(self.netD.parameters())
    self.optimizerD.step()

    # Logs: aggregate all the discriminator losses
    lossD = 0
    for key, val in allLosses.items():
        if key.find("lossD") == 0:
            lossD += val
    allLosses["lossD"] = lossD

    # Update the generator
    self.optimizerG.zero_grad()
    self.optimizerD.zero_grad()

    # #1 Image generation
    inputNoise, targetCatNoise = self.buildNoiseData(n_samples,
                                                     self.realLabels)
    predFakeG = self.netG(inputNoise)

    # #2 Status evaluation
    predFakeD, phiGFake = self.netD(predFakeG, True)

    # #3 Classification criterion
    allLosses["lossG_classif"] = \
        self.classificationPenalty(predFakeD,
                                   targetCatNoise,
                                   self.config.weightConditionG,
                                   backward=True)

    # #4 GAN criterion
    lossGFake = self.lossCriterion.getCriterion(predFakeD, True)
    allLosses["lossG_fake"] = lossGFake.item()

    # Back-propagate the generator loss
    lossGFake.backward()
    finiteCheck(self.getOriginalG().parameters())
    self.optimizerG.step()

    # Aggregate all the generator losses for logging
    lossG = 0
    for key, val in allLosses.items():
        if key.find("lossG") == 0:
            lossG += val
    allLosses["lossG"] = lossG

    # Update the moving average of the generator, if relevant
    if isinstance(self.avgG, nn.DataParallel):
        avgGparams = self.avgG.module.parameters()
    else:
        avgGparams = self.avgG.parameters()

    for p, avg_p in zip(self.getOriginalG().parameters(), avgGparams):
        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

    return allLosses
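# optimizeParameters and optimizeD both rely on WGANGPGradientPenalty, which
# is defined elsewhere in the repository. For reference, here is a minimal
# sketch of a WGAN-GP penalty consistent with the call sites above. The
# argument names and the (loss, norm) return pair are inferred from those
# call sites; the internals are an assumption, not the actual implementation.
def wgangp_gradient_penalty_sketch(input, fake, discriminator, weight,
                                   backward=True):
    """Penalise deviations of the discriminator's gradient norm from 1,
    evaluated on random interpolates between real and fake samples."""
    batch_size = input.size(0)
    # Per-sample interpolation coefficients, assuming NCHW image batches
    eps = torch.rand(batch_size, 1, 1, 1, device=input.device)
    interpolates = (eps * input + (1.0 - eps) * fake).requires_grad_(True)
    decision = discriminator(interpolates, False)[:, 0].sum()
    gradients = torch.autograd.grad(outputs=decision,
                                    inputs=interpolates,
                                    create_graph=True)[0]
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    penalty = weight * ((grad_norm - 1.0) ** 2).mean()
    if backward:
        penalty.backward(retain_graph=True)
    return penalty.item(), grad_norm.mean().item()
# With backward=True the penalty contributes its gradients inside the helper,
# which is consistent with the call sites above only logging the returned
# value instead of adding it to lossD.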
def optimizeD(self, allLosses):
    r"""
    Update the discriminator D on the current real batch and return the
    updated loss dictionary.
    """
    batch_size = self.real_input.size(0)

    inputLatent1, targetRandCat1 = self.buildNoiseData(batch_size,
                                                       self.realLabels,
                                                       skipAtts=True)
    if getattr(self.config, 'style_mixing', False):
        # Draw a second latent so the generator mixes two styles
        inputLatent2, targetRandCat2 = self.buildNoiseData(batch_size,
                                                           self.realLabels,
                                                           skipAtts=True)
        predFakeG = self.netG([inputLatent1, inputLatent2]).detach()
    else:
        predFakeG = self.netG(inputLatent1).detach()

    self.optimizerD.zero_grad()

    if self.mix_true_fake:
        # Build a batch mixing real and fake samples.
        # NOTE: this mixed batch is currently not fed to the discriminator
        # below, which still evaluates self.real_input.
        input_batch = self.mix_true_fake_batch(self.real_input,
                                               predFakeG,
                                               self.true_fake_split)

    # #1 Real and fake data evaluation
    predRealD = self.netD(self.real_input, False)
    predFakeD, D_fake_latent = self.netD(predFakeG, True)

    # #2 Classification criterion for real data
    if self.config.ac_gan:
        allLosses["lossD_classif"] = \
            self.classificationPenalty(predRealD,
                                       self.realLabels,
                                       self.config.weightConditionD,
                                       backward=True)

    # Objective function for real and fake data
    lossD = self.lossCriterion.getCriterion(predRealD, True)
    allLosses["lossD_real"] = lossD.item()

    lossDFake = self.lossCriterion.getCriterion(predFakeD, False)
    allLosses["lossD_fake"] = lossDFake.item()
    lossD += lossDFake

    # #3 WGAN-GP gradient penalty loss
    if self.config.lambdaGP > 0:
        allLosses["lossD_GP"], allLosses["lipschitz_norm"] = \
            WGANGPGradientPenalty(input=self.real_input,
                                  fake=predFakeG,
                                  discriminator=self.netD,
                                  weight=self.config.lambdaGP,
                                  backward=True)

    # #4 Epsilon loss
    if self.config.epsilonD > 0:
        lossEpsilon = (predRealD[:, -1] ** 2).sum() * self.config.epsilonD
        lossD += lossEpsilon
        allLosses["lossD_Epsilon"] = lossEpsilon.item()

    # #5 Logistic gradient loss on real data
    if self.config.logisticGradReal > 0:
        allLosses["lossD_logistic"] = \
            logisticGradientPenalty(self.real_input,
                                    self.netD,
                                    self.config.logisticGradReal,
                                    backward=True)

    lossD.backward()
    finiteCheck(self.netD.parameters())
    self.optimizerD.step()

    # Logs: aggregate all the discriminator losses
    lossD = 0
    for key, val in allLosses.items():
        if key.find("lossD") == 0:
            lossD += val
    allLosses["lossD"] = lossD

    return allLosses
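# logisticGradientPenalty, called in step #5 above, regularises the gradient
# of the discriminator on real data, in the spirit of the R1 penalty of
# Mescheder et al., "Which Training Methods for GANs do actually Converge?".
# A minimal sketch consistent with the call site follows; the argument order
# and scalar return value are inferred from that call, and the body is an
# assumption rather than the repository's actual implementation.
def logistic_gradient_penalty_sketch(input, discriminator, weight,
                                     backward=True):
    """R1-style penalty: weight * E[ ||grad_x D(x)||^2 ] on real samples."""
    input = input.detach().requires_grad_(True)
    decision = discriminator(input, False)[:, 0].sum()
    gradients = torch.autograd.grad(outputs=decision,
                                    inputs=input,
                                    create_graph=True)[0]
    penalty = weight * gradients.pow(2).reshape(input.size(0), -1).sum(1).mean()
    if backward:
        penalty.backward()
    return penalty.item()
# As with the WGAN-GP helper, backward=True back-propagates the penalty inside
# the call, so optimizeD only records the returned scalar in allLosses.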