def update_state(self, real_minibatch, fake_minibatch, *args, **kwargs):
    """Accumulate sliced-Wasserstein patch descriptors for one minibatch pair.

    On the first call, lazily derives the Laplacian-pyramid resolutions by
    halving the minibatch resolution down to 16, and allocates one running
    descriptor list per pyramid level. Every call then extracts random patch
    descriptors from each pyramid level of both minibatches and appends them
    to the per-level accumulators.

    Args:
        real_minibatch: batch of real images; the pyramid base resolution is
            read from shape[1] (assumes square images -- layout NHWC or
            similar, TODO confirm against caller).
        fake_minibatch: batch of generated images, same shape as
            real_minibatch.
        *args, **kwargs: ignored; accepted for interface compatibility.
    """
    if len(self.resolutions) == 0:
        # First call: build the list of pyramid resolutions (res, res/2, ...
        # down to 16) and the matching per-level descriptor accumulators.
        res = real_minibatch.shape[1]
        while res >= 16:
            self.resolutions.append(res)
            res //= 2
        self.real_descriptors = [[] for _ in self.resolutions]
        self.fake_descriptors = [[] for _ in self.resolutions]

    for lod, level in enumerate(
            sw.generate_laplacian_pyramid(real_minibatch, len(self.resolutions))):
        desc = sw.get_descriptors_for_minibatch(level, self.nhood_size,
                                                self.nhoods_per_image)
        self.real_descriptors[lod].append(desc)

    # BUG FIX: the original built this pyramid from real_minibatch as well,
    # so the "fake" descriptors were actually computed from real images and
    # the resulting SWD scores degenerated to a reals-vs-reals comparison.
    for lod, level in enumerate(
            sw.generate_laplacian_pyramid(fake_minibatch, len(self.resolutions))):
        desc = sw.get_descriptors_for_minibatch(level, self.nhood_size,
                                                self.nhoods_per_image)
        self.fake_descriptors[lod].append(desc)
