Example #1
    def discriminator_loss(
            self, batch: tp.Tuple[torch.Tensor,
                                  torch.Tensor]) -> torch.Tensor:
        inputs, labels = batch
        labels = labels.float()
        permuted_labels = permute_labels(labels)

        discriminator_on_real_outputs, classification_on_real_outputs = self.discriminator(
            inputs)
        classification_loss = self.classification_loss(
            classification_on_real_outputs, labels)

        # The generator is frozen here: only the discriminator is being updated.
        with torch.no_grad():
            generator_outputs = self.generator(inputs, permuted_labels)
        discriminator_on_fake_outputs, classification_on_fake_outputs = self.discriminator(
            generator_outputs)

        adversarial_loss = self.adversarial_loss(
            discriminator_on_real_outputs, discriminator_on_fake_outputs)
        gradient_penalty = self.gradient_penalty(inputs, generator_outputs)

        loss = adversarial_loss + \
               self.hparams.lambda_gradient_penalty * gradient_penalty + \
               self.hparams.lambda_classification * classification_loss

        self.log_dict({
            'discriminator adversarial loss': adversarial_loss,
            'discriminator classification loss': classification_loss,
            'discriminator gradient penalty': gradient_penalty,
            'discriminator loss': loss
        })
        return loss
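permute_labels itself is never shown on this page. From the way it is called (it maps a batch of source attribute labels to target labels of the same shape), a plausible reading is a random shuffle along the batch dimension, as is common in StarGAN-style training. The sketch below is an assumption, not the original helper:

import torch


def permute_labels(labels: torch.Tensor) -> torch.Tensor:
    # Assumed implementation: pair every image with another sample's
    # attribute vector by shuffling the batch, yielding target labels
    # for the translation task.
    perm = torch.randperm(labels.size(0), device=labels.device)
    return labels[perm]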
Example #2
def calculate_activation_statistics(dataloader, model, classifier, attr):
    classifier.eval()
    model.eval()
    device = model.device

    real_act = []
    gen_act = []

    with torch.no_grad():
        for image, label in dataloader:
            image, label = image.to(device), label.to(device)
            label_src = label[:, attr]
            label_trg = permute_labels(label_src)

            gen = model.generate(image, label_trg)
            image, gen = transform(image), transform(gen)

            real_act.append(classifier(image))
            gen_act.append(classifier(gen))
    # Move to CPU and convert to numpy so the mean and np.cov below operate on plain arrays.
    real_act = torch.cat(real_act).cpu().numpy()
    gen_act = torch.cat(gen_act).cpu().numpy()

    mu1, sigma1 = real_act.mean(axis=0), np.cov(real_act, rowvar=False)
    mu2, sigma2 = gen_act.mean(axis=0), np.cov(gen_act, rowvar=False)

    return mu1, sigma1, mu2, sigma2
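The four statistics returned here are the inputs to the Fréchet Inception Distance. The remaining step is the standard Fréchet distance formula; the sketch below (frechet_distance is a name chosen here, not part of the original code) shows that step:

import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2 * np.trace(covmean))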
Example #3
    def generator_loss(
            self, batch: tp.Tuple[torch.Tensor,
                                  torch.Tensor]) -> torch.Tensor:
        inputs, labels = batch
        labels = labels.float()
        permuted_labels = permute_labels(labels)

        generator_outputs = self.generator(inputs, permuted_labels)
        discriminator_on_fake_outputs, classification_on_fake_outputs = self.discriminator(
            generator_outputs)

        # The generator wants its fakes to be scored as real by the discriminator.
        adversarial_loss = self.adversarial_loss(
            on_real_outputs=discriminator_on_fake_outputs)
        classification_loss = self.classification_loss(
            classification_on_fake_outputs, permuted_labels)

        reconstructed_inputs = self.generator(generator_outputs, labels)
        reconstruction_loss = self.reconstruction_loss(reconstructed_inputs,
                                                       inputs)

        loss = adversarial_loss + \
               self.hparams.lambda_reconstruction * reconstruction_loss + \
               self.hparams.lambda_classification * classification_loss

        self.log_dict({
            'generator adversarial loss': adversarial_loss,
            'generator classification loss': classification_loss,
            'generator reconstruction loss': reconstruction_loss,
            'generator loss': loss
        })

        return loss
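The loss callables used above (self.adversarial_loss, self.classification_loss, self.reconstruction_loss) are defined elsewhere in the module. Given the gradient-penalty term in the discriminator objective, WGAN-GP style losses as in StarGAN are a plausible reading; the free-function sketch below is an assumption consistent with how they are called, not the original code:

import torch
import torch.nn.functional as F


def adversarial_loss(on_real_outputs: torch.Tensor,
                     on_fake_outputs: torch.Tensor = None) -> torch.Tensor:
    # Wasserstein critic objective (assumed): push real scores up, fake scores down.
    # Called with only on_real_outputs (as in the generator step above), it
    # reduces to maximizing the critic score of the generated images.
    loss = -on_real_outputs.mean()
    if on_fake_outputs is not None:
        loss = loss + on_fake_outputs.mean()
    return loss


def classification_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Multi-label attribute classification on the discriminator's second head.
    return F.binary_cross_entropy_with_logits(logits, labels)


def reconstruction_loss(reconstructed: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    # Cycle-consistency term: translating back to the original labels
    # should recover the input image.
    return F.l1_loss(reconstructed, inputs)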
Example #4
 def trainG(self, image, label):
     self.optimizerG.zero_grad()
     new_label = permute_labels(label)
     generated = self.G(image, new_label)
     reconstructed = self.G(generated, label)
     src_out, cls_out = self.D(generated)
     loss_dict, loss = self.criterion.generator_loss(
         image, reconstructed, src_out, cls_out, new_label)
     wandb.log(loss_dict)
     loss.backward()
     torch.nn.utils.clip_grad_norm_(self.G.parameters(), 10.0)
     self.optimizerG.step()
     self.optimizerD.zero_grad()  # discard the gradients that backprop through D left behind during this G step
Example #5
 def trainD(self, image, label):
     self.optimizerD.zero_grad()
     new_label = permute_labels(label)
     generated = self.G(image, new_label).detach()
     src_out_gen, _ = self.D(generated)
     src_out, cls_out = self.D(image)
     gp = compute_gradient_penalty(self.D, generated, image, self.device)
     loss_dict, loss = self.criterion.discriminator_loss(
         image, src_out_gen, src_out, cls_out, label, gp)
     wandb.log(loss_dict)
     loss.backward()
     torch.nn.utils.clip_grad_norm_(self.D.parameters(), 10.0)
     self.optimizerD.step()
     self.optimizerG.zero_grad()  # defensively clear any stale gradients in G before its next update
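compute_gradient_penalty is referenced above but not defined on this page. A standard WGAN-GP penalty matching the call signature (D, generated, image, device) might look like the sketch below; it follows the usual formulation and is an assumption, not the original helper:

import torch


def compute_gradient_penalty(D, fake, real, device):
    # Sample random points on straight lines between real and generated images.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    src_out, _ = D(interpolates)  # D returns (real/fake score, attribute logits)
    grads = torch.autograd.grad(outputs=src_out,
                                inputs=interpolates,
                                grad_outputs=torch.ones_like(src_out),
                                create_graph=True,
                                retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Penalize deviations of the gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()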
Example #6
def calculate_activation_statistics(dataloader, model, classifier, device, ATTRIBUTE_IDX):
    classifier.eval()
    model.eval()
    batch_size = dataloader.batch_size
    examples = len(dataloader) * batch_size
    input_acts = np.zeros((examples, CLF_HIDDEN))
    output_acts = np.zeros((examples, CLF_HIDDEN))
    n = 0

    # Only activations are needed; without no_grad the .numpy() calls below
    # would fail on tensors that still require grad.
    with torch.no_grad():
        for image, label in tqdm(dataloader, leave=False, desc="fid"):
            input_img = image.to(device)
            label = label[:, ATTRIBUTE_IDX].to(device)
            new_label = permute_labels(label)
            output_img = model.generate(input_img, new_label)
            input_act = classifier(input_img)
            output_act = classifier(output_img)
            # The last batch may be smaller than batch_size, so slice by its actual size.
            input_acts[n:n + image.size(0)] = input_act.cpu().numpy()
            output_acts[n:n + image.size(0)] = output_act.cpu().numpy()
            n += image.size(0)

    # Drop any rows that were never filled before computing the statistics.
    input_acts, output_acts = input_acts[:n], output_acts[:n]

    mu1, sigma1 = input_acts.mean(axis=0), np.cov(input_acts, rowvar=False)
    mu2, sigma2 = output_acts.mean(axis=0), np.cov(output_acts, rowvar=False)
    return mu1, sigma1, mu2, sigma2
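These statistics feed the same Fréchet distance as in Example #2; with the frechet_distance sketch shown after that example, the score would be obtained as:

mu1, sigma1, mu2, sigma2 = calculate_activation_statistics(
    dataloader, model, classifier, device, ATTRIBUTE_IDX)
fid = frechet_distance(mu1, sigma1, mu2, sigma2)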
Example #7
 def validation_step(self, batch: tp.Tuple[torch.Tensor, torch.Tensor],
                     batch_idx: int):
     images, labels = batch
     permuted_labels = permute_labels(labels).float()
     # desired_labels = self.desired_labels.type_as(images)
     # for label in desired_labels
     generator_outputs = self.generator(images, permuted_labels)
     discriminator_on_real_outputs, classification_on_real_outputs = self.discriminator(
         images)
     self.fid(images, generator_outputs)
     self.accuracy(torch.sigmoid(classification_on_real_outputs), labels)
     # discriminator_on_fake_outputs, classification_on_fake_outputs = self.discriminator(generator_outputs)
     if batch_idx == 0:
         return {
             'real images': images,
             # 'real labels': labels,
         }
     return {
         'real images': None,
         # 'real labels': None
     }