def generate_counterfactual_column(networks, start_images, target_class, **options):
    netG = networks['generator']
    netC = networks['classifier_k']
    netE = networks['encoder']
    speed = options['cf_speed']
    max_iters = options['cf_max_iters']
    distance_weight = options['cf_distance_weight']
    gan_scale = options['cf_gan_scale']
    cf_batch_size = len(start_images)

    loss_class = losses.losses()

    # Start with the latent encodings
    z_value = to_np(netE(start_images, gan_scale))
    z0_value = z_value

    # Move them so their labels match target_label
    target_label = Variable(torch.LongTensor(cf_batch_size)).cuda()
    target_label[:] = target_class

    for i in range(max_iters):
        z = to_torch(z_value, requires_grad=True)
        z_0 = to_torch(z0_value)
        logits = netC(netG(z, gan_scale))
        # Append a zero logit for the K+1th "open" class
        augmented_logits = F.pad(logits, pad=(0,1))

        # Classification loss: push the augmented logits toward the target class
        cf_loss = loss_class.power_loss_05(augmented_logits, target_label)

        distance_loss = torch.sum(
            (z.mean(dim=-1).mean(dim=-1) - z_0.mean(dim=-1).mean(dim=-1)) ** 2
        ) * distance_weight

        total_loss = cf_loss + distance_loss

        scores = augmented_logits

        log.collect('Counterfactual loss', cf_loss)
        log.collect('Distance Loss', distance_loss)
        log.collect('Classification as {}'.format(target_class), scores[0][target_class])
        log.print_every(n_sec=1)

        # Gradient of the loss w.r.t. z; passing total_loss as grad_outputs
        # scales the step by the current loss value
        dc_dz = autograd.grad(total_loss, z, total_loss)[0]
        z = z - dc_dz * speed
        z = clamp_to_unit_sphere(z, gan_scale)

        # TODO: Workaround for PyTorch memory leak
        # Convert back to numpy and destroy the computational graph
        # See https://github.com/pytorch/pytorch/issues/4661
        z_value = to_np(z)
        del z
    print(log)
    z = to_torch(z_value)

    images = netG(z, gan_scale)
    return images.data.cpu().numpy()
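A minimal usage sketch for the function above, assuming the dictionary keys and cf_* option names shown in its body; the network handles, the start_images batch, and the option values are placeholders, not part of the source.

# Hypothetical invocation (all names and values below are assumptions):
networks = {'generator': netG, 'classifier_k': netC, 'encoder': netE}
options = {
    'cf_speed': 0.01,           # latent-space step size (assumed value)
    'cf_max_iters': 100,        # number of gradient steps (assumed value)
    'cf_distance_weight': 1.0,  # weight on the distance-to-start penalty (assumed value)
    'cf_gan_scale': 4,          # scale passed to the encoder/generator (assumed value)
}
counterfactuals = generate_counterfactual_column(
    networks, start_images, target_class=3, **options)
# counterfactuals is a numpy array of generated images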
Example #2
    def forward(self, x, scale=4, output_scale=4):
        batch_size = len(x)

        x = self.features(x)
        x = self.conv(x)
        x = x.view(batch_size, -1)
        x = clamp_to_unit_sphere(x, scale * scale)
        return x
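clamp_to_unit_sphere itself is not shown in these examples. Below is a plausible sketch of the projection it performs, assuming it splits each flattened row into a number of equal chunks (the second argument in the calls above) and L2-normalizes every chunk; the actual implementation in the repository may differ.

import torch

def clamp_to_unit_sphere(x, components=1):
    # Assumed behavior: split each row into `components` equal chunks and
    # L2-normalize every chunk so it lies on a unit hypersphere.
    batch_size, latent_size = x.size()
    step = latent_size // components
    chunks = []
    for i in range(components):
        piece = x[:, i * step:(i + 1) * step]
        norm = piece.norm(p=2, dim=1).unsqueeze(1).expand_as(piece) + 1e-8
        chunks.append(piece / norm)
    return torch.cat(chunks, dim=1)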
Example #3
def generate_images_for_class(networks, dataloader, class_idx, **options):
    netG = networks['generator']
    netD = networks['discriminator']
    result_dir = options['result_dir']
    image_size = options['image_size']
    latent_size = options['latent_size']
    output_frame_count = options['counterfactual_frame_count']
    speed = options['speed']
    momentum_mu = options['momentum_mu']
    max_iters = options['counterfactual_max_iters']

    # Start with K random points
    K = dataloader.num_classes
    z = gen_noise(K, latent_size)
    z = Variable(z, requires_grad=True).cuda()

    # Move them so their labels match target_label
    target_label = torch.LongTensor(K)
    target_label[:] = class_idx
    target_label = Variable(target_label).cuda()

    for i in range(max_iters):
        images = netG(z)
        net_y = netD(images)
        preds = softmax(net_y, dim=1)

        pred_classes = to_np(preds.max(1)[1])
        predicted_class = pred_classes[0]
        pred_confidences = to_np(preds.max(1)[0])
        pred_confidence = pred_confidences[0]
        predicted_class_name = dataloader.lab_conv.labels[predicted_class]
        print("Class: {} ({:.3f} confidence). Target class {}".format(
            predicted_class_name, pred_confidence, class_idx))

        cf_loss = nll_loss(log_softmax(net_y, dim=1), target_label)

        dc_dz = autograd.grad(cf_loss, z, cf_loss, retain_graph=True)[0]
        z = z - dc_dz * speed
        z = clamp_to_unit_sphere(z)
        if all(pred_classes == class_idx) and all(pred_confidences > 0.75):
            break
    return images.data.cpu().numpy()
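gen_noise is also not defined in these snippets. Given that fixed_noise in Example #5 is drawn from a standard normal and then projected with clamp_to_unit_sphere, a reasonable (assumed) sketch is:

import torch

def gen_noise(batch_size, latent_size):
    # Assumed behavior, mirroring the fixed_noise construction in Example #5:
    # standard-normal noise projected onto the unit sphere.
    noise = torch.FloatTensor(batch_size, latent_size).normal_(0, 1)
    return clamp_to_unit_sphere(noise)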
Example #4
    def forward(self, x, output_scale=1):
        batch_size = len(x)

        x = self.dr1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = nn.LeakyReLU(0.2)(x)

        x = self.dr2(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv5(x)
        x = self.bn5(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv6(x)
        x = self.bn6(x)
        x = nn.LeakyReLU(0.2)(x)

        # Image representation is now 8x8
        if output_scale == 8:
            x = self.conv_out_6(x)
            x = x.view(batch_size, -1)
            x = clamp_to_unit_sphere(x, 8 * 8)
            return x

        # x = self.dr3(x)
        # x = self.conv7(x)
        # x = self.bn7(x)
        # x = nn.LeakyReLU(0.2)(x)
        # x = self.conv8(x)
        # x = self.bn8(x)
        # x = nn.LeakyReLU(0.2)(x)
        # x = self.conv9(x)
        # x = self.bn9(x)
        # x = nn.LeakyReLU(0.2)(x)
        x = self.layers(x)

        # Image representation is now 4x4
        if output_scale == 4:
            x = self.conv_out_9(x)
            x = x.view(batch_size, -1)
            x = clamp_to_unit_sphere(x, 4 * 4)
            return x

        x = self.dr4(x)
        x = self.conv10(x)
        x = self.bn10(x)
        x = nn.LeakyReLU(0.2)(x)

        # Image representation is now 2x2
        if output_scale == 2:
            x = self.conv_out_10(x)
            x = x.view(batch_size, -1)
            x = clamp_to_unit_sphere(x, 2 * 2)
            return x

        x = x.view(batch_size, -1)
        x = self.fc1(x)
        x = clamp_to_unit_sphere(x)
        return x
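A brief usage note for the multi-scale forward pass above: output_scale selects which intermediate feature map is flattened and projected onto the unit sphere. The variable names below are assumptions.

# Hypothetical calls (assumes `encoder` is an instance of the module above):
z8 = encoder(images, output_scale=8)  # features from the 8x8 map
z4 = encoder(images, output_scale=4)  # features from the 4x4 map
z1 = encoder(images)                  # default: fully-connected latent code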
Example #5
def train_gan(networks, optimizers, dataloader, epoch=None, **options):
    for net in networks.values():
        net.train()
    netD = networks['discriminator']
    netG = networks['generator']
    optimizerD = optimizers['discriminator']
    optimizerG = optimizers['generator']
    result_dir = options['result_dir']
    batch_size = options['batch_size']
    image_size = options['image_size']
    latent_size = options['latent_size']
    discriminator_per_gen = options['discriminator_per_gen']

    fixed_noise = Variable(torch.FloatTensor(batch_size, latent_size).normal_(0, 1)).cuda()
    fixed_noise = clamp_to_unit_sphere(fixed_noise)

    start_time = time.time()
    correct = 0
    total = 0

    for i, (images, class_labels) in enumerate(dataloader):
        images = Variable(images).cuda()
        labels = Variable(class_labels).cuda()

        ############################
        # Generator Updates
        ############################
        netG.zero_grad()
        z = gen_noise(batch_size, latent_size)
        z = Variable(z).cuda()
        gen_images = netG(z)
        
        # Feature Matching: Average of one batch of real vs. generated
        features_real = netD(images, return_features=True)
        features_gen = netD(gen_images, return_features=True)
        fm_loss = torch.mean((features_real.mean(0) - features_gen.mean(0)) ** 2)

        # Pull-away term from https://github.com/kimiyoung/ssl_bad_gan
        # Penalizes pairwise cosine similarity between generated-feature vectors
        nsample = features_gen.size(0)
        denom = features_gen.norm(dim=1).unsqueeze(1).expand_as(features_gen)
        gen_feat_norm = features_gen / denom
        cosine = torch.mm(gen_feat_norm, gen_feat_norm.t())
        mask = Variable((torch.ones(cosine.size()) - torch.diag(torch.ones(nsample))).cuda())
        pt_loss = torch.sum((cosine * mask) ** 2) / (nsample * (nsample - 1))

        errG = fm_loss + pt_loss

        # Classify generated examples as "not fake"
        gen_logits = netD(gen_images)
        augmented_logits = F.pad(-gen_logits, pad=(0,1))
        log_prob_gen = F.log_softmax(augmented_logits, dim=1)[:, -1]
        errG += -log_prob_gen.mean()

        errG.backward()
        optimizerG.step()
        ###########################

        ############################
        # Discriminator Updates
        ###########################
        netD.zero_grad()

        # Classify generated examples as "fake" (ie the K+1th "open" class)
        z = gen_noise(batch_size, latent_size)
        z = Variable(z).cuda()
        fake_images = netG(z).detach()
        fake_logits = netD(fake_images)
        augmented_logits = F.pad(fake_logits, pad=(0,1))
        log_prob_fake = F.log_softmax(augmented_logits, dim=1)[:, -1]
        errD = -log_prob_fake.mean()
        errD.backward()

        # Classify real examples into the correct K classes
        real_logits = netD(images)
        positive_labels = (labels == 1).type(torch.cuda.FloatTensor)
        augmented_logits = F.pad(real_logits, pad=(0,1))
        augmented_labels = F.pad(positive_labels, pad=(0,1))
        log_prob_real = F.log_softmax(augmented_logits, dim=1) * augmented_labels
        #log_prob_real = F.log_softmax(augmented_logits, dim=1)[:, 0]
        errC = -log_prob_real.mean()
        errC.backward()

        optimizerD.step()
        ############################

        # Keep track of accuracy on positive-labeled examples for monitoring
        _, pred_idx = real_logits.max(1)
        _, label_idx = labels.max(1)
        correct += (pred_idx == label_idx).sum().data.cpu().numpy()[0]
        total += len(labels)

        if i % 100 == 0:
            demo_fakes = netG(fixed_noise)
            img = torch.cat([demo_fakes.data[:36]])
            filename = "{}/demo_{}.jpg".format(result_dir, int(time.time()))
            imutil.show(img, filename=filename, resize_to=(512,512))

            bps = (i+1) / (time.time() - start_time)
            ed = errD.data[0]
            eg = errG.data[0]
            ec = errC.data[0]
            acc = correct / max(total, 1)
            msg = '[{}][{}/{}] D:{:.3f} G:{:.3f} C:{:.3f} Acc. {:.3f} {:.3f} batch/sec'
            msg = msg.format(
                  epoch, i+1, len(dataloader),
                  ed, eg, ec, acc, bps)
            print(msg)
            print("log_prob_real {:.3f}".format(log_prob_real.mean().data[0]))
            print("log_prob_fake {:.3f}".format(log_prob_fake.mean().data[0]))
            print("log_prob_gen {:.3f}".format(log_prob_gen.mean().data[0]))
            print("pt_loss {:.3f}".format(pt_loss.data[0]))
            print("fm_loss {:.3f}".format(fm_loss.data[0]))
            print("Accuracy {}/{}".format(correct, total))
    return True
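A hedged sketch of how train_gan might be driven across epochs; the construction of the networks, optimizers, and dataloader is not shown in the source, and the option values are placeholders.

# Hypothetical training driver (all values below are assumptions):
options = {
    'result_dir': 'results',
    'batch_size': 64,
    'image_size': 32,
    'latent_size': 100,
    'discriminator_per_gen': 1,
}
for epoch in range(25):
    train_gan(networks, optimizers, dataloader, epoch=epoch, **options)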