def calc_sliced_wasserstein_scores(run_id, log='sliced-wasserstein.txt', resolution_min=16, resolution_max=1024, num_images=8192, nhoods_per_image=64, nhood_size=7, dir_repeats=1, dirs_per_repeat=147, minibatch_size=16):
    """Compute multi-scale sliced Wasserstein distance (SWD) scores for every
    network snapshot of a previous training run, printing a table of results.

    A reals-vs-reals baseline row is printed first (two independent descriptor
    sets drawn from the training data), followed by one row per network
    snapshot comparing generated images against the real descriptor set at
    each Laplacian-pyramid resolution. Output goes both to stdout and to
    `log` inside the run's result subdirectory.

    Args:
        run_id: identifier of the training run (resolved via
            misc.locate_result_subdir).
        log: name of the log file created in the result subdirectory.
        resolution_min: smallest pyramid resolution to evaluate (clamped to
            the dataset's full resolution).
        resolution_max: largest pyramid resolution to evaluate (clamped to
            the dataset's full resolution).
        num_images: number of images sampled per descriptor set; must be a
            multiple of minibatch_size.
        nhoods_per_image: number of patch neighborhoods extracted per image.
        nhood_size: side length of each patch neighborhood.
        dir_repeats: number of Monte Carlo repeats for the SWD projections.
        dirs_per_repeat: number of random projection directions per repeat.
        minibatch_size: batch size for both dataset sampling and generation.
    """
    import sliced_wasserstein
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    # Mirror all subsequent prints into the run's log file.
    misc.set_output_log_file(os.path.join(result_subdir, log))

    # Load dataset.
    print 'Loading dataset...'
    training_set, drange_orig = load_dataset_for_previous_run(result_subdir)
    assert training_set.shape[1] == 3  # RGB
    assert num_images % minibatch_size == 0

    # Select resolutions.
    # Dataset layout appears to be NCHW: shape[1] is channels, shape[3] width.
    resolution_full = training_set.shape[3]
    resolution_min = min(resolution_min, resolution_full)
    resolution_max = min(resolution_max, resolution_full)
    # base_lod: number of 2x downscales needed to reach resolution_max from
    # the dataset's native resolution.
    base_lod = int(np.log2(resolution_full)) - int(np.log2(resolution_max))
    # Descending powers of two: resolution_max, resolution_max/2, ...,
    # resolution_min.
    resolutions = [
        2**i for i in xrange(int(np.log2(resolution_max)),
                             int(np.log2(resolution_min)) - 1, -1)
    ]

    # Collect descriptors for reals.
    print 'Extracting descriptors for reals...',
    time_begin = time.time()
    desc_real = [[] for res in resolutions]
    desc_test = [[] for res in resolutions]
    for minibatch_begin in xrange(0, num_images, minibatch_size):
        minibatch = training_set.get_random_minibatch(minibatch_size, lod=base_lod)
        for lod, level in enumerate(
                sliced_wasserstein.generate_laplacian_pyramid(
                    minibatch, len(resolutions))):
            # Two independent random patch samples from the same real images:
            # desc_real is the reference set, desc_test feeds the
            # reals-vs-reals baseline row below.
            desc_real[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(
                    level, nhood_size, nhoods_per_image))
            desc_test[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(
                    level, nhood_size, nhoods_per_image))
    print 'done in %s' % misc.format_time(time.time() - time_begin)

    # Evaluate scores for reals.
    print 'Evaluating scores for reals...',
    time_begin = time.time()
    scores = []
    for lod, res in enumerate(resolutions):
        desc_real[lod] = sliced_wasserstein.finalize_descriptors(
            desc_real[lod])
        desc_test[lod] = sliced_wasserstein.finalize_descriptors(
            desc_test[lod])
        scores.append(
            sliced_wasserstein.sliced_wasserstein(desc_real[lod],
                                                  desc_test[lod], dir_repeats,
                                                  dirs_per_repeat))
    del desc_test  # release baseline descriptors; only desc_real is reused
    print 'done in %s' % misc.format_time(time.time() - time_begin)

    # Print table header.
    # Trailing commas keep multiple columns on one physical output line
    # (Python 2 print-statement semantics).
    print
    print '%-32s' % 'Case',
    for lod, res in enumerate(resolutions):
        print '%-12s' % ('%dx%d' % (res, res)),
    print 'Average'
    print '%-32s' % '---',
    for lod, res in enumerate(resolutions):
        print '%-12s' % '---',
    print '---'
    print '%-32s' % 'reals',
    for lod, res in enumerate(resolutions):
        print '%-12.6f' % scores[lod],
    print '%.6f' % np.mean(scores)

    # Process each network snapshot.
    for network_idx, network_pkl in enumerate(network_pkls):
        print '%-32s' % os.path.basename(network_pkl),
        # Seed latents per snapshot index so each snapshot is evaluated on a
        # distinct but reproducible latent set.
        net = imgapi_load_net(run_id=result_subdir,
                              snapshot=network_pkl,
                              num_example_latents=num_images,
                              random_seed=network_idx)

        # Extract descriptors for generated images.
        desc_fake = [[] for res in resolutions]
        for minibatch_begin in xrange(0, num_images, minibatch_size):
            latents = net.example_latents[minibatch_begin:minibatch_begin +
                                          minibatch_size]
            labels = net.example_labels[minibatch_begin:minibatch_begin +
                                        minibatch_size]
            minibatch = imgapi_generate_batch(net, latents, labels,
                                              minibatch_size=minibatch_size,
                                              convert_to_uint8=True)
            # Downscale generated images to match the reals' base resolution.
            minibatch = sliced_wasserstein.downscale_minibatch(
                minibatch, base_lod)
            for lod, level in enumerate(
                    sliced_wasserstein.generate_laplacian_pyramid(
                        minibatch, len(resolutions))):
                desc_fake[lod].append(
                    sliced_wasserstein.get_descriptors_for_minibatch(
                        level, nhood_size, nhoods_per_image))

        # Evaluate scores.
        scores = []
        for lod, res in enumerate(resolutions):
            desc_fake[lod] = sliced_wasserstein.finalize_descriptors(
                desc_fake[lod])
            scores.append(
                sliced_wasserstein.sliced_wasserstein(desc_real[lod],
                                                      desc_fake[lod],
                                                      dir_repeats,
                                                      dirs_per_repeat))
        del desc_fake  # free before loading the next snapshot

        # Report results.
        for lod, res in enumerate(resolutions):
            print '%-12.6f' % scores[lod],
        print '%.6f' % np.mean(scores)
    print
    print 'Done.'