コード例 #1
0
ファイル: metrics.py プロジェクト: lebrice/blurred-GAN
    def update_state(self, real_minibatch, fake_minibatch, *args, **kwargs):
        """Accumulate sliced-Wasserstein patch descriptors for one batch pair.

        On the first call (while ``self.resolutions`` is empty) the list of
        pyramid resolutions is derived from ``real_minibatch.shape[1]`` by
        repeatedly halving it while it stays >= 16. One empty descriptor
        list per resolution is then created for reals and for fakes.

        Each minibatch is decomposed with ``sw.generate_laplacian_pyramid``
        into ``len(self.resolutions)`` levels. The descriptors returned by
        ``sw.get_descriptors_for_minibatch`` for level ``lod`` are appended
        to ``self.real_descriptors[lod]`` (real batch) or
        ``self.fake_descriptors[lod]`` (fake batch).

        Args:
            real_minibatch: batch of real images. NOTE(review): axis 1 is
                read as the spatial resolution here, which suggests an NHWC
                layout — confirm against callers.
            fake_minibatch: batch of generated images, assumed to share the
                layout of ``real_minibatch``.
            *args, **kwargs: accepted for signature compatibility; ignored.
        """
        if not self.resolutions:
            res = real_minibatch.shape[1]
            while res >= 16:
                self.resolutions.append(res)
                res //= 2
            self.real_descriptors = [[] for _ in self.resolutions]
            self.fake_descriptors = [[] for _ in self.resolutions]

        num_levels = len(self.resolutions)
        # Pair each batch with its own accumulator. This guarantees the fake
        # descriptors are computed from fake_minibatch, not real_minibatch.
        batches = ((real_minibatch, self.real_descriptors),
                   (fake_minibatch, self.fake_descriptors))
        for minibatch, accumulators in batches:
            pyramid = sw.generate_laplacian_pyramid(minibatch, num_levels)
            for lod, level in enumerate(pyramid):
                desc = sw.get_descriptors_for_minibatch(level, self.nhood_size,
                                                        self.nhoods_per_image)
                accumulators[lod].append(desc)
