def __init__(self, opt):
    """Initialize the BaseModel class.

    Parameters:
        opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

    When creating your custom class, you need to implement your own initialization.
    In this function, you should first call <BaseModel.__init__(self, opt)>
    Then, you need to define four lists:
        -- self.loss_names (str list):          specify the training losses that you want to plot and save.
        -- self.model_names (str list):         define networks used in our training.
        -- self.visual_names (str list):        specify the images that you want to display and save.
        -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer
           for each network. If two networks are updated at the same time, you can use itertools.chain to
           group them. See cycle_gan_model.py for an example.

    Score setup (skipped entirely when ``opt.score_name`` is None):
        -- if ``opt.use_pytorch_scores`` is truthy, ``self.get_inception_metrics`` is created with
           ``inception_utils.prepare_inception_metrics``, enabling only the requested 'FID' / 'IS' scores;
        -- otherwise, if 'FID' is requested, the precalculated real-data statistics in
           ``opt.fid_stat_file`` are loaded into ``self.mu_real`` / ``self.sigma_real``, the Inception
           graph is loaded into the current TF graph, and a TensorFlow session is stored in ``self.sess``.
    """
    self.opt = opt
    self.gpu_ids = opt.gpu_ids
    self.isTrain = opt.isTrain
    # get device name: CPU or GPU
    self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
    # save all the checkpoints to save_dir
    self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    if opt.preprocess != 'scale_width':
        # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
        torch.backends.cudnn.benchmark = True
    self.loss_names = []
    self.model_names = []
    self.visual_names = []
    self.optimizers = []
    self.image_paths = []
    self.metric = 0  # used for learning rate policy 'plateau'

    # scores init
    score_names = self.opt.score_name
    # Guard None for BOTH backends: previously a None score_name fell through to
    # the TF branch and crashed iterating over None.
    if score_names is not None:
        # Materialize once; list membership uses the same element-wise ``==``
        # comparison as iterating and testing each name.
        requested = list(score_names)
        if self.opt.use_pytorch_scores:
            parallel = len(opt.gpu_ids) > 1
            no_FID = 'FID' not in requested
            no_IS = 'IS' not in requested
            self.get_inception_metrics = inception_utils.prepare_inception_metrics(
                opt.dataset_name, parallel, no_IS, no_FID)
        elif 'FID' in requested:
            stat_file = self.opt.fid_stat_file
            inception_dir = "./inception_v3/"
            print("load train stats.. ")
            # load precalculated training set statistics; ``with`` closes the
            # file even if a key is missing
            with np.load(stat_file) as f:
                self.mu_real, self.sigma_real = f['mu'][:], f['sigma'][:]
            print("ok")
            inception_path = fid.check_or_download_inception(inception_dir)  # download inception network
            fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)
            self.sess.run(tf.global_variables_initializer())
import numpy as np
import os
import gzip
import pickle
import tensorflow as tf
from imageio import imread
from scipy import linalg
import pathlib
import urllib
import warnings
import tqdm
from TTUR.fid import check_or_download_inception, _handle_path, calculate_frechet_distance, create_inception_graph


def main(images_path='fashion_mnist_images', stats_path='FashionMNIST.npz'):
    """Run TTUR's ``_handle_path`` over a directory of sample images.

    Parameters:
        images_path (str) -- directory of sample images (default: 'fashion_mnist_images')
        stats_path (str)  -- forwarded to ``_handle_path`` as ``stats_path``; presumably the
                             file the computed statistics are saved to -- confirm in TTUR.fid

    Raises:
        FileNotFoundError -- if ``images_path`` does not exist.
    """
    # Explicit check rather than ``assert`` so the guard still runs under ``python -O``.
    if not os.path.exists(images_path):
        raise FileNotFoundError(
            'Need to run ../datasets.save_fashion_mnist_to_samples()')
    # None is forwarded exactly as before; presumably the helper then uses its
    # default download location -- NOTE(review): confirm in TTUR.fid.
    inception_path = check_or_download_inception(None)
    create_inception_graph(str(inception_path))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _handle_path(images_path, sess, low_profile=True, stats_path=stats_path)


if __name__ == '__main__':
    main()
# Parse CLI arguments and resolve the dataset statistics file: either use the
# precomputed file named by --dataset_stats_name, or compute and save it.
args = parser.parse_args()
dataset_dir = args.dataset_dir
dataset_name = args.dataset_name
dataset_stats_dir = args.dataset_stats_dir
dataset_stats_name = args.dataset_stats_name
fid_img_dir = args.fid_img_dir

if dataset_stats_name is None:
    # No precomputed statistics given: compute them from the dataset images.
    if not os.path.exists(dataset_stats_dir):
        os.makedirs(dataset_stats_dir)
    # Find the statistics for the dataset
    inception_path = fid.check_or_download_inception(None)
    fid.create_inception_graph(str(inception_path))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid._handle_path(dataset_dir, sess, low_profile=True)
    dataset_stats_file = dataset_stats_dir + '{0}.npz'.format(dataset_name)
    # Save the dataset statistics
    np.savez(dataset_stats_file, mu=mu, sigma=sigma)
else:
    # Built only when a name is given: concatenating None used to raise
    # TypeError before the branch above could ever run. Plain concatenation is
    # kept, so dataset_stats_dir presumably ends with a separator -- confirm.
    dataset_stats_file = dataset_stats_dir + dataset_stats_name

# Keep track of the FID scores
fid_score = OrderedDict()
# NOTE(review): the ' ' entry looks like a placeholder filled in later (not visible here).
paths = [dataset_stats_file]
paths.append(' ')