def generate_counterfactual_column(networks, start_images, target_class, **options):
    """Move encoded start images toward ``target_class`` by gradient descent on their latents.

    The start images are encoded with ``networks['encoder']``. The resulting
    latent codes are then updated for ``cf_max_iters`` steps so that
    ``networks['classifier_k']``, applied to the generator output, scores them
    as ``target_class``. A weighted squared-distance term penalizes drift of
    the latent mean (taken over the last two dimensions) away from the
    starting encoding.

    Args:
        networks: Mapping providing the 'generator', 'classifier_k' and
            'encoder' networks.
        start_images: Batch of images to start from; ``len(start_images)`` is
            used as the batch size.
        target_class: Class index every example is pushed toward. Because one
            zero-valued logit is appended to the classifier output, the index
            one past the classifier's own classes is also addressable.
        **options: Must contain 'cf_speed' (step size), 'cf_max_iters'
            (number of update steps), 'cf_distance_weight' (weight of the
            distance penalty) and 'cf_gan_scale' (forwarded to the encoder,
            the generator and ``clamp_to_unit_sphere``).

    Returns:
        numpy array holding the generator output for the final latent codes.

    Raises:
        KeyError: If a required network or option is missing.
    """
    netG = networks['generator']
    netC = networks['classifier_k']
    netE = networks['encoder']
    speed = options['cf_speed']
    max_iters = options['cf_max_iters']
    distance_weight = options['cf_distance_weight']
    gan_scale = options['cf_gan_scale']
    cf_batch_size = len(start_images)

    loss_class = losses.losses()

    # Start with the latent encodings
    z_value = to_np(netE(start_images, gan_scale))

    # The starting point never changes, so its tensor form and its mean over
    # the last two dimensions are computed once rather than on every step.
    z_0 = to_torch(z_value)
    z_0_mean = z_0.mean(dim=-1).mean(dim=-1)

    # Every example in the batch is driven toward the same target label
    target_label = Variable(torch.LongTensor(cf_batch_size)).cuda()
    target_label[:] = target_class

    for _ in range(max_iters):
        z = to_torch(z_value, requires_grad=True)
        logits = netC(netG(z, gan_scale))
        # Append one zero-valued logit after the classifier's own outputs
        augmented_logits = F.pad(logits, pad=(0, 1))

        cf_loss = loss_class.power_loss_05(augmented_logits, target_label)
        distance_loss = torch.sum(
            (z.mean(dim=-1).mean(dim=-1) - z_0_mean) ** 2
        ) * distance_weight
        total_loss = cf_loss + distance_loss

        log.collect('Counterfactual loss', cf_loss)
        log.collect('Distance Loss', distance_loss)
        # Logged value is the raw (un-normalized) logit of the first example
        log.collect('Classification as {}'.format(target_class),
                    augmented_logits[0][target_class])
        log.print_every(n_sec=1)

        # Passing total_loss as grad_outputs scales the gradient by the
        # current loss value, so the step shrinks as the loss falls.
        dc_dz = autograd.grad(total_loss, z, total_loss)[0]
        z = z - dc_dz * speed
        z = clamp_to_unit_sphere(z, gan_scale)

        # TODO: Workaround for Pytorch memory leak
        # Convert back to numpy and destroy the computational graph
        # See https://github.com/pytorch/pytorch/issues/4661
        z_value = to_np(z)
        del z

    print(log)
    z = to_torch(z_value)
    images = netG(z, gan_scale)
    return images.data.cpu().numpy()
def forward(self, x, scale=4, output_scale=4):
    """Encode ``x`` into one flattened, sphere-clamped vector per example.

    Args:
        x: Input batch; ``len(x)`` is taken as the batch size.
        scale: Squared and forwarded to ``clamp_to_unit_sphere`` as its
            second argument.
        output_scale: Accepted for interface compatibility; not read here.

    Returns:
        ``clamp_to_unit_sphere`` applied to the ``(batch, -1)`` view of
        ``self.conv(self.features(x))``.
    """
    num_examples = len(x)
    conv_out = self.conv(self.features(x))
    flattened = conv_out.view(num_examples, -1)
    return clamp_to_unit_sphere(flattened, scale * scale)
