def calc_inception_scores(run_id, log='inception.txt', num_images=50000, minibatch_size=16, eval_reals=True, reverse_order=False):
    """Print the Inception score of the reals and of every network snapshot of a run.

    Output goes through ``print`` after ``misc.set_output_log_file`` has been
    pointed at ``<result_subdir>/<log>``.

    Args:
        run_id: Identifier handed to ``misc.locate_result_subdir``.
        log: File name of the log, created inside the run's result directory.
        num_images: Number of real images fetched and number of example
            latents requested per snapshot.
        minibatch_size: Forwarded to ``imgapi_generate_batch``.
        eval_reals: When true, score the real images before the snapshots.
        reverse_order: When true, walk the snapshot list back to front.

    Returns:
        None.

    NOTE(review): ``load_dataset_for_previous_run``, ``imgapi_load_net`` and
    ``imgapi_generate_batch`` are expected to be available at module scope.
    """
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))

    # Imported lazily, only once the log file has been set up.
    print('Importing inception score module...')
    import inception_score

    def calc_inception_score(images):
        # Single-channel input is replicated to three channels along axis 1.
        if images.shape[1] == 1:
            images = images.repeat(3, axis=1)
        # Axis 1 is moved last and the batch is split into a list of images.
        # Presumably NCHW -> list of HWC arrays; confirm against the dataset.
        images = list(images.transpose(0, 2, 3, 1))
        return inception_score.get_inception_score(images)

    # Load dataset. The labels are needed for the fakes even if the reals
    # themselves are not scored.
    training_set, _drange_orig = load_dataset_for_previous_run(result_subdir, shuffle=False)
    reals, labels = training_set.get_random_minibatch(num_images, labels=True)

    # Evaluate reals.
    if eval_reals:
        print('Evaluating inception score for reals...')
        time_begin = time.time()
        mean, std = calc_inception_score(reals)
        print('Done in %s' % misc.format_time(time.time() - time_begin))
        print('%-32s mean %-8.4f std %-8.4f' % ('reals', mean, std))

    # Evaluate each network snapshot. Indices are assigned before the optional
    # reversal, so a snapshot gets the same random_seed in either order.
    indexed_pkls = list(enumerate(network_pkls))
    if reverse_order:
        indexed_pkls = indexed_pkls[::-1]
    for network_idx, network_pkl in indexed_pkls:
        # end=' ' keeps the score on the same line as the snapshot name.
        print('%-32s' % os.path.basename(network_pkl), end=' ')
        net = imgapi_load_net(run_id=result_subdir, snapshot=network_pkl,
                              num_example_latents=num_images,
                              random_seed=network_idx, load_dataset=False)
        # The real labels are shuffled before being paired with the latents.
        fakes = imgapi_generate_batch(net, net.example_latents,
                                      np.random.permutation(labels),
                                      minibatch_size=minibatch_size,
                                      convert_to_uint8=True)
        mean, std = calc_inception_score(fakes)
        print('mean %-8.4f std %-8.4f' % (mean, std))
    print()
    print('Done.')
def run_all_snapshots(submit_config, metric_args, run_id):
    """Run a single metric over every network snapshot listed for a run.

    Args:
        submit_config: Passed to ``dnnlib.RunContext``; its ``num_gpus``
            attribute is forwarded to each ``metric.run`` call.
        metric_args: Mapping expanded into ``dnnlib.util.call_func_by_name``
            to build the metric; its ``name`` attribute is used in the banner.
        run_id: Identifier handed to ``misc.locate_run_dir``.

    Returns:
        None.
    """
    run_context = dnnlib.RunContext(submit_config)
    tflib.init_tf()

    banner = 'Evaluating %s metric on all snapshots of run_id %s...'
    print(banner % (metric_args.name, run_id))

    run_dir = misc.locate_run_dir(run_id)
    snapshot_paths = misc.list_network_pkls(run_dir)
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()

    # Report progress to the run context before each snapshot is evaluated.
    for position, snapshot_path in enumerate(snapshot_paths):
        run_context.update('', position, len(snapshot_paths))
        metric.run(
            snapshot_path,
            run_dir=run_dir,
            num_gpus=submit_config.num_gpus,
        )
    print()
    run_context.close()
