Example #1
    def optimizeG(self, allLosses):
        r"""
        Update the generator G on the current batch and append the
        generator losses to the given dictionary.
        """
        batch_size = self.real_input.size(0)
        # Update the generator. Both optimizers are zeroed: the forward
        # pass through D below also populates gradients in D.
        self.optimizerG.zero_grad()
        self.optimizerD.zero_grad()

        # #1 Image generation
        inputLatent, targetCatNoise = self.buildNoiseData(batch_size,
                                                          self.realLabels,
                                                          skipAtts=True)

        if getattr(self.config, 'style_mixing', False):
            # Style mixing: sample a second latent so the generator can mix
            # the two codes across its layers
            inputLatent2, targetRandCat2 = self.buildNoiseData(batch_size,
                                                               self.realLabels,
                                                               skipAtts=True)
            predFakeG = self.netG([inputLatent, inputLatent2])
        else:
            predFakeG = self.netG(inputLatent)

        # #2 Status evaluation
        predFakeD, phiGFake = self.netD(predFakeG, True)

        # #3 Classification criterion
        if self.config.ac_gan:
            G_classif_fake = \
                self.classificationPenalty(predFakeD,
                                           targetCatNoise,
                                           self.config.weightConditionG,
                                           backward=True,
                                           skipAtts=True)
            allLosses["lossG_classif"] = G_classif_fake
        # #4 GAN criterion
        lossGFake = self.lossCriterion.getCriterion(predFakeD, True)
        allLosses["lossG_fake"] = lossGFake.item()
        allLosses["Spread_R-F"] = (allLosses["lossD_real"]
                                   - allLosses["lossG_fake"])

        # Back-propagate the generator loss
        lossGFake.backward()
        finiteCheck(self.getOriginalG().parameters())
        self.register_G_grads()
        self.optimizerG.step()

        lossG = 0
        for key, val in allLosses.items():
            if key.startswith("lossG"):
                lossG += val

        allLosses["lossG"] = lossG

        # Update the exponential moving average of the generator's weights
        if isinstance(self.avgG, nn.DataParallel):
            avgGparams = self.avgG.module.parameters()
        else:
            avgGparams = self.avgG.parameters()

        for p, avg_p in zip(self.getOriginalG().parameters(), avgGparams):
            # avg_p = 0.999 * avg_p + 0.001 * p (keyword form of add_;
            # the positional scalar overload is deprecated)
            avg_p.mul_(0.999).add_(p.data, alpha=0.001)

        return allLosses
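
The moving-average block at the end of optimizeG (repeated in Example #2 below) is self-contained and reads more clearly as a standalone helper. Here is a minimal sketch of the same EMA update, assuming plain nn.Module or nn.DataParallel models; the helper name update_ema and the decay argument are illustrative, not part of the original code:

import torch
import torch.nn as nn


@torch.no_grad()
def update_ema(model: nn.Module, ema_model: nn.Module, decay: float = 0.999):
    """Exponential moving average of weights: ema = decay * ema + (1 - decay) * p."""
    # Unwrap DataParallel so the two parameter lists line up
    src = model.module if isinstance(model, nn.DataParallel) else model
    dst = ema_model.module if isinstance(ema_model, nn.DataParallel) else ema_model
    for p, avg_p in zip(src.parameters(), dst.parameters()):
        avg_p.mul_(decay).add_(p, alpha=1.0 - decay)

With decay=0.999 this reproduces the hard-coded 0.999 / 0.001 constants used above.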
Example #2
    def optimizeParameters(self, input_batch, inputLabels=None, **args):
        r"""
        Run one optimization step on the given batch: update the
        discriminator D on the "real" inputs, then update the generator G.

        Args:
            input_batch (torch.Tensor): input batch of real data
            inputLabels (torch.Tensor): labels of the real data
        """
        allLosses = {}
        try:
            # Log the current blending factor (progressive growing), if any
            allLosses['alpha'] = self.getOriginalD().alpha
        except AttributeError:
            pass
        # Retrieve the input data
        self.real_input, self.realLabels = input_batch.to(self.device), None
        if self.config.attribKeysOrder is not None:
            self.realLabels = inputLabels.to(self.device)

        n_samples = self.real_input.size(0)

        # Update the discriminator
        self.optimizerD.zero_grad()

        # #1 Real data
        predRealD = self.netD(self.real_input, False)
        # Classification criterion
        allLosses["lossD_classif"] = \
            self.classificationPenalty(predRealD,
                                       self.realLabels,
                                       self.config.weightConditionD,
                                       backward=True)

        lossD = self.lossCriterion.getCriterion(predRealD, True)
        allLosses["lossD_real"] = lossD.item()

        # #2 Fake data
        inputLatent, targetRandCat = self.buildNoiseData(
            n_samples, self.realLabels)

        predFakeG = self.netG(inputLatent).detach()
        predFakeD = self.netD(predFakeG, False)
        lossDFake = self.lossCriterion.getCriterion(predFakeD, False)
        allLosses["lossD_fake"] = lossDFake.item()
        lossD += lossDFake

        # #3 WGANGP gradient loss
        if self.config.lambdaGP > 0:
            allLosses["lossD_Grad"], allLosses["lipschitz_norm"] = \
                WGANGPGradientPenalty(input=self.real_input,
                                      fake=predFakeG,
                                      discriminator=self.netD,
                                      weight=self.config.lambdaGP,
                                      backward=True)

        # #4 Epsilon (drift) loss: penalizes large raw scores on real data
        if self.config.epsilonD > 0:
            lossEpsilon = (predRealD[:, 0]**2).sum() * self.config.epsilonD
            lossD += lossEpsilon
            allLosses["lossD_Epsilon"] = lossEpsilon.item()

        lossD.backward(retain_graph=True)
        finiteCheck(self.netD.parameters())
        self.optimizerD.step()

        # Logs
        lossD = 0
        for key, val in allLosses.items():
            if key.startswith("lossD"):
                lossD += val

        allLosses["lossD"] = lossD

        # Update the generator
        self.optimizerG.zero_grad()
        self.optimizerD.zero_grad()

        # #1 Image generation
        inputNoise, targetCatNoise = self.buildNoiseData(
            n_samples, self.realLabels)
        predFakeG = self.netG(inputNoise)

        # #2 Status evaluation
        predFakeD, phiGFake = self.netD(predFakeG, True)

        # #3 Classification criterion
        allLosses["lossG_classif"] = \
            self.classificationPenalty(predFakeD,
                                       targetCatNoise,
                                       self.config.weightConditionG,
                                       backward=True)

        # #4 GAN criterion
        lossGFake = self.lossCriterion.getCriterion(predFakeD, True)
        allLosses["lossG_fake"] = lossGFake.item()
        lossGFake.backward()

        finiteCheck(self.getOriginalG().parameters())
        self.optimizerG.step()

        lossG = 0
        for key, val in allLosses.items():
            if key.startswith("lossG"):
                lossG += val

        allLosses["lossG"] = lossG

        # Update the exponential moving average of the generator's weights
        if isinstance(self.avgG, nn.DataParallel):
            avgGparams = self.avgG.module.parameters()
        else:
            avgGparams = self.avgG.parameters()

        for p, avg_p in zip(self.getOriginalG().parameters(), avgGparams):
            avg_p.mul_(0.999).add_(p.data, alpha=0.001)

        return allLosses
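
Both examples call WGANGPGradientPenalty but the listing never defines it. Below is a minimal sketch of the standard WGAN-GP penalty (Gulrajani et al., 2017) written to match the call sites above; returning the penalty together with a logged Lipschitz norm, and running backward() internally when backward=True, are assumptions inferred from how the callers consume the result, not a verified copy of the original helper:

import torch


def WGANGPGradientPenalty(input, fake, discriminator, weight, backward=True):
    """weight * E[(||grad D(x_hat)|| - 1)^2], with x_hat sampled uniformly
    on segments between real and fake points."""
    batch_size = input.size(0)
    eps = torch.rand(batch_size, 1, device=input.device)
    eps = eps.expand(batch_size, input[0].nelement()).contiguous().view_as(input)
    interpolates = (eps * input + (1.0 - eps) * fake).requires_grad_(True)

    # Scalar score of the interpolated samples (first output channel)
    decision = discriminator(interpolates, False)[:, 0].sum()
    gradients = torch.autograd.grad(outputs=decision,
                                    inputs=interpolates,
                                    create_graph=True,
                                    retain_graph=True)[0]

    gradients = gradients.view(batch_size, -1)
    grad_norm = gradients.norm(2, dim=1)
    penalty = ((grad_norm - 1.0) ** 2).sum() * weight
    if backward:
        penalty.backward(retain_graph=True)

    # The second value is what the callers log as "lipschitz_norm"
    return penalty.item(), grad_norm.max().item()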
Example #3
    def optimizeD(self, allLosses):
        r"""
        Update the discriminator D on the current batch and append the
        discriminator losses to the given dictionary.
        """
        batch_size = self.real_input.size(0)

        inputLatent1, targetRandCat1 = self.buildNoiseData(batch_size,
                                                           self.realLabels,
                                                           skipAtts=True)
        if getattr(self.config, 'style_mixing', False):
            inputLatent2, targetRandCat2 = self.buildNoiseData(batch_size,
                                                               self.realLabels,
                                                               skipAtts=True)
            predFakeG = self.netG([inputLatent1, inputLatent2]).detach()
        else:
            predFakeG = self.netG(inputLatent1).detach()

        self.optimizerD.zero_grad()

        if self.mix_true_fake:
            # Blend real and fake samples before the discriminator's "real" pass
            real_batch = self.mix_true_fake_batch(self.real_input, predFakeG,
                                                  self.true_fake_split)
        else:
            real_batch = self.real_input

        # #1 Real data
        predRealD = self.netD(real_batch, False)
        predFakeD, D_fake_latent = self.netD(predFakeG, True)

        # Classification loss
        if self.config.ac_gan:
            # Classification criterion on the real data only
            allLosses["lossD_classif"] = \
                self.classificationPenalty(predRealD,
                                           self.realLabels,
                                           self.config.weightConditionD,
                                           backward=True)

        # #2 GAN objective for true and fake data
        lossD = self.lossCriterion.getCriterion(predRealD, True)
        allLosses["lossD_real"] = lossD.item()

        lossDFake = self.lossCriterion.getCriterion(predFakeD, False)
        allLosses["lossD_fake"] = lossDFake.item()
        lossD += lossDFake

        # #3 WGAN Gradient Penalty loss
        if self.config.lambdaGP > 0:
            allLosses["lossD_GP"], allLosses["lipschitz_norm"] = \
                WGANGPGradientPenalty(input=self.real_input,
                                      fake=predFakeG,
                                      discriminator=self.netD,
                                      weight=self.config.lambdaGP,
                                      backward=True)

        # #4 Epsilon loss
        if self.config.epsilonD > 0:
            lossEpsilon = (predRealD[:, -1]**2).sum() * self.config.epsilonD
            lossD += lossEpsilon
            allLosses["lossD_Epsilon"] = lossEpsilon.item()

        # #5 Logistic gradient loss
        if self.config.logisticGradReal > 0:
            allLosses["lossD_logistic"] = \
                logisticGradientPenalty(self.real_input, self.netD,
                                        self.config.logisticGradReal,
                                        backward=True)
        lossD.backward()

        finiteCheck(self.netD.parameters())
        self.optimizerD.step()

        # Logs
        lossD = 0
        for key, val in allLosses.items():
            if key.startswith("lossD"):
                lossD += val

        allLosses["lossD"] = lossD

        return allLosses
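
Example #3 additionally calls logisticGradientPenalty, which is also absent from the listing. A plausible sketch is a gradient penalty on real samples for the logistic GAN loss, i.e. an R1-style regularizer (Mescheder et al., 2018); the exact scaling and the scalar return value are assumptions matched to the call site above:

import torch


def logisticGradientPenalty(input, discriminator, weight, backward=True):
    """weight * E[||grad D(x_real)||^2], computed on real samples only."""
    input = input.detach().requires_grad_(True)
    decision = discriminator(input, False)[:, 0].sum()
    gradients = torch.autograd.grad(outputs=decision,
                                    inputs=input,
                                    create_graph=True,
                                    retain_graph=True)[0]

    # Squared gradient norm per sample, averaged over the batch
    penalty = gradients.pow(2).view(input.size(0), -1).sum(dim=1).mean() * weight
    if backward:
        penalty.backward(retain_graph=True)

    return penalty.item()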