Exemple #1
0
def main():
    """Print the FID of generated samples against CelebA statistics.

    Loads generated samples from a ``.npy`` dump, writes the first sample to
    ``my.png`` as a visual sanity check, computes Inception activation
    statistics for the samples and prints the Frechet distance to the
    precalculated statistics stored in ``fid_stats_celeba.npz``.
    """
    # Paths
    pic_path = os.path.join('./out/c/', 'checkpoints',
                            'dummy.samples%d.npy' % (NUM_POINTS))
    stats_path = 'fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    # load precalculated training set statistics; the context manager
    # closes the archive even if one of the keys is missing
    with np.load(stats_path) as f:
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]

    images = np.load(pic_path)

    # The rescaling implies the samples are stored in [-1, 1] -- TODO confirm.
    # Map them to [0, 255].
    images_t = images / 2.0 + 0.5
    images_t = 255.0 * images_t

    from PIL import Image
    img = Image.fromarray(np.uint8(images_t[0]), 'RGB')
    img.save('my.png')

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Feed the rescaled images: the original passed the raw samples,
        # leaving ``images_t`` unused for the FID computation. Assumes `fid`
        # is the TTUR implementation, which documents inputs in [0, 255].
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            images_t, sess)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)
Exemple #2
0
 def compute(images):
     """Return the Frechet distance of ``images`` to the reference stats.

     The images are passed through ``transform_for_fid`` and embedded with
     the 'lenet' model in ``inception_sess``; the resulting statistics are
     compared against the enclosing scope's ``mu0`` / ``sig0``.
     """
     prepared = transform_for_fid(images)
     mean, cov = fid.calculate_activation_statistics(
         prepared,
         inception_sess,
         args.batch_size,
         verbose=True,
         model='lenet',
     )
     distance = fid.calculate_frechet_distance(mean, cov, mu0, sig0)
     return distance
Exemple #3
0
def main():
    """Precalculate FID activation statistics for aligned CelebA images.

    Takes the first 10000 ``*.jpg`` files (sorted by name) from
    ``./img_align_celeba``, loads each through ``get_image(path, 148)``,
    multiplies the pixel values by 255 and stores the resulting ``mu`` and
    ``sigma`` in a compressed archive named ``fid_stats_face``.
    """
    print("check for inception model..", end=" ", flush=True)
    # passing None asks the helper to download inception if necessary
    inception_path = fid.check_or_download_inception(None)
    print("ok")

    # every image is held in memory at once (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)

    pattern = os.path.join("./img_align_celeba", "*.jpg")
    selected = np.array(sorted(glob.glob(pattern))[:10000])
    loaded = [get_image(path, 148) for path in selected]
    images = np.array(loaded).astype(np.float32) * 255

    output_name = 'fid_stats_face'

    print("create inception graph..", end=" ", flush=True)
    # the graph is loaded into the current TF graph
    fid.create_inception_graph(inception_path)
    print("ok")

    print("calculte FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        stats = fid.calculate_activation_statistics(images,
                                                    sess,
                                                    batch_size=100)
        mu, sigma = stats
        np.savez_compressed(output_name, mu=mu, sigma=sigma)
    print("finished")
