Esempio n. 1
0
    def forward(self, x, reparameterize=False):
        """Encode ``x`` and return the list of enabled output heads.

        ``x`` is pushed through ``self.main`` and flattened to
        ``(batch, 1024 * int(self.fcsize))`` before being fed to each head.
        Every module call is dispatched with ``nn.parallel.data_parallel``
        over ``self.gpu_ids``.

        Args:
            x: Input batch for ``self.main``.
            reparameterize: Accepted but not read by this method.

        Returns:
            A list holding, in this order and only when the matching count
            is positive: the class head output (``self.nClasses``), the
            reference head output (``self.nRef``), and the latent entry
            (``self.nLatentDim``). The latent entry is the pair
            ``[mu, log_sigma]`` while ``self.training`` is true, otherwise
            the result of ``bvae.reparameterize(..., add_noise=False)``.
        """
        devices = self.gpu_ids

        def run(module, inputs):
            # Single place where the multi-device dispatch happens.
            return nn.parallel.data_parallel(module, inputs, devices)

        features = run(self.main, x)
        flat_width = 1024 * int(self.fcsize * 1 * 1)
        features = features.view(features.size(0), flat_width)

        outputs = []

        if self.nClasses > 0:
            outputs.append(run(self.classOut, features))

        if self.nRef > 0:
            outputs.append(run(self.refOutMu, features))

        if self.nLatentDim <= 0:
            return outputs

        latent_mu = run(self.latentOutMu, features)
        latent_log_sigma = run(self.latentOutLogSigma, features)

        if self.training:
            # Hand back both distribution parameters to the caller.
            latent_entry = [latent_mu, latent_log_sigma]
        else:
            latent_entry = bvae.reparameterize(latent_mu,
                                               latent_log_sigma,
                                               add_noise=False)
        outputs.append(latent_entry)

        return outputs
Esempio n. 2
0
    def forward(self, x):
        """Encode ``x`` through the pooled/transition stack and return heads.

        ``self.first`` is applied to ``x``; then, for each
        ``(pool, transition)`` pair, ``pool`` is applied to the *original*
        ``x`` and concatenated (dim 1) in front of the running feature map
        before the transition block runs. The final map is flattened to
        ``(batch, 1024 * int(self.fcsize))`` and fed to each enabled head.
        ``self.first``, the transition blocks and the heads are dispatched
        with ``nn.parallel.data_parallel`` over ``self.gpu_ids``; the pool
        modules are called directly.

        Args:
            x: Input batch.

        Returns:
            A list holding, in this order and only when the matching count
            is positive: the class head output (``self.nClasses``), the
            reference head output (``self.nRef``), and the latent entry
            (``self.nLatentDim``). The latent entry is the pair
            ``[mu, log_sigma]`` while ``self.training`` is true, otherwise
            the result of ``bvae.reparameterize(..., add_noise=False)``.
        """
        devices = self.gpu_ids

        def run(module, inputs):
            # Single place where the multi-device dispatch happens.
            return nn.parallel.data_parallel(module, inputs, devices)

        merged = run(self.first, x)

        for pool, transition in zip(self.poolBlocks, self.transitionBlocks):
            pooled_input = pool(x)
            stacked = torch.cat([pooled_input, merged], 1)
            merged = run(transition, stacked)

        flat_width = 1024 * int(self.fcsize * 1 * 1)
        features = merged.view(merged.size(0), flat_width)

        outputs = []

        if self.nClasses > 0:
            outputs.append(run(self.classOut, features))

        if self.nRef > 0:
            outputs.append(run(self.refOut, features))

        if self.nLatentDim <= 0:
            return outputs

        latent_mu = run(self.latentOutMu, features)
        latent_log_sigma = run(self.latentOutLogSigma, features)

        if self.training:
            # Hand back both distribution parameters to the caller.
            latent_entry = [latent_mu, latent_log_sigma]
        else:
            latent_entry = bvae.reparameterize(latent_mu,
                                               latent_log_sigma,
                                               add_noise=False)
        outputs.append(latent_entry)

        return outputs