def generate_fake_images_all(run_id, out_dir, num_pngs, image_shrink=1, random_seed=1000, minibatch_size=1, num_pkls=50):
    """Render ``num_pngs`` generated images for each of the first ``num_pkls`` snapshots.

    Images for a snapshot are written as ``ProGAN_%08d.png`` into
    ``<out_dir>/<run_id>/<snapshot file name up to the first '.'>/``.
    A PNG that already exists at its destination is left untouched.

    Args:
        run_id: Identifier handed to ``misc.locate_result_subdir``; also used
            as the name of the sub-directory created under ``out_dir``.
        out_dir: Root output directory.
        num_pngs: Number of images generated and saved per snapshot.
        image_shrink: Forwarded to ``Gs.run`` as ``out_shrink``.
        random_seed: Seed of the ``RandomState`` shared by all snapshots, so
            successive snapshots draw different latents from one stream.
        minibatch_size: Currently unused; ``Gs.run`` is called with
            ``config.num_gpus * 32``. Kept for interface compatibility.
        num_pkls: Upper bound on the number of snapshots processed.

    Returns:
        None.

    Raises:
        AssertionError: If no snapshot is found or a snapshot file name is not
            of the form ``network-snapshot-<...>.pkl``.
        ValueError: If the part between prefix and postfix is not an integer.
    """
    random_state = np.random.RandomState(random_seed)
    out_dir = os.path.join(out_dir, str(run_id))
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1

    prefix = 'network-snapshot-'
    postfix = '.pkl'
    for snapshot_pkl in snapshot_pkls[:num_pkls]:
        snapshot_name = os.path.basename(snapshot_pkl)

        # Validate the file name before touching the file system. The parsed
        # number is not used further; int() only checks that it is numeric.
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(postfix)
        int(snapshot_name[len(prefix):-len(postfix)])

        tmp_dir = os.path.join(out_dir, snapshot_name.split('.')[0])
        if not os.path.isdir(tmp_dir):
            os.makedirs(tmp_dir)

        print('Loading network...')
        _G, _D, Gs = misc.load_network_pkl(snapshot_pkl)

        latents = misc.random_latents(num_pngs, Gs, random_state=random_state)
        # Zero-width label array: one row per latent, no label columns.
        labels = np.zeros([latents.shape[0], 0], np.float32)
        # out_mul/out_add of 127.5 with uint8 output presumably maps a
        # [-1, 1] generator range onto [0, 255]; confirm against Gs.run.
        images = Gs.run(latents, labels,
                        minibatch_size=config.num_gpus * 32,
                        num_gpus=config.num_gpus,
                        out_mul=127.5, out_add=127.5,
                        out_shrink=image_shrink, out_dtype=np.uint8)

        for png_idx in range(num_pngs):
            print('Generating png to %s: %d / %d...' % (tmp_dir, png_idx, num_pngs), end='\r')
            png_path = os.path.join(tmp_dir, 'ProGAN_%08d.png' % png_idx)
            # Check the path that is actually written. The previous check
            # looked in out_dir, where no PNG is ever saved, so it never
            # skipped anything.
            if not os.path.exists(png_path):
                misc.save_image_grid(images[png_idx:png_idx + 1], png_path,
                                     [0, 255], [1, 1])
        # Terminate the '\r' progress line.
        print()
