def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 200, resize = None, norm_img = False, num_workers=0):
    '''
    PreNet: pre-trained CNN used to predict labels
    images: fake images
    labels_assi: assigned labels
    min_label_before_shift, max_label_after_shift: constants used to undo the label normalization
    norm_img: if True, normalize images before feeding them to PreNet
    resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
    '''

    PreNet.eval()

    # assume images have shape n x nc x img_size x img_size
    n = images.shape[0]
    nc = images.shape[1] #number of channels
    img_size = images.shape[2]
    labels_assi = labels_assi.reshape(-1)

    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
    eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    labels_pred = np.zeros(n+batch_size)

    nimgs_got = 0
    pb = SimpleProgressBar()
    for batch_idx, (batch_images, batch_labels) in enumerate(eval_dataloader):
        batch_images = batch_images.type(torch.float).cuda()
        batch_labels = batch_labels.type(torch.float).cuda()
        batch_size_curr = len(batch_labels)

        if norm_img:
            batch_images = normalize_images(batch_images)

        batch_labels_pred, _ = PreNet(batch_images)
        labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)

        nimgs_got += batch_size_curr
        pb.update((float(nimgs_got)/n)*100)

        del batch_images; gc.collect()
        torch.cuda.empty_cache()
    #end for batch_idx

    labels_pred = labels_pred[0:n]


    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    ls_mean = np.mean(np.abs(labels_pred-labels_assi))
    ls_std = np.std(np.abs(labels_pred-labels_assi))

    return ls_mean, ls_std
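
For reference, a minimal, self-contained sketch of what cal_labelscore computes, using random predictions and hypothetical shift constants in place of a real PreNet:

# Sketch of the label-score computation (hypothetical constants, no real CNN).
import numpy as np

min_label_before_shift = -10.0   # hypothetical shift constant
max_label_after_shift = 20.0     # hypothetical scale constant

rng = np.random.default_rng(0)
labels_assi = rng.uniform(0, 1, size=1000)               # assigned (normalized) labels
labels_pred = labels_assi + rng.normal(0, 0.05, 1000)    # stand-in for CNN predictions

# undo the label normalization, exactly as cal_labelscore does
labels_pred = labels_pred * max_label_after_shift - np.abs(min_label_before_shift)
labels_assi = labels_assi * max_label_after_shift - np.abs(min_label_before_shift)

abs_err = np.abs(labels_pred - labels_assi)
print("label score (mean, std):", abs_err.mean(), abs_err.std())
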
Example #2
assert len(labels) == len(images)

# define training and validation sets
if args.CVMode:
    # 90% training; 10% validation
    valid_prop = 0.1  #proportion of the validation samples
    indx_all = np.arange(len(images))
    np.random.shuffle(indx_all)
    indx_valid = indx_all[0:int(valid_prop * len(images))]
    indx_train = indx_all[int(valid_prop * len(images)):]

    if args.transform:
        trainset = IMGs_dataset(images[indx_train],
                                labels=None,
                                normalize=True,
                                rotate=True,
                                degrees=[90, 180, 270],
                                hflip=True,
                                vflip=True)
    else:
        trainset = IMGs_dataset(images[indx_train],
                                labels=None,
                                normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size_train,
                                              shuffle=True,
                                              num_workers=8)
    validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
    validloader = torch.utils.data.DataLoader(validset,
                                              batch_size=args.batch_size_valid,
                                              shuffle=False,
                                              num_workers=8)

# keep only samples whose labels lie strictly between the min and max label
q1 = args.min_label
q2 = args.max_label
indx = np.where((labels > q1) * (labels < q2) == True)[0]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)

# define training and validation sets
if args.CVMode:
    # 90% training; 10% validation
    valid_prop = 0.1  #proportion of the validation samples
    indx_all = np.arange(len(images))
    np.random.shuffle(indx_all)
    indx_valid = indx_all[0:int(valid_prop * len(images))]
    indx_train = indx_all[int(valid_prop * len(images)):]

    trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size_train,
                                              shuffle=True,
                                              num_workers=8)
    validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
    validloader = torch.utils.data.DataLoader(validset,
                                              batch_size=args.batch_size_valid,
                                              shuffle=False,
                                              num_workers=8)

else:
    trainset = IMGs_dataset(images, labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size_train,
                                              shuffle=True,
                                              num_workers=8)

def train_cgan_concat(images,
                      labels,
                      netG,
                      netD,
                      save_images_folder,
                      save_models_folder=None):

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=lr_g,
                                  betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=lr_d,
                                  betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers)
    unique_labels = np.sort(np.array(list(set(labels)))).astype(int)

    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(
            gan_arch, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # visualize images with labels between the 5th and 95th quantiles of the training labels
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(labels, 0.05)
    end_label = np.quantile(labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        if batch_idx + 1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0
        '''

        Train Generator: maximize log(D(G(z)))

        '''

        netG.train()

        # get training images
        _, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_labels.shape[0]
        batch_train_labels = batch_train_labels.type(torch.long).cuda()
        batch_idx += 1

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()

        #generate fake images
        batch_fake_images = netG(z, batch_train_labels)

        # Loss measures generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy),
                           batch_train_labels)
        else:
            dis_out = netD(batch_fake_images, batch_train_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = -torch.mean(torch.log(dis_out + 1e-20))
        elif loss_type == "hinge":
            g_loss = -torch.mean(dis_out)

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()
        '''

        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))

        '''

        for _ in range(num_D_steps):

            if batch_idx + 1 == len(train_dataloader):
                dataloader_iter = iter(train_dataloader)
                batch_idx = 0

            # get training images
            batch_train_images, batch_train_labels = next(dataloader_iter)
            assert batch_size == batch_train_images.shape[0]
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()
            batch_idx += 1

            # Measure discriminator's ability to classify real from generated samples
            if use_DiffAugment:
                real_dis_out = netD(
                    DiffAugment(batch_train_images, policy=policy),
                    batch_train_labels)
                fake_dis_out = netD(
                    DiffAugment(batch_fake_images.detach(), policy=policy),
                    batch_train_labels.detach())
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(),
                                    batch_train_labels.detach())

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = -torch.log(real_dis_out + 1e-20)
                d_loss_fake = -torch.log(1 - fake_dis_out + 1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter + 1) % 20 == 0:
            print(
                "cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]"
                % (gan_arch, niter + 1, niters, d_loss.item(), g_loss.item(),
                   real_dis_out.mean().item(), fake_dis_out.mean().item(),
                   timeit.default_timer() - start_time))

        if (niter + 1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data,
                       save_images_folder + '/{}.png'.format(niter + 1),
                       nrow=n_row,
                       normalize=True)

        if save_models_folder is not None and (
            (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(
                gan_arch, num_D_steps, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter

    return netG, netD
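
A minimal, self-contained sketch of the two loss branches used above (hinge and non-saturating vanilla), evaluated on hypothetical discriminator logits rather than real networks:

# Sketch of the "hinge" and "vanilla" GAN losses on fake logits (no real netD/netG).
import torch

real_dis_out = torch.randn(8, 1)   # hypothetical D(x, y) logits
fake_dis_out = torch.randn(8, 1)   # hypothetical D(G(z, y), y) logits

# hinge loss (loss_type == "hinge")
d_loss_hinge = (torch.relu(1.0 - real_dis_out) + torch.relu(1.0 + fake_dis_out)).mean()
g_loss_hinge = -fake_dis_out.mean()

# vanilla (non-saturating) loss (loss_type == "vanilla")
real_p = torch.sigmoid(real_dis_out)
fake_p = torch.sigmoid(fake_dis_out)
d_loss_vanilla = (-torch.log(real_p + 1e-20) - torch.log(1 - fake_p + 1e-20)).mean()
g_loss_vanilla = -torch.mean(torch.log(fake_p + 1e-20))

print(d_loss_hinge.item(), g_loss_hinge.item(), d_loss_vanilla.item(), g_loss_vanilla.item())
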
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size_train,
                                              shuffle=True,
                                              num_workers=8)
else:
    h5py_file = wd + '/data/MNIST_reduced_trainset_' + str(
        args.N_TRAIN) + '.h5'
    hf = h5py.File(h5py_file, 'r')
    images_train = hf['images_train'][:]
    labels_train = hf['labels_train'][:]
    hf.close()
    if args.transform:
        trainset = IMGs_dataset(images_train,
                                labels_train,
                                normalize=True,
                                rotate=True,
                                degrees=15,
                                crop=True,
                                crop_size=28,
                                crop_pad=4)
    else:
        trainset = IMGs_dataset(images_train,
                                labels_train,
                                normalize=True,
                                rotate=False,
                                degrees=15,
                                crop=False,
                                crop_size=28,
                                crop_pad=4)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size_train,
                                              shuffle=True,
                                              num_workers=8)

Example #6
def train_cgan(train_images,
               train_labels,
               netG,
               netD,
               save_images_folder,
               save_models_folder=None):

    netG = netG.cuda()
    netD = netD.cuda()

    criterion = nn.BCELoss()
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=lr_g,
                                  betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=lr_d,
                                  betas=(0.5, 0.999))

    trainset = IMGs_dataset(train_images, train_labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers)
    unique_labels = np.sort(np.array(list(set(train_labels)))).astype(int)

    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/{}_checkpoint_intrain/{}_checkpoint_niters_{}.pth".format(
            gan_arch, gan_arch, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # visualize images with labels between the 5th and 95th quantiles of the training labels
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_z, dtype=torch.float).cuda()
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        if batch_idx + 1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0

        # training images
        batch_train_images, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_images.shape[0]
        batch_train_images = batch_train_images.type(torch.float).cuda()
        batch_train_labels = batch_train_labels.type(torch.long).cuda()

        # Adversarial ground truths
        GAN_real = torch.ones(batch_size, 1).cuda()
        GAN_fake = torch.zeros(batch_size, 1).cuda()
        '''

        Train Generator: maximize log(D(G(z)))

        '''
        netG.train()

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, dim_z, dtype=torch.float).cuda()

        #generate fake images
        batch_fake_images = netG(z, batch_train_labels)

        # Loss measures generator's ability to fool the discriminator
        dis_out = netD(batch_fake_images, batch_train_labels)

        # the generator tries to make the discriminator believe the generated images are real
        g_loss = criterion(dis_out, GAN_real)

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()
        '''

        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))

        '''

        # Measure discriminator's ability to classify real from generated samples
        prob_real = netD(batch_train_images, batch_train_labels)
        prob_fake = netD(batch_fake_images.detach(),
                         batch_train_labels.detach())
        real_loss = criterion(prob_real, GAN_real)
        fake_loss = criterion(prob_fake, GAN_fake)
        d_loss = (real_loss + fake_loss) / 2

        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        batch_idx += 1

        if (niter + 1) % 20 == 0:
            print(
                "%s-concat: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D prob real:%.4f] [D prob fake:%.4f] [Time: %.4f]"
                % (gan_arch, niter + 1, niters, d_loss.item(), g_loss.item(),
                   prob_real.mean().item(), prob_fake.mean().item(),
                   timeit.default_timer() - start_time))

        if (niter + 1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data,
                       save_images_folder + '/{}.png'.format(niter + 1),
                       nrow=n_row,
                       normalize=True)

        if save_models_folder is not None and (
            (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/{}_checkpoint_intrain/{}_checkpoint_niters_{}.pth".format(
                gan_arch, gan_arch, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter

    return netG, netD
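
A minimal sketch of how the fixed visualization grid (z_fixed, y_fixed) is built in both training loops above: one label per row, spanning the 5th to 95th quantiles of the training labels. The latent dimension and labels below are hypothetical:

# Sketch of building the fixed (z, y) grid used for periodic visualization.
import numpy as np
import torch

train_labels = np.random.uniform(0, 100, size=5000)   # hypothetical labels
n_row = n_col = 10
dim_z = 128                                            # hypothetical latent dimension

z_fixed = torch.randn(n_row * n_col, dim_z)
start_label = np.quantile(train_labels, 0.05)
end_label = np.quantile(train_labels, 0.95)
selected_labels = np.linspace(start_label, end_label, num=n_row)

# each row of the image grid shares one label
y_fixed = np.repeat(selected_labels, n_col)
y_fixed = torch.from_numpy(y_fixed).float().view(-1, 1)
print(y_fixed.view(n_row, n_col)[:, 0])   # one label per row
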
Example #7
N_train = len(images_train)
N_valid = len(images_valid)
assert len(images_train) == len(counts_train)

print("Number of images: {}/{}".format(N_train, N_valid))

# normalization is very important here!
# counts = counts/np.max(counts)
counts_train = counts_train / args.end_count
counts_valid = counts_valid / args.end_count

if args.transform:
    trainset = IMGs_dataset(images_train,
                            counts_train,
                            normalize=True,
                            rotate=True,
                            degrees=[90, 180, 270],
                            hflip=True,
                            vflip=True)
else:
    trainset = IMGs_dataset(images_train, counts_train, normalize=True)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.batch_size_train,
                                          shuffle=True)

validset = IMGs_dataset(images_valid, counts_valid, normalize=True)
validloader = torch.utils.data.DataLoader(validset,
                                          batch_size=args.batch_size_valid,
                                          shuffle=False)

# model initialization
Example #8
def train_SNGAN(EPOCHS_GAN, GAN_Latent_Length, trainloader, netG, netD, optimizerG, optimizerD, save_SNGANimages_folder, save_models_folder = None, ResumeEpoch = 0, device="cuda", tfboard_writer=None):


    netG = netG.to(device)
    netD = netD.to(device)

    if save_models_folder is not None and ResumeEpoch>0:
        print("\r Resume training >>>")
        save_file = save_models_folder + "/SNGAN_checkpoint_intrain/SNGAN_checkpoint_epoch" + str(ResumeEpoch) + ".pth"
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        gen_iterations = checkpoint['gen_iterations']
    else:
        gen_iterations = 0
    #end if

    n_row=10
    z_fixed = torch.randn(n_row**2, GAN_Latent_Length, dtype=torch.float).to(device)

    start_tmp = timeit.default_timer()
    for epoch in range(ResumeEpoch, EPOCHS_GAN):
        # adjust_learning_rate(optimizerG, optimizerD, epoch, base_lr_g=1e-4, base_lr_d=4e-4)
        for batch_idx, (batch_train_images, _) in enumerate(trainloader):

            BATCH_SIZE = batch_train_images.shape[0]
            batch_train_images = batch_train_images.to(device)
            # batch_train_images = batch_train_images.type(torch.float).to(device)

            '''

            Train Discriminator: hinge loss

            '''
            d_out_real,_ = netD(batch_train_images)
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            z = torch.randn(BATCH_SIZE, GAN_Latent_Length, dtype=torch.float).to(device)
            gen_imgs = netG(z)
            d_out_fake,_ = netD(gen_imgs.detach())
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

            '''

            Train Generator: hinge loss

            '''
            z = torch.randn(BATCH_SIZE, GAN_Latent_Length, dtype=torch.float).to(device)
            gen_imgs = netG(z)
            g_out_fake,_ = netD(gen_imgs)

            g_loss = - g_out_fake.mean()
            optimizerG.zero_grad()
            g_loss.backward()
            optimizerG.step()

            gen_iterations += 1

            if gen_iterations % N_ITER_IS == 0:
                with torch.no_grad():
                    # n_row=10
                    # z = torch.from_numpy(np.random.normal(0, 1, (n_row**2, GAN_Latent_Length))).type(torch.float).to(device)
                    gen_imgs = netG(z_fixed)
                    gen_imgs = gen_imgs.detach()
                save_image(gen_imgs.data, save_SNGANimages_folder +'%d.png' % gen_iterations, nrow=n_row, normalize=True)


            tfboard_writer.add_scalar('D loss', d_loss.item(), gen_iterations)
            tfboard_writer.add_scalar('G loss', g_loss.item(), gen_iterations)

            if gen_iterations%20 == 0 and gen_iterations%N_ITER_IS != 0:
                print ("SNGAN: [Iter %d/%d] [Epoch %d/%d] [D loss: %.4f] [G loss: %.4f] [Time: %.4f]" % (gen_iterations, len(trainloader)*EPOCHS_GAN, epoch+1, EPOCHS_GAN, d_loss.item(), g_loss.item(), timeit.default_timer()-start_tmp))
            elif gen_iterations%N_ITER_IS == 0: #compute inception score
                del gen_imgs, batch_train_images; gc.collect()
                fake_images = np.zeros((NFAKE_IS_TRAIN+BATCH_SIZE_IS_TRAIN, NC, IMG_SIZE, IMG_SIZE))
                netG.eval()
                with torch.no_grad():
                    tmp = 0
                    while tmp < NFAKE_IS_TRAIN:
                        z = torch.randn(BATCH_SIZE_IS_TRAIN, GAN_Latent_Length, dtype=torch.float).to(device)
                        batch_fake_images = netG(z)
                        fake_images[tmp:(tmp+BATCH_SIZE_IS_TRAIN)] = batch_fake_images.cpu().detach().numpy()
                        tmp += BATCH_SIZE_IS_TRAIN
                fake_images = fake_images[0:NFAKE_IS_TRAIN]
                del batch_fake_images; gc.collect()
                (IS_mean, IS_std) = inception_score(IMGs_dataset(fake_images), cuda=True, batch_size=IS_BATCH_SIZE, resize=True, splits=10, ngpu=NGPU)

                tfboard_writer.add_scalar('Inception Score (mean)', IS_mean, gen_iterations)
                tfboard_writer.add_scalar('Inception Score (std)', IS_std, gen_iterations)

                print ("SNGAN: [Iter %d/%d] [Epoch %d/%d] [D loss: %.4f] [G loss: %.4f] [Time: %.4f] [IS: %.3f/%.3f]" % (gen_iterations, len(trainloader)*EPOCHS_GAN, epoch+1, EPOCHS_GAN, d_loss.item(), g_loss.item(), timeit.default_timer()-start_tmp, IS_mean, IS_std))

        if save_models_folder is not None and (epoch+1) % 25 == 0:
            save_file = save_models_folder + "/SNGAN_checkpoint_intrain"
            os.makedirs(save_file, exist_ok=True)
            save_file = save_file + "/SNGAN_checkpoint_epoch" + str(epoch+1) + ".pth"
            torch.save({
                    'gen_iterations': gen_iterations,
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict()
            }, save_file)
    #end for epoch

    return netG, netD, optimizerG, optimizerD
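
A minimal sketch of the scalar logging used above, assuming tfboard_writer is a standard torch.utils.tensorboard.SummaryWriter; the log directory and loss values are placeholders:

# Sketch of TensorBoard scalar logging (placeholder losses, hypothetical log dir).
import torch
from torch.utils.tensorboard import SummaryWriter

tfboard_writer = SummaryWriter(log_dir="./runs/sngan_demo")   # hypothetical path
for gen_iterations in range(1, 101):
    d_loss = torch.rand(1)   # placeholder discriminator loss
    g_loss = torch.rand(1)   # placeholder generator loss
    tfboard_writer.add_scalar('D loss', d_loss.item(), gen_iterations)
    tfboard_writer.add_scalar('G loss', g_loss.item(), gen_iterations)
tfboard_writer.close()
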
def inception_score(imgs,
                    num_classes,
                    net,
                    cuda=True,
                    batch_size=32,
                    splits=1,
                    normalize_img=False):
    """Computes the inception score of the generated images imgs
    imgs -- unnormalized (3xHxW) numpy images
    net -- Classification CNN
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably set cuda=True"
            )
        dtype = torch.FloatTensor

    # Set up dataloader
    dataset = IMGs_dataset(imgs, labels=None, normalize=normalize_img)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # Load inception model
    if cuda:
        net = net.cuda()
    else:
        net = net.cpu()
    net.eval()

    def get_pred(x):
        x, _ = net(x)
        return F.softmax(x, dim=1).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, num_classes))

    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]

        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
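
A minimal, self-contained sketch of the split / exp-mean-KL computation inside inception_score, using random softmax outputs in place of a real classifier:

# Sketch of the inception-score computation on fake softmax predictions.
import numpy as np
from scipy.stats import entropy

rng = np.random.default_rng(0)
N, num_classes, splits = 1000, 10, 5
preds = rng.dirichlet(np.ones(num_classes), size=N)   # stand-in for softmax outputs

split_scores = []
for k in range(splits):
    part = preds[k * (N // splits):(k + 1) * (N // splits), :]
    py = part.mean(axis=0)                             # marginal class distribution
    kl = [entropy(part[i, :], py) for i in range(part.shape[0])]
    split_scores.append(np.exp(np.mean(kl)))

print("IS mean/std:", np.mean(split_scores), np.std(split_scores))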


# from torchvision.models.inception import inception_v3
# def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1, normalize_img=False):
#     """Computes the inception score of the generated images imgs based on Inception V3 which is pretrained on ImageNet
#     imgs -- unnormalized (3xHxW) numpy images
#     net -- Classification CNN
#     cuda -- whether or not to run on GPU
#     batch_size -- batch size for feeding into Inception v3
#     splits -- number of splits
#     """
#     N = len(imgs)

#     assert batch_size > 0
#     assert N > batch_size

#     # Set up dtype
#     if cuda:
#         dtype = torch.cuda.FloatTensor
#     else:
#         if torch.cuda.is_available():
#             print("WARNING: You have a CUDA device, so you should probably set cuda=True")
#         dtype = torch.FloatTensor

#     # Set up dataloader
#     dataset = IMGs_dataset(imgs, labels=None, normalize=normalize_img)
#     dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

#     # Load inception model
#     inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
#     inception_model.eval();
#     # up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
#     def get_pred(x):
#         if resize:
#             x = nn.functional.interpolate(x, size = (299, 299), scale_factor=None, mode='bilinear', align_corners=False)
#         x = inception_model(x)
#         return F.softmax(x, dim=1).data.cpu().numpy()

#     # Get predictions
#     preds = np.zeros((N, 1000))

#     for i, batch in enumerate(dataloader, 0):
#         batch = batch.type(dtype)
#         batchv = Variable(batch)
#         batch_size_i = batch.size()[0]

#         preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)

#     # Now compute the mean kl-div
#     split_scores = []

#     for k in range(splits):
#         part = preds[k * (N // splits): (k+1) * (N // splits), :]
#         py = np.mean(part, axis=0)
#         scores = []
#         for i in range(part.shape[0]):
#             pyx = part[i, :]
#             scores.append(entropy(pyx, py))
#         split_scores.append(np.exp(np.mean(scores)))

#     return np.mean(split_scores), np.std(split_scores)