def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 200, resize = None, norm_img = False, num_workers=0):
    """
    Label Score: mean and std of |predicted label - assigned label| over fake images.

    PreNet: pre-trained CNN; its forward pass returns a tuple whose first
        element is the predicted labels (second element is ignored here).
    images: fake images, assumed shaped (n, nc, img_size, img_size).
    labels_assi: labels assigned to the fake images (normalized scale).
    min_label_before_shift / max_label_after_shift: constants mapping the
        normalized labels back to the original label scale.
    batch_size: evaluation batch size.
    resize: unused in this implementation; kept for API compatibility.
    norm_img: if True, run normalize_images() on each batch before PreNet.
    num_workers: DataLoader worker count.

    Returns (ls_mean, ls_std).
    """
    PreNet.eval()

    num_fake = images.shape[0]
    num_channels = images.shape[1]   # kept for parity with surrounding code (unused)
    img_size = images.shape[2]       # idem (unused)
    labels_assi = labels_assi.reshape(-1)

    # No shuffling: prediction order must line up with labels_assi.
    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Over-allocate by one batch; trimmed back to num_fake after the loop.
    labels_pred = np.zeros(num_fake + batch_size)

    num_done = 0
    pb = SimpleProgressBar()
    for _, (batch_images, batch_labels) in enumerate(eval_dataloader):
        batch_images = batch_images.type(torch.float).cuda()
        batch_labels = batch_labels.type(torch.float).cuda()
        n_curr = len(batch_labels)
        if norm_img:
            batch_images = normalize_images(batch_images)
        batch_labels_pred, _ = PreNet(batch_images)
        labels_pred[num_done:num_done + n_curr] = batch_labels_pred.detach().cpu().numpy().reshape(-1)
        num_done += n_curr
        pb.update((float(num_done) / num_fake) * 100)
        # Free GPU memory eagerly between batches.
        del batch_images
        gc.collect()
        torch.cuda.empty_cache()
    #end for

    labels_pred = labels_pred[:num_fake]

    # Map both predicted and assigned labels back to the original label scale.
    labels_pred = labels_pred * max_label_after_shift - np.abs(min_label_before_shift)
    labels_assi = labels_assi * max_label_after_shift - np.abs(min_label_before_shift)

    abs_err = np.abs(labels_pred - labels_assi)
    ls_mean = np.mean(abs_err)
    ls_std = np.std(abs_err)
    return ls_mean, ls_std
assert len(labels) == len(images) # define training and validation sets if args.CVMode: #90% Training; 10% valdation valid_prop = 0.1 #proportion of the validation samples indx_all = np.arange(len(images)) np.random.shuffle(indx_all) indx_valid = indx_all[0:int(valid_prop * len(images))] indx_train = indx_all[int(valid_prop * len(images)):] if args.transform: trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True, rotate=True, degrees=[90, 180, 270], hflip=True, vflip=True) else: trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True) trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=8) validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True) validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False,
q2 = args.max_label
# Keep only samples whose label lies strictly inside (q1, q2).
# NOTE(review): `(cond1) * (cond2) == True` is an unusual spelling of
# element-wise AND; `(labels > q1) & (labels < q2)` would be clearer.
indx = np.where((labels > q1) * (labels < q2) == True)[0]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)
# Define training and validation sets.
# NOTE(review): fragment of a larger script — `args`, `q1`, `IMGs_dataset`
# come from earlier in the file; the final DataLoader call continues elsewhere.
if args.CVMode: #90% training; 10% validation
    valid_prop = 0.1 #proportion of the validation samples
    # Random permutation of indices, then first 10% -> validation, rest -> training.
    indx_all = np.arange(len(images))
    np.random.shuffle(indx_all)
    indx_valid = indx_all[0:int(valid_prop * len(images))]
    indx_train = indx_all[int(valid_prop * len(images)):]
    trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=8)
    validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=8)
