Example No. 1
def main():
    # Paths
    pic_path = os.path.join('./out/c/', 'checkpoints',
                            'dummy.samples%d.npy' % (NUM_POINTS))
    image_path = 'results_celeba/generated'  # set path to some generated images
    stats_path = 'fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    # load precalculated training set statistics
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    #image_list = glob.glob(os.path.join(image_path, '*.png'))
    #images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
    images = np.load(pic_path)

    images_t = images / 2.0 + 0.5
    images_t = 255.0 * images_t

    from PIL import Image
    img = Image.fromarray(np.uint8(images_t[0]), 'RGB')
    img.save('my.png')

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)
Example No. 2
def main():
    data = get_celeb_a(random_flip=False)[0]
    print()
    i = 0
    dir_celeb_a = './statistics/celeb_a_images'
    if not os.path.exists(dir_celeb_a):
        os.makedirs(dir_celeb_a)
    for batch in data:
        for image in batch:
            i += 1
            save_image(image, dir_celeb_a + '/{}'.format(i))
            if i % 10000 == 0:
                print('{}/167000 images saved'.format(i))

    tf.disable_v2_behavior()
    inception_path = fid.check_or_download_inception(None)
    fid.create_inception_graph(inception_path)
    files = [
        dir_celeb_a + '/' + filename for filename in os.listdir(dir_celeb_a)
    ]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics_from_files(
            files, sess, verbose=True)
        np.savez('./statistics/fid_stats_celeb_a_train.npz',
                 mu=mu,
                 sigma=sigma)
Example No. 3
def main():
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(
        inception_path)  # download inception if necessary
    print("ok")

    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)

    data_files = glob.glob(os.path.join("./img_align_celeba", "*.jpg"))
    data_files = sorted(data_files)[:10000]
    data_files = np.array(data_files)
    images = np.array([get_image(data_file, 148)
                       for data_file in data_files]).astype(np.float32)
    images = images * 255

    output_name = 'fid_stats_face'

    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    print("ok")

    print("calculte FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(images,
                                                        sess,
                                                        batch_size=100)
        np.savez_compressed(output_name, mu=mu, sigma=sigma)
    print("finished")
Example No. 4
def fid_ms_for_imgs(images, mem_fraction=0.5):
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_fraction)
    inception_path = fid.check_or_download_inception(None)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            images, sess, batch_size=100)
    return mu_gen, sigma_gen
Example No. 5
def main(model, data_source, noise_method, noise_factors, lambdas):
    """
    model: RVAE or VAE
    data_source: data set of training. Either 'MNIST' or 'FASHION'
    noise_method: method of adding noise. Either 'sp' (represents salt-and-pepper) 
                  or 'gs' (represents Gaussian)
    noise_factors: noise factors
    lambdas: lambda
    """
    
    input_path = "../output/"+model+"_"+data_source+"_"+noise_method+"/"
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(inception_path) # download inception if necessary
    print("ok")
    
    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" " , flush=True)
    
    output_path = "fid_precalc/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    output_path = output_path+model+"_"+data_source+"_"+noise_method+"/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    
    for l in lambdas:
        for nr in noise_factors:
            if model == 'RVAE':
                data_path = input_path+'lambda_'+str(l)+'/noise_'+str(nr)+'/generation_fid.npy'
                output_name = 'fid_stats_lambda_'+str(l)+'noise_'+str(nr)
            else:
                data_path = input_path+str(nr)+'/generation_fid.npy'
                output_name = 'fid_stats_noise_'+str(nr)
            images = np.load(data_path)[:10000]
            images = np.stack((((images*255)).reshape(-1,28,28),)*3,axis=-1)
            
            print("create inception graph..", end=" ", flush=True)
            fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
            print("ok")
            
            print("calculte FID stats..", end=" ", flush=True)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
                np.savez_compressed(output_path+output_name, mu=mu, sigma=sigma)
            print("finished")
Example No. 6
def load_fid(args):
    act_stats = np.load(CIFAR_STATS_PATH)
    mu0, sig0 = act_stats['mu'], act_stats['sigma']
    inception_path = fid.check_or_download_inception(INCEPTION_PATH)
    inception_graph = tf.Graph()
    with inception_graph.as_default():
        fid.create_inception_graph(str(inception_path))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    inception_sess = tf.Session(config=config, graph=inception_graph)
    def compute(images):
        m, s = fid.calculate_activation_statistics(
            np.array(images), inception_sess, args.batch_size, verbose=True)
        return fid.calculate_frechet_distance(m, s, mu0, sig0)

    return compute, locals()
Example No. 7
def generate(args):
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    if not os.path.exists(CIFAR_STATS_PATH):
        print('Generating FID statistics for test set...')
        print('Building Inception graph')
        with tf.Session(config=config) as sess:
            inception_path = fid.check_or_download_inception(INCEPTION_PATH)
            fid.create_inception_graph(str(inception_path))
            ds = datasets.load_cifar10(True)
            all_test_set = (ds.test.images + 1) * 128
            print(all_test_set.shape)
            m, s = fid.calculate_activation_statistics(
                all_test_set, sess, args.batch_size, verbose=True)
        np.savez(CIFAR_STATS_PATH, mu=m, sigma=s)
        print('Done')

    root_dir = os.path.dirname(args.dir)
    args_json = json.load(open(os.path.join(root_dir, 'hps.txt')))
    ckpt_dir = args.dir
    vars(args).update(args_json)

    model_graph = tf.Graph()
    with model_graph.as_default():
        x_ph, is_training_ph, model, optimizer, batch_size_sym, z_sample_sym, x_sample_sym = build_graph(args)
        saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=3, max_to_keep=6)

    model_sess = tf.Session(config=config, graph=model_graph)
    print('RESTORING MODEL FROM', ckpt_dir)
    saver.restore(model_sess, ckpt_dir)
    compute_fid, _ = load_fid(args)
    images = []
    for j in range(100):
        x_samples = model_sess.run(x_sample_sym, {batch_size_sym: 100, is_training_ph: False})
        x_samples = (np.clip(x_samples, -1, 1) + 1) / 2 * 256
        images.extend(x_samples)

    fscore = compute_fid(images)
    print('FID score = {}'.format(fscore))
    
    dest = os.path.join(root_dir, 'generated')
    if not os.path.exists(dest):
        os.makedirs(dest)
    for j, im in enumerate(images):
        plt.imsave(os.path.join(dest, '{}.png'.format(j)), im/256)
Example No. 8
def main(model, noise_factors, lambdas):
    """
    model: RVAE or VAE
    noise_factors: noise factors
    lambdas: lambda
    """

    input_path = model
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(
        inception_path)  # download inception if necessary
    print("ok")

    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)

    output_path = "fid_precalc/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    for l in lambdas:
        for nr in noise_factors:
            data_path = input_path + 'lambda_' + str(l) + '/noise_' + str(
                nr) + '/generation_fid.npy'
            output_name = 'fid_stats_lambda_' + str(l) + 'noise_' + str(nr)
            images = np.load(data_path)
            images = np.transpose(images * 255, (0, 2, 3, 1))
            #images = np.stack((((images*255)).reshape(-1,28,28),)*3,axis=-1)

            print("create inception graph..", end=" ", flush=True)
            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            print("ok")

            print("calculte FID stats..", end=" ", flush=True)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images,
                                                                sess,
                                                                batch_size=100)
                np.savez_compressed(output_path + output_name,
                                    mu=mu,
                                    sigma=sigma)
            print("finished")
Example No. 9
def load_fid(dtest, args):
    import fid

    def transform_for_fid(im):
        assert len(im.shape) == 4 and im.dtype == np.float32
        if im.shape[-1] == 1:
            assert im.shape[-2] == 28
            im = np.tile(im, [1, 1, 1, 3])
        if not (im.std() < 1. and im.min() > -1.):
            print('WARNING: abnormal image range', im.std(), im.min())
        return (im + 1) * 128

    inception_path = fid.check_or_download_inception(INCEPTION_PATH)
    inception_graph = tf.Graph()
    with inception_graph.as_default():
        fid.create_inception_graph(str(inception_path))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    inception_sess = tf.Session(config=config, graph=inception_graph)

    stats_path = os.path.join(INCEPTION_PATH, f'{args.dataset}-stats.npz')
    if not os.path.exists(stats_path):
        mu0, sig0 = fid.calculate_activation_statistics(
            transform_for_fid(dtest),
            inception_sess,
            args.batch_size,
            verbose=True)
        np.savez(stats_path, mu0=mu0, sig0=sig0)
    else:
        sdict = np.load(stats_path)
        mu0, sig0 = sdict['mu0'], sdict['sig0']

    def compute(images):
        m, s = fid.calculate_activation_statistics(transform_for_fid(images),
                                                   inception_sess,
                                                   args.batch_size,
                                                   verbose=True)
        return fid.calculate_frechet_distance(m, s, mu0, sig0)

    return compute, locals()
Example No. 10
File: eval.py Project: ok1zjf/LBAE
def precalc(data_path, output_path):
    print("CALCULATING THE GT STATS....")
    # data_path = 'reconstructed_test/eval' # set path to training set images
    # output_path = data_path+'/fid_stats.npz' # path for where to store the statistics
    # if you have downloaded and extracted
    #   http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # set this path to the directory where the extracted files are, otherwise
    # just set it to None and the script will later download the files for you
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(
        inception_path)  # download inception if necessary
    print("ok")

    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)
    image_list = glob.glob(os.path.join(data_path, '*.jpg'))
    if len(image_list) == 0:
        print("No images in directory ", data_path)
        return

    images = np.array([
        imageio.imread(str(fn), as_gray=False,
                       pilmode="RGB").astype(np.float32) for fn in image_list
    ])
    print("%d images found and loaded" % len(images))

    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    print("ok")

    print("calculte FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma, acts = fid.calculate_activation_statistics(
            images, sess, batch_size=BATCH_SIZE)
        np.savez_compressed(output_path, mu=mu, sigma=sigma, activations=acts)
    print("finished")
Example No. 11
# path where precalculated statistics for fid are to be saved
STATS_PATH = "C:/Users/andre/jupyter_ws/ganMetrics/FID/saved_statistics"

from enum import Enum


####
# enum of fid statistics to be precalculated.
####
class SavedStatistics(Enum):
    WMN_EASY = "wmn_easy.pickle"
    WMN_DIFFICULT = "wmn_difficult.pickle"


# init inceptionmodel
inceptionPath = fid.check_or_download_inception(None)
fid.create_inception_graph(inceptionPath)

default_batchsize = 100


###
# whether the fid statistic defined by a SavedStatistics member has already been calculated
###
def stats_exist(statfile):
    path = Path(STATS_PATH) / statfile.value
    return path.exists()


###
# given an array of images, calculates the fid stats (mu, sigma) used for computing FID
###
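
# A minimal sketch of the helper described by the comment above, assuming the
# Inception graph created via fid.create_inception_graph() earlier in this
# snippet; the name calc_fid_stats and the pickle layout are illustrative
# placeholders, not taken from the source.
import pickle


def calc_fid_stats(images, statfile, batch_size=default_batchsize):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(images, sess,
                                                        batch_size=batch_size)
    # store (mu, sigma) under the filename defined by the SavedStatistics member
    with open(str(Path(STATS_PATH) / statfile.value), 'wb') as f:
        pickle.dump({'mu': mu, 'sigma': sigma}, f)
    return mu, sigma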
Example No. 12
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
import os
import glob
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import fid
from scipy.misc import imread
import tensorflow as tf

# Paths
image_path = '/home/minje/dev/dataset/cifar/cifar-10-fake' # set path to some generated images
#image_path = '/home/minje/dev/dataset/stl/fake-images' # set path to some generated images
stats_path = '/home/minje/dev/dataset/cifar/fid_stats_cifar10.npz' # training set statistics (maybe pre-calculated)
#stats_path = '/home/minje/dev/dataset/stl/fid_stats_stl10.npz' # training set statistics (maybe pre-calculated)
inception_path = fid.check_or_download_inception(None) # download inception network

# precalculate training set statistics
# #image_files = glob.glob(os.path.join('/home/minje/dev/dataset/cifar/cifar-10-images', '*.jpg'))
# image_files = glob.glob(os.path.join('/home/minje/dev/dataset/stl/images', '*.jpg'))
# fid.create_inception_graph(inception_path)
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     mu_real, sigma_real = fid.calculate_activation_statistics_from_files(image_files, sess,
#         batch_size=100, verbose=True)
# np.savez(stats_path, mu=mu_real, sigma=sigma_real)
# exit(0)

# loads all images into memory (this might require a lot of RAM!)
image_files = glob.glob(os.path.join(image_path, '*.jpg'))
images = np.array([imread(str(fn)).astype(np.float32) for fn in image_files])
Example No. 13
model_name = args.model_name
epoch = args.epoch
batch_size = args.batch_size
lr = args.lr
use_bn = args.use_bn
z_dim = args.z_dim
init_steps = args.init_steps
zn_rec_coeff = args.zn_rec_coeff
zh_rec_coeff = args.zh_rec_coeff
nll_coeff = args.nll_coeff
experiment_name = args.experiment_name

pylib.mkdir('./output/%s' % experiment_name)
with open('./output/%s/setting.txt' % experiment_name, 'w') as f:
    f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))
inception_path = fid.check_or_download_inception('../data/inception_model/')
fid_stats_dict = {
    'mnist': '../data/fid/fid_stats_mnist_train.npz',
    'cifar10': '../data/fid/fid_stats_cifar10_train.npz',
    'celeba': '../data/fid/fid_stats_celeba.npz'
}
fid_stats_path = fid_stats_dict[
    dataset_name] if dataset_name in fid_stats_dict else None

# dataset and models
Dataset, img_shape, get_imgs = utils.get_dataset(dataset_name)
dataset = Dataset(batch_size=batch_size)
# TODO: use a separate validation set
dataset_val = Dataset(batch_size=100)
Enc, Dec = utils.get_models(model_name)
Enc = partial(Enc, z_dim=z_dim, use_bn=use_bn)
Example No. 14
def run(dataset,
        generator_type,
        discriminator_type,
        latentsize,
        kernel_dimension,
        epsilon,
        learning_rate,
        batch_size,
        options,
        logdir_base='/tmp'):
    if dataset in ['billion_word']:
        dataset_type = 'text'
    else:
        dataset_type = 'image'
    tf.reset_default_graph()
    dtype = tf.float32

    run_name = '_'.join([
        '%s' % get_timestamp(),
        'g%s' % generator_type,
        'd%s' % discriminator_type,
        'z%d' % latentsize,
        'l%1.0e' % learning_rate,
        'l2p%1.0e' % options.l2_penalty,
        #'d%d' % kernel_dimension,
        #'eps%3.2f' % epsilon,
        'lds%1.e' % options.discriminator_lr_scale,
    ])
    run_name += ("_l2pscale%1.e" %
                 options.gen_l2p_scale) if options.gen_l2p_scale != 1.0 else ''
    run_name += "_M" if options.remember_previous else ''
    run_name += ("_dl%s" %
                 options.disc_loss) if options.disc_loss != 'l2' else ''
    run_name += ("_%s" %
                 options.logdir_suffix) if options.logdir_suffix else ''
    run_name = run_name.replace('+', '')

    if options.verbosity == 0:
        tf.logging.set_verbosity(tf.logging.ERROR)

    subdir = "%s_%s" % (get_timestamp('%y%m%d'), dataset)
    logdir = Path(logdir_base) / subdir / run_name
    print_info("\nLogdir: %s\n" % logdir, options.verbosity > 0)
    if __name__ == "__main__" and options.sample_images is None:
        startup_bookkeeping(logdir, __file__)
        trainlog = open(str(logdir / 'logfile.csv'), 'w')
    else:
        trainlog = None

    dataset_pattern, n_samples, img_shape = get_dataset_path(dataset)
    z = tf.random_normal([batch_size, latentsize], dtype=dtype, name="z")
    if dataset_type == 'text':
        n_samples = options.num_examples
        y, lines_as_ints, charmap, inv_charmap = load_text_dataset(
            dataset_pattern,
            batch_size,
            options.sequence_length,
            options.num_examples,
            options.max_vocab_size,
            shuffle=True,
            num_epochs=None)
        img_shape = [options.sequence_length, len(charmap)]
        true_ngram_model = ngram_language_model.NgramLanguageModel(
            lines_as_ints, options.ngrams, len(charmap))
    else:
        y = load_image_dataset(dataset_pattern,
                               batch_size,
                               img_shape,
                               n_threads=options.threads)

    x = create_generator(z, img_shape,
                         options.l2_penalty * options.gen_l2p_scale,
                         generator_type, batch_size)
    assert x.get_shape().as_list()[1:] == y.get_shape().as_list(
    )[1:], "X and Y have different shapes: %s vs %s" % (
        x.get_shape().as_list(), y.get_shape().as_list())

    disc_x = create_discriminator(x, discriminator_type, options.l2_penalty,
                                  False)
    disc_y = create_discriminator(y, discriminator_type, options.l2_penalty,
                                  True)

    with tf.name_scope('loss'):
        disc_x = tf.reshape(disc_x, [-1])
        disc_y = tf.reshape(disc_y, [-1])
        pot_x, pot_y = get_potentials(x, y, kernel_dimension, epsilon)

        if options.disc_loss == 'l2':
            disc_loss_fn = tf.losses.mean_squared_error
        elif options.disc_loss == 'l1':
            disc_loss_fn = tf.losses.absolute_difference
        else:
            assert False, "Unknown Discriminator Loss: %s" % options.disc_loss

        loss_d_x = disc_loss_fn(pot_x, disc_x)
        loss_d_y = disc_loss_fn(pot_y, disc_y)
        loss_d = loss_d_x + loss_d_y
        loss_g = tf.reduce_mean(disc_x)

        if options.remember_previous:
            x_old = tf.get_variable("x_old",
                                    shape=x.shape,
                                    initializer=tf.zeros_initializer(),
                                    trainable=False)
            disc_x_old = create_discriminator(x_old, discriminator_type,
                                              options.l2_penalty, True)
            disc_x_old = tf.reshape(disc_x_old, [-1])
            pot_x_old = calculate_potential(x, y, x_old, kernel_dimension,
                                            epsilon)
            loss_d_x_old = disc_loss_fn(pot_x_old, disc_x_old)
            loss_d += loss_d_x_old

    vars_d = [
        v for v in tf.global_variables() if v.name.startswith('discriminator')
    ]
    vars_g = [
        v for v in tf.global_variables() if v.name.startswith('generator')
    ]
    optim_d = tf.train.AdamOptimizer(learning_rate *
                                     options.discriminator_lr_scale,
                                     beta1=options.discriminator_beta1,
                                     beta2=options.discriminator_beta2)
    optim_g = tf.train.AdamOptimizer(learning_rate,
                                     beta1=options.generator_beta1,
                                     beta2=options.generator_beta2)

    # we can sum all regularizers in one term, the var-list argument to minimize
    # should make sure each optimizer only regularizes "its own" variables
    regularizers = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    train_op_d = optim_d.minimize(loss_d + regularizers, var_list=vars_d)
    train_op_g = optim_g.minimize(loss_g + regularizers, var_list=vars_g)
    train_op = tf.group(train_op_d, train_op_g)

    if options.remember_previous:
        with tf.control_dependencies([train_op]):
            assign_x_op = tf.assign(x_old, x)
        train_op = tf.group(train_op, assign_x_op)

    # Tensorboard summaries
    if dataset_type == 'image':
        x_img = (tf.clip_by_value(x, -1.0, 1.0) + 1) / 2.0
        y_img = tf.clip_by_value((y + 1) / 2, 0.0, 1.0)
    with tf.name_scope('potential'):
        tf.summary.histogram('x', pot_x)
        tf.summary.histogram('y', pot_y)
        if options.remember_previous:
            tf.summary.histogram('x_old', pot_x_old)
    if options.create_summaries:
        if dataset_type == 'image':
            with tf.name_scope("distances"):
                tf.summary.histogram("xx", generate_all_distances(x, x))
                tf.summary.histogram("xy", generate_all_distances(x, y))
                tf.summary.histogram("yy", generate_all_distances(y, y))
        with tf.name_scope('discriminator_stats'):
            tf.summary.histogram('output_x', disc_x)
            tf.summary.histogram('output_y', disc_y)
            tf.summary.histogram('pred_error_y', pot_y - disc_y)
            tf.summary.histogram('pred_error_x', pot_x - disc_x)
        if dataset_type == 'image':
            img_smry = tf.summary.image("out_img", x_img, 2)
            img_smry = tf.summary.image("in_img", y_img, 2)
        with tf.name_scope("losses"):
            tf.summary.scalar('loss_d_x', loss_d_x)
            tf.summary.scalar('loss_d_y', loss_d_y)
            tf.summary.scalar('loss_d', loss_d)
            tf.summary.scalar('loss_g', loss_g)

        with tf.name_scope('weightnorm'):
            for v in tf.global_variables():
                if not v.name.endswith('kernel:0'):
                    continue
                tf.summary.scalar("wn_" + v.name[:-8], tf.norm(v))
        with tf.name_scope('mean_activations'):
            for op in tf.get_default_graph().get_operations():
                if not op.name.endswith('Tanh'):
                    continue
                tf.summary.scalar("act_" + op.name,
                                  tf.reduce_mean(op.outputs[0]))
    merged_smry = tf.summary.merge_all()

    if dataset_type == 'image':
        fid_stats_file = options.fid_stats % dataset.lower()
        assert Path(fid_stats_file).exists(
        ), "Can't find training set statistics for FID (%s)" % fid_stats_file
        f = np.load(fid_stats_file)
        mu_fid, sigma_fid = f['mu'][:], f['sigma'][:]
        f.close()
        inception_path = fid.check_or_download_inception(
            options.inception_path)
        fid.create_inception_graph(inception_path)

    maxv = 0.05
    cmap = plt.cm.ScalarMappable(mpl.colors.Normalize(-maxv, maxv),
                                 cmap=plt.cm.RdBu)
    config = tf.ConfigProto(intra_op_parallelism_threads=2,
                            inter_op_parallelism_threads=2,
                            use_per_session_threads=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))

    save_vars = [
        v for v in tf.global_variables() if v.name.startswith('generator')
    ]
    save_vars += [
        v for v in tf.global_variables() if v.name.startswith('discriminator')
    ]

    with tf.Session(config=config) as sess:
        log = tf.summary.FileWriter(str(logdir), sess.graph)
        sess.run(tf.global_variables_initializer())
        if options.resume_checkpoint:
            loader = tf.train.Saver(save_vars)
            loader.restore(sess, options.resume_checkpoint)
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        fd = {}

        if options.sample_images is not None:
            n_batches = (1 + options.sample_images_size // batch_size)
            sample_images(x_img, sess, n_batches, path=options.sample_images)
            coord.request_stop()
            coord.join(threads)
            return

        saver = tf.train.Saver(save_vars, max_to_keep=50)
        max_iter = int(options.iterations * 1000)

        n_epochs = max_iter / (n_samples / batch_size)
        print_info(
            "total iterations: %d (= %3.2f epochs)" % (max_iter, n_epochs),
            options.verbosity > 0)
        t0 = time.time()

        try:
            for cur_iter in range(
                    max_iter + 1
            ):  # +1 so we are more likely to get a model/stats line at the end
                sess.run(train_op)
                if (cur_iter > 0) and (cur_iter % options.checkpoint_every
                                       == 0):
                    saver.save(sess,
                               str(logdir / 'model'),
                               global_step=cur_iter)

                if cur_iter % options.stats_every == 0:
                    if dataset_type == 'image':
                        smry, xx_img = sess.run([merged_smry, x_img])
                        log.add_summary(smry, cur_iter)
                        images = sample_images(
                            x_img, sess,
                            n_batches=5 * 1024 // batch_size) * 255
                        mu_gen, sigma_gen = fid.calculate_activation_statistics(
                            images, sess, batch_size=128)
                        quality_measure = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, mu_fid, sigma_fid)
                        fig = plot_tiles(xx_img,
                                         10,
                                         10,
                                         local_norm="none",
                                         figsize=(6.6, 6.6))
                        fig.savefig(str(logdir / ('%09d.png' % cur_iter)))
                        plt.close(fig)
                    elif dataset_type == 'text':
                        smry = sess.run(merged_smry)
                        # Note: to compare with WGAN-GP, we can only take 5 samples since our batch size is 2x theirs
                        # and JSD improves a lot with larger sample sizes
                        sample_text_ = sample_text(x, sess, 5, inv_charmap,
                                                   logdir / 'samples',
                                                   cur_iter)
                        gen_ngram_model = ngram_language_model.NgramLanguageModel(
                            sample_text_, options.ngrams, len(charmap))
                        js = []
                        for i in range(options.ngrams):
                            js.append(
                                true_ngram_model.js_with(
                                    gen_ngram_model, i + 1))
                            #print('js%d' % (i+1), quality_measure[i])
                        quality_measure = js[3] if options.ngrams < 6 else (
                            str(js[3]) + '/' + str(js[5]))

                    s = (cur_iter, quality_measure, time.time() - t0, dataset,
                         run_name)
                    print_info("%9d  %s -- %3.2fs %s %s" % s,
                               options.verbosity > 0)
                    if trainlog:
                        print(', '.join([str(ss) for ss in s]),
                              file=trainlog,
                              flush=True)
                    log.add_summary(smry, cur_iter)

        except KeyboardInterrupt:
            saver.save(sess, str(logdir / 'model'), global_step=cur_iter)
        finally:
            if trainlog:
                trainlog.close()
            coord.request_stop()
            coord.join(threads)
        return
Example No. 15
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
import os
import glob
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import fid
from scipy.misc import imread
import tensorflow as tf

# Paths
image_path = './score/'  # set path to some generated images
stats_path = './data/train/figs'  # training set statistics
inception_path = fid.check_or_download_inception(
    './tmp/inception-2015-12-05/')  # download inception network

# loads all images into memory (this might require a lot of RAM!)
image_list = glob.glob(os.path.join(image_path, '*.jpg'))
images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])

# load precalculated training set statistics
f = np.load(stats_path)
mu_real, sigma_real = f['mu'][:], f['sigma'][:]
f.close()

fid.create_inception_graph(
    inception_path)  # load the graph into the current TF graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    mu_gen, sigma_gen = fid.calculate_activation_statistics(images,
                                                            sess,
                                                            batch_size=100)

fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                           sigma_real)
print("FID: %s" % fid_value)
Example No. 16
    def __init__(self,
                 d_net,
                 g_net,
                 x_sampler,
                 z_sampler,
                 args,
                 inception,
                 log_dir,
                 scale=10.0):
        self.model = args.model
        self.data = args.data
        self.log_dir = log_dir
        self.g_net = g_net
        self.d_net = d_net
        self.x_sampler = x_sampler
        self.z_sampler = z_sampler
        self.x_dim = d_net.x_dim
        self.z_dim = g_net.z_dim
        self.beta = 0.9999
        self.d_iters = 1
        self.batch_size = 64
        self.inception = inception
        self.inception_path = fid.check_or_download_inception(
            './data/imagenet_model')

        if self.data == 'cifar10':
            self.stats_path = './data/fid_stats_cifar10_train.npz'
        elif self.data == 'stl10':
            self.stats_path = './data/fid_stats_stl10.npz'

        self.x = tf.placeholder(tf.float32, [None] + self.x_dim, name='x')
        self.z = tf.placeholder(tf.float32, [None] + [self.z_dim], name='z')

        self.x_ = self.g_net(self.z)

        self.d = self.d_net(self.x)
        self.d_ = self.d_net(self.x_, reuse=True)

        self.g_loss = tf.reduce_mean(self.d_)
        self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)

        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = epsilon * self.x + (1 - epsilon) * self.x_
        d_hat = self.d_net(x_hat, reuse=True)

        ddx = tf.gradients(d_hat, x_hat)[0]
        print(ddx.get_shape().as_list())
        ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1))
        self.gp_loss = tf.reduce_mean(tf.square(ddx - 1.0) * scale)

        self.d_loss_reg = self.d_loss + self.gp_loss

        self.d_adam, self.g_adam = None, None
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.0, beta2=0.9)\
                .minimize(self.d_loss_reg, var_list=self.d_net.vars)
            self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.0, beta2=0.9)\
                .minimize(self.g_loss, var_list=self.g_net.vars)

        for var_ in tf.model_variables('g_woa'):
            print(var_.name)

        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        with self.sess:
            fid.create_inception_graph(
                self.inception_path
            )  # load the graph into the current TF graph
Example No. 17
def fid_run(gan, run_name, train_params, g_model_params, d_model_params,
            summ_root, chk_root, stat_root, data_set, data_set_size):
    '''
    cGAN Training cycle

    * Runs training
    * Calculate FID measure
    * Write summaries for tensorboard
    * Save/restore model

    Args:
    'gan'               An instance of CGAN class
    'run_name'          Name for a run
    'train_params'      A dictionary of parameters for training
        ['lr']              Learning rate (float)
        ['beta1']           Beta1 parameter of AdamOptimizer (float)
        ['batch_size']      Mini-batch batch size (int)
        ['max_epoch']       Epoch limit (int)
    'g_model_params'    A dictionary of parameters for generator
        ['input_shape']        input shape of a single data. Usually a pair of shape (z shape, y shape)
        ['output_shape']       output shape of a single data
        ['output_image_shape'] output shape when interpreted as image
        ['summ_shape_per_class'] summary tile shape per class
        other keys depend on your cgan model function
    'd_model_params'    A dictionary of parameters for discriminator
        ['input_shape']        input shape of a single data. Usually a pair of shape (x shape, y shape)
        ['output_shape']       output shape of a single data. Usually a pair of shape (score shape, logit shape)
        other keys depend on your cgan model function
    'summ_root'         Path for summary
    'chk_root'          Path for checkpoints. checkpoints are saved every epoch
    'stat_root'         Path where activation statistics are saved
    'data_set'         dataset for training (tf.data.Dataset)
    'data_set_size'    data_set size

    Return:
    None
    '''
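    # Illustrative (hypothetical) parameter dictionaries matching the docstring
    # above; all concrete values and shapes are placeholders, not from the source:
    #   train_params   = {'lr': 2e-4, 'beta1': 0.5, 'batch_size': 64, 'max_epoch': 20}
    #   g_model_params = {'input_shape': ((100,), (10,)),   # (z shape, y shape)
    #                     'output_image_shape': (28, 28, 1), ...}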
    '''Create Features'''
    placeholders = {}
    placeholders['z'] = tf.placeholder(tf.float32,
                                       shape=_shape_for_batch(
                                           g_model_params['input_shape'][0]),
                                       name='z')
    placeholders['y'] = tf.placeholder(tf.float32,
                                       shape=_shape_for_batch(
                                           g_model_params['input_shape'][1]),
                                       name='y')
    # y is shared between generator and discriminator
    assert (
        g_model_params['input_shape'][1] == d_model_params['input_shape'][1])
    placeholders['x'] = tf.placeholder(tf.float32,
                                       shape=_shape_for_batch(
                                           d_model_params['input_shape'][0]),
                                       name='x')
    placeholders['mode'] = tf.placeholder(dtype=tf.bool, name='mode')

    g_features = {'z': placeholders['z'], 'y': placeholders['y']}
    d_real_features = {'x': placeholders['x'], 'y': placeholders['y']}
    '''Create networks'''
    generator = gan.generator_fn(g_features, g_model_params,
                                 placeholders['mode'])
    real_discriminator, d_real_logit = gan.discriminator_fn(
        d_real_features, d_model_params, placeholders['mode'])
    d_fake_features = {'x': generator, 'y': placeholders['y']}
    fake_discriminator, d_fake_logit = gan.discriminator_fn(
        d_fake_features, d_model_params, placeholders['mode'], reuse=True)
    '''Define loss for optimization'''
    losses = {}
    d_real_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real_logit,
                                                labels=tf.ones(
                                                    tf.shape(d_real_logit))))
    d_fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logit,
                                                labels=tf.zeros(
                                                    tf.shape(d_fake_logit))))
    losses['d_loss'] = d_real_loss + d_fake_loss
    losses['g_loss'] = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logit,
                                                labels=tf.ones(
                                                    tf.shape(d_fake_logit))))
    '''Setup optimizer'''
    trainable_vars = tf.trainable_variables()
    d_vars = [
        var for var in trainable_vars if var.name.startswith('discriminator')
    ]
    g_vars = [
        var for var in trainable_vars if var.name.startswith('generator')
    ]

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        d_optim = (tf.train.AdamOptimizer(
            train_params['lr'],
            beta1=train_params['beta1']).minimize(losses['d_loss'],
                                                  var_list=d_vars))
        g_optim = (tf.train.AdamOptimizer(
            train_params['lr'],
            beta1=train_params['beta1']).minimize(losses['g_loss'],
                                                  var_list=g_vars))
    '''Define summaries'''
    summ_g_loss = tf.summary.scalar('g_loss', losses['g_loss'])
    summ_d_loss = tf.summary.scalar('d_loss', losses['d_loss'])
    summ_merged = tf.summary.merge_all()

    tile_image = util.make_tile_image(
        generator=generator,
        image_shape=g_model_params['output_image_shape'],
        input_shape=g_model_params['input_shape'],
        shape_per_class=g_model_params['summ_shape_per_class'])
    summ_image = tf.summary.image('generator', tile_image, max_outputs=1)
    '''Setup InceptionNet'''
    inception_graph = tf.Graph()
    inception_path = fid.check_or_download_inception(None)
    with inception_graph.as_default():
        fid.create_inception_graph(inception_path)
    inception_sess = tf.Session(graph=inception_graph)

    epoch_var = tf.get_variable('epoch', shape=(), dtype=tf.int32)
    '''Run training'''
    with tf.Session() as (sess), tf.summary.FileWriter(
            os.path.join(summ_root, run_name),
            sess.graph) as train_summ_writer:
        '''Load validation/test data'''
        print('loading data...')
        all_set = (data_set.batch(data_set_size))
        (data_all,
         all_label) = sess.run(all_set.make_one_shot_iterator().get_next())
        print('data loading done')
        print('calculating InceptionNet activations...')
        mu, sigma = _get_statistics(stat_root, data_all,
                                    g_model_params['output_image_shape'],
                                    inception_sess)
        print('activation calculation done')
        '''Initialize misc training variables'''
        prefix_tag = run_name
        chk_path = os.path.join(chk_root, run_name, prefix_tag + "model.ckpt")
        chk_saver = tf.train.Saver()
        initial_epoch = 0
        if os.path.exists(chk_path + ".index"):
            print('Resuming from previous checkpoint {}'.format(chk_path))
            chk_saver.restore(sess, chk_path)
            initial_epoch = sess.run(epoch_var) + 1
        else:
            tf.global_variables_initializer().run()

        for epoch in range(initial_epoch, train_params['max_epoch']):
            epoch_update_op = epoch_var.assign(epoch)
            sess.run(epoch_update_op)

            it = 0
            batch_size = train_params['batch_size']
            while it < data_all.shape[0]:
                data, label = data_all[it:it +
                                       train_params['batch_size']], all_label[
                                           it:it + train_params['batch_size']]
                it += train_params['batch_size']
                if it >= data_all.shape[0]:
                    batch_size = data_all.shape[0] % train_params['batch_size']
                    if batch_size == 0:
                        batch_size = train_params['batch_size']
                else:
                    batch_size = train_params['batch_size']
                '''Train discriminator'''
                d_loss = _run_discriminator_train(
                    sess=sess,
                    placeholders=placeholders,
                    loss=losses['d_loss'],
                    optim=d_optim,
                    x_data=data,
                    y_label=label,
                    batch_size=batch_size,
                    z_input_shape=g_model_params['input_shape'][0])
                '''Train generator'''
                g_loss = _run_generator_train(
                    sess=sess,
                    placeholders=placeholders,
                    loss=losses['g_loss'],
                    optim=g_optim,
                    x_data=data,
                    batch_size=batch_size,
                    z_input_shape=g_model_params['input_shape'][0],
                    y_input_shape=g_model_params['input_shape'][1])
            '''Training loss summary'''
            _run_summary(sess=sess,
                         placeholders=placeholders,
                         writer=train_summ_writer,
                         summ=summ_merged,
                         epoch=epoch,
                         x_data=data,
                         y_label=label,
                         batch_size=batch_size,
                         z_input_shape=g_model_params['input_shape'][0])
            '''Make image tile'''
            _run_tile_summary(
                sess=sess,
                placeholders=placeholders,
                epoch=epoch,
                writer=train_summ_writer,
                summ=summ_image,
                shape_per_class=g_model_params['summ_shape_per_class'],
                z_input_shape=g_model_params['input_shape'][0],
                y_input_shape=g_model_params['input_shape'][1])
            '''Calculate FID of a sample'''
            _log_scalar(
                'fid', epoch, train_summ_writer,
                _run_fid_calculation(
                    sess=sess,
                    inception_sess=inception_sess,
                    placeholders=placeholders,
                    batch_size=100,
                    iteration=1,
                    generator=generator,
                    mu=mu,
                    sigma=sigma,
                    epoch=epoch,
                    image_shape=g_model_params['output_image_shape'],
                    z_input_shape=g_model_params['input_shape'][0],
                    y_input_shape=g_model_params['input_shape'][1]))
            '''Save checkpoint'''
            chk_saver.save(sess, chk_path)
            print('epoch {} : d_loss : {} / g_loss : {}'.format(
                epoch, d_loss, g_loss))
        '''Finalize Loop'''
        '''Calculate FID of a sample'''
        final_fid = _run_fid_calculation(
            sess=sess,
            inception_sess=inception_sess,
            placeholders=placeholders,
            batch_size=100,
            iteration=10,
            generator=generator,
            mu=mu,
            sigma=sigma,
            epoch=train_params['max_epoch'] - 1,
            image_shape=g_model_params['output_image_shape'],
            z_input_shape=g_model_params['input_shape'][0],
            y_input_shape=g_model_params['input_shape'][1])
        print('final_fid: {}'.format(final_fid))
        _log_scalar('final_fid', train_params['max_epoch'] - 1,
                    train_summ_writer, final_fid)
        inception_sess.close()
Example No. 18
File: eval.py Project: ok1zjf/LBAE
def fid_imgs(cfg):
    print("CALCULATING FID/KID scores")
    rnd_seed = 12345
    random.seed(rnd_seed)
    np.random.seed(rnd_seed)
    tf.compat.v2.random.set_seed(rnd_seed)
    tf.random.set_random_seed(rnd_seed)
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    # load precalculated training set statistics
    print("Loading stats from:", cfg.stats_filename, '  ...', end='')
    f = np.load(cfg.stats_filename)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]

    activations_ref = None
    if 'activations' in f:
        activations_ref = f['activations']
        print(" reference activations #:", activations_ref.shape[0])

    f.close()
    print("done")

    fid_epoch = 0
    epoch_info_file = cfg.exp_path + '/fid-epoch.txt'
    if os.path.isfile(epoch_info_file):
        fid_epoch = open(epoch_info_file, 'rt').read()
    else:
        print("ERROR: couldnot find file:", epoch_info_file)

    best_fid_file = cfg.exp_path + '/fid-best.txt'
    best_fid = 1e10
    if os.path.isfile(best_fid_file):
        best_fid = float(open(best_fid_file, 'rt').read())
        print("Best FID: " + str(best_fid))

    pr = None
    pr_file = cfg.exp_path + '/pr.txt'
    if os.path.isfile(pr_file):
        pr = open(pr_file).read()
        print("PR: " + str(pr))

    rec = []
    rec.append(fid_epoch)
    rec.append('nref:' + str(activations_ref.shape[0]))

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    dirs = cfg.image_path.split(',')
    first_fid = None
    for dir in dirs:
        print("Working on:", dir)
        test_name = dir.split('/')[-1]
        rec.append(test_name)
        # loads all images into memory (this might require a lot of RAM!)
        image_list = glob.glob(os.path.join(dir, '*.jpg'))
        image_list = image_list + glob.glob(os.path.join(dir, '*.png'))
        image_list.sort()
        print("Loading images:", len(image_list), '  ...', end='')
        images = np.array([
            imageio.imread(str(fn), as_gray=False,
                           pilmode="RGB").astype(np.float32)
            for fn in image_list
        ])
        print("done")

        print("Extracting features ", end='')
        os.environ['CUDA_VISIBLE_DEVICES'] = '1'
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            mu_gen, sigma_gen, activations = fid.calculate_activation_statistics(
                images, sess, batch_size=BATCH_SIZE)
        print("Extracted activations:", activations.shape[0])
        rec.append('ntest:' + str(activations.shape[0]))

        if cfg.fid:
            # Calculate FID
            print("Calculating FID.....")
            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            rec.append('fid:' + str(fid_value))
            if first_fid is None:
                first_fid = fid_value

            if best_fid > first_fid and fid_epoch != 0:
                epoch = int(fid_epoch.split(' ')[0].split(':')[1])
                print("Storing best FID model. Epoch: " + str(epoch) +
                      "  Current FID: " + str(best_fid) + " new: " +
                      str(first_fid))
                best_fid = first_fid
                # Store best fid & weights
                with open(best_fid_file, 'wt') as f:
                    f.write(str(first_fid))
                model_file = cfg.exp_path + '/models/weights-' + str(
                    epoch) + '.cp'
                backup_model_file = cfg.exp_path + '/models/' + str(
                    epoch) + '.cp'
                os.system('cp ' + model_file + '  ' + backup_model_file)

        if cfg.kid:
            # Calculate KID
            # Parameters:
            print("Calculating KID...")
            mmd_degree = 3
            mmd_gamma = None
            mmd_coef0 = 1
            mmd_var = False
            mmd_subsets = 100
            mmd_subset_size = 1000

            ret = polynomial_mmd_averages(activations,
                                          activations_ref,
                                          degree=mmd_degree,
                                          gamma=mmd_gamma,
                                          coef0=mmd_coef0,
                                          ret_var=mmd_var,
                                          n_subsets=mmd_subsets,
                                          subset_size=mmd_subset_size)

            if mmd_var:
                mmd2s, vars = ret
            else:
                mmd2s = ret

            kid_value = mmd2s.mean()
            kid_value_std = mmd2s.std()
            rec.append('kid_mean:' + str(kid_value))
            rec.append('kid_std:' + str(kid_value_std))

        if cfg.psnr and test_name == 'reco':
            image_list = glob.glob(os.path.join(cfg.stats_path, '*.jpg'))
            image_list.sort()
            if len(image_list) == 0:
                print("No images in directory ", cfg.stats_path)
                return

            images_gt = np.array([
                imageio.imread(str(fn), as_gray=False,
                               pilmode="RGB").astype(np.float32)
                for fn in image_list
            ])
            print("%d images found and loaded" % len(images_gt))
            print("Calculating PSNR...")
            psnr_val = psnr(images_gt, images)
            print("Calculating SSIM...")
            ssim_val = ssim(images_gt, images)

            print('PSNR:', psnr_val, 'SSIM:', ssim_val)
            rec.append('psnr:' + str(psnr_val))
            rec.append('ssim:' + str(ssim_val))

        print(' '.join(rec))

    if pr is not None:
        rec.append(pr)

    print(' '.join(rec))

    # Write out results
    with open(cfg.exp_path + '/results.txt', 'a+') as f:
        f.write(' '.join(rec) + '\n')

    return first_fid
Example No. 19
def begin_training(params):
    """
    Takes model name, Generator and Discriminator architectures as input,
    builds the rest of the graph.

    """
    model_name, Generator, Discriminator, epochs, restore = params
    fid_stats_file = "./tmp/"
    inception_path = "./tmp/"
    TRAIN_FOR_N_EPOCHS = epochs
    MODEL_NAME = model_name + "_" + FLAGS.dataset
    SUMMARY_DIR = 'summary/' + MODEL_NAME + "/"
    SAVE_DIR = "./saved_models/" + MODEL_NAME + "/"
    OUTPUT_DIR = './outputs/' + MODEL_NAME + "/"
    helpers.refresh_dirs(SUMMARY_DIR, OUTPUT_DIR, SAVE_DIR, restore)
    with tf.Graph().as_default():
        with tf.variable_scope('input'):
            all_real_data_conv = input_pipeline(
                train_data_list, batch_size=BATCH_SIZE)
            # Split data over multiple GPUs:
            split_real_data_conv = tf.split(all_real_data_conv, len(DEVICES))
        global_step = tf.train.get_or_create_global_step()

        gen_cost, disc_cost, pre_real, pre_fake, gradient_penalty, real_data, fake_data, disc_fake, disc_real = split_and_setup_costs(
            Generator, Discriminator, split_real_data_conv)

        gen_train_op, disc_train_op, gen_learning_rate = setup_train_ops(
            gen_cost, disc_cost, global_step)

        performance_merged, distances_merged = add_summaries(gen_cost, disc_cost, fake_data, real_data,
                                                             gen_learning_rate, gradient_penalty, pre_real, pre_fake)

        saver = tf.train.Saver(max_to_keep=1)
        all_fixed_noise_samples = helpers.prepare_noise_samples(
            DEVICES, Generator)

        fid_stats_file += FLAGS.dataset + "_stats.npz"
        assert tf.gfile.Exists(
            fid_stats_file), "Can't find training set statistics for FID (%s)" % fid_stats_file
        f = np.load(fid_stats_file)
        mu_fid, sigma_fid = f['mu'][:], f['sigma'][:]
        f.close()
        inception_path = fid.check_or_download_inception(inception_path)
        fid.create_inception_graph(inception_path)

        # Create session
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        if FLAGS.use_XLA:
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        with tf.Session(config=config) as sess:
            # Restore variables if required
            ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
            if restore and ckpt and ckpt.model_checkpoint_path:
                print("Restoring variables...")
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Variables restored from:\n', ckpt.model_checkpoint_path)
            else:
                # Initialise all the variables
                print("Initialising variables")
                sess.run(tf.local_variables_initializer())
                sess.run(tf.global_variables_initializer())
                print('Variables initialised.')
            # Start input enqueue threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('Queue runners started.')
            real_im = sess.run([all_real_data_conv])[0][0][0][0:5]
            print("Real Image range sample: ", real_im)

            summary_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
            helpers.sample_dataset(sess, all_real_data_conv, OUTPUT_DIR)
            # Training loop
            try:
                ep_start = (global_step.eval(sess)) // EPOCH
                for epoch in tqdm(range(ep_start, TRAIN_FOR_N_EPOCHS), desc="Epochs passed"):
                    step = (global_step.eval(sess)) % EPOCH
                    for _ in tqdm(range(step, EPOCH), desc="Current epoch %i" % epoch, mininterval=0.5):
                        # train gen
                        _, step = sess.run([gen_train_op, global_step])
                        # Train discriminator
                        if (MODE == 'dcgan') or (MODE == 'lsgan'):
                            disc_iters = 1
                        else:
                            disc_iters = CRITIC_ITERS
                        for _ in range(disc_iters):
                            _disc_cost, _ = sess.run(
                                [disc_cost, disc_train_op])
                        if step % (128) == 0:
                            _, _, _, performance_summary, distances_summary = sess.run(
                                [gen_train_op, disc_cost, disc_train_op, performance_merged, distances_merged])
                            summary_writer.add_summary(
                                performance_summary, step)
                            summary_writer.add_summary(
                                distances_summary, step)

                        if step % (512) == 0:
                            saver.save(sess, SAVE_DIR, global_step=step)
                            helpers.generate_image(step, sess, OUTPUT_DIR,
                                                   all_fixed_noise_samples, Generator, summary_writer)
                            fid_score, IS_mean, IS_std, kid_score = fake_batch_stats(
                                sess, fake_data)
                            pre_real_out, pre_fake_out, fake_out, real_out = sess.run(
                                [pre_real, pre_fake, disc_fake, disc_real])
                            scalar_avg_fake = np.mean(fake_out)
                            scalar_sdev_fake = np.std(fake_out)
                            scalar_avg_real = np.mean(real_out)
                            scalar_sdev_real = np.std(real_out)

                            frechet_dist = frechet_distance(
                                pre_real_out, pre_fake_out)
                            kid_score = np.mean(kid_score)
                            inception_summary = tf.Summary()
                            inception_summary.value.add(
                                tag="distances/FD", simple_value=frechet_dist)
                            inception_summary.value.add(
                                tag="distances/FID", simple_value=fid_score)
                            inception_summary.value.add(
                                tag="distances/IS_mean", simple_value=IS_mean)
                            inception_summary.value.add(
                                tag="distances/IS_std", simple_value=IS_std)
                            inception_summary.value.add(
                                tag="distances/KID", simple_value=kid_score)
                            inception_summary.value.add(
                                tag="distances/scalar_mean_fake", simple_value=scalar_avg_fake)
                            inception_summary.value.add(
                                tag="distances/scalar_sdev_fake", simple_value=scalar_sdev_fake)
                            inception_summary.value.add(
                                tag="distances/scalar_mean_real", simple_value=scalar_avg_real)
                            inception_summary.value.add(
                                tag="distances/scalar_sdev_real", simple_value=scalar_sdev_real)
                            summary_writer.add_summary(inception_summary, step)
            except KeyboardInterrupt as e:
                print("Manual interrupt occurred.")
            except Exception as e:
                print(e)
            finally:
                coord.request_stop()
                coord.join(threads)
                print('Finished training.')
                saver.save(sess, SAVE_DIR, global_step=step)
                print("Model " + MODEL_NAME +
                      " saved in file: {} at step {}".format(SAVE_DIR, step))
Exemplo n.º 20
0
    def calculate_fid(self):
        import fid, pickle
        import tensorflow as tf

        stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
        inception_path = fid.check_or_download_inception(
            "./tmp/"
        )  # download inception network

        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas_th = get_sigmas(self.config)
        sigmas = sigmas_th.cpu().numpy()

        fids = {}
        for ckpt in tqdm.tqdm(
            range(
                self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000
            ),
            desc="processing ckpt",
        ):
            states = torch.load(
                os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
                map_location=self.config.device,
            )

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            num_iters = (
                self.config.fast_fid.num_samples // self.config.fast_fid.batch_size
            )
            output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
            os.makedirs(output_path, exist_ok=True)
            for i in range(num_iters):
                init_samples = torch.rand(
                    self.config.fast_fid.batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                    device=self.config.device,
                )
                init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics(
                    init_samples,
                    score,
                    sigmas,
                    self.config.fast_fid.n_steps_each,
                    self.config.fast_fid.step_lr,
                    verbose=self.config.fast_fid.verbose,
                )

                final_samples = all_samples[-1]
                for img_id, sample in enumerate(final_samples):
                    sample = sample.view(
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    # offset by the batch index so earlier batches are not overwritten
                    save_image(
                        sample,
                        os.path.join(
                            output_path,
                            "sample_{}.png".format(
                                i * self.config.fast_fid.batch_size + img_id
                            ),
                        ),
                    )

            # load precalculated training set statistics
            f = np.load(stats_path)
            mu_real, sigma_real = f["mu"][:], f["sigma"][:]
            f.close()

            fid.create_inception_graph(
                inception_path
            )  # load the graph into the current TF graph
            # normalise to [0, 255], then move to numpy in NHWC layout
            final_samples = (final_samples - final_samples.min()) / (
                final_samples.max() - final_samples.min()
            )
            final_samples = (final_samples * 255).data.cpu().numpy()
            final_samples = np.transpose(final_samples, [0, 2, 3, 1])
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    final_samples, sess, batch_size=100
                )

            fid_value = fid.calculate_frechet_distance(
                mu_gen, sigma_gen, mu_real, sigma_real
            )
            print("FID: %s" % fid_value)
            fids[ckpt] = fid_value  # record the score so the pickle below is not empty

        with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
            pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
Exemplo n.º 21
0
import os
import glob

import numpy as np
from imageio import imread
import tensorflow as tf

import fid

########
# PATHS
########
# data_path = '/media/data1/vox/Resized/test' # set path to training set images
data_path = '/home/yuthon/Workspace/pix2pixHD/checkpoints/vox_embedded_ssim_4gpus_new_eval/eval/vox_embedded_eval/generated' # set path to training set images
output_path = '/home/yuthon/Workspace/pix2pixHD/checkpoints/vox_embedded_ssim_4gpus_new_eval/eval/vox_embedded_eval/generated/fid_stats.npz' # path for where to store the statistics
# if you have downloaded and extracted
#   http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
# set this path to the directory where the extracted files are, otherwise
# just set it to None and the script will later download the files for you
inception_path = None
print("check for inception model..", end=" ", flush=True)
inception_path = fid.check_or_download_inception(inception_path) # download inception if necessary
print("ok")

# loads all images into memory (this might require a lot of RAM!)
# print("load images..", end=" " , flush=True)
# image_list = glob.glob(os.path.join(data_path, '*.jpg'))
# images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
# print("%d images found and loaded" % len(images))

print("create inception graph..", end=" ", flush=True)
fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
print("ok")

print("calculte FID stats..", end=" ", flush=True)
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
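
# A hedged sketch of how this snippet presumably continues (assumed, following the
# earlier examples): load the generated images from data_path, compute their Inception
# activation statistics, and save them to output_path.
image_list = glob.glob(os.path.join(data_path, '*.jpg')) + glob.glob(os.path.join(data_path, '*.png'))
images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
print("%d images found and loaded" % len(images))
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
    np.savez_compressed(output_path, mu=mu, sigma=sigma)
print("finished")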
Exemplo n.º 22
0
    def __init__(self, d_net, g_net, x_sampler, z_sampler, args, inception, log_dir, scale=10.0):
        self.model = args.model
        self.data = args.data
        self.log_dir = log_dir
        self.g_net = g_net
        self.d_net = d_net
        self.x_sampler = x_sampler
        self.z_sampler = z_sampler
        self.x_dim = d_net.x_dim
        self.z_dim = g_net.z_dim
        self.beta = 0.9999
        self.d_iters = 1
        self.batch_size = 64
        self.inception = inception
        self.inception_path = fid.check_or_download_inception('./data/imagenet_model')

        if self.data == 'cifar10':
            self.stats_path = './data/fid_stats_cifar10_train.npz'
        elif self.data == 'stl10':
            self.stats_path = './data/fid_stats_stl10.npz'

        self.x = tf.placeholder(tf.float32, [None] + self.x_dim, name='x')
        self.z = tf.placeholder(tf.float32, [None] + [self.z_dim], name='z')

        self.x_ = self.g_net(self.z)
        self.d = self.d_net(self.x)
        self.d_ = self.d_net(self.x_, reuse=True)

        self.g_loss = tf.reduce_mean(self.d_)
        self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)

        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = epsilon * self.x + (1 - epsilon) * self.x_
        d_hat = self.d_net(x_hat, reuse=True)

        ddx = tf.gradients(d_hat, x_hat)[0]
        print(ddx.get_shape().as_list())
        ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1))
        self.gp_loss = tf.reduce_mean(tf.square(ddx - 1.0) * scale)
        self.d_loss_reg = self.d_loss + self.gp_loss

        '''
        print('gen vars')
        for var_ in self.g_net.vars:
            print(var_.name)

        print('gen_ema vars')
        for var_ in self.g_ema.vars:
            print(var_.name)
        '''

        ################################################################
        self.d_adam, self.g_adam = None, None

        with tf.variable_scope('var_prim'):
            self.disc_vprim = [tf.Variable(var.initialized_value()) for var in self.d_net.vars]
            self.gen_vprim = [tf.Variable(var.initialized_value()) for var in self.g_net.vars]

        self.varprim = self.disc_vprim + self.gen_vprim
        self.variables = self.d_net.vars + self.g_net.vars

        self.copy_w_to_wprim = [wprim.assign(w) for wprim, w in zip(self.varprim, self.variables)]  # copy w into wprim
        self.assign = [w.assign(wprim) for wprim, w in zip(self.varprim, self.variables)]  # copy wprim back into w

        optimizer1 = tf.train.AdamOptimizer(1e-4, beta1=0., beta2=0.9)
        optimizer2 = tf.train.AdamOptimizer(1e-4, beta1=0., beta2=0.9)

        d_grads = tf.gradients(self.d_loss_reg, self.d_net.vars)
        g_grads = tf.gradients(self.g_loss, self.g_net.vars)
        grads = d_grads + g_grads

        g_w = [(g, v) for (g, v) in zip(grads, self.variables)]

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.first_step = optimizer1.apply_gradients(g_w)  # update w with grad_w

        g_wprim = [(g, v) for (g, v) in zip(grads, self.varprim)]
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = optimizer2.apply_gradients(g_wprim)

        print('trainable')

        for var_ in tf.model_variables('g_woa'):
            print(var_.name)

        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        # as_default() keeps the session open for later use; a bare `with self.sess:` would close it on exit
        with self.sess.as_default():
            fid.create_inception_graph(self.inception_path)  # load the graph into the current TF graph
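
    # Hypothetical helper, not part of the original listing: one way the precomputed
    # statistics in self.stats_path could be turned into an FID during training.
    # Assumptions: z_sampler(batch_size, z_dim) returns a noise batch, g_net outputs
    # NHWC images in [-1, 1], and numpy/fid are imported as np/fid at module level.
    def compute_fid(self, n_samples=10000):
        f = np.load(self.stats_path)
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]
        f.close()

        batches = []
        for _ in range(n_samples // self.batch_size):
            bz = self.z_sampler(self.batch_size, self.z_dim)
            bx = self.sess.run(self.x_, feed_dict={self.z: bz})
            batches.append(255.0 * (bx / 2.0 + 0.5))  # rescale [-1, 1] -> [0, 255]
        samples = np.concatenate(batches, axis=0)

        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            samples, self.sess, batch_size=100)
        return fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)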
Exemplo n.º 23
0
parser.add_argument("image_path")
parser.add_argument("stats_path")
parser.add_argument("model_path")
parser.add_argument("output_file")
parser.add_argument("--gpu", default="-1")
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
import glob
import numpy as np
import fid
from scipy.misc import imread
import tensorflow as tf
import datetime

inception_path = fid.check_or_download_inception(args.model_path)

if args.mode == "pre-calculate":
    print("load images..")
    image_list = glob.glob(os.path.join(args.image_path, '*.jpg'))
    images = np.array(
        [imread(image).astype(np.float32) for image in image_list])
    print("%d images found and loaded" % len(images))

    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(inception_path)
    print("ok")

    print("calculate FID stats..", end=" ", flush=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
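        # Hedged completion of the "pre-calculate" branch (assumed, following the other
        # examples): compute the activation statistics and store them in args.output_file.
        mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
        np.savez_compressed(args.output_file, mu=mu, sigma=sigma)
        print("finished; statistics written to %s" % args.output_file)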
Exemplo n.º 24
0
    def __init__(self, data_loader, config):
        # Fix seed
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)

        # Data loader
        self.data_loader = data_loader

        # arch and loss
        self.arch = config.arch
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.extra = config.extra

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.optim = config.optim
        self.lr_scheduler = config.lr_scheduler
        self.g_beta1 = config.g_beta1
        self.d_beta1 = config.d_beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model
        self.momentum = config.momentum

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.backup_freq = config.backup_freq
        self.bup_path = config.bup_path
        self.metrics_path = config.metrics_path
        self.store_models_freq = config.store_models_freq

        # lookahead
        self.lookahead = config.lookahead
        self.lookahead_k = config.lookahead_k
        self.lookahead_super_slow_k = config.lookahead_super_slow_k
        self.lookahead_k_min = config.lookahead_k_min
        self.lookahead_k_max = config.lookahead_k_max
        self.lookahead_alpha = config.lookahead_alpha

        self.build_model()

        # imagenet
        if self.dataset == 'imagenet':
            z_ = inception_utils.prepare_z_(self.batch_size,
                                            self.z_dim,
                                            device='cuda',
                                            z_var=1.0)
            # Prepare Sample function for use with inception metrics
            self.sample_G_func = functools.partial(inception_utils.sample,
                                                   G=self.G,
                                                   z_=z_)
            self.sample_G_ema_func = functools.partial(inception_utils.sample,
                                                       G=self.G_ema,
                                                       z_=z_)
            self.sample_G_ema_slow_func = functools.partial(
                inception_utils.sample, G=self.G_ema_slow, z_=z_)
            # Prepare inception metrics: FID and IS
            self.get_inception_metrics = inception_utils.prepare_inception_metrics(
                dataset="./I32", parallel=False, no_fid=False)

        self.best_path = config.best_path  # dir for best-perf checkpoint

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

        self.info_logger = setup_logger(self.log_path)
        self.info_logger.info(config)
        self.cont = config.cont
        self.fid_freq = config.fid_freq

        if self.fid_freq > 0 and self.dataset != 'imagenet':
            self.fid_json_file = os.path.join(self.model_save_path, '../FID',
                                              'fid.json')
            self.sample_size_fid = config.sample_size_fid
            if self.cont and os.path.isfile(self.fid_json_file):
                # load json files with fid scores
                self.fid_scores = load_json(self.fid_json_file)
            else:
                self.fid_scores = []
            sample_noise = torch.FloatTensor(self.sample_size_fid,
                                             self.z_dim).normal_()
            self.fid_noise_loader = torch.utils.data.DataLoader(sample_noise,
                                                                batch_size=200,
                                                                shuffle=False)
            # Inception Network
            _INCEPTION_PTH = fid.check_or_download_inception(
                './precalculated_statistics/inception-2015-12-05.pb')
            self.info_logger.info(
                'Loading the Inception Network from: {}'.format(
                    _INCEPTION_PTH))
            fid.create_inception_graph(
                _INCEPTION_PTH)  # load the graph into the current TF graph
            _gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
            # _gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.4)
            self.fid_session = tf.Session(config=tf.ConfigProto(
                gpu_options=_gpu_options))
            self.info_logger.info(
                'Loading real data FID stats from: {}'.format(
                    config.fid_stats_path))
            _real_fid = np.load(config.fid_stats_path)
            self.mu_real, self.sigma_real = _real_fid['mu'][:], _real_fid[
                'sigma'][:]
            _real_fid.close()
            make_folder(os.path.dirname(self.fid_json_file))
        elif self.fid_freq > 0:
            # make_folder(self.path)
            self.metrics_json_file = os.path.join(self.metrics_path,
                                                  'metrics.json')
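
    # Hypothetical sketch, not from the original listing: how the members prepared above
    # (fid_noise_loader, fid_session, mu_real/sigma_real, fid_scores) could be combined at
    # evaluation time. The generator call signature and its [-1, 1] output range are assumptions.
    def compute_fid_score(self, step):
        import json

        fake_batches = []
        for noise in self.fid_noise_loader:
            with torch.no_grad():
                fake = self.G(noise.cuda())
            fake = ((fake + 1.0) * 127.5).clamp_(0, 255)  # [-1, 1] -> [0, 255]
            fake_batches.append(fake.permute(0, 2, 3, 1).cpu().numpy())
        fake_images = np.concatenate(fake_batches, axis=0)

        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            fake_images, self.fid_session, batch_size=100)
        fid_value = fid.calculate_frechet_distance(
            mu_gen, sigma_gen, self.mu_real, self.sigma_real)

        self.fid_scores.append({'step': step, 'fid': float(fid_value)})
        with open(self.fid_json_file, 'w') as handle:
            json.dump(self.fid_scores, handle)
        return fid_value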
Exemplo n.º 25
0
    def calculate_fid(self):
        import fid
        import tensorflow as tf

        num_of_step = 500
        bs = 100

        sigmas = np.exp(
            np.linspace(np.log(self.config.model.sigma_begin),
                        np.log(self.config.model.sigma_end),
                        self.config.model.num_classes))
        stats_path = 'fid_stats_cifar10_train.npz'  # training set statistics
        inception_path = fid.check_or_download_inception(
            None)  # download inception network

        print('Load checkpoint from ' + self.args.log)
        #for epochs in range(140000, 200001, 1000):
        for epochs in [149000]:
            states = torch.load(os.path.join(
                self.args.log, 'checkpoint_' + str(epochs) + '.pth'),
                                map_location=self.config.device)
            #states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
            score = CondRefineNetDilated(self.config).to(self.config.device)
            score = torch.nn.DataParallel(score)

            score.load_state_dict(states[0])

            score.eval()

            if self.config.data.dataset == 'MNIST':
                print("Begin epochs", epochs)
                samples = torch.rand(bs, 1, 28, 28, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                for j in range(num_of_step - 1):
                    # MNIST samples keep the (1, 28, 28) shape so the concatenation below stays consistent
                    samples = torch.rand(bs,
                                         1,
                                         28,
                                         28,
                                         device=self.config.device)
                    all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                        samples, score, sigmas, 100, 0.00002)
                    images_new = all_samples.mul_(255).add_(0.5).clamp_(
                        0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                    images = np.concatenate((images, images_new), axis=0)

            else:
                print("Begin epochs", epochs)
                samples = torch.rand(bs, 3, 32, 32, device=self.config.device)
                all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                    samples, score, sigmas, 100, 0.00002)
                images = all_samples.mul_(255).add_(0.5).clamp_(
                    0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                for j in range(num_of_step - 1):
                    samples = torch.rand(bs,
                                         3,
                                         32,
                                         32,
                                         device=self.config.device)
                    all_samples = self.anneal_Langevin_dynamics_GenerateImages(
                        samples, score, sigmas, 100, 0.00002)
                    images_new = all_samples.mul_(255).add_(0.5).clamp_(
                        0, 255).permute(0, 2, 3, 1).to('cpu').numpy()
                    images = np.concatenate((images, images_new), axis=0)

            # load precalculated training set statistics
            f = np.load(stats_path)
            mu_real, sigma_real = f['mu'][:], f['sigma'][:]
            f.close()

            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)

            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen,
                                                       mu_real, sigma_real)
            print("FID: %s" % fid_value)
Exemplo n.º 26
0
def fid_example():
    # Paths
    re_est_gth = False

    dbname = 'cifar10'
    input_dir = '../../gan/output/'
    start = 10000
    niters = 300000
    step = 10000

    if dbname == 'cifar10':
        model = 'cifar10_wgangp_dcgan_wdis_lp_10_300000'
    elif dbname == 'stl10':
        model = 'stl10_distgan_resnet_hinge_gngan_0_ssgan_3_ld_1.0_lg_0.010_300000'

    mu_gth_file = 'mu_gth_' + dbname + '_10k.npy'
    sigma_gth_file = 'sigma_gth_' + dbname + '_5k.npy'
    """
    # loads all images into memory (this might require a lot of RAM!)
    gth_list = glob.glob(os.path.join(gth_path, '*.jpg'))
    gen_list = glob.glob(os.path.join(gen_path, '*.jpg'))
    gth_images = np.array([imread(str(fn)).astype(np.float32) for fn in gth_list])
    gen_images = np.array([imread(str(fn)).astype(np.float32) for fn in gen_list])
    """
    """
    # load precalculated training set statistics
    f = np.load(path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    """
    print('FID ESTIMATE')

    import os
    import os.path

    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    inception_path = fid.check_or_download_inception(
        '/tmp')  # download inception network

    logfile = os.path.join(
        input_dir, model,
        dbname + '_' + model + '_fid_%d_%d.txt' % (start, niters))
    print(logfile)
    fid_log = open(logfile, 'w')

    # reuse the cached ground-truth statistics unless re-estimation is requested
    if os.path.isfile(mu_gth_file) and os.path.isfile(
            sigma_gth_file) and not re_est_gth:
        fid.create_inception_graph(
            inception_path)  # load the graph into the current TF graph
        mu_gth = np.load(mu_gth_file)
        sigma_gth = np.load(sigma_gth_file)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(start, niters + 1, step):
                gen_path = os.path.join(input_dir, model, dbname, 'fake_%d' %
                                        i)  # set path to some generated images
                print('[%s]' % (gen_path))
                mu_gen, sigma_gen = fid._handle_path(gen_path, sess)
                fid_value = fid.calculate_frechet_distance(
                    mu_gen, sigma_gen, mu_gth, sigma_gth)
                strout = "step: %d - FID: %s" % (i, fid_value)
                print(strout)
                fid_log.write(strout + '\n')
                fid_log.flush()

    else:
        fid.create_inception_graph(
            inception_path)  # load the graph into the current TF graph
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            gth_path = os.path.join(
                input_dir, model, dbname,
                'real')  # set path to some ground truth images
            mu_gth, sigma_gth = fid._handle_path(gth_path, sess)
            for i in range(start, niters + 1, step):
                gen_path = os.path.join(input_dir, model, dbname, 'fake_%d' %
                                        i)  # set path to some generated images
                print('[%s]' % (gen_path))
                mu_gen, sigma_gen = fid._handle_path(gen_path, sess)
                fid_value = fid.calculate_frechet_distance(
                    mu_gen, sigma_gen, mu_gth, sigma_gth)
                strout = "step: %d - FID: %s" % (i, fid_value)
                print(strout)
                fid_log.write(strout + '\n')
                fid_log.flush()

        np.save(mu_gth_file, mu_gth)
        np.save(sigma_gth_file, sigma_gth)

    return fid_value