def generate_images_for_class(networks, dataloader, class_idx, **options):
    """Optimize random latent points until the discriminator labels them ``class_idx``.

    ``dataloader.num_classes`` latent points are drawn with ``gen_noise`` and
    updated by gradient descent on the negative log-likelihood of
    ``class_idx`` under ``networks['discriminator']``. Iteration stops early
    once every point is predicted as ``class_idx`` with confidence above
    0.75, or after ``counterfactual_max_iters`` steps.

    Args:
        networks: Mapping providing the 'generator' and 'discriminator'.
        dataloader: Supplies ``num_classes`` and ``lab_conv.labels`` (used
            only to print the predicted class name).
        class_idx: Target class index for every generated example.
        **options: Must contain 'latent_size', 'speed' (step size) and
            'counterfactual_max_iters'.

    Returns:
        numpy array of the most recently generated images. These are the
        images the final printed prediction refers to. If
        ``counterfactual_max_iters`` is not positive, the images generated
        from the initial random points are returned.

    Raises:
        KeyError: If a required network or option is missing.
    """
    netG = networks['generator']
    netD = networks['discriminator']
    latent_size = options['latent_size']
    speed = options['speed']
    max_iters = options['counterfactual_max_iters']

    # Start with K random points
    K = dataloader.num_classes
    z = Variable(gen_noise(K, latent_size), requires_grad=True).cuda()

    # Every point is driven toward the same target label
    target_label = torch.LongTensor(K)
    target_label[:] = class_idx
    target_label = Variable(target_label).cuda()

    images = None
    for _ in range(max_iters):
        images = netG(z)
        net_y = netD(images)
        preds = softmax(net_y, dim=1)
        top_confidences, top_classes = preds.max(1)
        pred_classes = to_np(top_classes)
        pred_confidences = to_np(top_confidences)

        # Progress report for the first example only
        predicted_class_name = dataloader.lab_conv.labels[pred_classes[0]]
        print("Class: {} ({:.3f} confidence). Target class {}".format(
            predicted_class_name, pred_confidences[0], class_idx))

        # Stop before spending another gradient step once every example is
        # already confidently classified as the target.
        if all(pred_classes == class_idx) and all(pred_confidences > 0.75):
            break

        cf_loss = nll_loss(log_softmax(net_y, dim=1), target_label)
        # Passing cf_loss as grad_outputs scales the gradient by the current
        # loss value, so the step shrinks as the loss falls.
        dc_dz = autograd.grad(cf_loss, z, cf_loss)[0]

        # Update out-of-place and re-wrap as a fresh leaf so each iteration's
        # graph can be released instead of chaining onto the previous ones.
        z = clamp_to_unit_sphere(z - dc_dz * speed)
        z = Variable(z.data, requires_grad=True)

    if images is None:
        # The loop body never ran; fall back to the un-optimized samples
        # rather than failing on an unbound name.
        images = netG(z)
    return images.data.cpu().numpy()
def forward(self, x, output_scale=1):
    """Run the convolutional stack and return a flattened, sphere-clamped code.

    ``output_scale`` selects how far the input travels before it is
    projected and returned:

    * 8 -- exit after ``conv6``/``bn6`` through ``conv_out_6``
      (clamp argument ``8 * 8``)
    * 4 -- exit after ``self.layers`` through ``conv_out_9``
      (clamp argument ``4 * 4``)
    * 2 -- exit after ``conv10``/``bn10`` through ``conv_out_10``
      (clamp argument ``2 * 2``)
    * any other value -- flatten, apply ``fc1`` and clamp with no extra
      argument

    Args:
        x: Input batch; ``len(x)`` is taken as the batch size.
        output_scale: Exit selector described above.

    Returns:
        The ``(batch, -1)`` shaped result of ``clamp_to_unit_sphere`` for
        the selected exit.
    """
    batch_size = len(x)
    leaky_relu = nn.LeakyReLU(0.2)

    def conv_stage(h, conv, bn):
        # conv -> bn -> LeakyReLU(0.2)
        return leaky_relu(bn(conv(h)))

    def flatten(h):
        return h.view(batch_size, -1)

    x = self.dr1(x)
    x = conv_stage(x, self.conv1, self.bn1)
    x = conv_stage(x, self.conv2, self.bn2)
    x = conv_stage(x, self.conv3, self.bn3)

    x = self.dr2(x)
    x = conv_stage(x, self.conv4, self.bn4)
    x = conv_stage(x, self.conv5, self.bn5)
    x = conv_stage(x, self.conv6, self.bn6)

    # Original author's note: image representation is now 8 x 8
    if output_scale == 8:
        return clamp_to_unit_sphere(flatten(self.conv_out_6(x)), 8 * 8)

    # NOTE(review): self.layers presumably stands in for a former
    # dr3 / conv7-9 / bn7-9 stack that was left commented out here -- confirm.
    x = self.layers(x)

    # Original author's note: image representation is now 4x4
    if output_scale == 4:
        return clamp_to_unit_sphere(flatten(self.conv_out_9(x)), 4 * 4)

    x = self.dr4(x)
    x = conv_stage(x, self.conv10, self.bn10)

    # Original author's note: image representation is now 2x2
    if output_scale == 2:
        return clamp_to_unit_sphere(flatten(self.conv_out_10(x)), 2 * 2)

    return clamp_to_unit_sphere(self.fc1(flatten(x)))