else:
    # No cross-validation: train on the full filtered set.
    trainset = IMGs_dataset(images, labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True,
def train_cgan_concat(images, labels, netG, netD, save_images_folder, save_models_folder=None):
    """
    Train a label-concatenation cGAN.

    images, labels : training data; labels are fed to netG/netD as long tensors.
    netG, netD     : generator / discriminator (moved to CUDA here).
    save_images_folder : folder for periodic sample grids.
    save_models_folder : if not None, checkpoints are saved here (and training
        resumes from one when resume_niters > 0).

    Relies on module-level config: lr_g, lr_d, batch_size, num_workers,
    resume_niters, niters, gan_arch, num_D_steps, dim_gan, loss_type
    ("vanilla" or "hinge"), use_DiffAugment, policy, visualize_freq,
    save_niters_freq.

    Returns (netG, netD).
    """
    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Resume from an in-training checkpoint if requested.
    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(
            gan_arch, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # Fixed noise + labels for visualization: a 10x10 grid whose rows sweep
    # labels between the 5th and 95th quantile of the training labels.
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(labels, 0.05)
    end_label = np.quantile(labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # Restart the dataloader iterator when it is about to be exhausted.
        if batch_idx + 1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0

        '''  Train Generator: maximize log(D(G(z)))  '''
        netG.train()

        # FIX: Python-3 iterator protocol — `dataloader_iter.next()` raises
        # AttributeError; use the builtin next().
        _, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_labels.shape[0]
        batch_train_labels = batch_train_labels.type(torch.long).cuda()
        batch_idx += 1

        # Sample noise as generator input and generate fakes.
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_train_labels)

        # Generator loss: ability to fool the discriminator.
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_train_labels)
        else:
            dis_out = netD(batch_fake_images, batch_train_labels)
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = -torch.mean(torch.log(dis_out + 1e-20))  # 1e-20 guards log(0)
        elif loss_type == "hinge":
            g_loss = -torch.mean(dis_out)
        else:
            raise ValueError("Unsupported loss_type: {}".format(loss_type))

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        '''  Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))  '''
        for _ in range(num_D_steps):
            if batch_idx + 1 == len(train_dataloader):
                dataloader_iter = iter(train_dataloader)
                batch_idx = 0

            # FIX: next() builtin instead of the removed .next() method.
            batch_train_images, batch_train_labels = next(dataloader_iter)
            assert batch_size == batch_train_images.shape[0]
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()
            batch_idx += 1

            # Discriminator outputs on real and (detached) fake samples.
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_train_labels.detach())
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(), batch_train_labels.detach())

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = -torch.log(real_dis_out + 1e-20)
                d_loss_fake = -torch.log(1 - fake_dis_out + 1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError("Unsupported loss_type: {}".format(loss_type))
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()
        #end for D steps

        if (niter + 1) % 20 == 0:
            print(
                "cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]"
                % (gan_arch, niter + 1, niters, d_loss.item(), g_loss.item(),
                   real_dis_out.mean().item(), fake_dis_out.mean().item(),
                   timeit.default_timer() - start_time))

        # Periodically dump a grid of samples at the fixed (z, y).
        if (niter + 1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data,
                       save_images_folder + '/{}.png'.format(niter + 1),
                       nrow=n_row, normalize=True)

        # Periodic checkpoint (and final one at niters).
        if save_models_folder is not None and (
                (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(
                gan_arch, num_D_steps, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter

    return netG, netD
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=8)
# NOTE(review): fragment of a larger script — the matching `if` branch and the
# definitions of `args`, `wd`, `IMGs_dataset` are elsewhere; the final
# DataLoader call is continued elsewhere too.
else:
    # Load a reduced MNIST training set of args.N_TRAIN samples from HDF5.
    h5py_file = wd + '/data/MNIST_reduced_trainset_' + str(
        args.N_TRAIN) + '.h5'
    hf = h5py.File(h5py_file, 'r')
    images_train = hf['images_train'][:]
    labels_train = hf['labels_train'][:]
    hf.close()
    if args.transform:
        # Augmented: random rotation up to 15 degrees plus padded random crop to 28x28.
        trainset = IMGs_dataset(images_train, labels_train, normalize=True, rotate=True, degrees=15, crop=True, crop_size=28, crop_pad=4)
    else:
        # Unaugmented: rotate/crop disabled (remaining kwargs are then inert).
        trainset = IMGs_dataset(images_train, labels_train, normalize=True, rotate=False, degrees=15, crop=False, crop_size=28, crop_pad=4)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True,
def train_cgan(train_images, train_labels, netG, netD, save_images_folder, save_models_folder=None):
    """
    Train a cGAN with the standard (vanilla, BCE) GAN loss.

    train_images, train_labels : training data; labels are fed to netG/netD as
        long tensors.
    netG, netD : generator / discriminator (moved to CUDA here).
    save_images_folder : folder for periodic sample grids.
    save_models_folder : if not None, checkpoints are saved here (and training
        resumes from one when resume_niters > 0).

    Relies on module-level config: lr_g, lr_d, batch_size, num_workers,
    resume_niters, niters, gan_arch, dim_z, visualize_freq, save_niters_freq.

    Returns (netG, netD).
    """
    netG = netG.cuda()
    netD = netD.cuda()

    criterion = nn.BCELoss()
    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(train_images, train_labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Resume from an in-training checkpoint if requested.
    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/{}_checkpoint_intrain/{}_checkpoint_niters_{}.pth".format(
            gan_arch, gan_arch, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # Fixed noise + labels for visualization: a 10x10 grid whose rows sweep
    # labels between the 5th and 95th quantile of the training labels.
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_z, dtype=torch.float).cuda()
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # Restart the dataloader iterator when it is about to be exhausted.
        if batch_idx + 1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0

        # FIX: Python-3 iterator protocol — `dataloader_iter.next()` raises
        # AttributeError; use the builtin next().
        batch_train_images, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_images.shape[0]
        batch_train_images = batch_train_images.type(torch.float).cuda()
        batch_train_labels = batch_train_labels.type(torch.long).cuda()

        # Adversarial ground truths for the BCE loss.
        GAN_real = torch.ones(batch_size, 1).cuda()
        GAN_fake = torch.zeros(batch_size, 1).cuda()

        '''  Train Generator: maximize log(D(G(z)))  '''
        netG.train()

        # Sample noise as generator input and generate fakes.
        z = torch.randn(batch_size, dim_z, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_train_labels)

        # Generator tries to make the discriminator believe fakes are real.
        dis_out = netD(batch_fake_images, batch_train_labels)
        g_loss = criterion(dis_out, GAN_real)

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        '''  Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))  '''
        prob_real = netD(batch_train_images, batch_train_labels)
        prob_fake = netD(batch_fake_images.detach(), batch_train_labels.detach())
        real_loss = criterion(prob_real, GAN_real)
        fake_loss = criterion(prob_fake, GAN_fake)
        d_loss = (real_loss + fake_loss) / 2

        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        batch_idx += 1

        if (niter + 1) % 20 == 0:
            print(
                "%s-concat: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D prob real:%.4f] [D prob fake:%.4f] [Time: %.4f]"
                % (gan_arch, niter + 1, niters, d_loss.item(), g_loss.item(),
                   prob_real.mean().item(), prob_fake.mean().item(),
                   timeit.default_timer() - start_time))

        # Periodically dump a grid of samples at the fixed (z, y).
        if (niter + 1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data,
                       save_images_folder + '/{}.png'.format(niter + 1),
                       nrow=n_row, normalize=True)

        # Periodic checkpoint (and final one at niters).
        if save_models_folder is not None and (
                (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/{}_checkpoint_intrain/{}_checkpoint_niters_{}.pth".format(
                gan_arch, gan_arch, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter

    return netG, netD
N_train = len(images_train)
N_valid = len(images_valid)
assert len(images_train) == len(counts_train)
print("Number of images: {}/{}".format(N_train, N_valid))
# Normalization is very important here: scale counts into [0, 1]-ish range
# by dividing by the configured maximum count.
# counts = counts/np.max(counts)
counts_train = counts_train / args.end_count
counts_valid = counts_valid / args.end_count
# NOTE(review): fragment of a larger script — `args`, `images_train`,
# `images_valid`, `counts_*`, `IMGs_dataset` are defined elsewhere.
if args.transform:
    # Augmented: right-angle rotations plus horizontal/vertical flips.
    trainset = IMGs_dataset(images_train, counts_train, normalize=True, rotate=True, degrees=[90, 180, 270], hflip=True, vflip=True)
else:
    trainset = IMGs_dataset(images_train, counts_train, normalize=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True)
# Validation: no augmentation, no shuffling.
validset = IMGs_dataset(images_valid, counts_valid, normalize=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False)
# model initialization
def train_SNGAN(EPOCHS_GAN, GAN_Latent_Length, trainloader, netG, netD, optimizerG, optimizerD, save_SNGANimages_folder, save_models_folder = None, ResumeEpoch = 0, device="cuda", tfboard_writer=None):
    """
    Train an unconditional SNGAN with the hinge loss.

    EPOCHS_GAN          : total number of epochs.
    GAN_Latent_Length   : dimension of the latent noise vector z.
    trainloader         : DataLoader yielding (images, labels); labels ignored.
    netG, netD          : generator / discriminator (moved to `device` here).
    optimizerG/D        : their optimizers.
    save_SNGANimages_folder : prefix for periodic sample grids
        (NOTE(review): '%d.png' is appended directly — presumably the folder
        string ends with '/'; confirm at call site).
    save_models_folder  : if not None, checkpoint every 25 epochs and resume
        from epoch `ResumeEpoch` when > 0.
    tfboard_writer      : optional TensorBoard SummaryWriter; skipped when None
        (previously an unguarded None would crash at the first logging point).

    Relies on module-level config: N_ITER_IS, NFAKE_IS_TRAIN,
    BATCH_SIZE_IS_TRAIN, NC, IMG_SIZE, IS_BATCH_SIZE, NGPU, plus the
    inception_score() and IMGs_dataset helpers.

    Returns (netG, netD, optimizerG, optimizerD).
    """
    netG = netG.to(device)
    netD = netD.to(device)

    if save_models_folder is not None and ResumeEpoch > 0:
        print("\r Resume training >>>")
        save_file = save_models_folder + "/SNGAN_checkpoint_intrain/SNGAN_checkpoint_epoch" + str(ResumeEpoch) + ".pth"
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        gen_iterations = checkpoint['gen_iterations']
    else:
        gen_iterations = 0
    #end if

    # Fixed noise for a 10x10 visualization grid.
    n_row = 10
    z_fixed = torch.randn(n_row**2, GAN_Latent_Length, dtype=torch.float).to(device)

    start_tmp = timeit.default_timer()
    for epoch in range(ResumeEpoch, EPOCHS_GAN):
        # adjust_learning_rate(optimizerG, optimizerD, epoch, base_lr_g=1e-4, base_lr_d=4e-4)
        for batch_idx, (batch_train_images, _) in enumerate(trainloader):

            BATCH_SIZE = batch_train_images.shape[0]
            batch_train_images = batch_train_images.to(device)

            '''  Train Discriminator: hinge loss  '''
            d_out_real, _ = netD(batch_train_images)
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            z = torch.randn(BATCH_SIZE, GAN_Latent_Length, dtype=torch.float).to(device)
            gen_imgs = netG(z)
            d_out_fake, _ = netD(gen_imgs.detach())
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            d_loss = d_loss_real + d_loss_fake
            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

            '''  Train Generator: hinge loss  '''
            z = torch.randn(BATCH_SIZE, GAN_Latent_Length, dtype=torch.float).to(device)
            gen_imgs = netG(z)
            g_out_fake, _ = netD(gen_imgs)
            g_loss = - g_out_fake.mean()

            optimizerG.zero_grad()
            g_loss.backward()
            optimizerG.step()

            gen_iterations += 1

            # Periodic visualization + TensorBoard scalars.
            if gen_iterations % N_ITER_IS == 0:
                with torch.no_grad():
                    gen_imgs = netG(z_fixed)
                    gen_imgs = gen_imgs.detach()
                    save_image(gen_imgs.data, save_SNGANimages_folder +'%d.png' % gen_iterations, nrow=n_row, normalize=True)
                # FIX: tfboard_writer defaults to None — guard before logging.
                if tfboard_writer is not None:
                    tfboard_writer.add_scalar('D loss', d_loss.item(), gen_iterations)
                    tfboard_writer.add_scalar('G loss', g_loss.item(), gen_iterations)

            if gen_iterations % 20 == 0 and gen_iterations % N_ITER_IS != 0:
                print ("SNGAN: [Iter %d/%d] [Epoch %d/%d] [D loss: %.4f] [G loss: %.4f] [Time: %.4f]" % (gen_iterations, len(trainloader)*EPOCHS_GAN, epoch+1, EPOCHS_GAN, d_loss.item(), g_loss.item(), timeit.default_timer()-start_tmp))
            elif gen_iterations % N_ITER_IS == 0:
                # Compute the Inception Score on NFAKE_IS_TRAIN freshly
                # generated fakes.
                del gen_imgs, batch_train_images; gc.collect()
                fake_images = np.zeros((NFAKE_IS_TRAIN+BATCH_SIZE_IS_TRAIN, NC, IMG_SIZE, IMG_SIZE))
                # NOTE(review): netG.eval() is never undone afterwards; netG is
                # never explicitly set back to train mode in this loop — verify
                # intended.
                netG.eval()
                with torch.no_grad():
                    tmp = 0
                    while tmp < NFAKE_IS_TRAIN:
                        z = torch.randn(BATCH_SIZE_IS_TRAIN, GAN_Latent_Length, dtype=torch.float).to(device)
                        batch_fake_images = netG(z)
                        fake_images[tmp:(tmp+BATCH_SIZE_IS_TRAIN)] = batch_fake_images.cpu().detach().numpy()
                        tmp += BATCH_SIZE_IS_TRAIN
                    fake_images = fake_images[0:NFAKE_IS_TRAIN]
                del batch_fake_images; gc.collect()

                (IS_mean, IS_std) = inception_score(IMGs_dataset(fake_images), cuda=True, batch_size=IS_BATCH_SIZE, resize=True, splits=10, ngpu=NGPU)
                # FIX: same None guard as above.
                if tfboard_writer is not None:
                    tfboard_writer.add_scalar('Inception Score (mean)', IS_mean, gen_iterations)
                    tfboard_writer.add_scalar('Inception Score (std)', IS_std, gen_iterations)
                print ("SNGAN: [Iter %d/%d] [Epoch %d/%d] [D loss: %.4f] [G loss: %.4f] [Time: %.4f] [IS: %.3f/%.3f]" % (gen_iterations, len(trainloader)*EPOCHS_GAN, epoch+1, EPOCHS_GAN, d_loss.item(), g_loss.item(), timeit.default_timer()-start_tmp, IS_mean, IS_std))
        #end for batch_idx

        # Checkpoint every 25 epochs.
        if save_models_folder is not None and (epoch+1) % 25 == 0:
            save_file = save_models_folder + "/SNGAN_checkpoint_intrain"
            os.makedirs(save_file, exist_ok=True)
            save_file = save_file + "/SNGAN_checkpoint_epoch" + str(epoch+1) + ".pth"
            torch.save({
                'gen_iterations': gen_iterations,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optimizerG_state_dict': optimizerG.state_dict(),
                'optimizerD_state_dict': optimizerD.state_dict()
            }, save_file)
    #end for epoch

    return netG, netD, optimizerG, optimizerD
def inception_score(imgs, num_classes, net, cuda=True, batch_size=32, splits=1, normalize_img=False):
    """Compute the inception score of the generated images `imgs`.

    imgs -- unnormalized (3xHxW) numpy images
    num_classes -- number of classes output by `net`
    net -- classification CNN; its forward pass returns (logits, features)
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding images through `net`
    splits -- number of splits over which the score is averaged
    normalize_img -- passed through to IMGs_dataset

    Returns (mean, std) of exp(mean KL(p(y|x) || p(y))) over the splits.
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably set cuda=True"
            )
        dtype = torch.FloatTensor

    # Set up dataloader
    dataset = IMGs_dataset(imgs, labels=None, normalize=normalize_img)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # Move the classifier to the requested device and freeze batchnorm/dropout.
    if cuda:
        net = net.cuda()
    else:
        net = net.cpu()
    net.eval()

    def get_pred(x):
        # net returns (logits, features); only the logits are needed.
        logits, _ = net(x)
        return F.softmax(logits, dim=1).cpu().numpy()

    # Get predictions.
    # FIX: dropped the deprecated torch.autograd.Variable wrapper and run the
    # whole pass under no_grad so no gradient graph is built during scoring.
    preds = np.zeros((N, num_classes))
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = batch.type(dtype)
            batch_size_i = batch.size(0)
            preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batch)

    # Now compute the mean KL-divergence per split:
    # IS = exp( E_x KL(p(y|x) || p(y)) ).
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = [entropy(part[i, :], py) for i in range(part.shape[0])]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
def inception_score(imgs, num_classes, net, cuda=True, batch_size=32, splits=1, normalize_img=False):
    """Compute the inception score of the generated images `imgs`.

    imgs -- unnormalized (3xHxW) numpy images
    num_classes -- number of classes output by `net`
    net -- classification CNN; its forward pass returns (logits, features)
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding images through `net`
    splits -- number of splits over which the score is averaged
    normalize_img -- passed through to IMGs_dataset

    Returns (mean, std) of exp(mean KL(p(y|x) || p(y))) over the splits.
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably set cuda=True"
            )
        dtype = torch.FloatTensor

    # Set up dataloader
    dataset = IMGs_dataset(imgs, labels=None, normalize=normalize_img)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # Move the classifier to the requested device and freeze batchnorm/dropout.
    if cuda:
        net = net.cuda()
    else:
        net = net.cpu()
    net.eval()

    def get_pred(x):
        # net returns (logits, features); only the logits are needed.
        logits, _ = net(x)
        return F.softmax(logits, dim=1).cpu().numpy()

    # Get predictions.
    # FIX: dropped the deprecated torch.autograd.Variable wrapper and run the
    # whole pass under no_grad so no gradient graph is built during scoring.
    preds = np.zeros((N, num_classes))
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = batch.type(dtype)
            batch_size_i = batch.size(0)
            preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batch)

    # Now compute the mean KL-divergence per split:
    # IS = exp( E_x KL(p(y|x) || p(y)) ).
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = [entropy(part[i, :], py) for i in range(part.shape[0])]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)