Esempio n. 3
0
    def iteration(self):
        """Run one training step: discriminator, then encoder, then decoder.

        One batch is drawn from ``self.data_provider`` and used for three
        updates, in this order:

        1. ``decD`` (image discriminator) is stepped on the real batch, on
           the reconstruction ``dec(enc(x))`` and on images decoded from a
           freshly sampled latent (``self.zReal``).
        2. ``enc`` is stepped after the class loss, the reference loss and
           the beta-VAE loss have been back-propagated.
        3. ``dec`` is stepped after two additional adversarial losses
           (reconstruction and sampled images scored by ``decD``) have been
           back-propagated on top of the gradients from step 2.

        Returns:
            ``(errors, zLatent)``. ``errors`` is a list ordered as
            reconstruction, class loss (if ``n_classes > 0``), reference
            loss (if ``n_ref > 0``), KLD, averaged adversarial decoder loss,
            discriminator loss; ``.cpu()`` is called on every entry.
            ``zLatent`` is ``zAll[c][0].data.cpu()`` from the autoencoder
            pass (presumably the latent means -- confirm against ``enc``).
        """
        # NOTE(review): gpu_id is not read again in this method.
        gpu_id = self.gpu_ids[0]

        enc, dec, decD = self.enc, self.dec, self.decD
        optEnc, optDec, optDecD = self.optEnc, self.optDec, self.optDecD
        critRecon, critZClass, critZRef, critDecD = self.critRecon, self.critZClass, self.critZRef, self.critDecD

        # Force training mode in case upstream code switched the modules.
        enc.train(True)
        dec.train(True)
        decD.train(True)

        # Draw one batch; the discriminator (decD) is updated first.
        x, classes, ref = self.data_provider.get_sample()

        # Copy the batch into the preallocated input buffer and use that.
        self.x.data.copy_(x)
        x = self.x

        y_xFake = self.y_xFake
        # NOTE(review): y_zReal and y_zFake are assigned but never read here.
        y_zReal = self.y_zReal

        if self.n_classes == 0:
            y_xReal = self.y_xReal
            y_zFake = self.y_zFake
        else:
            # With classes, the class index doubles as the "real" target
            # handed to critDecD.
            classes = classes.type_as(x).long()

            y_xReal = classes
            y_zFake = classes

        if self.n_ref > 0:
            ref = ref.type_as(x)

        # Only the discriminator is trainable during its own update.
        for p in decD.parameters():
            p.requires_grad = True

        for p in enc.parameters():
            p.requires_grad = False

        for p in dec.parameters():
            p.requires_grad = False

        zAll = enc(x)

        # Detach every encoder output in place; the last entry is iterated
        # and indexed as a pair ([0], [1]) below.
        for i in range(len(zAll) - 1):
            zAll[i].detach_()

        for var in zAll[-1]:
            var.detach_()

        # Replace the pair with a single value from bvae.reparameterize.
        zAll[-1] = bvae.reparameterize(zAll[-1][0], zAll[-1][1])

        xHat = dec(zAll)

        # Refill the latent sample buffer with standard-normal noise.
        self.zReal.data.normal_()
        zReal = self.zReal
        # zReal = Variable(opt.latentSample(self.batch_size, self.n_latent_dim).cuda(gpu_id))
        # NOTE(review): zFake is not read again in this method.
        zFake = zAll[-1]

        optEnc.zero_grad()
        optDec.zero_grad()
        optDecD.zero_grad()

        ##############
        ### Train decD
        ##############

        yHat_xReal = decD(x)

        # Discriminator loss on the real batch.
        errDecD_real = critDecD(yHat_xReal, y_xReal)

        # Discriminator loss on reconstructed images.
        yHat_xFake = decD(xHat)
        errDecD_fake = critDecD(yHat_xFake, y_xFake)

        # Discriminator loss on images decoded from the sampled latent;
        # the other entries of zAll are kept as encoded.
        zAll[-1] = zReal

        yHat_xFake2 = decD(dec(zAll))
        errDecD_fake2 = critDecD(yHat_xFake2, y_xFake)

        # The two fake terms are averaged, then averaged with the real term.
        decDLoss = (errDecD_real + (errDecD_fake + errDecD_fake2) / 2) / 2
        decDLoss.backward(retain_graph=True)
        optDecD.step()

        # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a
        # scalar loss; newer releases reject indexing a 0-dim tensor --
        # confirm the targeted PyTorch version before changing.
        decDLoss = decDLoss.data[0]

        # Drop references to the partial losses.
        errDecD_real = None
        errDecD_fake = None
        errDecD_fake2 = None

        # Swap roles: encoder/decoder trainable, discriminator frozen.
        for p in enc.parameters():
            p.requires_grad = True

        for p in dec.parameters():
            p.requires_grad = True

        for p in decD.parameters():
            p.requires_grad = False

        optEnc.zero_grad()
        optDec.zero_grad()
        optDecD.zero_grad()

        #####################
        ### train autoencoder
        #####################

        # Fresh forward pass, this time without detaching.
        zAll = enc(x)

        # c walks through zAll: class slot, then reference slot, then latent.
        c = 0
        # Class prediction loss on the encoder's class output.
        if self.n_classes > 0:
            classLoss = critZClass(zAll[c], classes) * self.lambda_class_loss
            classLoss.backward(retain_graph=True)
            classLoss = classLoss.data[0]

            if self.provide_decoder_vars:
                # Feed the decoder the log of the one-hot labels (1E-8
                # keeps the log finite) instead of the encoder prediction.
                zAll[c] = torch.log(
                    utils.index_to_onehot(classes, self.n_classes) + 1E-8)

            c += 1

        # Reference prediction loss on the encoder's reference output.
        if self.n_ref > 0:
            refLoss = critZRef(zAll[c], ref) * self.lambda_ref_loss
            refLoss.backward(retain_graph=True)
            refLoss = refLoss.data[0]

            if self.provide_decoder_vars:
                # Feed the decoder the provided reference values instead.
                zAll[c] = ref

            c += 1

        # Only total_kld is used below.
        total_kld, dimension_wise_kld, mean_kld = bvae.kl_divergence(
            zAll[c][0], zAll[c][1])

        # CPU copy of the first element of the latent pair; returned as-is.
        zLatent = zAll[c][0].data.cpu()

        zAll[c] = bvae.reparameterize(zAll[c][0], zAll[c][1])

        xHat = dec(zAll)

        # Image reconstruction loss.
        recon_loss = critRecon(xHat, x)

        # 'H': recon + beta * KLD. 'B': recon + gamma * |KLD - C| with
        # C = c_max / c_iters_max * len(self.logger), clamped to [0, c_max].
        # NOTE(review): any other objective leaves beta_vae_loss unbound and
        # the backward() call below raises a NameError.
        if self.objective == 'H':
            beta_vae_loss = recon_loss + self.beta * total_kld
        elif self.objective == 'B':
            C = torch.clamp(
                torch.Tensor([
                    self.c_max / self.c_iters_max * len(self.logger)
                ]).type_as(x), 0, self.c_max)
            beta_vae_loss = recon_loss + self.gamma * (total_kld - C).abs()

        beta_vae_loss.backward(retain_graph=True)
        kld_loss = total_kld.data[0]

        recon_loss = recon_loss.data[0]

        # The encoder is stepped here, before the adversarial terms.
        optEnc.step()

        # Only optDec.step() follows, so the remaining losses are used for
        # the decoder alone.
        for p in enc.parameters():
            p.requires_grad = False

        # Adversarial term 1: reconstructions scored against the "real"
        # targets, weighted by lambda_decD_loss / 2.
        yHat_xFake = decD(xHat)
        minimaxDecDLoss = critDecD(yHat_xFake, y_xReal)
        (minimaxDecDLoss.mul(
            self.lambda_decD_loss).div(2)).backward(retain_graph=True)
        minimaxDecDLoss = minimaxDecDLoss.data[0]
        yHat_xFake = None

        # Adversarial term 2: images decoded from randomised inputs.

        c = 0
        # With classes: shuffle the batch's class labels and decode images
        # of those shuffled classes.
        if self.n_classes > 0:
            shuffle_inds = np.arange(0, zAll[0].size(0))

            # (onehot - 1) * 25: assuming 0/1 one-hot entries, this yields 0
            # for the labelled class and -25 elsewhere -- TODO confirm
            # against utils.index_to_onehot.
            classes_one_hot = (
                (utils.index_to_onehot(classes, self.n_classes) - 1) *
                25).type_as(zAll[c].data).cuda(self.gpu_ids[0])

            np.random.shuffle(shuffle_inds)
            zAll[c] = classes_one_hot[shuffle_inds, :]
            # Keep the discriminator targets aligned with the shuffle.
            y_xReal = y_xReal[torch.LongTensor(shuffle_inds).cuda(
                self.gpu_ids[0])]

            c += 1

        if self.n_ref > 0:
            # Overwrite the reference slot in place with normal noise.
            zAll[c].data.normal_()

        # Fresh standard-normal sample for the latent slot.
        self.zReal.data.normal_()
        zAll[-1] = self.zReal

        xHat = dec(zAll)

        yHat_xFake2 = decD(xHat)
        minimaxDecDLoss2 = critDecD(yHat_xFake2, y_xReal)
        (minimaxDecDLoss2.mul(
            self.lambda_decD_loss).div(2)).backward(retain_graph=True)
        minimaxDecDLoss2 = minimaxDecDLoss2.data[0]
        yHat_xFake2 = None

        # Reported adversarial loss is the mean of the two terms.
        minimaxDecLoss = (minimaxDecDLoss + minimaxDecDLoss2) / 2

        optDec.step()

        errors = (recon_loss, )
        if self.n_classes > 0:
            errors += (classLoss, )

        if self.n_ref > 0:
            errors += (refLoss, )

        errors += (kld_loss, minimaxDecLoss, decDLoss)
        # NOTE(review): .cpu() requires every entry to still be tensor-like
        # after the `.data[0]` reads above -- confirm for the PyTorch
        # version in use.
        errors = [error.cpu() for error in errors]

        return errors, zLatent