def train_gan(networks, optimizers, dataloader, epoch=None, **options):
    """Run one pass over ``dataloader``, alternating generator and discriminator updates.

    Per batch the generator is stepped on a feature-matching loss, a
    pull-away term and a "not fake" classification loss. The discriminator
    is then stepped on classifying fresh generator samples as the extra
    last class and real images into their labeled classes. Every 100
    batches a demo image is written under ``result_dir`` and running
    statistics are printed.

    Args:
        networks: Mapping of networks; all are put in train mode. Must
            contain 'discriminator' and 'generator'.
        optimizers: Mapping containing the 'discriminator' and 'generator'
            optimizers.
        dataloader: Iterable of ``(images, class_labels)`` batches that also
            supports ``len()``.
        epoch: Only used in the printed progress message.
        **options: Must contain 'result_dir', 'batch_size', 'image_size',
            'latent_size' and 'discriminator_per_gen'.

    Returns:
        Always ``True``.
    """
    for net in networks.values():
        net.train()
    netD = networks['discriminator']
    netG = networks['generator']
    optimizerD = optimizers['discriminator']
    optimizerG = optimizers['generator']
    result_dir = options['result_dir']
    batch_size = options['batch_size']
    # NOTE(review): image_size and discriminator_per_gen are read but never
    # used below; the lookups still raise KeyError when the option is absent.
    image_size = options['image_size']
    latent_size = options['latent_size']
    discriminator_per_gen = options['discriminator_per_gen']
    # Noise drawn once so the periodic demo images are comparable over time
    fixed_noise = Variable(torch.FloatTensor(batch_size, latent_size).normal_(0, 1)).cuda()
    fixed_noise = clamp_to_unit_sphere(fixed_noise)
    start_time = time.time()
    correct = 0
    total = 0

    for i, (images, class_labels) in enumerate(dataloader):
        # NOTE(review): images/labels are wrapped but not moved with .cuda()
        # here, unlike z -- presumably the dataloader already yields GPU
        # tensors; confirm.
        images = Variable(images)
        labels = Variable(class_labels)

        ############################
        # Generator Updates
        ############################
        netG.zero_grad()
        z = gen_noise(batch_size, latent_size)
        z = Variable(z).cuda()
        gen_images = netG(z)

        # Feature Matching: Average of one batch of real vs. generated
        features_real = netD(images, return_features=True)
        features_gen = netD(gen_images, return_features=True)
        fm_loss = torch.mean((features_real.mean(0) - features_gen.mean(0)) ** 2)

        # Pull-away term from https://github.com/kimiyoung/ssl_bad_gan
        nsample = features_gen.size(0)
        # NOTE(review): gen_feat_norm is computed but never used -- `cosine`
        # below is the product of the raw, un-normalized features. The norm is
        # also taken over dim 0 rather than per example. The 128 * 128
        # rescaling may compensate for the missing normalization; confirm the
        # intent before changing any of it.
        denom = features_gen.norm(dim=0).expand_as(features_gen)
        gen_feat_norm = features_gen / denom
        cosine = torch.mm(features_gen, features_gen.t())
        # Ones everywhere except the diagonal: ignore each sample paired with itself
        mask = Variable((torch.ones(cosine.size()) - torch.diag(torch.ones(nsample))).cuda())
        pt_loss = torch.sum((cosine * mask) ** 2) / (nsample * (nsample + 1))
        pt_loss /= (128 * 128)

        errG = fm_loss + pt_loss

        # Classify generated examples as "not fake": the logits are negated,
        # one zero logit is appended, and the log-probability of that last
        # column is maximized.
        gen_logits = netD(gen_images)
        augmented_logits = F.pad(-gen_logits, pad=(0,1))
        log_prob_gen = F.log_softmax(augmented_logits, dim=1)[:, -1]
        errG += -log_prob_gen.mean()

        errG.backward()
        optimizerG.step()
        ###########################

        ############################
        # Discriminator Updates
        ###########################
        # Also clears the discriminator gradients accumulated by errG.backward()
        netD.zero_grad()

        # Classify generated examples as "fake" (ie the K+1th "open" class)
        z = gen_noise(batch_size, latent_size)
        z = Variable(z).cuda()
        # detach() keeps this loss from back-propagating into the generator
        fake_images = netG(z).detach()
        fake_logits = netD(fake_images)
        augmented_logits = F.pad(fake_logits, pad=(0,1))
        log_prob_fake = F.log_softmax(augmented_logits, dim=1)[:, -1]
        errD = -log_prob_fake.mean()
        errD.backward()

        # Classify real examples into the correct K classes
        real_logits = netD(images)
        # labels are compared with 1 and arg-maxed along dim 1 further down,
        # so they are presumably one-hot rows -- TODO confirm
        positive_labels = (labels == 1).type(torch.cuda.FloatTensor)
        augmented_logits = F.pad(real_logits, pad=(0,1))
        augmented_labels = F.pad(positive_labels, pad=(0,1))
        # Masked log-probabilities; the mean below runs over every element of
        # the (batch, K + 1) product, including the zeroed entries
        log_prob_real = F.log_softmax(augmented_logits, dim=1) * augmented_labels
        #log_prob_real = F.log_softmax(augmented_logits, dim=1)[:, 0]
        errC = -log_prob_real.mean()
        errC.backward()

        # Single step on the gradients accumulated from errD and errC
        optimizerD.step()
        ############################

        # Keep track of accuracy on positive-labeled examples for monitoring
        _, pred_idx = real_logits.max(1)
        _, label_idx = labels.max(1)
        # NOTE(review): `.data.cpu().numpy()[0]` here and `.data[0]` below
        # index a one-element result, which looks like the pre-0.4 PyTorch
        # idiom -- verify against the PyTorch version in use.
        correct += sum(pred_idx == label_idx).data.cpu().numpy()[0]
        total += len(labels)

        if i % 100 == 0:
            # Save up to the first 36 images generated from the fixed noise
            demo_fakes = netG(fixed_noise)
            img = torch.cat([demo_fakes.data[:36]])
            filename = "{}/demo_{}.jpg".format(result_dir, int(time.time()))
            imutil.show(img, filename=filename, resize_to=(512,512))

            # Batches processed per second since the start of this call
            bps = (i+1) / (time.time() - start_time)
            ed = errD.data[0]
            eg = errG.data[0]
            ec = errC.data[0]
            acc = correct / max(total, 1)
            msg = '[{}][{}/{}] D:{:.3f} G:{:.3f} C:{:.3f} Acc. {:.3f} {:.3f} batch/sec'
            msg = msg.format(
                epoch, i+1, len(dataloader),
                ed, eg, ec, acc, bps)
            print(msg)
            print("log_prob_real {:.3f}".format(log_prob_real.mean().data[0]))
            print("log_prob_fake {:.3f}".format(log_prob_fake.mean().data[0]))
            print("log_prob_gen {:.3f}".format(log_prob_gen.mean().data[0]))
            print("pt_loss {:.3f}".format(pt_loss.data[0]))
            print("fm_loss {:.3f}".format(fm_loss.data[0]))
            print("Accuracy {}/{}".format(correct, total))
    return True