コード例 #2
0
def calc_sliced_wasserstein_scores(run_id,
                                   log='sliced-wasserstein.txt',
                                   resolution_min=16,
                                   resolution_max=1024,
                                   num_images=8192,
                                   nhoods_per_image=64,
                                   nhood_size=7,
                                   dir_repeats=1,
                                   dirs_per_repeat=147,
                                   minibatch_size=16):
    """Print sliced Wasserstein scores for every network snapshot of a run.

    Locates the result directory of ``run_id``, points misc's output log at
    ``<result_subdir>/<log>``, and prints a table with one column per
    resolution (from ``resolution_max`` down to ``resolution_min``, halving
    each step) plus an ``Average`` column. The rows are:

    * ``reals``: a baseline comparing two descriptor sets extracted from the
      same real minibatches. NOTE(review): these presumably differ because
      descriptor extraction samples neighborhoods at random — confirm in
      sliced_wasserstein.get_descriptors_for_minibatch.
    * one row per snapshot pickle from ``misc.list_network_pkls``, comparing
      the real descriptors against descriptors of ``num_images`` generated
      images.

    Args:
        run_id: passed to ``misc.locate_result_subdir`` (and the located
            directory is passed on as ``run_id`` to ``imgapi_load_net``).
        log: file name, relative to the result directory, given to
            ``misc.set_output_log_file``.
        resolution_min: smallest resolution to evaluate; clamped to the
            dataset's full resolution.
        resolution_max: largest resolution to evaluate; clamped to the
            dataset's full resolution.
        num_images: number of real images, and of generated images per
            snapshot; must be a multiple of ``minibatch_size``.
        nhoods_per_image: forwarded to
            ``sliced_wasserstein.get_descriptors_for_minibatch``.
        nhood_size: forwarded to
            ``sliced_wasserstein.get_descriptors_for_minibatch``.
        dir_repeats: forwarded to ``sliced_wasserstein.sliced_wasserstein``.
        dirs_per_repeat: forwarded to
            ``sliced_wasserstein.sliced_wasserstein``.
        minibatch_size: images per batch, both when loading reals and when
            generating fakes.

    Returns:
        None. Results are only printed.

    Raises:
        AssertionError: if axis 1 of the training set is not 3, or if
            ``num_images`` is not divisible by ``minibatch_size``. Note these
            checks are ``assert`` statements and vanish under ``python -O``.
    """

    import sliced_wasserstein
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))

    # Load dataset.
    print 'Loading dataset...'
    # drange_orig is not used anywhere in this function.
    training_set, drange_orig = load_dataset_for_previous_run(result_subdir)
    assert training_set.shape[1] == 3  # RGB
    assert num_images % minibatch_size == 0

    # Select resolutions.
    # Axis 3 of the training set is read as its full resolution.
    resolution_full = training_set.shape[3]
    resolution_min = min(resolution_min, resolution_full)
    resolution_max = min(resolution_max, resolution_full)
    # Number of halvings from the full resolution down to resolution_max.
    # Used as the lod when loading reals and when downscaling fakes, so both
    # start their pyramids at resolution_max.
    base_lod = int(np.log2(resolution_full)) - int(np.log2(resolution_max))
    # Powers of two from resolution_max down to resolution_min, inclusive;
    # index i in this list is pyramid level (lod) i below.
    resolutions = [
        2**i for i in xrange(int(np.log2(resolution_max)),
                             int(np.log2(resolution_min)) - 1, -1)
    ]

    # Collect descriptors for reals.
    print 'Extracting descriptors for reals...',
    time_begin = time.time()
    # desc_real is kept as the reference set for every snapshot below;
    # desc_test only feeds the reals-vs-reals baseline and is deleted after.
    desc_real = [[] for res in resolutions]
    desc_test = [[] for res in resolutions]
    # minibatch_begin is unused here; the loop just draws
    # num_images // minibatch_size random minibatches.
    for minibatch_begin in xrange(0, num_images, minibatch_size):
        minibatch = training_set.get_random_minibatch(minibatch_size,
                                                      lod=base_lod)
        for lod, level in enumerate(
                sliced_wasserstein.generate_laplacian_pyramid(
                    minibatch, len(resolutions))):
            desc_real[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(
                    level, nhood_size, nhoods_per_image))
            desc_test[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(
                    level, nhood_size, nhoods_per_image))
    print 'done in %s' % misc.format_time(time.time() - time_begin)

    # Evaluate scores for reals.
    print 'Evaluating scores for reals...',
    time_begin = time.time()
    scores = []
    for lod, res in enumerate(resolutions):
        # Replace each per-level list of minibatch descriptors with its
        # finalized form; desc_real stays finalized for reuse per snapshot.
        desc_real[lod] = sliced_wasserstein.finalize_descriptors(
            desc_real[lod])
        desc_test[lod] = sliced_wasserstein.finalize_descriptors(
            desc_test[lod])
        scores.append(
            sliced_wasserstein.sliced_wasserstein(desc_real[lod],
                                                  desc_test[lod], dir_repeats,
                                                  dirs_per_repeat))
    del desc_test
    print 'done in %s' % misc.format_time(time.time() - time_begin)

    # Print table header.
    # The trailing commas on print statements (Python 2) suppress the
    # newline, so each header/score row is built up on a single line.
    print
    print '%-32s' % 'Case',
    for lod, res in enumerate(resolutions):
        print '%-12s' % ('%dx%d' % (res, res)),
    print 'Average'
    print '%-32s' % '---',
    for lod, res in enumerate(resolutions):
        print '%-12s' % '---',
    print '---'
    print '%-32s' % 'reals',
    for lod, res in enumerate(resolutions):
        print '%-12.6f' % scores[lod],
    print '%.6f' % np.mean(scores)

    # Process each network snapshot.
    for network_idx, network_pkl in enumerate(network_pkls):
        print '%-32s' % os.path.basename(network_pkl),
        # random_seed differs per snapshot; presumably it seeds
        # net.example_latents / net.example_labels — confirm in
        # imgapi_load_net.
        net = imgapi_load_net(run_id=result_subdir,
                              snapshot=network_pkl,
                              num_example_latents=num_images,
                              random_seed=network_idx)

        # Extract descriptors for generated images.
        desc_fake = [[] for res in resolutions]
        for minibatch_begin in xrange(0, num_images, minibatch_size):
            latents = net.example_latents[minibatch_begin:minibatch_begin +
                                          minibatch_size]
            labels = net.example_labels[minibatch_begin:minibatch_begin +
                                        minibatch_size]
            minibatch = imgapi_generate_batch(net,
                                              latents,
                                              labels,
                                              minibatch_size=minibatch_size,
                                              convert_to_uint8=True)
            # Downscale by base_lod, the same lod the reals were loaded at,
            # presumably so fakes start at resolution_max too — confirm.
            minibatch = sliced_wasserstein.downscale_minibatch(
                minibatch, base_lod)
            for lod, level in enumerate(
                    sliced_wasserstein.generate_laplacian_pyramid(
                        minibatch, len(resolutions))):
                desc_fake[lod].append(
                    sliced_wasserstein.get_descriptors_for_minibatch(
                        level, nhood_size, nhoods_per_image))

        # Evaluate scores.
        # Each snapshot is compared against the same finalized desc_real.
        scores = []
        for lod, res in enumerate(resolutions):
            desc_fake[lod] = sliced_wasserstein.finalize_descriptors(
                desc_fake[lod])
            scores.append(
                sliced_wasserstein.sliced_wasserstein(desc_real[lod],
                                                      desc_fake[lod],
                                                      dir_repeats,
                                                      dirs_per_repeat))
        # Drop this snapshot's descriptors before processing the next one.
        del desc_fake

        # Report results.
        for lod, res in enumerate(resolutions):
            print '%-12.6f' % scores[lod],
        print '%.6f' % np.mean(scores)
    print
    print 'Done.'