Example #1
def imgapi_load_net(run_id, snapshot=None, random_seed=1000, num_example_latents=1000, load_dataset=True, compile_gen_fn=True):
    class Net: pass
    net = Net()
    net.result_subdir = misc.locate_result_subdir(run_id)
    net.network_pkl = misc.locate_network_pkl(net.result_subdir, snapshot)
    _, _, net.G = misc.load_pkl(net.network_pkl)

    # Generate example latents and labels.
    np.random.seed(random_seed)
    net.example_latents = random_latents(num_example_latents, net.G.input_shape)
    net.example_labels = np.zeros((num_example_latents, 0), dtype=np.float32)
    net.dynamic_range = [0, 255]
    if load_dataset:
        imgapi_load_dataset(net)

    # Compile Theano func.
    net.latents_var = T.TensorType('float32', [False] * len(net.example_latents.shape))('latents_var')
    net.labels_var  = T.TensorType('float32', [False] * len(net.example_labels.shape))('labels_var')

    if hasattr(net.G, 'cur_lod'):
        net.lod = net.G.cur_lod.get_value()
        net.images_expr = net.G.eval(net.latents_var, net.labels_var, min_lod=net.lod, max_lod=net.lod, ignore_unused_inputs=True)
    else:
        net.lod = 0.0
        net.images_expr = net.G.eval(net.latents_var, net.labels_var, ignore_unused_inputs=True)

    net.images_expr = misc.adjust_dynamic_range(net.images_expr, [-1,1], net.dynamic_range)
    if compile_gen_fn:
        imgapi_compile_gen_fn(net)
    return net
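A minimal usage sketch of the loader above; the run id, latent count, and the companion imgapi_generate_batch helper (used in the later examples) are assumptions here, not part of this listing.

# Hypothetical usage: load a trained generator and synthesize a small batch.
net = imgapi_load_net(run_id=23, num_example_latents=16)
fakes = imgapi_generate_batch(net,
                              net.example_latents,
                              net.example_labels,
                              minibatch_size=8,
                              convert_to_uint8=True)
print(fakes.shape)  # expected: (16, channels, height, width), uint8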
Example #2
def main():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    parser = argparse.ArgumentParser(
        description='converter to create pytorch models')
    #parser.add_argument('--id', type=int, required=True,
    #        help='number of model to convert')
    parser.add_argument('--id',
                        type=str,
                        required=True,
                        help='number of model to convert')
    parser.add_argument('--outdir', default=None)
    args = parser.parse_args()

    # Configuration
    snapshot = None  # Default, implies last snapshot

    # Get parameters from checkpoint
    tfutil.init_tf()
    directory = misc.locate_result_subdir(args.id)
    print('Loading snapshot from %s' % directory)
    G, D, Gs = misc.load_network_pkl(args.id, snapshot)
    print(G)
    print(D)
    print(Gs)

    # import pdb; pdb.set_trace()

    # model = from_tf_parameters(Gs.variables)
    model = from_tf_parameters(Gs.vars)
    if args.outdir is None:
        args.outdir = directory
    filename = os.path.join(args.outdir, 'generator.pth')
    print('Saving pytorch model as %s' % filename)
    torch.save(model.state_dict(), filename)
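A short follow-up sketch for consuming the exported file; the path below is an assumption, and reusing the weights requires instantiating the matching generator architecture first.

# Hypothetical follow-up: inspect the exported state dict in PyTorch.
import torch

state_dict = torch.load('results/000-example/generator.pth',
                        map_location='cpu')
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
# To reuse the weights, build the matching generator module and call
# model.load_state_dict(state_dict).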
Example #3
def calc_inception_scores(run_id,
                          log='inception.txt',
                          num_images=50000,
                          minibatch_size=16,
                          eval_reals=True,
                          reverse_order=False):
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))

    print 'Importing inception score module...'
    import inception_score

    def calc_inception_score(images):
        if images.shape[1] == 1:
            images = images.repeat(3, axis=1)
        images = list(images.transpose(0, 2, 3, 1))
        return inception_score.get_inception_score(images)

    # Load dataset.
    training_set, drange_orig = load_dataset_for_previous_run(result_subdir,
                                                              shuffle=False)
    reals, labels = training_set.get_random_minibatch(num_images, labels=True)

    # Evaluate reals.
    if eval_reals:
        print 'Evaluating inception score for reals...'
        time_begin = time.time()
        mean, std = calc_inception_score(reals)
        print 'Done in %s' % misc.format_time(time.time() - time_begin)
        print '%-32s mean %-8.4f std %-8.4f' % ('reals', mean, std)

    # Evaluate each network snapshot.
    network_pkls = list(enumerate(network_pkls))
    if reverse_order:
        network_pkls = network_pkls[::-1]
    for network_idx, network_pkl in network_pkls:
        print '%-32s' % os.path.basename(network_pkl),
        net = imgapi_load_net(run_id=result_subdir,
                              snapshot=network_pkl,
                              num_example_latents=num_images,
                              random_seed=network_idx,
                              load_dataset=False)
        fakes = imgapi_generate_batch(net,
                                      net.example_latents,
                                      np.random.permutation(labels),
                                      minibatch_size=minibatch_size,
                                      convert_to_uint8=True)
        mean, std = calc_inception_score(fakes)
        print 'mean %-8.4f std %-8.4f' % (mean, std)
    print
    print 'Done.'
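A minimal invocation sketch for the function above; the run id and image count are assumptions.

# Hypothetical invocation: score 10k generated images per snapshot and
# append the results to inception.txt inside the run's result directory.
calc_inception_scores(run_id=23,
                      num_images=10000,
                      minibatch_size=32,
                      eval_reals=True)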
Example #4
def generate_fake_images_all(run_id,
                             out_dir,
                             num_pngs,
                             image_shrink=1,
                             random_seed=1000,
                             minibatch_size=1,
                             num_pkls=50):
    random_state = np.random.RandomState(random_seed)
    out_dir = os.path.join(out_dir, str(run_id))

    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1

    for snapshot_idx, snapshot_pkl in enumerate(snapshot_pkls[:num_pkls]):
        prefix = 'network-snapshot-'
        postfix = '.pkl'
        snapshot_name = os.path.basename(snapshot_pkl)
        tmp_dir = os.path.join(out_dir, snapshot_name.split('.')[0])
        if not os.path.isdir(tmp_dir):
            os.makedirs(tmp_dir)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(
            postfix)
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])

        print('Loading network...')
        G, D, Gs = misc.load_network_pkl(snapshot_pkl)

        latents = misc.random_latents(num_pngs, Gs, random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=config.num_gpus * 32,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        for png_idx in range(num_pngs):
            print('Generating png to %s: %d / %d...' %
                  (tmp_dir, png_idx, num_pngs),
                  end='\r')
            if not os.path.exists(
                    os.path.join(out_dir, 'ProGAN_%08d.png' % png_idx)):
                misc.save_image_grid(
                    images[png_idx:png_idx + 1],
                    os.path.join(tmp_dir, 'ProGAN_%08d.png' % png_idx),
                    [0, 255], [1, 1])
    print()
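A minimal invocation sketch; the run id, output directory, and counts are assumptions.

# Hypothetical invocation: dump 1000 PNGs for each of the first 10
# snapshots of run 23 into ./fake_images/23/<snapshot-name>/.
generate_fake_images_all(run_id=23,
                         out_dir='fake_images',
                         num_pngs=1000,
                         num_pkls=10)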
Example #5
def evaluate_metrics(run_id,
                     log,
                     metrics,
                     num_images,
                     real_passes,
                     minibatch_size=None):
    metric_class_names = {
        'swd': 'metrics.sliced_wasserstein.API',
        'fid': 'metrics.frechet_inception_distance.API',
        'is': 'metrics.inception_score.API',
        'msssim': 'metrics.ms_ssim.API',
    }

    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)
    if minibatch_size is None:
        minibatch_size = np.clip(8192 // dataset_obj.shape[1], 4, 256)

    # Initialize metrics.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(num_images=num_images,
                        image_shape=image_shape,
                        image_dtype=np.uint8,
                        minibatch_size=minibatch_size)
        tfutil.init_uninited_vars()
        mode = 'warmup'
        obj.begin(mode)
        for idx in range(10):
            obj.feed(
                mode,
                np.random.randint(0,
                                  256,
                                  size=[minibatch_size] + image_shape,
                                  dtype=np.uint8))
        obj.end(mode)
        metric_objs.append(obj)

    # Print table header.
    print()
    print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='')
    for obj in metric_objs:
        for name, fmt in zip(obj.get_metric_names(),
                             obj.get_metric_formatting()):
            print('%-*s' % (len(fmt % 0), name), end='')
    print()
    print('%-10s%-12s' % ('---', '---'), end='')
    for obj in metric_objs:
        for fmt in obj.get_metric_formatting():
            print('%-*s' % (len(fmt % 0), '---'), end='')
    print()

    # Feed in reals.
    for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]:
        print('%-10s' % title, end='')
        time_begin = time.time()
        labels = np.zeros([num_images, dataset_obj.label_size],
                          dtype=np.float32)
        [obj.begin(mode) for obj in metric_objs]
        for begin in range(0, num_images, minibatch_size):
            end = min(begin + minibatch_size, num_images)
            images, labels[begin:end] = dataset_obj.get_minibatch_np(end -
                                                                     begin)
            if mirror_augment:
                images = misc.apply_mirror_augment(images)
            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
            [obj.feed(mode, images) for obj in metric_objs]
        results = [obj.end(mode) for obj in metric_objs]
        print('%-12s' % misc.format_time(time.time() - time_begin), end='')
        for obj, vals in zip(metric_objs, results):
            for val, fmt in zip(vals, obj.get_metric_formatting()):
                print(fmt % val, end='')
        print()

    # Evaluate each network snapshot.
    for snapshot_idx, snapshot_pkl in enumerate(reversed(snapshot_pkls)):
        prefix = 'network-snapshot-'
        postfix = '.pkl'
        snapshot_name = os.path.basename(snapshot_pkl)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(
            postfix)
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])

        print('%-10d' % snapshot_kimg, end='')
        mode = 'fakes'
        [obj.begin(mode) for obj in metric_objs]
        time_begin = time.time()
        with tf.Graph().as_default(), tfutil.create_session(
                config.tf_config).as_default():
            G, D, Gs = misc.load_pkl(snapshot_pkl)
            for begin in range(0, num_images, minibatch_size):
                end = min(begin + minibatch_size, num_images)
                latents = misc.random_latents(end - begin, Gs)
                images = Gs.run(latents,
                                labels[begin:end],
                                num_gpus=config.num_gpus,
                                out_mul=127.5,
                                out_add=127.5,
                                out_dtype=np.uint8)
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
                [obj.feed(mode, images) for obj in metric_objs]
        results = [obj.end(mode) for obj in metric_objs]
        print('%-12s' % misc.format_time(time.time() - time_begin), end='')
        for obj, vals in zip(metric_objs, results):
            for val, fmt in zip(vals, obj.get_metric_formatting()):
                print(fmt % val, end='')
        print()
    print()
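A minimal invocation sketch; the run id, metric selection, and log name are assumptions, but the keyword names match the signature above.

# Hypothetical invocation: compute SWD and FID on 10k images for every
# snapshot of run 23, logging to metrics.txt in the result directory.
evaluate_metrics(run_id=23,
                 log='metrics.txt',
                 metrics=['swd', 'fid'],
                 num_images=10000,
                 real_passes=1)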
Example #6
def calc_mnistrgb_histogram(run_id,
                            num_images=25600,
                            log='histogram.txt',
                            minibatch_size=256,
                            num_evals=10,
                            eval_reals=True,
                            final_only=False):

    # Load the classification network.
    # NOTE: The PKL can be downloaded from https://drive.google.com/open?id=0B4qLcYyJmiz0NHFULTdYc05lX0U
    net = network.load_mnist_classifier(
        os.path.join(config.data_dir,
                     '../networks/mnist_classifier_weights.pkl'))
    input_var = T.tensor4()
    output_expr = lasagne.layers.get_output(net,
                                            inputs=input_var,
                                            deterministic=True)
    classify_fn = theano.function([input_var], [output_expr])

    # Process folders
    print 'Processing directory %s' % (run_id)
    result_subdir = misc.locate_result_subdir(run_id)

    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))

    if final_only:
        network_pkls = [network_pkls[-1]]

    # Histogram calculation.
    def calc_histogram(images_all):
        scores = []
        divergences = []
        for i in range(num_evals):
            images = images_all[i * num_images:(i + 1) * num_images]
            model = [0.] * 1000
            for s in range(0, images.shape[0], minibatch_size):
                img = images[s:s + minibatch_size].reshape((-1, 1, 32, 32))
                res = np.asarray(classify_fn(img)[0])
                res = np.argmax(res, axis=1)
                res = res.reshape((-1, 3)) * np.asarray([[1, 10, 100]])
                res = np.sum(res, axis=1)
                for x in res:
                    model[int(x)] += 1.
            model = np.array([b / 25600. for b in model
                              if b > 0])  # remove empty buckets, normalize
            data = np.array([1. / 1000] *
                            len(model))  # corresponding ideal counts
            scores.append(len(model))
            divergences.append(np.sum(
                model *
                np.log(model /
                       data)))  # reverse KL? Metz et al. say KL(model || data)
        scores = np.asarray(scores, dtype=np.float32)
        return np.mean(scores), np.mean(divergences)

    # Load dataset.
    training_set, drange_orig = load_dataset_for_previous_run(result_subdir,
                                                              shuffle=False)
    reals, labels = training_set.get_random_minibatch(num_images * num_evals,
                                                      labels=True)

    # Evaluate reals.
    if eval_reals:
        print 'Evaluating histogram for reals...'
        time_begin = time.time()
        mean, kld = calc_histogram(reals)
        print 'Done in %s' % misc.format_time(time.time() - time_begin)
        print 'mean %-8.4f kld %-8.4f' % (mean, kld)

    # Evaluate each network snapshot.
    latents = None
    for network_idx, network_pkl in enumerate(network_pkls):
        print '%-32s' % os.path.basename(network_pkl),
        net = imgapi_load_net(run_id=result_subdir,
                              snapshot=network_pkl,
                              num_example_latents=num_images * num_evals)
        fakes = imgapi_generate_batch(net,
                                      net.example_latents,
                                      labels,
                                      minibatch_size=minibatch_size,
                                      convert_to_uint8=True)
        mean, kld = calc_histogram(fakes)
        print 'mean %-8.4f kld %-8.4f' % (mean, kld)
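A standalone sketch of the bucket statistics computed inside calc_histogram above: the number of non-empty buckets ("modes covered") and KL(model || data) against the ideal uniform distribution over the 1000 three-digit classes. The bucket_stats helper and the simulated counts are hypothetical, not part of the original code.

import numpy as np

def bucket_stats(counts, num_buckets=1000):
    counts = np.asarray(counts, dtype=np.float64)
    model = counts[counts > 0] / counts.sum()      # empirical distribution over hit buckets
    data = np.full(len(model), 1.0 / num_buckets)  # ideal uniform probability per bucket
    return len(model), float(np.sum(model * np.log(model / data)))

# Simulate 25600 classifier outputs drawn uniformly over the 1000 classes.
modes, kld = bucket_stats(np.random.multinomial(25600, [1.0 / 1000] * 1000))
print(modes, kld)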
Example #7
def calc_sliced_wasserstein_scores(run_id,
                                   log='sliced-wasserstein.txt',
                                   resolution_min=16,
                                   resolution_max=1024,
                                   num_images=8192,
                                   nhoods_per_image=64,
                                   nhood_size=7,
                                   dir_repeats=1,
                                   dirs_per_repeat=147,
                                   minibatch_size=16):

    import sliced_wasserstein
    result_subdir = misc.locate_result_subdir(run_id)
    network_pkls = misc.list_network_pkls(result_subdir)
    misc.set_output_log_file(os.path.join(result_subdir, log))

    # Load dataset.
    print 'Loading dataset...'
    training_set, drange_orig = load_dataset_for_previous_run(result_subdir)
    assert training_set.shape[1] == 3  # RGB
    assert num_images % minibatch_size == 0

    # Select resolutions.
    resolution_full = training_set.shape[3]
    resolution_min = min(resolution_min, resolution_full)
    resolution_max = min(resolution_max, resolution_full)
    base_lod = int(np.log2(resolution_full)) - int(np.log2(resolution_max))
    resolutions = [
        2**i for i in xrange(int(np.log2(resolution_max)),
                             int(np.log2(resolution_min)) - 1, -1)
    ]

    # Collect descriptors for reals.
    print 'Extracting descriptors for reals...',
    time_begin = time.time()
    desc_real = [[] for res in resolutions]
    desc_test = [[] for res in resolutions]
    for minibatch_begin in xrange(0, num_images, minibatch_size):
        minibatch = training_set.get_random_minibatch(minibatch_size,
                                                      lod=base_lod)
        for lod, level in enumerate(
                sliced_wasserstein.generate_laplacian_pyramid(
                    minibatch, len(resolutions))):
            desc_real[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(
                    level, nhood_size, nhoods_per_image))
            desc_test[lod].append(
                sliced_wasserstein.get_descriptors_for_minibatch(
                    level, nhood_size, nhoods_per_image))
    print 'done in %s' % misc.format_time(time.time() - time_begin)

    # Evaluate scores for reals.
    print 'Evaluating scores for reals...',
    time_begin = time.time()
    scores = []
    for lod, res in enumerate(resolutions):
        desc_real[lod] = sliced_wasserstein.finalize_descriptors(
            desc_real[lod])
        desc_test[lod] = sliced_wasserstein.finalize_descriptors(
            desc_test[lod])
        scores.append(
            sliced_wasserstein.sliced_wasserstein(desc_real[lod],
                                                  desc_test[lod], dir_repeats,
                                                  dirs_per_repeat))
    del desc_test
    print 'done in %s' % misc.format_time(time.time() - time_begin)

    # Print table header.
    print
    print '%-32s' % 'Case',
    for lod, res in enumerate(resolutions):
        print '%-12s' % ('%dx%d' % (res, res)),
    print 'Average'
    print '%-32s' % '---',
    for lod, res in enumerate(resolutions):
        print '%-12s' % '---',
    print '---'
    print '%-32s' % 'reals',
    for lod, res in enumerate(resolutions):
        print '%-12.6f' % scores[lod],
    print '%.6f' % np.mean(scores)

    # Process each network snapshot.
    for network_idx, network_pkl in enumerate(network_pkls):
        print '%-32s' % os.path.basename(network_pkl),
        net = imgapi_load_net(run_id=result_subdir,
                              snapshot=network_pkl,
                              num_example_latents=num_images,
                              random_seed=network_idx)

        # Extract descriptors for generated images.
        desc_fake = [[] for res in resolutions]
        for minibatch_begin in xrange(0, num_images, minibatch_size):
            latents = net.example_latents[minibatch_begin:minibatch_begin +
                                          minibatch_size]
            labels = net.example_labels[minibatch_begin:minibatch_begin +
                                        minibatch_size]
            minibatch = imgapi_generate_batch(net,
                                              latents,
                                              labels,
                                              minibatch_size=minibatch_size,
                                              convert_to_uint8=True)
            minibatch = sliced_wasserstein.downscale_minibatch(
                minibatch, base_lod)
            for lod, level in enumerate(
                    sliced_wasserstein.generate_laplacian_pyramid(
                        minibatch, len(resolutions))):
                desc_fake[lod].append(
                    sliced_wasserstein.get_descriptors_for_minibatch(
                        level, nhood_size, nhoods_per_image))

        # Evaluate scores.
        scores = []
        for lod, res in enumerate(resolutions):
            desc_fake[lod] = sliced_wasserstein.finalize_descriptors(
                desc_fake[lod])
            scores.append(
                sliced_wasserstein.sliced_wasserstein(desc_real[lod],
                                                      desc_fake[lod],
                                                      dir_repeats,
                                                      dirs_per_repeat))
        del desc_fake

        # Report results.
        for lod, res in enumerate(resolutions):
            print '%-12.6f' % scores[lod],
        print '%.6f' % np.mean(scores)
    print
    print 'Done.'
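A minimal invocation sketch; the run id is an assumption, and the remaining keyword arguments mirror the defaults of the function above.

# Hypothetical invocation: SWD between reals and each snapshot's fakes,
# measured from 16x16 up to the dataset's full resolution for run 23.
calc_sliced_wasserstein_scores(run_id=23,
                               num_images=8192,
                               resolution_min=16,
                               minibatch_size=16)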