Пример #1
0
    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.

        Score setup (controlled by opt.use_pytorch_scores and opt.score_name):
            -- if opt.use_pytorch_scores is truthy and opt.score_name is not None,
               self.get_inception_metrics is built with
               inception_utils.prepare_inception_metrics; FID / IS are enabled
               only when 'FID' / 'IS' appear in opt.score_name.
            -- otherwise, if 'FID' appears in opt.score_name, the precomputed
               statistics from opt.fid_stat_file are loaded into
               self.mu_real / self.sigma_real, the TF Inception graph is
               created and a TF session is stored in self.sess.
            -- an opt.score_name of None sets up no scores.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # get device name: CPU or GPU (the first listed GPU id is used)
        if self.gpu_ids:
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        # save all the checkpoints to save_dir
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # with [scale_width], input images might have different sizes, which
        # hurts the performance of cudnn.benchmark.
        if opt.preprocess != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

        # scores init
        # A missing score list means "no scores"; iterating None directly
        # would raise TypeError.
        score_names = self.opt.score_name if self.opt.score_name is not None else []
        wants_fid = any(name == 'FID' for name in score_names)
        wants_is = any(name == 'IS' for name in score_names)

        if self.opt.use_pytorch_scores and self.opt.score_name is not None:
            parallel = len(opt.gpu_ids) > 1
            self.get_inception_metrics = inception_utils.prepare_inception_metrics(
                opt.dataset_name, parallel, not wants_is, not wants_fid)
        elif wants_fid:
            # Set up the TF-based FID pipeline once, even if 'FID' is listed
            # more than once.
            stat_file = self.opt.fid_stat_file
            inception_dir = "./inception_v3/"

            print("load train stats.. ")
            # load precalculated training set statistics; the context manager
            # closes the .npz archive even if a key lookup fails.
            with np.load(stat_file) as f:
                self.mu_real, self.sigma_real = f['mu'][:], f['sigma'][:]
            print("ok")

            # download inception network (if not already present)
            inception_path = fid.check_or_download_inception(inception_dir)
            # load the graph into the current TF graph
            fid.create_inception_graph(inception_path)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)
            self.sess.run(tf.global_variables_initializer())
Пример #2
0
import numpy as np
import os
import gzip, pickle
import tensorflow as tf
from imageio import imread
from scipy import linalg
import pathlib
import urllib
import warnings
import tqdm
from TTUR.fid import check_or_download_inception, _handle_path, calculate_frechet_distance, create_inception_graph

if __name__ == '__main__':
    # Compute Inception statistics for the images in images_path using the
    # TTUR FID helpers; _handle_path is given stats_path='FashionMNIST.npz',
    # presumably the output file for the statistics (TODO confirm in TTUR.fid).
    images_path = 'fashion_mnist_images'
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if not os.path.exists(images_path):
        raise FileNotFoundError(
            'Need to run ../datasets.save_fashion_mnist_to_samples()')

    # Passing None lets the helper choose its default location for the
    # Inception model (presumably downloading it if missing — TODO confirm).
    inception_path = check_or_download_inception(None)

    create_inception_graph(str(inception_path))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _handle_path(images_path,
                     sess,
                     low_profile=True,
                     stats_path='FashionMNIST.npz')
Пример #3
0
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    dataset_name = args.dataset_name
    dataset_stats_dir = args.dataset_stats_dir
    dataset_stats_name = args.dataset_stats_name
    fid_img_dir = args.fid_img_dir

    if dataset_stats_name is None:
        # No precomputed statistics file was named: compute (mu, sigma) for
        # dataset_dir and save them as <dataset_stats_dir>/<dataset_name>.npz.
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(dataset_stats_dir, exist_ok=True)

        # Find the statistics for the dataset
        inception_path = fid.check_or_download_inception(None)
        fid.create_inception_graph(str(inception_path))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            mu, sigma = fid._handle_path(dataset_dir, sess, low_profile=True)

        dataset_stats_file = os.path.join(
            dataset_stats_dir, '{0}.npz'.format(dataset_name))

        # Save the dataset statistics
        np.savez(dataset_stats_file, mu=mu, sigma=sigma)
    else:
        # Use the user-supplied statistics file. This path is built only here:
        # evaluating it before the branch fails with TypeError (str + None)
        # whenever no stats name was given.
        dataset_stats_file = os.path.join(dataset_stats_dir, dataset_stats_name)

    # Keep track of the FID scores
    fid_score = OrderedDict()
    paths = [dataset_stats_file]
    # NOTE(review): second entry is a placeholder; presumably filled in by the
    # code that follows — confirm.
    paths.append(' ')