def compute_inception_score(gan_model, ckpt_path=None):
    sess = gan_model.sess
    gan_model.load_generator(ckpt_path=ckpt_path)
    sess.run(tf.local_variables_initializer())

    print("Evaluating...")
    num_images_to_eval = 50000
    eval_images = []
    num_batches = num_images_to_eval // gan_model.batch_size + 1
    print("Calculating Inception Score. Sampling {} images...".format(
        num_images_to_eval))
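    # Fix the NumPy seed so evaluation sampling is reproducible across runs.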
    np.random.seed(0)
    for _ in range(num_batches):
        images = sess.run(gan_model.fake_data,
                          feed_dict={gan_model.is_training: False})
        eval_images.append(images)
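    # Restore nondeterministic seeding once evaluation sampling is done.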
    np.random.seed()
    eval_images = np.vstack(eval_images)
    eval_images = eval_images[:num_images_to_eval]
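    # Map generator outputs from [-1, 1] to [0, 255] and convert to uint8.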
    eval_images = np.clip((eval_images + 1.0) * 127.5, 0.0,
                          255.0).astype(np.uint8)
    # Calc Inception score
    eval_images = list(eval_images)
    inception_score_mean, inception_score_std = get_inception_score(
        eval_images)
    print("Inception Score: Mean = {} \tStd = {}.".format(
        inception_score_mean, inception_score_std))
Example #2
def cifar_get_inception_score(session, samples_100):
    all_samples = []
    for i in range(10):
        all_samples.append(session.run(samples_100))
    all_samples = np.concatenate(all_samples, axis=0)
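    # Rescale from [-1, 1] to [0, 255] and convert NCHW to NHWC for Inception.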
    all_samples = ((all_samples + 1.) * (255. / 2)).astype('int32')
    all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    return get_inception_score(list(all_samples))
Example #3
def get_inception_score(G):
    all_samples = []
    for i in range(10):
        samples_100 = torch.randn(100, 128)
        if use_cuda:
            samples_100 = samples_100.cuda(gpu)
        # Run the generator without tracking gradients; torch.no_grad() replaces
        # the deprecated Variable(..., volatile=True) idiom.
        with torch.no_grad():
            all_samples.append(G(samples_100).cpu().numpy())

    all_samples = np.concatenate(all_samples, axis=0)
    # Rescale from [-1, 1] to [0, 255] pixel values.
    all_samples = ((all_samples + 1.) * (255. / 2)).astype('int32')
    all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    return inception_score.get_inception_score(list(all_samples))
Example #4
    def get_inception_score():
        all_samples = []
        samples = torch.randn(NUM_SAMPLES, 128)
        for i in range(0, NUM_SAMPLES, BATCH_SIZE):
            samples_100 = samples[i:i + BATCH_SIZE]
            if CUDA:
                samples_100 = samples_100.cuda(0)
            all_samples.append(gen(samples_100).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        # Rescale from [-1, 1] to [0, 255] pixel values.
        all_samples = ((all_samples + 1.) * (255. / 2)).astype('int32')
        all_samples = all_samples.reshape(
            (-1, N_CHANNEL, RESOLUTION, RESOLUTION)).transpose(0, 2, 3, 1)
        return inception_score.get_inception_score(list(all_samples))
Example #5
def measure_gan(gan,
                rec_data_path=None,
                probe_size=10000,
                calc_real_data_is=True):
    """Based on MNIST tutorial from cleverhans.
    
    Args:
         gan: A `GAN` model.
         rec_data_path: A string to the directory.
    """
    FLAGS = tf.flags.FLAGS

    # Keep logging at WARNING to suppress debug output.
    set_log_level(logging.WARNING)
    sess = gan.sess

    # FID init
    stats_path = 'data/fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, False)

    # Take the first probe_size training images and rescale them to [0, 255].
    images = train_images[0:probe_size] * 255

    # Inception Score for real data
    is_orig_mean, is_orig_stddev = (-1, -1)
    if calc_real_data_is:
        is_orig_mean, is_orig_stddev = get_inception_score(images)
        print(
            '\n[#] Inception Score for original data: mean = %f, stddev = %f\n'
            % (is_orig_mean, is_orig_stddev))

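    # Fix NumPy and TensorFlow seeds so sample generation is reproducible.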
    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    # Calculate Inception Score for GAN
    gan.batch_size = probe_size
    generated_images_tensor = gan.generator_fn()
    generated_images = sess.run(generated_images_tensor)
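    # Map generator outputs from [-1, 1] to [0, 255].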
    generated_images = 255 * ((generated_images + 1) / 2)
    is_gen_mean, is_gen_stddev = get_inception_score(generated_images)
    print(
        '\n[#] Inception Score for generated data: mean = %f, stddev = %f\n' %
        (is_gen_mean, is_gen_stddev))

    # load precalculated training set statistics
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    # Use a fresh session for the Inception graph, renamed to avoid shadowing gan.sess.
    with tf.Session() as fid_sess:
        fid_sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            generated_images, fid_sess, batch_size=100)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)

    return is_gen_mean, is_gen_stddev, is_orig_mean, is_orig_stddev, fid_value
Example #6
def tester(args, model_bucket, test_fid, test_is, sess, mu_real, sigma_real,
           get_inception_metrics):
    # Get the requisite network
    use_pytorch = args.use_pytorch_scores
    net_type = args.netG
    device = args.device
    net = model_bucket
    input_dim = args.z_dim
    input_nc = args.input_nc
    crop_size = args.crop_size
    evaluation_size = args.test_size
    fid_batch_size = args.fid_batch_size

    scores_ret = OrderedDict()

    samples = torch.zeros((evaluation_size, 3, crop_size, crop_size),
                          device=device)
    n_fid_batches = evaluation_size // fid_batch_size

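    # Generate evaluation_size samples in chunks of fid_batch_size.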
    for i in range(n_fid_batches):
        frm = i * fid_batch_size
        to = frm + fid_batch_size

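        # Draw the latent prior that matches the generator architecture.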
        if net_type == 'DCGAN':
            z = torch.randn(fid_batch_size, input_dim, 1, 1, device=device)
        elif 'EGAN' in net_type:
            z = torch.rand(fid_batch_size, input_dim, 1, 1,
                           device=device) * 2. - 1.
        elif net_type == 'WGAN':
            z = torch.randn(fid_batch_size, input_dim, device=device)
        else:
            raise NotImplementedError('netG [%s] is not found' % net_type)

        gen_s = net(z).detach()
        samples[frm:to] = gen_s
        print("\rgenerate fid sample batch %d/%d " % (i + 1, n_fid_batches),
              end="",
              flush=True)

    print("%d samples generating done" % evaluation_size)

    if use_pytorch:
        IS_mean, IS_var, FID = get_inception_metrics(samples,
                                                     evaluation_size,
                                                     num_splits=10)

        if test_fid:
            scores_ret['FID'] = float(FID)
        if test_is:
            scores_ret['IS_mean'] = float(IS_mean)
            scores_ret['IS_var'] = float(IS_var)

    else:
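        # Convert samples to uint8 NHWC numpy arrays for the TensorFlow-based metrics.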
        samples = samples.cpu().numpy()
        samples = ((samples + 1.0) * 127.5).astype('uint8')
        samples = samples.reshape(evaluation_size, input_nc, crop_size,
                                  crop_size)
        samples = samples.transpose(0, 2, 3, 1)

        if test_fid:
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                samples, sess, batch_size=fid_batch_size, verbose=True)
            print("calculate FID:")
            try:
                FID = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_real, sigma_real)
            except Exception as e:
                print(e)
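                # Fall back to a sentinel value if the FID computation fails.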
                FID = 500
            scores_ret['FID'] = float(FID)
        if test_is:
            Imlist = []
            for i in range(len(samples)):
                im = samples[i, :, :, :]
                Imlist.append(im)
            print(np.array(Imlist).shape)
            IS_mean, IS_var = get_inception_score(Imlist)
            scores_ret['IS_mean'] = float(IS_mean)
            scores_ret['IS_var'] = float(IS_var)

    return scores_ret