def evaluate_metrics(run_id, log, metrics, num_images, real_passes, minibatch_size=None):
    """Evaluate a set of image metrics on the reals and on every snapshot of a run.

    A table is printed (and logged to ``<result_subdir>/<log>``) with one row
    per pass over the reals and one row per network snapshot, newest snapshot
    first as returned by ``reversed(misc.list_network_pkls(...))``.

    Args:
        run_id: Identifier handed to ``misc.locate_result_subdir``.
        log: File name of the log, created inside the run's result directory.
        metrics: Iterable of metric names. The short names ``'swd'``,
            ``'fid'``, ``'is'`` and ``'msssim'`` are mapped to their API
            classes; any other value is used as an import path as-is.
        num_images: Number of images fed to each metric per row.
        real_passes: How many passes over the reals to run (0, 1 or 2). The
            first pass is fed in mode ``'reals'``, the second in mode
            ``'fakes'``.
        minibatch_size: Images per feed. When ``None`` it is derived from the
            dataset shape as ``clip(8192 // shape[1], 4, 256)``.

    Returns:
        None.

    Raises:
        AssertionError: If no snapshot is found or a snapshot file name is not
            of the form ``network-snapshot-<kimg>.pkl``.
    """
    metric_class_names = {
        'swd': 'metrics.sliced_wasserstein.API',
        'fid': 'metrics.frechet_inception_distance.API',
        'is': 'metrics.inception_score.API',
        'msssim': 'metrics.ms_ssim.API',
    }

    def print_results(time_begin, results):
        # Finish the current table row: elapsed time, then every metric value
        # formatted with the format string supplied by its metric object.
        print('%-12s' % misc.format_time(time.time() - time_begin), end='')
        for obj, vals in zip(metric_objs, results):
            for val, fmt in zip(vals, obj.get_metric_formatting()):
                print(fmt % val, end='')
        print()

    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)
    if minibatch_size is None:
        minibatch_size = np.clip(8192 // dataset_obj.shape[1], 4, 256)

    # Initialize metrics.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        # Metrics always receive three channels; single-channel images are
        # tiled to three channels before being fed (see below).
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(num_images=num_images, image_shape=image_shape,
                        image_dtype=np.uint8, minibatch_size=minibatch_size)
        tfutil.init_uninited_vars()
        # Warm-up: push ten minibatches of random uint8 noise through the
        # metric before any timed evaluation.
        mode = 'warmup'
        obj.begin(mode)
        for _ in range(10):
            obj.feed(mode, np.random.randint(0, 256, size=[minibatch_size] + image_shape, dtype=np.uint8))
        obj.end(mode)
        metric_objs.append(obj)

    # Print table header.
    print()
    print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='')
    for obj in metric_objs:
        for name, fmt in zip(obj.get_metric_names(), obj.get_metric_formatting()):
            # Column width is taken from the width of a formatted zero.
            print('%-*s' % (len(fmt % 0), name), end='')
    print()
    print('%-10s%-12s' % ('---', '---'), end='')
    for obj in metric_objs:
        for fmt in obj.get_metric_formatting():
            print('%-*s' % (len(fmt % 0), '---'), end='')
    print()

    # Labels collected from the reals are reused to condition the generator.
    # Allocated up front so that real_passes == 0 still leaves a valid
    # (all-zero) label array for the snapshot loop instead of an unbound name.
    # Each pass over the reals overwrites every row.
    labels = np.zeros([num_images, dataset_obj.label_size], dtype=np.float32)

    # Feed in reals.
    for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]:
        print('%-10s' % title, end='')
        time_begin = time.time()
        for obj in metric_objs:
            obj.begin(mode)
        for begin in range(0, num_images, minibatch_size):
            end = min(begin + minibatch_size, num_images)
            images, labels[begin:end] = dataset_obj.get_minibatch_np(end - begin)
            if mirror_augment:
                images = misc.apply_mirror_augment(images)
            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
            for obj in metric_objs:
                obj.feed(mode, images)
        results = [obj.end(mode) for obj in metric_objs]
        print_results(time_begin, results)

    # Evaluate each network snapshot.
    prefix = 'network-snapshot-'
    postfix = '.pkl'
    for snapshot_pkl in reversed(snapshot_pkls):
        snapshot_name = os.path.basename(snapshot_pkl)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(postfix)
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])

        print('%-10d' % snapshot_kimg, end='')
        mode = 'fakes'
        for obj in metric_objs:
            obj.begin(mode)
        time_begin = time.time()
        # Each snapshot is loaded into its own graph and session.
        with tf.Graph().as_default(), tfutil.create_session(config.tf_config).as_default():
            _G, _D, Gs = misc.load_pkl(snapshot_pkl)
            for begin in range(0, num_images, minibatch_size):
                end = min(begin + minibatch_size, num_images)
                latents = misc.random_latents(end - begin, Gs)
                images = Gs.run(latents, labels[begin:end],
                                num_gpus=config.num_gpus,
                                out_mul=127.5, out_add=127.5,
                                out_dtype=np.uint8)
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
                for obj in metric_objs:
                    obj.feed(mode, images)
        results = [obj.end(mode) for obj in metric_objs]
        print_results(time_begin, results)
    print()