Exemple #4
0
    def train(self, num_batches=200000):
        """Run the adversarial training loop with periodic evaluation.

        Each iteration performs ``self.d_iters`` discriminator updates
        (``self.d_adam``) followed by one generator update (``self.g_adam``).
        Every 1000 iterations the losses are printed and a sample grid is
        written to ``<log_dir>/wos/<t>.png``. Every 10000 iterations (except
        ``t == 0``) the Inception score is computed on freshly generated
        samples and, when the module-level ``args.fid`` is truthy, the FID
        of the first 10000 of those samples as well. Both score histories
        are saved as ``.npy`` files under ``self.log_dir``.

        :param num_batches: number of training iterations to run.
        :return: None
        """
        self.sess.run(tf.global_variables_initializer())
        mean_list = []
        fid_list = []
        start_time = time.time()
        for t in range(0, num_batches):
            for _ in range(0, self.d_iters):
                bx = self.x_sampler(self.batch_size)
                bz = self.z_sampler(self.batch_size, self.z_dim)
                self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz})

            bx = self.x_sampler(self.batch_size)
            bz = self.z_sampler(self.batch_size, self.z_dim)
            self.sess.run([self.g_adam], feed_dict={self.z: bz, self.x: bx})

            if t % 1000 == 0:
                bx = self.x_sampler(self.batch_size)
                bz = self.z_sampler(self.batch_size, self.z_dim)
                dl, gl, gp, x_ = self.sess.run(
                    [self.d_loss, self.g_loss, self.gp_loss, self.x_],
                    feed_dict={
                        self.x: bx,
                        self.z: bz
                    })

                print('Iter [%8d] Time [%.4f] dl [%.4f] gl [%.4f] gp [%.4f]' %
                      (t, time.time() - start_time, dl, gl, gp))

                x_ = self.x_sampler.data2img(x_)
                x_ = grid_transform(x_, self.x_sampler.shape)
                imsave(self.log_dir + '/wos/{}.png'.format(int(t)), x_)

            if t % 10000 == 0 and t > 0:
                in_list = []
                for _ in range(50000 // self.batch_size):
                    bz = self.z_sampler(self.batch_size, self.z_dim)
                    x_ = self.sess.run(self.x_, feed_dict={self.z: bz})
                    x_ = self.x_sampler.data2img(x_)
                    bx_list = np.split(x_, self.batch_size)
                    # extend in place; rebuilding the list with ``+`` on
                    # every batch copied all previously collected samples
                    in_list.extend(np.squeeze(x) for x in bx_list)
                mean, _std = self.inception.get_inception_score(in_list,
                                                                splits=10)
                mean_list.append(mean)
                np.save(self.log_dir + '/inception_score_wgan_gp.npy',
                        np.asarray(mean_list))
                print('inception score [%.4f]' % (mean))

                # FID reuses the samples generated for the inception score,
                # so it lives under the same guard instead of repeating it
                if args.fid:
                    with np.load(self.stats_path) as f:
                        mu_real, sigma_real = f['mu'][:], f['sigma'][:]

                    mu_gen, sigma_gen = fid.calculate_activation_statistics(
                        np.array(in_list[:10000]), self.sess, batch_size=100)
                    fid_value = fid.calculate_frechet_distance(
                        mu_gen, sigma_gen, mu_real, sigma_real)
                    print("FID: %s" % fid_value)
                    fid_list.append(fid_value)
                    np.save(self.log_dir + '/fid_score_wgan_gp.npy',
                            np.asarray(fid_list))
def calculate_stats(imageset, batch_size=default_batchsize, printTime=False):
    """Return ``(mu, sigma)`` activation statistics for ``imageset``.

    Input with fewer than four dimensions gets a trailing singleton axis
    appended, and a single-channel last axis is repeated three times, before
    the array is handed to ``fid.calculate_activation_statistics`` inside a
    fresh session. With ``printTime`` set, the elapsed time is printed.
    """
    rank = len(imageset.shape)
    if rank < 4:
        # append a channel axis of size one
        with_channel = list(imageset.shape) + [1]
        imageset = np.reshape(imageset, with_channel)

    single_channel = imageset.shape[3] == 1
    if single_channel:
        # replicate the lone channel so the last axis has size three
        imageset = np.repeat(imageset, 3, axis=-1)

    starttime = time()
    with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        stats = fid.calculate_activation_statistics(imageset,
                                                    sess,
                                                    batch_size=batch_size)
    mu, sigma = stats
    if printTime:
        elapsed = time() - starttime
        print("calculating rlts took %f seconds" % elapsed)
    return (mu, sigma)
Exemple #6
0
def load_fid(mnist_test_images, args, binarize=True):
    """Build a FID scorer backed by a saved 'lenet' model.

    Loads the saved model from ``~/lenet/savedmodel`` into its own graph and
    session, computes reference statistics (``mu0``, ``sig0``) on
    ``mnist_test_images`` and returns a closure that scores other image sets
    against them.

    :param mnist_test_images: 2-D float32 array (asserted by
        ``transform_for_fid``); reshaped to ``(-1, 28, 28, 1)``.
    :param args: object providing ``batch_size``.
    :param binarize: when true, every value is replaced by 0/1 depending on
        whether it exceeds a fresh uniform random draw.
    :return: ``(compute, locals())`` -- the scoring closure and this
        function's local namespace. Local variable names are therefore part
        of the return value; do not rename them.
    """
    import fid

    def transform_for_fid(im):
        # Inputs must be 2-D float32 arrays.
        assert len(im.shape) == 2 and im.dtype == np.float32
        if binarize:
            # stochastic binarization against uniform noise
            im = (im > np.random.random(size=im.shape)).astype(np.float32)
        # shift by 0.5 and reshape to (n, 28, 28, 1)
        a = np.array(im) - 0.5
        return a.reshape((-1, 28, 28, 1))

    inception_path = os.path.expanduser('~/lenet/savedmodel')
    # dedicated graph and session, kept apart from the default graph
    inception_graph = tf.Graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    inception_sess = tf.Session(config=config, graph=inception_graph)
    with inception_graph.as_default():
        tf.saved_model.loader.load(inception_sess,
                                   [tf.saved_model.tag_constants.TRAINING],
                                   inception_path)
    # reference statistics computed once on the supplied test images
    mu0, sig0 = fid.calculate_activation_statistics(
        transform_for_fid(mnist_test_images),
        inception_sess,
        args.batch_size,
        verbose=True,
        model='lenet')

    def compute(images):
        # Frechet distance between ``images`` and the reference statistics.
        m, s = fid.calculate_activation_statistics(transform_for_fid(images),
                                                   inception_sess,
                                                   args.batch_size,
                                                   verbose=True,
                                                   model='lenet')
        return fid.calculate_frechet_distance(m, s, mu0, sig0)

    # NOTE: locals() exposes every local name above to the caller.
    return compute, locals()
Exemple #7
0
 def compute_fid(self, images, inception_path, batch_size=100):
     """Return the activation statistics ``(mu, sigma)`` for ``images``.

     The Inception graph at ``inception_path`` is loaded into a private
     ``tf.Graph`` and evaluated in a session bound to that graph. Despite
     the method name, no Frechet distance is computed here.

     :param images: image array passed to
         ``fid.calculate_activation_statistics``.
     :param inception_path: path handed to ``fid.create_inception_graph``.
     :param batch_size: batch size used for the activation pass.
     :return: tuple ``(mu, sigma)``.
     """
     g = tf.Graph()
     with g.as_default():
         fid.create_inception_graph(
             inception_path)  # load the graph into the current TF graph
         # the context manager closes the session even when the statistics
         # call raises; the explicit close() was skipped in that case
         with tf.Session(graph=g) as sess:
             mu, sigma = fid.calculate_activation_statistics(
                 images, sess, batch_size=batch_size)
     return mu, sigma
Exemple #8
0
 def fid_ms_for_imgs(images, mem_fraction=0.5):
     """Return ``(mu, sigma)`` activation statistics for ``images``.

     The session is configured with ``per_process_gpu_memory_fraction`` set
     to ``mem_fraction``; statistics are computed with a batch size of 100.
     """
     gpu_options = tf.GPUOptions(
         per_process_gpu_memory_fraction=mem_fraction)
     inception_path = fid.check_or_download_inception(None)
     # the graph is loaded into the current TF graph
     fid.create_inception_graph(inception_path)
     session_config = tf.ConfigProto(gpu_options=gpu_options)
     with tf.Session(config=session_config) as sess:
         sess.run(tf.global_variables_initializer())
         stats = fid.calculate_activation_statistics(images,
                                                     sess,
                                                     batch_size=100)
     mu_gen, sigma_gen = stats
     return mu_gen, sigma_gen
def _get_statistics(stat_root, data, image_shape, inception_sess):
    """Return activation statistics for ``data``, cached under ``stat_root``.

    When both ``ac_mu.npy`` and ``ac_sigma.npy`` exist in ``stat_root`` they
    are loaded and returned. Otherwise ``data`` is reshaped to
    ``(-1,) + image_shape``, passed through ``_maybe_grayscale_to_rgb``,
    mapped with ``(x + 1) / 2 * 255``, and the freshly computed statistics
    are saved to those two files before being returned.
    """
    os.makedirs(stat_root, exist_ok=True)
    mu_path = os.path.join(stat_root, 'ac_mu.npy')
    sigma_path = os.path.join(stat_root, 'ac_sigma.npy')

    have_cache = os.path.exists(mu_path) and os.path.exists(sigma_path)
    if have_cache:
        print('Using cached activation statistics')
        cached_mu = np.load(mu_path)
        cached_sigma = np.load(sigma_path)
        return cached_mu, cached_sigma

    batch = np.reshape(data, (-1, ) + image_shape)
    batch = _maybe_grayscale_to_rgb(batch)
    batch = (batch + 1.0) / 2.0 * 255.0
    mu, sigma = fid.calculate_activation_statistics(batch, inception_sess)
    np.save(mu_path, mu)
    np.save(sigma_path, sigma)
    return mu, sigma
Exemple #10
0
def main(model, data_source, noise_method, noise_factors, lambdas):
    """
    Precalculate FID activation statistics for generated image sets.

    For every (lambda, noise factor) pair the first 10000 samples of the
    matching ``generation_fid.npy`` are scaled by 255, reshaped to 28x28,
    replicated to three channels, and their ``mu``/``sigma`` are written to
    ``fid_precalc/<model>_<data_source>_<noise_method>/``.

    model: RVAE or VAE
    data_source: data set of training. Either 'MNIST' or 'FASHION'
    noise_method: method of adding noise. Either 'sp' (represents salt-and-pepper)
                  or 'gs' (represents Gaussian)
    noise_factors: noise factors
    lambdas: lambda
    """

    input_path = "../output/"+model+"_"+data_source+"_"+noise_method+"/"
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(None) # download inception if necessary
    print("ok")

    # makedirs(exist_ok=True) creates both directory levels and avoids the
    # check-then-create race of the separate exists()/mkdir() calls
    output_path = "fid_precalc/"+model+"_"+data_source+"_"+noise_method+"/"
    os.makedirs(output_path, exist_ok=True)

    # Load the Inception graph once. Calling create_inception_graph inside
    # the loop imported another copy into the default graph per iteration.
    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
    print("ok")

    for lam in lambdas:
        for nr in noise_factors:
            # NOTE(review): for non-RVAE models the paths do not depend on
            # the lambda, so each noise factor is recomputed per lambda.
            if model == 'RVAE':
                data_path = input_path+'lambda_'+str(lam)+'/noise_'+str(nr)+'/generation_fid.npy'
                output_name = 'fid_stats_lambda_'+str(lam)+'noise_'+str(nr)
            else:
                data_path = input_path+str(nr)+'/generation_fid.npy'
                output_name = 'fid_stats_noise_'+str(nr)

            # loads all images into memory (this might require a lot of RAM!)
            print("load images..", end=" ", flush=True)
            images = np.load(data_path)[:10000]
            images = np.stack((((images*255)).reshape(-1,28,28),)*3,axis=-1)
            print("ok")

            print("calculte FID stats..", end=" ", flush=True)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
                np.savez_compressed(output_path+output_name, mu=mu, sigma=sigma)
            print("finished")
Exemple #11
0
    def compute_fid_score(self, generator, timestamp):
        """
        Computes FID of generator using fixed noise dataset;
        appends the current score to the list of computed scores;
        and overwrites the json file that logs the fid scores.

        The generator is switched to eval mode while sampling and back to
        train mode afterwards, also when sampling raises. When no previously
        logged score is lower than the new one, ``self.backup`` is called
        with ``dir=self.best_path``.

        :param generator: [nn.Module]
        :param timestamp: [int]
        :return: None
        """
        fake_samples = np.empty(
            (self.sample_size_fid, self.imsize, self.imsize, 3))
        generator.eval()
        try:
            # Track the write position cumulatively instead of assuming that
            # the noise loader yields batches of exactly 200 samples.
            start = 0
            for noise in self.fid_noise_loader:
                noise = noise.cuda()
                stop = start + noise.size(0)
                # maps generator output with (x + 1) * 255 / 2
                samples = generator(noise).cpu().data.add(1).mul(255 / 2.0)
                # channels-first -> channels-last
                fake_samples[start:stop] = samples.permute(0, 2, 3,
                                                           1).numpy()
                start = stop
        finally:
            generator.train()

        mu_g, sigma_g = fid.calculate_activation_statistics(fake_samples,
                                                            self.fid_session,
                                                            batch_size=100)
        fid_score = fid.calculate_frechet_distance(mu_g, sigma_g, self.mu_real,
                                                   self.sigma_real)
        _result = {
            'entry': len(self.fid_scores),
            'iter': timestamp,
            'fid': fid_score
        }

        # if best update the checkpoint in self.best_path
        new_best = not any(prev_fid['fid'] < fid_score
                           for prev_fid in self.fid_scores)
        if new_best:
            self.backup(timestamp, dir=self.best_path)

        self.fid_scores.append(_result)
        with open(self.fid_json_file, 'w') as _f_fid:
            json.dump(self.fid_scores,
                      _f_fid,
                      sort_keys=True,
                      indent=4,
                      separators=(',', ': '))
Exemple #12
0
def generate(args):
    """Restore a trained model, sample from it, and report the FID.

    If ``CIFAR_STATS_PATH`` does not exist yet, activation statistics of the
    CIFAR-10 test set are computed and stored there first. The model
    hyper-parameters are read from ``hps.txt`` next to ``args.dir`` and
    merged into ``args``; the checkpoint at ``args.dir`` is restored, 100
    batches of 100 samples are drawn, their FID is printed, and every sample
    is written as a PNG into a ``generated`` directory.

    :param args: namespace providing at least ``dir`` and ``batch_size``.
    :return: None
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    if not os.path.exists(CIFAR_STATS_PATH):
        print('Generating FID statistics for test set...')
        print('Building Inception graph')
        with tf.Session(config=config) as sess:
            inception_path = fid.check_or_download_inception(INCEPTION_PATH)
            fid.create_inception_graph(str(inception_path))
            ds = datasets.load_cifar10(True)
            all_test_set = (ds.test.images + 1) * 128
            print(all_test_set.shape)
            m, s = fid.calculate_activation_statistics(
                all_test_set, sess, args.batch_size, verbose=True)
        np.savez(CIFAR_STATS_PATH, mu=m, sigma=s)
        print('Done')

    root_dir = os.path.dirname(args.dir)
    # close the hyper-parameter file deterministically instead of leaving
    # the handle to the garbage collector
    with open(os.path.join(root_dir, 'hps.txt')) as hps_file:
        args_json = json.load(hps_file)
    ckpt_dir = args.dir
    vars(args).update(args_json)

    model_graph = tf.Graph()
    with model_graph.as_default():
        x_ph, is_training_ph, model, optimizer, batch_size_sym, z_sample_sym, x_sample_sym = build_graph(args)
        saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=3, max_to_keep=6)

    model_sess = tf.Session(config=config, graph=model_graph)
    print('RESTORING MODEL FROM', ckpt_dir)
    saver.restore(model_sess, ckpt_dir)
    compute_fid, _ = load_fid(args)
    images = []
    for j in range(100):
        x_samples = model_sess.run(x_sample_sym, {batch_size_sym: 100, is_training_ph: False})
        # clip to [-1, 1], then map to [0, 256]
        x_samples = (np.clip(x_samples, -1, 1) + 1) / 2 * 256
        images.extend(x_samples)

    fscore = compute_fid(images)
    print('FID score = {}'.format(fscore))

    dest = os.path.join(root_dir, 'generated')
    # exist_ok avoids the race between the existence check and the creation
    os.makedirs(dest, exist_ok=True)
    for j, im in enumerate(images):
        plt.imsave(os.path.join(dest, '{}.png'.format(j)), im/256)
Exemple #13
0
def main(model, noise_factors, lambdas):
    """
    Precalculate FID activation statistics for generated image sets.

    For every (lambda, noise factor) pair the samples in
    ``<model>lambda_<l>/noise_<nr>/generation_fid.npy`` are scaled by 255,
    transposed with axes (0, 2, 3, 1), and their ``mu``/``sigma`` are
    written to ``fid_precalc/``.

    model: RVAE or VAE
    noise_factors: noise factors
    lambdas: lambda
    """

    input_path = model
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(
        None)  # download inception if necessary
    print("ok")

    output_path = "fid_precalc/"
    # exist_ok avoids the race between the existence check and the creation
    os.makedirs(output_path, exist_ok=True)

    # Load the Inception graph once. Calling create_inception_graph inside
    # the loop imported another copy into the default graph per iteration.
    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    print("ok")

    for lam in lambdas:
        for nr in noise_factors:
            data_path = input_path + 'lambda_' + str(lam) + '/noise_' + str(
                nr) + '/generation_fid.npy'
            output_name = 'fid_stats_lambda_' + str(lam) + 'noise_' + str(nr)

            # loads all images into memory (this might require a lot of RAM!)
            print("load images..", end=" ", flush=True)
            images = np.load(data_path)
            images = np.transpose(images * 255, (0, 2, 3, 1))
            print("ok")

            print("calculte FID stats..", end=" ", flush=True)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images,
                                                                sess,
                                                                batch_size=100)
                np.savez_compressed(output_path + output_name,
                                    mu=mu,
                                    sigma=sigma)
            print("finished")
def main(args):
	"""Compute and store activation statistics for a video dataset.

	Loads ResNeXt-101 weights from ``resnext-101-kinetics.pth``, collects the
	``*.mp4`` files under ``args.data_path`` (only the 'uva' dataset is
	supported; anything else raises ``NotImplementedError``), and saves the
	resulting ``mu``/``sigma`` to ``./stats/<dataset>.npz``.
	"""
	device = 'cuda'

	print('Loading ResNext101 model...')
	backbone = resnet101(sample_duration=16).cuda()
	model = nn.DataParallel(backbone)
	checkpoint = torch.load('resnext-101-kinetics.pth')
	model.load_state_dict(checkpoint['state_dict'])

	print('Loading video paths...')

	# only one dataset layout is implemented
	if args.dataset != 'uva':
		raise NotImplementedError
	files = glob.glob(args.data_path + '/*.mp4')
	data_type = 'video'

	stats = fid.calculate_activation_statistics(
		files, data_type, model, args.batch_size, args.size, args.length,
		args.dims, device)
	mu, sigma = stats
	np.savez_compressed('./stats/'+args.dataset+'.npz', mu=mu, sigma=sigma)

	print('finished')
Exemple #15
0
def load_fid(dtest, args):
    """Build a FID scorer with reference statistics cached on disk.

    Creates the Inception graph in its own ``tf.Graph`` with a dedicated
    session, then obtains reference statistics for ``dtest``: they are read
    from ``<INCEPTION_PATH>/<args.dataset>-stats.npz`` when that file
    exists, otherwise computed and saved there (keys ``mu0`` and ``sig0``).

    :param dtest: 4-D float32 array (asserted by ``transform_for_fid``);
        only used when no cached statistics exist.
    :param args: object providing ``dataset`` and ``batch_size``.
    :return: ``(compute, locals())`` -- the scoring closure and this
        function's local namespace. Local variable names are therefore part
        of the return value; do not rename them.
    """
    import fid

    def transform_for_fid(im):
        # Inputs must be 4-D float32 arrays.
        assert len(im.shape) == 4 and im.dtype == np.float32
        if im.shape[-1] == 1:
            # single-channel input: expect 28 along axis -2, then repeat
            # the channel three times
            assert im.shape[-2] == 28
            im = np.tile(im, [1, 1, 1, 3])
        # warn (but continue) when the value statistics look unusual
        if not (im.std() < 1. and im.min() > -1.):
            print('WARNING: abnormal image range', im.std(), im.min())
        return (im + 1) * 128

    inception_path = fid.check_or_download_inception(INCEPTION_PATH)
    # dedicated graph, kept apart from the default graph
    inception_graph = tf.Graph()
    with inception_graph.as_default():
        fid.create_inception_graph(str(inception_path))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    inception_sess = tf.Session(config=config, graph=inception_graph)

    stats_path = os.path.join(INCEPTION_PATH, f'{args.dataset}-stats.npz')
    if not os.path.exists(stats_path):
        # first run for this dataset: compute and cache the statistics
        mu0, sig0 = fid.calculate_activation_statistics(
            transform_for_fid(dtest),
            inception_sess,
            args.batch_size,
            verbose=True)
        np.savez(stats_path, mu0=mu0, sig0=sig0)
    else:
        sdict = np.load(stats_path)
        mu0, sig0 = sdict['mu0'], sdict['sig0']

    def compute(images):
        # Frechet distance between ``images`` and the reference statistics.
        m, s = fid.calculate_activation_statistics(transform_for_fid(images),
                                                   inception_sess,
                                                   args.batch_size,
                                                   verbose=True)
        return fid.calculate_frechet_distance(m, s, mu0, sig0)

    # NOTE: locals() exposes every local name above to the caller.
    return compute, locals()
def _run_fid_calculation(sess, inception_sess, placeholders, batch_size,
                         iteration, generator, mu, sigma, epoch, image_shape,
                         z_input_shape, y_input_shape):
    """Return the FID averaged over ``iteration`` generated batches.

    Each round draws random noise and labels, runs the generator in ``sess``
    with ``placeholders['mode']`` fed as False, maps the images with
    ``(x + 1) / 2 * 255`` and measures the Frechet distance between the
    reference statistics (``mu``, ``sigma``) and the batch statistics
    computed in ``inception_sess``.

    ``epoch`` is accepted for interface compatibility but not used.
    Raises ``ZeroDivisionError`` when ``iteration`` is 0.
    """
    # Build the reshape op once: calling tf.reshape inside the loop added a
    # new node to the graph on every iteration.
    reshaped = tf.reshape(generator, (-1, ) + image_shape)

    f = 0.0
    for _ in range(iteration):
        z = util.gen_random_noise(batch_size, z_input_shape)
        y = util.gen_random_label(batch_size, y_input_shape[0])

        images = sess.run(
            reshaped, {
                placeholders['z']: z,
                placeholders['y']: y,
                placeholders['mode']: False,
            })
        images = _maybe_grayscale_to_rgb(images)
        images = (images + 1.0) / 2.0 * 255.0

        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            images, inception_sess)
        f += fid.calculate_frechet_distance(mu, sigma, mu_gen, sigma_gen)
    return f / iteration
Exemple #17
0
def precalc(data_path, output_path):
    """Compute FID statistics for all ``*.jpg`` images in ``data_path``.

    The images are read as RGB float32 arrays and held in memory at once.
    ``mu``, ``sigma`` and the raw activations returned by the (three-value)
    ``fid.calculate_activation_statistics`` are stored compressed at
    ``output_path``. Returns early, after printing a message, when the
    directory contains no matching image.
    """
    print("CALCULATING THE GT STATS....")
    # Passing None makes the helper download the Inception model when needed;
    # point it at an extracted inception-2015-12-05 directory to skip that.
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(None)
    print("ok")

    # every image is held in memory at once (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)
    image_list = glob.glob(os.path.join(data_path, '*.jpg'))
    if len(image_list) == 0:
        print("No images in directory ", data_path)
        return

    loaded = []
    for fn in image_list:
        pixels = imageio.imread(str(fn), as_gray=False, pilmode="RGB")
        loaded.append(pixels.astype(np.float32))
    images = np.array(loaded)
    print("%d images found and loaded" % len(images))

    print("create inception graph..", end=" ", flush=True)
    # the graph is loaded into the current TF graph
    fid.create_inception_graph(inception_path)
    print("ok")

    print("calculte FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        stats = fid.calculate_activation_statistics(images,
                                                    sess,
                                                    batch_size=BATCH_SIZE)
        mu, sigma, acts = stats
        np.savez_compressed(output_path, mu=mu, sigma=sigma, activations=acts)
    print("finished")
Exemple #18
0
            t += 1
            elapsed_time = datetime.datetime.now() - start_time
            print(str(t), "/", test_len, ": ", blender_filename,
                  pc2pix_filename, "Elapsed :", elapsed_time)
            print(np.array(gt).shape)

    gt = np.array(gt)
    bl = np.array(bl)
    pc = np.array(pc)

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gt, sigma_gt = fid.calculate_activation_statistics(gt, sess)
        mu_bl, sigma_bl = fid.calculate_activation_statistics(bl, sess)
        mu_pc, sigma_pc = fid.calculate_activation_statistics(pc, sess)

    fid_value = fid.calculate_frechet_distance(mu_bl, sigma_bl, mu_gt,
                                               sigma_gt)
    filename = "fid.log"
    fd = open(filename, "a+")
    fd.write("---| ")
    fd.write(args.split_file)
    fd.write(" |---\n")
    print("Surface FID: %s" % fid_value)
    fd.write("Surface FID: %s\n" % fid_value)
    fid_value = fid.calculate_frechet_distance(mu_pc, sigma_pc, mu_gt,
                                               sigma_gt)
    print("PC2PIX FID: %s" % fid_value)
Exemple #19
0
                #            '%s/Epoch_(%d)_(%dof%d)_img_rec.png' % (save_dir, ep, it_in_epoch, it_per_epoch))
                im.imwrite(
                    im.immerge(img_intp_opt_sample, n_col=1, padding=0),
                    '%s/Epoch_(%d)_(%dof%d)_img_intp.png' %
                    (save_dir, ep, it_in_epoch, it_per_epoch))
                im.imwrite(
                    im.immerge(img_opt_sample),
                    '%s/Epoch_(%d)_(%dof%d)_img_sample.png' %
                    (save_dir, ep, it_in_epoch, it_per_epoch))

                if fid_stats_path:
                    try:
                        mu_gen, sigma_gen = fid.calculate_activation_statistics(
                            im.im2uint(
                                np.concatenate([
                                    sess.run(fid_sample).squeeze()
                                    for _ in range(5)
                                ], 0)),
                            sess,
                            batch_size=100)
                        fid_value = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, mu_real, sigma_real)
                    except:
                        fid_value = -1.
                    fid_summary = tf.Summary()
                    fid_summary.value.add(tag='FID', simple_value=fid_value)
                    summary_writer.add_summary(fid_summary, it)
                    print("FID: %s" % fid_value)

        save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep))
        print('Model is saved in file: %s' % save_path)
except:
    image_list = h5py.File(os.path.join(data_path, 'test_x.h5'),
                           'r',
                           swmr=True)['x']
    # print(image_list[10])
    # exit(0)
    # images = np.array([files[index, ...].astype(np.float32) for index in range(len(files))])
    output_path = 'fid_stats_pcam.npz'  # path for where to store the statistics
else:
    print("Unsupported dataset")
    sys.exit(1)
print("%d images found" %
      len(image_list))  # These are not fully read into memory though

print("create inception graph..", end=" ", flush=True)
fid.create_inception_graph(
    inception_path)  # load the graph into the current TF graph
print("ok")

print("calculte FID stats..", end=" ", flush=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if config.dataset == "CelebA":
        mu, sigma = fid.calculate_activation_statistics(image_list,
                                                        sess,
                                                        batch_size=100)
    elif config.dataset == "PCam":
        mu, sigma = fid.calculate_activation_statistics(image_list,
                                                        sess,
                                                        batch_size=100)
    np.savez_compressed(output_path, mu=mu, sigma=sigma)
print("finished")
                      end="",
                      flush=True)

                frm = i * FID_SAMPLE_BATCH_SIZE
                to = frm + FID_SAMPLE_BATCH_SIZE

                samples[frm:to] = session.run(Generator(FID_SAMPLE_BATCH_SIZE))

            # Cast, reshape and transpose (BCHW -> BHWC)
            samples = ((samples + 1.0) * 127.5).astype('uint8')
            samples = samples.reshape(FID_EVAL_SIZE, 3, DIM, DIM)
            samples = samples.transpose(0, 2, 3, 1)

            print("ok")

            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                samples, session, batch_size=FID_BATCH_SIZE, verbose=True)

            print("calculate FID:", end=" ", flush=True)
            try:
                FID = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_real, sigma_real)
            except Exception as e:
                print(e)
                FID = 500

            print(FID)

            session.run(tf.assign(fid_tfvar, FID))
            summary_str = session.run(fid_sum)
            writer.add_summary(summary_str, iteration)
    def calculate_fid(self):
        """Generate samples from a saved checkpoint and print their FID.

        For each checkpoint in the hard-coded list, the score network is
        restored, ``num_of_step`` batches of ``bs`` images are generated with
        annealed Langevin dynamics from uniform noise, and the Frechet
        distance to the precalculated statistics in
        ``fid_stats_cifar10_train.npz`` is printed.
        """
        import fid
        import tensorflow as tf

        num_of_step = 500
        bs = 100

        sigmas = np.exp(
            np.linspace(np.log(self.config.model.sigma_begin),
                        np.log(self.config.model.sigma_end),
                        self.config.model.num_classes))
        stats_path = 'fid_stats_cifar10_train.npz'  # training set statistics
        inception_path = fid.check_or_download_inception(
            None)  # download inception network

        # One noise shape per dataset. The original MNIST branch drew its
        # first batch as (bs, 1, 28, 28) but every later batch as
        # (bs, 3, 32, 32), so the batches could never be concatenated.
        if self.config.data.dataset == 'MNIST':
            # NOTE(review): the reference statistics are still the CIFAR-10
            # ones and the images stay single-channel -- confirm intent.
            sample_shape = (bs, 1, 28, 28)
        else:
            sample_shape = (bs, 3, 32, 32)

        print('Load checkpoint from' + self.args.log)
        #for epochs in range(140000, 200001, 1000):
        for epochs in [149000]:
            states = torch.load(os.path.join(
                self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
                                map_location=self.config.device)
            score = CondRefineNetDilated(self.config).to(self.config.device)
            score = torch.nn.DataParallel(score)

            score.load_state_dict(states[0])

            score.eval()

            print("Begin epochs", epochs)
            # Collect the batches and concatenate once; concatenating inside
            # the loop copied all previous images on every step.
            batches = []
            for _ in range(num_of_step):
                samples = torch.rand(sample_shape, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                # scale to [0, 255] with rounding offset, channels last
                batches.append(
                    all_samples.mul_(255).add_(0.5).clamp_(0, 255).permute(
                        0, 2, 3, 1).to('cpu').numpy())
            images = np.concatenate(batches, axis=0)

            # load precalculated training set statistics
            with np.load(stats_path) as f:
                mu_real, sigma_real = f['mu'][:], f['sigma'][:]

            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)

            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            print("FID: %s" % fid_value)
    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)
    if args.images_png_path is not None:
        image_list = []
        for path in args.images_png_path:
            image_list.extend(glob.glob(os.path.join(args.data_path, '*.png')))
        images = np.array(
            [imread(str(fn)).astype(np.float32) for fn in image_list])
    elif args.images_npy_path is not None:
        images = []
        for path in args.images_npy_path:
            with open(path, 'rb') as f:
                images.append(load_nist_images(np.load(f)))
        images = np.vstack(images)

    print(images.shape)
    print("%d images found and loaded" % len(images))

    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    print("ok")

    print("calculte FID stats..", end=" ", flush=True)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(images,
                                                        sess,
                                                        batch_size=256)
        np.savez_compressed(args.output_path, mu=mu, sigma=sigma)
    print("finished")
Exemple #24
0
# Script: compute the FID between the JPEG images found in `image_path` and
# the precalculated training-set statistics stored in `stats_path`.
# `image_path` and `stats_path` are not assigned in this excerpt and must be
# defined before this code runs.
#stats_path = '/home/minje/dev/dataset/stl/fid_stats_stl10.npz' # training set statistics (maybe pre-calculated)
inception_path = fid.check_or_download_inception(None) # download inception network

# precalculate training set statistics
# #image_files = glob.glob(os.path.join('/home/minje/dev/dataset/cifar/cifar-10-images', '*.jpg'))
# image_files = glob.glob(os.path.join('/home/minje/dev/dataset/stl/images', '*.jpg'))
# fid.create_inception_graph(inception_path)
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     mu_real, sigma_real = fid.calculate_activation_statistics_from_files(image_files, sess,
#         batch_size=100, verbose=True)
# np.savez(stats_path, mu=mu_real, sigma=sigma_real)
# exit(0)

# loads all images into memory (this might require a lot of RAM!)
image_files = glob.glob(os.path.join(image_path, '*.jpg'))
images = np.array([imread(str(fn)).astype(np.float32) for fn in image_files])

# load precalculated training set statistics; the context manager closes the
# .npz archive even when one of the keys is missing
with np.load(stats_path) as f:
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]

fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess, batch_size=100, verbose=True)

fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
print("FID: %s" % fid_value)
    def train(self, config):
        """Train DCGAN and periodically log the FID of generated samples.

        Loads the precalculated training-set statistics from
        ``self.stats_path``, collects the training files for
        ``config.dataset`` and then runs one discriminator update followed by
        one generator update per batch for ``config.epoch`` epochs.  Whenever
        the step counter is a multiple of ``config.fid_eval_steps`` a sample
        grid is saved and the FID of ``self.fid_n_samples`` generated images
        is written to the summary log.  A checkpoint is saved whenever the
        counter is a non-zero multiple of 2000.

        Args:
            config: run configuration.  This method reads ``dataset``,
                ``train_size``, ``batch_size``, ``epoch``,
                ``learning_rate_d``, ``learning_rate_g``, ``fid_eval_steps``,
                ``sample_dir`` and ``checkpoint_dir``.

        Raises:
            SystemExit: if ``config.dataset`` is none of the handled names.
        """

        print("load train stats.. ", end="", flush=True)
        # load precalculated training set statistics; the context manager
        # closes the .npz archive even when one of the keys is missing
        with np.load(self.stats_path) as f:
            mu_real, sigma_real = f['mu'][:], f['sigma'][:]
        print("ok")

        if config.dataset == 'mnist':
            print("scan files", end=" ", flush=True)
            # NOTE(review): this branch binds data_X / data_y but never
            # `data`, so the `len(data)` below raises NameError for mnist.
            # The intended mnist pipeline is not visible here -- confirm.
            data_X, data_y = self.load_mnist()
        elif config.dataset in ("celebA", "cifar10"):
            print("scan files", end=" ", flush=True)
            data = glob(
                os.path.join(self.data_path, self.input_fname_pattern))
        elif config.dataset == "lsun":
            print("scan files")
            data = []
            # the files are looked up in sub-directories named "0" .. "303"
            for i in range(304):
                print("\r%d" % i, end="", flush=True)
                data += glob(
                    os.path.join(self.data_path, str(i),
                                 self.input_fname_pattern))
        else:
            print(
                "Please specify dataset in run.sh [mnist, celebA, lsun, cifar10]"
            )
            raise SystemExit()

        print()
        print("%d images found" % len(data))

        # Z sample: fixed latent batch reused for every saved sample grid
        #sample_z = np.random.normal(0, 1.0, size=(self.sample_num , self.z_dim))
        sample_z = np.random.uniform(-1.0,
                                     1.0,
                                     size=(self.sample_num, self.z_dim))

        # Input samples: the first `sample_num` files, fed as real inputs
        # when the sample losses are evaluated
        sample_files = data[0:self.sample_num]
        sample = [
            get_image(sample_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      is_crop=self.is_crop,
                      is_grayscale=self.is_grayscale)
            for sample_file in sample_files
        ]
        if (self.is_grayscale):
            # append a trailing axis of length 1
            sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_inputs = np.array(sample).astype(np.float32)

        if self.load_checkpoint:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        # Batch preparing
        batch_nums = min(len(data), config.train_size) // config.batch_size
        data_idx = list(range(len(data)))

        counter = self.counter_start

        start_time = time.time()

        # Loop over epochs
        for epoch in range(config.epoch):

            # Assign learning rates for d and g
            # NOTE(review): tf.assign builds a new op on every call; in
            # graph mode that grows the graph each epoch -- consider a
            # pre-built assign op.
            lrate = config.learning_rate_d  # * (config.lr_decay_rate_d ** epoch)
            self.sess.run(tf.assign(self.learning_rate_d, lrate))
            lrate = config.learning_rate_g  # * (config.lr_decay_rate_g ** epoch)
            self.sess.run(tf.assign(self.learning_rate_g, lrate))

            # Shuffle the data indices
            np.random.shuffle(data_idx)

            # Loop over batches
            for batch_idx in range(batch_nums):

                # Prepare batch
                idx = data_idx[batch_idx * config.batch_size:(batch_idx + 1) *
                               config.batch_size]
                batch = [
                    get_image(data[i],
                              input_height=self.input_height,
                              input_width=self.input_width,
                              resize_height=self.output_height,
                              resize_width=self.output_width,
                              is_crop=self.is_crop,
                              is_grayscale=self.is_grayscale) for i in idx
                ]
                if (self.is_grayscale):
                    batch_images = np.array(batch).astype(np.float32)[:, :, :,
                                                                      None]
                else:
                    batch_images = np.array(batch).astype(np.float32)

                #batch_z = np.random.normal(0, 1.0, size=(config.batch_size , self.z_dim)).astype(np.float32)
                batch_z = np.random.uniform(
                    -1.0, 1.0,
                    size=(config.batch_size, self.z_dim)).astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run([self.d_optim, self.d_sum],
                                               feed_dict={
                                                   self.inputs: batch_images,
                                                   self.z: batch_z
                                               })
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([self.g_optim, self.g_sum],
                                               feed_dict={self.z: batch_z})
                if np.mod(counter, 20) == 0:
                    self.writer.add_summary(summary_str, counter)

                # Losses after both updates, used only for the progress line
                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.inputs: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                # Print
                if np.mod(counter, 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, batch_idx, batch_nums, time.time() - start_time, errD_fake+errD_real, errG))

                # Save generated samples and FID
                if np.mod(counter, config.fid_eval_steps) == 0:

                    # Save (best effort: a failure here must not stop training)
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.inputs: sample_inputs
                            })
                        save_images(
                            samples, [8, 8],
                            '{}/train_{:02d}_{:04d}.png'.format(
                                config.sample_dir, epoch, batch_idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                    except Exception as e:
                        print(e)
                        print("sample image error!")

                    # FID
                    print("samples for incept", end="", flush=True)

                    # Only n_batches * fid_sample_batchsize rows are filled;
                    # any remainder of fid_n_samples stays zero.
                    samples = np.zeros((self.fid_n_samples, self.output_height,
                                        self.output_width, 3))
                    n_batches = self.fid_n_samples // self.fid_sample_batchsize
                    lo = 0
                    for btch in range(n_batches):
                        print("\rsamples for incept %d/%d" %
                              (btch + 1, n_batches),
                              end=" ",
                              flush=True)
                        #sample_z_fid = np.random.normal(0, 1.0, size=(self.fid_sample_batchsize, self.z_dim))
                        sample_z_fid = np.random.uniform(
                            -1.0,
                            1.0,
                            size=(self.fid_sample_batchsize, self.z_dim))
                        samples[lo:(
                            lo + self.fid_sample_batchsize)] = self.sess.run(
                                self.sampler_fid,
                                feed_dict={self.z_fid: sample_z_fid})
                        lo += self.fid_sample_batchsize

                    # maps -1 -> 0 and 1 -> 255
                    # (assumes the sampler output lies in [-1, 1] -- confirm)
                    samples = (samples + 1.) * 127.5
                    print("ok")

                    mu_gen, sigma_gen = fid.calculate_activation_statistics(
                        samples,
                        self.sess,
                        batch_size=self.fid_batch_size,
                        verbose=self.fid_verbose)

                    print("calculate FID:", end=" ", flush=True)
                    try:
                        FID = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, mu_real, sigma_real)
                    except Exception as e:
                        # fall back to a fixed placeholder value so the
                        # summary can still be written
                        print(e)
                        FID = 500

                    print(FID)

                    # Update event log with FID
                    self.sess.run(tf.assign(self.fid, FID))
                    summary_str = self.sess.run(self.fid_sum)
                    self.writer.add_summary(summary_str, counter)

                # Save checkpoint
                if (counter != 0) and (np.mod(counter, 2000) == 0):
                    self.save(config.checkpoint_dir, counter)

                counter += 1
Exemple #26
0
# Script: precalculate the FID statistics of the CIFAR-10 train and test
# splits and store them as compressed .npz files.
print("Check for inception model..", end=" ", flush=True)
# download inception if necessary
inception_path = fid.check_or_download_inception(inception_path)
print("OK")

# Both splits are held in memory as float32 (this might require a lot of RAM!)
print("Load images..", end=" ", flush=True)
train_split, test_split = tf.keras.datasets.cifar10.load_data()
x_train = train_split[0].astype(np.float32)  # labels are not needed
x_test = test_split[0].astype(np.float32)

print("%d/%d images found and loaded" % (len(x_train), len(x_test)))

print("Create inception graph..", end=" ", flush=True)
# load the graph into the current TF graph
fid.create_inception_graph(inception_path)
print("OK")

print("Calculte FID stats..", end=" ", flush=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # statistics of both splits are computed first, then both are written
    mu_train, sigma_train = fid.calculate_activation_statistics(
        x_train, sess, batch_size=100)
    mu_test, sigma_test = fid.calculate_activation_statistics(
        x_test, sess, batch_size=100)
    np.savez_compressed(out_path_train, mu=mu_train, sigma=sigma_train)
    np.savez_compressed(out_path_test, mu=mu_test, sigma=sigma_test)
print("Finished")
Exemple #27
0
# Two modes: "pre-calculate" stores the activation statistics of the images
# in args.image_path under args.stats_path; any other mode loads those
# statistics to evaluate the images in args.image_path against them.
if args.mode == "pre-calculate":
    print("load images..")
    image_list = glob.glob(os.path.join(args.image_path, '*.jpg'))
    images = np.array(
        [imread(image).astype(np.float32) for image in image_list])
    print("%d images found and loaded" % len(images))

    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(inception_path)
    print("ok")

    print("calculate FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(images,
                                                        sess,
                                                        batch_size=100)
        np.savez_compressed(args.stats_path, mu=mu, sigma=sigma)
    print("finished")
else:
    image_list = glob.glob(os.path.join(args.image_path, '*.jpg'))
    images = np.array(
        [imread(str(fn)).astype(np.float32) for fn in image_list])

    # the context manager closes the .npz archive even when a key is missing
    with np.load(args.stats_path) as f:
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]

    fid.create_inception_graph(inception_path)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
Exemple #28
0
  def train(self, config):
    """Train DCGAN and periodically log the FID of generated samples.

    Loads the precalculated training-set statistics from ``self.stats_path``
    and then, for ``config.epoch`` epochs, runs
    ``self.num_discriminator_updates`` discriminator updates and (when
    ``config.learning_rate_g > 0``) one generator update per batch.  Losses
    and penalties are accumulated and printed whenever the step counter is a
    multiple of 100.  Whenever it is a multiple of ``config.fid_eval_steps``
    a sample grid is saved and the FID of ``self.fid_n_samples`` generated
    images is written to the summary log.  A checkpoint is saved whenever the
    counter is a non-zero multiple of 2000 and on ``KeyboardInterrupt``.

    Args:
      config: run configuration.  This method reads ``train_size``,
        ``batch_size``, ``epoch``, ``learning_rate_d``, ``learning_rate_g``,
        ``fid_eval_steps``, ``sample_dir`` and ``checkpoint_dir``.

    Raises:
      AssertionError: if ``self.paths`` is empty.
    """

    assert len(self.paths) > 0, 'no data loaded, was model not built?'
    print("load train stats.. ", end="", flush=True)
    # load precalculated training set statistics; the context manager closes
    # the .npz archive even when one of the keys is missing
    with np.load(self.stats_path) as f:
      mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    print("ok")

    if self.load_checkpoint:
      if self.load(self.checkpoint_dir):
        print(" [*] Load SUCCESS")
      else:
        print(" [!] Load failed...")

    # Batch preparing
    batch_nums = min(len(self.paths), config.train_size) // config.batch_size

    # Running sums, reset each time the progress line is printed
    counter = self.counter_start
    errD_fake = 0.
    errD_real = 0.
    errG = 0.
    errG_count = 0
    penD_gradient = 0.
    penD_lipschitz = 0.
    esti_slope = 0.
    lipschitz_estimate = 0.

    start_time = time.time()

    try:
      # Loop over epochs
      for epoch in range(config.epoch):

        # Assign learning rates for d and g
        # NOTE(review): tf.assign builds a new op on every call; in graph
        # mode that grows the graph each epoch -- consider a pre-built op.
        lrate =  config.learning_rate_d # * (config.lr_decay_rate_d ** epoch)
        self.sess.run(tf.assign(self.learning_rate_d, lrate))
        lrate =  config.learning_rate_g # * (config.lr_decay_rate_g ** epoch)
        self.sess.run(tf.assign(self.learning_rate_g, lrate))

        # Loop over batches
        for batch_idx in range(batch_nums):
          # Update D network; only the values of this first update are logged
          _, errD_fake_, errD_real_, summary_str, penD_gradient_, penD_lipschitz_, esti_slope_, lipschitz_estimate_ = self.sess.run(
              [self.d_optim, self.d_loss_fake, self.d_loss_real, self.d_sum,
              self.d_gradient_penalty_loss, self.d_lipschitz_penalty_loss, self.d_mean_slope_target, self.d_lipschitz_estimate])
          # Remaining discriminator updates for this step; results discarded
          for _ in range(self.num_discriminator_updates - 1):
            self.sess.run([self.d_optim, self.d_loss_fake, self.d_loss_real, self.d_sum,
                           self.d_gradient_penalty_loss, self.d_lipschitz_penalty_loss])
          if np.mod(counter, 20) == 0:
            self.writer.add_summary(summary_str, counter)

          # Update G network
          if config.learning_rate_g > 0.: # and (np.mod(counter, 100) == 0 or lipschitz_estimate_ > 1 / (20 * self.lipschitz_penalty)):
            _, errG_, summary_str = self.sess.run([self.g_optim, self.g_loss, self.g_sum])
            if np.mod(counter, 20) == 0:
              self.writer.add_summary(summary_str, counter)
            errG += errG_
            errG_count += 1

          errD_fake += errD_fake_
          errD_real += errD_real_
          penD_gradient += penD_gradient_
          penD_lipschitz += penD_lipschitz_
          esti_slope += esti_slope_
          lipschitz_estimate += lipschitz_estimate_

          # Print: the sums are divided by a fixed 100 regardless of how many
          # steps were actually accumulated since the last reset
          if np.mod(counter, 100) == 0:
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, lip_pen: %.8f, gradient_pen: %.8f, g_loss: %.8f, d_tgt_slope: %.6f, d_avg_lip: %.6f, g_updates: %3d"  \
              % (epoch, batch_idx, batch_nums, time.time() - start_time, (errD_fake+errD_real) / 100.,
                 penD_lipschitz / 100., penD_gradient / 100., errG / 100., esti_slope / 100., lipschitz_estimate / 100., errG_count))
            errD_fake = 0.
            errD_real = 0.
            errG = 0.
            errG_count = 0
            penD_gradient = 0.
            penD_lipschitz = 0.
            esti_slope = 0.
            lipschitz_estimate = 0.

          # Save generated samples and FID
          if np.mod(counter, config.fid_eval_steps) == 0:

            # Save (best effort: a failure here must not stop training)
            try:
              samples, d_loss, g_loss = self.sess.run(
                  [self.sampler, self.d_loss, self.g_loss])
              save_images(samples, [8, 8], '{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, batch_idx))
              print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
            except Exception as e:
              print(e)
              print("sample image error!")

            # FID
            print("samples for incept", end="", flush=True)

            # Only n_batches * fid_sample_batchsize rows are filled; any
            # remainder of fid_n_samples stays zero.
            samples = np.zeros((self.fid_n_samples, self.output_height, self.output_width, 3))
            n_batches = self.fid_n_samples // self.fid_sample_batchsize
            lo = 0
            for btch in range(n_batches):
              print("\rsamples for incept %d/%d" % (btch + 1, n_batches), end=" ", flush=True)
              samples[lo:(lo+self.fid_sample_batchsize)] = self.sess.run(self.sampler_fid)
              lo += self.fid_sample_batchsize

            # maps -1 -> 0 and 1 -> 255
            # (assumes the sampler output lies in [-1, 1] -- confirm)
            samples = (samples + 1.) * 127.5
            print("ok")

            mu_gen, sigma_gen = fid.calculate_activation_statistics(samples,
                                                             self.sess,
                                                             batch_size=self.fid_batch_size,
                                                             verbose=self.fid_verbose)

            print("calculate FID:", end=" ", flush=True)
            try:
              FID = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
            except Exception as e:
              # fall back to a fixed placeholder value so the summary can
              # still be written
              print(e)
              FID=500

            print(FID)

            # Update event log with FID
            self.sess.run(tf.assign(self.fid, FID))
            summary_str = self.sess.run(self.fid_sum)
            self.writer.add_summary(summary_str, counter)

          # Save checkpoint
          if (counter != 0) and (np.mod(counter, 2000) == 0):
            self.save(config.checkpoint_dir, counter)

          counter += 1
    except KeyboardInterrupt as e:
      # keep the progress made so far when the run is interrupted by hand
      self.save(config.checkpoint_dir, counter)
    except Exception as e:
      print(e)
    finally:
      # When done, ask the threads to stop.
      self.coord.request_stop()
      self.coord.join(self.threads)
Exemple #29
0
 def compute(images):
     """Return the Frechet distance between `images` and (mu0, sig0).

     `images` is converted with np.array, its activation statistics are
     computed in `inception_sess` with batch size `args.batch_size`, and the
     distance to the reference statistics `mu0` / `sig0` is returned.
     """
     image_array = np.array(images)
     mean, cov = fid.calculate_activation_statistics(
         image_array, inception_sess, args.batch_size, verbose=True)
     distance = fid.calculate_frechet_distance(mean, cov, mu0, sig0)
     return distance
    def calculate_fid(self):
        """Sample from a range of checkpoints and record the FID of each.

        For every checkpoint index from ``config.fast_fid.begin_ckpt`` to
        ``config.fast_fid.end_ckpt`` in steps of 5000, the score model is
        restored from ``<args.log_path>/checkpoint_<ckpt>.pth``.  Then
        ``config.fast_fid.num_samples // config.fast_fid.batch_size`` batches
        are drawn with annealed Langevin dynamics and saved as PNG files
        under ``<args.image_folder>/ckpt_<ckpt>/``.  The FID against the
        precalculated statistics in ``fid_stats_cifar10_train.npz`` is
        printed and stored per checkpoint, and the resulting
        ``{ckpt: fid}`` dict is pickled to
        ``<args.image_folder>/fids.pickle``.
        """
        import fid
        import pickle
        import tensorflow as tf

        stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
        inception_path = fid.check_or_download_inception(
            "./tmp/"
        )  # download inception network

        # The reference statistics do not depend on the checkpoint, so load
        # them once; the context manager closes the .npz archive even when a
        # key is missing.
        with np.load(stats_path) as f:
            mu_real, sigma_real = f["mu"][:], f["sigma"][:]

        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas_th = get_sigmas(self.config)
        sigmas = sigmas_th.cpu().numpy()

        fids = {}
        for ckpt in tqdm.tqdm(
            range(
                self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000
            ),
            desc="processing ckpt",
        ):
            states = torch.load(
                os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
                map_location=self.config.device,
            )

            # With EMA enabled the averaged weights (last entry of the saved
            # states) are copied into the model; otherwise the first entry is
            # loaded as the model state dict.
            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            batch_size = self.config.fast_fid.batch_size
            num_iters = self.config.fast_fid.num_samples // batch_size
            output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
            os.makedirs(output_path, exist_ok=True)
            for i in range(num_iters):
                init_samples = torch.rand(
                    batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                    device=self.config.device,
                )
                init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics(
                    init_samples,
                    score,
                    sigmas,
                    self.config.fast_fid.n_steps_each,
                    self.config.fast_fid.step_lr,
                    verbose=self.config.fast_fid.verbose,
                )

                final_samples = all_samples[-1]
                # Number the files across batches; restarting at 0 for every
                # batch would overwrite the images of the previous batches.
                for sample_idx, sample in enumerate(
                    final_samples, start=i * batch_size
                ):
                    sample = sample.view(
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    save_image(
                        sample,
                        os.path.join(output_path, "sample_{}.png".format(sample_idx)),
                    )

            # NOTE(review): create_inception_graph is called once per
            # checkpoint; confirm that repeated imports into the current TF
            # graph are intended.
            fid.create_inception_graph(
                inception_path
            )  # load the graph into the current TF graph

            # NOTE(review): only the last sampled batch is used for the FID,
            # not all num_samples images -- confirm this is intended.
            # Min-max scale to [0, 255] in torch, then convert the whole
            # result to a NumPy array (the original converted only the
            # denominator and relied on mixed tensor/ndarray arithmetic).
            lowest = final_samples.min()
            highest = final_samples.max()
            final_samples = (
                ((final_samples - lowest) / (highest - lowest) * 255)
                .detach()
                .cpu()
                .numpy()
            )
            # channels-first -> channels-last
            final_samples = np.transpose(final_samples, [0, 2, 3, 1])
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    final_samples, sess, batch_size=100
                )

            fid_value = fid.calculate_frechet_distance(
                mu_gen, sigma_gen, mu_real, sigma_real
            )
            # Record the value; without this the pickled dict is always empty.
            fids[ckpt] = fid_value
            print("FID: %s" % fid_value)

        with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
            pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)