def get_IS_and_FID(n):
    """Compute the Inception Score (and optionally the FID) for n generated samples."""
    all_samples = []
    # Draw n samples from the generator graph, 100 at a time.
    for i in range(n // 100):
        all_samples.append(session.run(samples_100))
    all_samples = np.concatenate(all_samples, axis=0)
    # Rescale from [-1, 1] to [0, 255] and convert NCHW -> NHWC.
    all_samples = ((all_samples + 1.) * (255.99 / 2)).astype('int32')
    all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    _fid_score = 0
    _inception_score = lib.inception_score.get_inception_score(list(all_samples))
    print("calculated IS")
    if run_fid:
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            all_samples, session, 100, verbose=True)
        try:
            _fid_score = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                        mu_real, sigma_real)
        except Exception as e:
            print(e)
            # Fall back to a large sentinel value if the distance cannot be computed.
            _fid_score = 10e4
        print("calculated IS and FID")
        return _inception_score, _fid_score
    else:
        return _inception_score, 0
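
# For reference, `fid.calculate_frechet_distance` used above computes the
# Fréchet distance between two Gaussians fitted to the Inception activations
# of generated and real samples:
#     FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))
# A minimal NumPy/SciPy sketch of that formula (the name `frechet_distance`
# is illustrative and not part of the `fid` module):
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Jitter the diagonals if the covariance product is close to singular.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
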
def cal_fid_score(G, device, z_dim):
    """Compute the FID of generator G against precomputed LSUN training statistics."""
    stats_path = 'tflib/data/fid_stats_lsun_train.npz'
    inception_path = fid.check_or_download_inception('tflib/model')
    # Load precomputed statistics of the real training data.
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    fid.create_inception_graph(inception_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Generate NUM_SAMPLES images in batches of BATCH_SIZE.
    all_samples = []
    samples = torch.randn(NUM_SAMPLES, z_dim, 1, 1)
    for i in range(0, NUM_SAMPLES, BATCH_SIZE):
        samples_100 = samples[i:i + BATCH_SIZE]
        samples_100 = samples_100.to(device=device)
        all_samples.append(G(samples_100).cpu().data.numpy())

    all_samples = np.concatenate(all_samples, axis=0)
    # Rescale from [-1, 1] to [0, 255] and convert NCHW -> NHWC.
    all_samples = ((all_samples * 0.5 + 0.5) * 255).astype('int32')
    all_samples = all_samples.reshape(
        (-1, N_CHANNEL, RESOLUTION, RESOLUTION)).transpose(0, 2, 3, 1)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            all_samples, sess, batch_size=BATCH_SIZE)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    return fid_value
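
# The statistics file loaded above ('tflib/data/fid_stats_lsun_train.npz',
# with keys 'mu' and 'sigma') has to be precomputed once from real training
# images. A sketch of how that could be done with the same `fid` helpers used
# throughout these examples, assuming `real_images` is a 0-255 NHWC array
# shaped like `all_samples` above (the function name and output path are
# illustrative):
import numpy as np
import tensorflow as tf

def precompute_real_stats(real_images, out_path='tflib/data/fid_stats_lsun_train.npz'):
    inception_path = fid.check_or_download_inception('tflib/model')
    fid.create_inception_graph(inception_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        mu_real, sigma_real = fid.calculate_activation_statistics(
            real_images, sess, batch_size=100)
    # Save with the keys the loading code above expects.
    np.savez(out_path, mu=mu_real, sigma=sigma_real)
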
Example #3
    def get_fid_score():
        """Compute the FID of the current generator against the real-data statistics."""
        all_samples = []
        samples = torch.randn(N_SAMPLES, N_LATENT)
        # Generate N_SAMPLES images in batches of BATCH_SIZE.
        for i in range(0, N_SAMPLES, BATCH_SIZE):
            samples_100 = samples[i:i + BATCH_SIZE]
            if CUDA:
                samples_100 = samples_100.cuda(0)
            all_samples.append(gen(samples_100).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        all_samples = np.multiply(np.add(np.multiply(all_samples, 0.5), 0.5), 255).astype('int32')
        all_samples = all_samples.reshape((-1, N_CHANNEL, RESOLUTION, RESOLUTION)).transpose(0, 2, 3, 1)

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(all_samples, sess, batch_size=BATCH_SIZE)

        fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
        return fid_value
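
# Several of these snippets repeat the same post-processing: generator output
# in [-1, 1], NCHW float32, rescaled to 0-255 and transposed to NHWC before it
# is fed to the Inception network. A small stand-alone helper capturing that
# step (the name and default arguments are illustrative):
import numpy as np

def to_inception_input(samples, n_channel=3, resolution=32):
    """Map [-1, 1] NCHW float samples to 0-255 NHWC int32 images."""
    samples = (samples * 0.5 + 0.5) * 255              # [-1, 1] -> [0, 255]
    samples = samples.astype('int32')
    samples = samples.reshape((-1, n_channel, resolution, resolution))
    return samples.transpose(0, 2, 3, 1)               # NCHW -> NHWC
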
Example #4
def measure_gan(gan,
                rec_data_path=None,
                probe_size=10000,
                calc_real_data_is=True):
    """Measure Inception Score and FID for a GAN. Based on the MNIST tutorial from cleverhans.

    Args:
        gan: A `GAN` model.
        rec_data_path: Path to the data directory.
        probe_size: Number of images used for the evaluation.
        calc_real_data_is: If True, also compute the Inception Score of the real data.
    """
    FLAGS = tf.flags.FLAGS

    # Set the logging level.
    set_log_level(logging.WARNING)
    sess = gan.sess

    # FID init
    stats_path = 'data/fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, False)

    # Use the first probe_size training images, rescaled to [0, 255].
    images = train_images[0:probe_size] * 255

    # Inception Score for real data
    is_orig_mean, is_orig_stddev = (-1, -1)
    if calc_real_data_is:
        is_orig_mean, is_orig_stddev = get_inception_score(images)
        print(
            '\n[#] Inception Score for original data: mean = %f, stddev = %f\n'
            % (is_orig_mean, is_orig_stddev))

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    # Calculate Inception Score for GAN
    gan.batch_size = probe_size
    generated_images_tensor = gan.generator_fn()
    generated_images = sess.run(generated_images_tensor)
    generated_images = 255 * ((generated_images + 1) / 2)
    is_gen_mean, is_gen_stddev = get_inception_score(generated_images)
    print(
        '\n[#] Inception Score for generated data: mean = %f, stddev = %f\n' %
        (is_gen_mean, is_gen_stddev))

    # load precalculated training set statistics
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            generated_images, sess, batch_size=100)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)

    return is_gen_mean, is_gen_stddev, is_orig_mean, is_orig_stddev, fid_value
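
# For reference, `get_inception_score` used above reports the mean and standard
# deviation over splits of exp(E_x[KL(p(y|x) || p(y))]), where p(y|x) are the
# Inception softmax predictions for each generated image. A minimal NumPy sketch
# of that formula, assuming `preds` is an (N, 1000) array of softmax outputs
# (the function name is illustrative):
import numpy as np

def inception_score_from_preds(preds, n_splits=10):
    scores = []
    for part in np.array_split(preds, n_splits):
        p_y = np.mean(part, axis=0, keepdims=True)             # marginal p(y)
        kl = part * (np.log(part + 1e-16) - np.log(p_y + 1e-16))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))     # exp(E[KL])
    return np.mean(scores), np.std(scores)
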
Example #5
def tester(args, model_bucket, test_fid, test_is, sess, mu_real, sigma_real,
           get_inception_metrics):
    # Get the requisite network
    use_pytorch = args.use_pytorch_scores
    net_type = args.netG
    device = args.device
    net = model_bucket
    input_dim = args.z_dim
    input_nc = args.input_nc
    crop_size = args.crop_size
    evaluation_size = args.test_size
    fid_batch_size = args.fid_batch_size

    scores_ret = OrderedDict()

    samples = torch.zeros((evaluation_size, 3, crop_size, crop_size),
                          device=device)
    n_fid_batches = evaluation_size // fid_batch_size

    for i in range(n_fid_batches):
        frm = i * fid_batch_size
        to = frm + fid_batch_size

        if net_type == 'DCGAN':
            z = torch.randn(fid_batch_size, input_dim, 1, 1, device=device)
        elif 'EGAN' in net_type:
            z = torch.rand(fid_batch_size, input_dim, 1, 1,
                           device=device) * 2. - 1.
        elif net_type == 'WGAN':
            z = torch.randn(fid_batch_size, input_dim, device=device)
        else:
            raise NotImplementedError('netG [%s] is not found' % net_type)

        gen_s = net(z).detach()
        samples[frm:to] = gen_s
        print("\rgenerating FID sample batch %d/%d " % (i + 1, n_fid_batches),
              end="",
              flush=True)

    print("%d samples generated" % evaluation_size)

    if use_pytorch:
        IS_mean, IS_var, FID = get_inception_metrics(samples,
                                                     evaluation_size,
                                                     num_splits=10)

        if test_fid:
            scores_ret['FID'] = float(FID)
        if test_is:
            scores_ret['IS_mean'] = float(IS_mean)
            scores_ret['IS_var'] = float(IS_var)

    else:
        samples = samples.cpu().numpy()
        samples = ((samples + 1.0) * 127.5).astype('uint8')
        samples = samples.reshape(evaluation_size, input_nc, crop_size,
                                  crop_size)
        samples = samples.transpose(0, 2, 3, 1)

        if test_fid:
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                samples, sess, batch_size=fid_batch_size, verbose=True)
            print("calculating FID:")
            try:
                FID = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_real, sigma_real)
            except Exception as e:
                print(e)
                # Fall back to a large sentinel value if FID cannot be computed.
                FID = 500
            scores_ret['FID'] = float(FID)
        if test_is:
            # Collect the individual HWC images into a list for the IS computation.
            Imlist = [samples[i, :, :, :] for i in range(len(samples))]
            print(np.array(Imlist).shape)
            IS_mean, IS_var = get_inception_score(Imlist)
            scores_ret['IS_mean'] = float(IS_mean)
            scores_ret['IS_var'] = float(IS_var)

    return scores_ret