def calc_mnistrgb_histogram(run_id, num_images=25600, log='histogram.txt', minibatch_size=256, num_evals=10, eval_reals=True, final_only=False):
    """Print mode-coverage statistics of the reals and of each network snapshot.

    Every image is split into 32x32 single-channel planes, each plane is
    classified, and the three class indices of an image are combined into one
    bucket id ``c0 + 10*c1 + 100*c2``. For each of ``num_evals`` disjoint
    slices of ``num_images`` images the function counts the non-empty buckets
    and computes ``sum(model * log(model / data))`` against a uniform
    distribution over 1000 buckets; the two quantities are averaged over the
    slices and printed as ``mean`` and ``kld``.

    Args:
        run_id: Identifier handed to ``misc.locate_result_subdir``.
        num_images: Number of images per evaluation slice.
        log: File name of the log, created inside the run's result directory.
        minibatch_size: Images per classifier call; also forwarded to
            ``imgapi_generate_batch``.
        num_evals: Number of slices the statistics are averaged over.
        eval_reals: When true, evaluate the real images before the snapshots.
        final_only: When true, only the last listed snapshot is evaluated.

    Returns:
        None.

    NOTE(review): assumes images are laid out as (N, 3, 32, 32) so that each
    image yields exactly three classifier rows -- confirm against the dataset.
    """
    # Load the classification network.
    # NOTE: The PKL can be downloaded from https://drive.google.com/open?id=0B4qLcYyJmiz0NHFULTdYc05lX0U
    net = network.load_mnist_classifier(
        os.path.join(config.data_dir, '../networks/mnist_classifier_weights.pkl'))
    input_var = T.tensor4()
    output_expr = lasagne.layers.get_output(net, inputs=input_var, deterministic=True)
    classify_fn = theano.function([input_var], [output_expr])

    # Process folders
    print('Processing directory %s' % (run_id))
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))
    if final_only:
        network_pkls = [network_pkls[-1]]

    num_buckets = 1000  # three decimal digits => 10**3 possible bucket ids

    # Histogram calculation.
    def calc_histogram(images_all):
        scores = []
        divergences = []
        for i in range(num_evals):
            images = images_all[i * num_images:(i + 1) * num_images]
            counts = np.zeros(num_buckets, dtype=np.float64)
            for s in range(0, images.shape[0], minibatch_size):
                img = images[s:s + minibatch_size].reshape((-1, 1, 32, 32))
                res = np.asarray(classify_fn(img)[0])
                res = np.argmax(res, axis=1)
                # Combine the three per-plane class indices into one id.
                res = res.reshape((-1, 3)) * np.asarray([[1, 10, 100]])
                res = np.sum(res, axis=1)
                counts += np.bincount(res, minlength=num_buckets)
            # Remove empty buckets and normalize by the number of samples
            # actually counted (previously a hard-coded 25600, which is only
            # right for the default num_images).
            total = counts.sum()
            model = counts[counts > 0] / total
            data = np.full(len(model), 1. / num_buckets)  # corresponding ideal counts
            scores.append(len(model))
            divergences.append(np.sum(model * np.log(model / data)))  # reverse KL? Metz et al. say KL(model || data)
        scores = np.asarray(scores, dtype=np.float32)
        return np.mean(scores), np.mean(divergences)

    # Load dataset.
    training_set, _drange_orig = load_dataset_for_previous_run(result_subdir, shuffle=False)
    reals, labels = training_set.get_random_minibatch(num_images * num_evals, labels=True)

    # Evaluate reals.
    if eval_reals:
        print('Evaluating histogram for reals...')
        time_begin = time.time()
        mean, kld = calc_histogram(reals)
        print('Done in %s' % misc.format_time(time.time() - time_begin))
        print('mean %-8.4f kld %-8.4f' % (mean, kld))

    # Evaluate each network snapshot.
    for network_pkl in network_pkls:
        # end=' ' keeps the statistics on the same line as the snapshot name.
        print('%-32s' % os.path.basename(network_pkl), end=' ')
        gen_net = imgapi_load_net(run_id=result_subdir, snapshot=network_pkl,
                                  num_example_latents=num_images * num_evals)
        fakes = imgapi_generate_batch(gen_net, gen_net.example_latents, labels,
                                      minibatch_size=minibatch_size,
                                      convert_to_uint8=True)
        mean, kld = calc_histogram(fakes)
        print('mean %-8.4f kld %-8.4f' % (mean, kld))