Esempio n. 4
0
    def iteration(self):
        """Run one training step of the encoder/decoder pair.

        One batch is drawn from ``self.data_provider``. The following losses
        are back-propagated before both optimizers are stepped:

        * class loss (``critZClass``, scaled by ``lambda_class_loss``) when
          ``n_classes > 0``;
        * reference loss (``critZRef``, scaled by ``lambda_ref_loss``) when
          ``n_ref > 0``;
        * beta-VAE loss: ``recon + beta * KLD`` for objective ``'H'``, or
          ``recon + gamma * |KLD - C|`` for objective ``'B'`` with
          ``C = c_max / c_iters_max * len(self.logger)`` clamped to
          ``[0, c_max]``;
        * self loss: ``critRecon`` between the encoder activations of the
          reconstruction and the (detached) activations of the input,
          summed over activation pairs and scaled by ``lambda_loss``. The
          encoder parameters are frozen while this term is computed and
          back-propagated.

        Changes relative to the previous revision: scalar losses are read
        uniformly with ``.item()`` (class/reference losses used the
        deprecated ``.data[0]``), the encoder is unfrozen again after the
        self-loss backward pass (the freeze loop used to be repeated, a
        no-op that left the encoder frozen on return), an unsupported
        ``self.objective`` raises ``ValueError`` instead of failing later on
        an unbound variable, and unused locals were dropped.

        Returns:
            ``(errors, zLatent)``. ``errors`` is a list of Python floats
            ordered as reconstruction, class loss (if ``n_classes > 0``),
            reference loss (if ``n_ref > 0``), KLD, self loss. ``zLatent``
            is ``zAll[c][0].data.cpu()`` (presumably the latent means --
            confirm against ``enc``).

        Raises:
            ValueError: if ``self.objective`` is neither ``'H'`` nor ``'B'``.
        """
        enc, dec = self.enc, self.dec
        optEnc, optDec = self.optEnc, self.optDec
        critRecon, critZClass, critZRef = (self.critRecon, self.critZClass,
                                           self.critZRef)

        # Force training mode in case upstream code switched the modules.
        enc.train(True)
        dec.train(True)

        x, classes, ref = self.data_provider.get_sample()

        # Copy the batch into the preallocated input buffer and use that.
        self.x.data.copy_(x)
        x = self.x

        if self.n_classes != 0:
            classes = classes.type_as(x).long()

        if self.n_ref > 0:
            ref = ref.type_as(x)

        for p in enc.parameters():
            p.requires_grad = True

        for p in dec.parameters():
            p.requires_grad = True

        optEnc.zero_grad()
        optDec.zero_grad()

        #####################
        ### train autoencoder
        #####################

        zAll, activations = enc(x)

        # c walks through zAll: class slot, then reference slot, then latent.
        c = 0
        if self.n_classes > 0:
            classLoss = critZClass(zAll[c], classes) * self.lambda_class_loss
            classLoss.backward(retain_graph=True)
            classLoss = classLoss.item()

            if self.provide_decoder_vars:
                # Feed the decoder the log of the one-hot labels (1E-8
                # keeps the log finite) instead of the encoder prediction.
                zAll[c] = torch.log(
                    utils.index_to_onehot(classes, self.n_classes) + 1E-8)

            c += 1

        if self.n_ref > 0:
            refLoss = critZRef(zAll[c], ref) * self.lambda_ref_loss
            refLoss.backward(retain_graph=True)
            refLoss = refLoss.item()

            if self.provide_decoder_vars:
                # Feed the decoder the provided reference values instead.
                zAll[c] = ref

            c += 1

        # Only the total KLD is used; the other two returns are discarded.
        total_kld, _, _ = bvae.kl_divergence(zAll[c][0], zAll[c][1])

        # CPU copy of the first element of the latent pair; returned as-is.
        zLatent = zAll[c][0].data.cpu()

        zAll[c] = bvae.reparameterize(zAll[c][0], zAll[c][1])

        xHat = dec(zAll)

        recon_loss = critRecon(xHat, x)

        if self.objective == 'H':
            beta_vae_loss = recon_loss + self.beta * total_kld
        elif self.objective == 'B':
            C = torch.clamp(
                torch.Tensor([
                    self.c_max / self.c_iters_max * len(self.logger)
                ]).type_as(x), 0, self.c_max)
            beta_vae_loss = recon_loss + self.gamma * (total_kld - C).abs()
        else:
            raise ValueError(
                "Unsupported objective {!r}; expected 'H' or 'B'".format(
                    self.objective))

        # The graph is reused by the self-loss backward pass below.
        beta_vae_loss.backward(retain_graph=True)

        # Freeze the encoder while it is used as a fixed feature extractor
        # for the self loss.
        for p in enc.parameters():
            p.requires_grad = False

        _, activations_hat = enc(xHat)

        self_loss = torch.tensor(0).type_as(x)
        for activation_hat, activation in zip(activations_hat, activations):
            self_loss += critRecon(activation_hat, activation.detach())

        (self_loss * self.lambda_loss).backward()

        # Unfreeze the encoder again so it is not left frozen on return.
        for p in enc.parameters():
            p.requires_grad = True

        optEnc.step()
        optDec.step()

        errors = [recon_loss.item()]
        if self.n_classes > 0:
            errors.append(classLoss)

        if self.n_ref > 0:
            errors.append(refLoss)

        errors += [total_kld.item(), self_loss.item()]

        return errors, zLatent