def calc_sliced_wasserstein_scores(run_id, log='sliced-wasserstein.txt', resolution_min=16, resolution_max=1024, num_images=8192, nhoods_per_image=64, nhood_size=7, dir_repeats=1, dirs_per_repeat=147, minibatch_size=16):
    """Print a table of sliced Wasserstein scores for the reals and every snapshot.

    Descriptors are extracted per level of a Laplacian pyramid. The 'reals'
    row compares two independently drawn descriptor sets taken from the same
    real minibatches; every following row compares the real descriptors with
    descriptors of images generated by one network snapshot.

    Args:
        run_id: Identifier handed to ``misc.locate_result_subdir``.
        log: File name of the log, created inside the run's result directory.
        resolution_min: Smallest resolution evaluated (clamped to the dataset).
        resolution_max: Largest resolution evaluated (clamped to the dataset).
        num_images: Images used per row; must be a multiple of
            ``minibatch_size``.
        nhoods_per_image: Forwarded to ``get_descriptors_for_minibatch``.
        nhood_size: Forwarded to ``get_descriptors_for_minibatch``.
        dir_repeats: Forwarded to ``sliced_wasserstein.sliced_wasserstein``.
        dirs_per_repeat: Forwarded to ``sliced_wasserstein.sliced_wasserstein``.
        minibatch_size: Images per minibatch for both reals and fakes.

    Returns:
        None.

    Raises:
        AssertionError: If the dataset does not have three channels or
            ``num_images`` is not a multiple of ``minibatch_size``.
    """
    import sliced_wasserstein
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))

    # Load dataset.
    print('Loading dataset...')
    training_set, _drange_orig = load_dataset_for_previous_run(result_subdir)
    assert training_set.shape[1] == 3  # RGB
    assert num_images % minibatch_size == 0

    # Select resolutions: powers of two from resolution_max down to
    # resolution_min, both clamped to the dataset resolution.
    resolution_full = training_set.shape[3]
    resolution_min = min(resolution_min, resolution_full)
    resolution_max = min(resolution_max, resolution_full)
    base_lod = int(np.log2(resolution_full)) - int(np.log2(resolution_max))
    resolutions = [
        2**i
        for i in range(int(np.log2(resolution_max)),
                       int(np.log2(resolution_min)) - 1, -1)
    ]

    # Collect descriptors for reals. Two descriptor sets are drawn from each
    # pyramid level so the reals can be scored against themselves.
    print('Extracting descriptors for reals...', end=' ')
    time_begin = time.time()
    desc_real = [[] for _ in resolutions]
    desc_test = [[] for _ in resolutions]
    for _minibatch_begin in range(0, num_images, minibatch_size):
        minibatch = training_set.get_random_minibatch(minibatch_size, lod=base_lod)
        pyramid = sliced_wasserstein.generate_laplacian_pyramid(minibatch, len(resolutions))
        for lod, level in enumerate(pyramid):
            desc_real[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(level, nhood_size, nhoods_per_image))
            desc_test[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(level, nhood_size, nhoods_per_image))
    print('done in %s' % misc.format_time(time.time() - time_begin))

    # Evaluate scores for reals.
    print('Evaluating scores for reals...', end=' ')
    time_begin = time.time()
    scores = []
    for lod in range(len(resolutions)):
        desc_real[lod] = sliced_wasserstein.finalize_descriptors(desc_real[lod])
        desc_test[lod] = sliced_wasserstein.finalize_descriptors(desc_test[lod])
        scores.append(
            sliced_wasserstein.sliced_wasserstein(desc_real[lod], desc_test[lod],
                                                  dir_repeats, dirs_per_repeat))
    del desc_test  # only desc_real is needed from here on
    print('done in %s' % misc.format_time(time.time() - time_begin))

    # Print table header. end=' ' keeps the cells of a row on one line,
    # separated by a single space.
    print()
    print('%-32s' % 'Case', end=' ')
    for res in resolutions:
        print('%-12s' % ('%dx%d' % (res, res)), end=' ')
    print('Average')
    print('%-32s' % '---', end=' ')
    for _ in resolutions:
        print('%-12s' % '---', end=' ')
    print('---')
    print('%-32s' % 'reals', end=' ')
    for score in scores:
        print('%-12.6f' % score, end=' ')
    print('%.6f' % np.mean(scores))

    # Process each network snapshot.
    for network_idx, network_pkl in enumerate(network_pkls):
        print('%-32s' % os.path.basename(network_pkl), end=' ')
        net = imgapi_load_net(run_id=result_subdir, snapshot=network_pkl,
                              num_example_latents=num_images,
                              random_seed=network_idx)

        # Extract descriptors for generated images.
        desc_fake = [[] for _ in resolutions]
        for minibatch_begin in range(0, num_images, minibatch_size):
            minibatch_end = minibatch_begin + minibatch_size
            latents = net.example_latents[minibatch_begin:minibatch_end]
            labels = net.example_labels[minibatch_begin:minibatch_end]
            minibatch = imgapi_generate_batch(net, latents, labels,
                                              minibatch_size=minibatch_size,
                                              convert_to_uint8=True)
            # Bring the generated images down to the same level of detail the
            # reals were fetched at.
            minibatch = sliced_wasserstein.downscale_minibatch(minibatch, base_lod)
            pyramid = sliced_wasserstein.generate_laplacian_pyramid(minibatch, len(resolutions))
            for lod, level in enumerate(pyramid):
                desc_fake[lod].append(
                    sliced_wasserstein.get_descriptors_for_minibatch(level, nhood_size, nhoods_per_image))

        # Evaluate scores.
        scores = []
        for lod in range(len(resolutions)):
            desc_fake[lod] = sliced_wasserstein.finalize_descriptors(desc_fake[lod])
            scores.append(
                sliced_wasserstein.sliced_wasserstein(desc_real[lod], desc_fake[lod],
                                                      dir_repeats, dirs_per_repeat))
        del desc_fake

        # Report results.
        for score in scores:
            print('%-12.6f' % score, end=' ')
        print('%.6f' % np.mean(scores))
    print()
    print('Done.')