Ejemplo n.º 1
0
def project_image(proj, targets, img_num, num_snapshots):
    """Run a full projection of ``targets`` and score the final result.

    Drives ``proj`` from ``proj.start(targets)`` until ``proj.get_cur_step()``
    reaches ``proj.num_steps``, printing an in-place progress line, then
    compares the first target image with the first projected image.

    Args:
        proj: Projector exposing ``start``, ``step``, ``get_cur_step``,
            ``get_images`` and ``num_steps``.
        targets: Batch of target images passed to ``proj.start``; only
            ``targets[0]`` is scored.
        img_num: Index of the image, used only in the progress message.
        num_snapshots: Kept for interface compatibility; this function takes
            no intermediate snapshots.

    Returns:
        Tuple ``(improj, temp_ssim, temp_mse)``: the projected image after
        ``normalize_img``, and the SSIM / MSE between the normalized target
        and projection.
    """
    proj.start(targets)
    while proj.get_cur_step() < proj.num_steps:
        print('\rProjecting image %d: %d / %d ... ' %
              (img_num, proj.get_cur_step(), proj.num_steps),
              end='',
              flush=True)
        proj.step()
    # Clear the progress line before scoring/returning.
    print('\r%-30s\r' % '', end='', flush=True)

    # Score once after the loop so the results are always bound, even when
    # the loop body never ran (e.g. num_steps == 0).
    imreal = normalize_img(np.array(misc.convert_to_pil_image(targets[0])))
    improj = normalize_img(
        np.array(misc.convert_to_pil_image(proj.get_images()[0])))
    temp_mse = mse(imreal, improj)
    temp_ssim = ssim(imreal, improj, multichannel=True)
    return improj, temp_ssim, temp_mse
Ejemplo n.º 2
0
def gen_images(latents,
               truncation_psi_val,
               outfile=None,
               display=False,
               labels=None,
               randomize_noise=False,
               is_validation=True,
               network=None,
               numpy=False):
    """Render ``latents`` through a generator and tile the outputs into a grid.

    Args:
        latents: Latent batch; its first dimension sizes the grid.
        truncation_psi_val: Forwarded to ``network.run`` as
            ``truncation_psi_val``.
        outfile: When truthy its parent directory is created; when not None
            the grid image is saved there.
        display: When true, show the grid inline via IPython as PNG.
        labels: Conditioning labels forwarded to ``network.run``.
        randomize_noise: Forwarded to ``network.run``.
        is_validation: Forwarded to ``network.run``.
        network: Generator to use; falls back to the module-level ``Gs``.
        numpy: When true, return the raw array instead of the PIL image.

    Returns:
        The network output sliced to its first three entries on axis 1 when
        ``numpy`` is true, otherwise the PIL grid image.
    """
    if outfile:
        # Make sure the destination directory exists before saving.
        Path(outfile).parent.mkdir(parents=True, exist_ok=True)

    net = Gs if network is None else network
    layout = get_grid_size(latents.shape[0])
    pixel_range = [-1, 1]

    with tflex.device('/gpu:0'):
        raw = net.run(latents,
                      labels,
                      truncation_psi_val=truncation_psi_val,
                      is_validation=is_validation,
                      randomize_noise=randomize_noise,
                      minibatch_size=sched.minibatch_gpu)
        # Keep only the first three channels of axis 1.
        rgb = raw[:, 0:3, :, :]
        grid_image = misc.convert_to_pil_image(
            misc.create_image_grid(rgb, layout), pixel_range)
        if outfile is not None:
            grid_image.save(outfile)
        if display:
            png_buffer = BytesIO()
            grid_image.save(png_buffer, 'png')
            IPython.display.display(
                IPython.display.Image(data=png_buffer.getvalue()))

    if numpy:
        return rgb
    return grid_image
Ejemplo n.º 3
0
    def gen():
        """Stream projection progress as double-newline-separated JSON chunks.

        Every ``yieldInterval`` steps, yields a JSON object with the step
        number, the current projection as a base64 PNG data URL, and the
        flattened dlatents encoded with ``latentCode.encodeFixed16``.
        ``image_array``, ``steps`` and ``yieldInterval`` come from the
        enclosing scope.
        """
        projector = loadProjector()
        projector.start([image_array])
        for step in projector.runSteps(steps):
            print('\rProjecting: %d / %d' % (step, steps), end='', flush=True)
            if step % yieldInterval != 0:
                continue

            dlatents = projector.get_dlatents()
            grid = misc.create_image_grid(projector.get_images())
            pil_image = misc.convert_to_pil_image(grid, drange=[-1, 1])

            buffer = io.BytesIO()
            pil_image.save(buffer, PIL.Image.registered_extensions()['.png'])
            encoded = base64.b64encode(buffer.getvalue()).decode('ascii')

            payload = dict(step=step,
                           img='data:image/png;base64,%s' % encoded,
                           latentCodes=latentCode.encodeFixed16(
                               dlatents.flatten()).decode('ascii'))
            yield json.dumps(payload) + '\n\n'

        print('\rProjecting finished.%s' % (' ' * 8))
Ejemplo n.º 4
0
def get_images(tags, seed=0, mu=0, sigma=0, truncation=None):
    """Generate images from ``tflex.Gs`` for a single seed and save a preview.

    Every latent row is drawn from ``np.random.RandomState(seed)``, so all
    ``batch_size`` rows are identical. The output is sliced to its first three
    channels, tiled with ``misc.create_image_grid(result, (1, 1))`` and saved
    as ``mammos.png``.

    NOTE(review): the grid is 1x1, which presumably assumes
    ``batch_size == 1`` -- confirm.

    Args:
        tags: Currently unused.
        seed: Seed for the latent draw.
        mu: Currently unused.
        sigma: Currently unused.
        truncation: When not None, forwarded to ``Gs.run`` as
            ``truncation_psi_val``. Previously it was silently ignored.

    Returns:
        Tuple ``(result, img)``: the channel-sliced network output and the
        PIL grid image.
    """
    print("Generating mammos...")

    all_seeds = [seed] * batch_size
    all_z = np.stack([
        np.random.RandomState(s).randn(*tflex.Gs.input_shape[1:])
        for s in all_seeds
    ])  # [minibatch, component]
    print(all_z.shape)

    # Forward truncation with the same keyword gen_images uses alongside
    # is_validation=True (presumably the network reads truncation_psi_val
    # in validation mode -- confirm). When None, the network default applies
    # as before.
    run_kwargs = {}
    if truncation is not None:
        run_kwargs['truncation_psi_val'] = truncation

    drange_net = [-1, 1]
    with tflex.device('/gpu:0'):
        result = tflex.Gs.run(all_z,
                              None,
                              is_validation=True,
                              randomize_noise=False,
                              minibatch_size=sched.minibatch_gpu,
                              **run_kwargs)
        result = result[:, 0:3, :, :]
        img = misc.convert_to_pil_image(misc.create_image_grid(result, (1, 1)),
                                        drange_net)
        img.save('mammos.png')
        return result, img
Ejemplo n.º 5
0
 async def send_picture(channel, image, kind='png', name='test', text=None):
     """Encode ``image`` (range [-1, 1]) and post it to ``channel``.

     The image is written into an in-memory buffer in ``kind`` format and
     attached as ``<name>.<kind>`` alongside the optional ``text`` message.
     """
     pil_image = misc.convert_to_pil_image(image, [-1, 1])
     buffer = BytesIO()
     pil_image.save(buffer, kind)
     buffer.seek(0)  # rewind so discord reads from the start
     attachment = discord.File(buffer)
     attachment.filename = name + '.' + kind
     await channel.send(content=text, file=attachment)
def get_projected_images(proj, targets, num_snapshots):
    """Run a projection and collect PIL snapshots of the first image.

    Snapshots are taken at the steps ``proj.num_steps - linspace(0,
    num_steps, num_snapshots, endpoint=False)``, and one more frame is always
    appended after the loop. For ``num_snapshots >= 1`` that schedule
    contains ``proj.num_steps`` itself, so the final frame may appear twice.

    Returns:
        List of PIL images in capture order.
    """
    schedule = set(proj.num_steps - np.linspace(
        0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))

    def _current_frame():
        # First image of the current projection, mapped from [-1, 1].
        return misc.convert_to_pil_image(proj.get_images()[0],
                                         drange=[-1, 1])

    frames = []
    proj.start(targets)
    while proj.get_cur_step() < proj.num_steps:
        print('%d / %d ... ' % (proj.get_cur_step(), proj.num_steps),
              flush=True)
        proj.step()
        if proj.get_cur_step() in schedule:
            frames.append(_current_frame())

    frames.append(_current_frame())
    return frames
Ejemplo n.º 7
0
def generate():
    """Flask handler: render one image from encoded latents and return a PNG.

    Query parameters:
        latents: Float32-encoded z latents (used when ``xlatents`` is absent).
        xlatents: Fixed16-encoded dlatents; takes precedence over ``latents``.
        psi: Truncation psi for the mapping path (default 0.5).
        randomize_noise: Non-zero enables noise randomization.
        fromW: Non-zero runs the synthesis network directly on dlatents.
        model_name: Model selector stored in the global ``model_name``
            (default "ffhq").
    """
    global model_name
    global g_Session
    global g_dLatentsIn

    args = flask.request.args
    latents_str = args.get("latents")
    xlatents_str = args.get("xlatents")
    psi = float(args.get("psi", 0.5))
    randomize_noise = int(args.get("randomize_noise", 0))
    from_w = int(args.get("fromW", 0))

    # Update the global model selection when the request asks for another one.
    requested_model = args.get("model_name", "ffhq")
    if requested_model != model_name:
        model_name = requested_model
    gs, synthesis = loadGs()

    latent_len = gs.input_shape[1]
    if not xlatents_str:
        latents = latentCode.decodeFloat32(latents_str, latent_len)
    else:
        latents = latentCode.decodeFixed16(xlatents_str, g_dLatentsIn.shape[0])

    started = time.time()

    uint8_transform = dict(func=dnnlib.tflib.convert_images_to_uint8,
                           nchw_to_nhwc=True)
    with g_Session.as_default():
        if from_w == 0:
            images = gs.run(
                latents.reshape([1, latent_len]),
                None,
                truncation_psi=psi,
                randomize_noise=randomize_noise != 0,
                output_transform=uint8_transform,
            )
            image = PIL.Image.fromarray(images[0], "RGB")
        else:
            target_len = g_dLatentsIn.shape[0]
            # Repeat short latents to fill the synthesis input.
            if latents.shape[0] < target_len:
                latents = np.tile(latents, target_len // latents.shape[0])
            images = dnnlib.tflib.run(synthesis, {g_dLatentsIn: latents})
            image = misc.convert_to_pil_image(misc.create_image_grid(images),
                                              drange=[-1, 1])

    print("generation cost:", time.time() - started)

    png_buffer = io.BytesIO()
    image.save(png_buffer, PIL.Image.registered_extensions()[".png"])
    return flask.Response(png_buffer.getvalue(), mimetype="image/png")
Ejemplo n.º 8
0
def generate_images(network_pkl, seed_z, seeds, truncation_psi):
    """Visualise the effect of the generator's per-layer noise inputs.

    For each cut-off in ``layer_cutoffs``, the noise variables before the
    cut-off keep one shared random draw and those from the cut-off onward are
    zeroed. One image is rendered per cut-off; each is saved as
    ``seed%04d.png`` (named by the cut-off), and all are saved together as a
    one-row grid ``img.png``.

    Args:
        network_pkl: Network pickle passed to
            ``pretrained_networks.load_networks``.
        seed_z: Seed for both the latent ``z`` and the noise draw.
        seeds: Currently unused; kept for interface compatibility.
        truncation_psi: Forwarded to ``Gs.run`` when not None.
    """
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    noise_vars = [
        var for name, var in Gs.components.synthesis.vars.items()
        if name.startswith('noise')
    ]

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    rnd = np.random.RandomState(seed_z)
    z = rnd.randn(1, *Gs.input_shape[1:])
    # Draw the noise once so that images differ only in which layers are
    # zeroed. A fresh draw per iteration would also change the kept layers
    # and confound the comparison.
    base_noise = {var: rnd.randn(*var.shape.as_list())
                  for var in noise_vars}  # [height, width]

    layer_cutoffs = [0, 2, 4, 6, 8, 10, 11]
    img = []
    for layer_idx in layer_cutoffs:
        # Progress shows how many noise layers are kept out of the total.
        print('Generating image for %d (%d/%d) ...' %
              (layer_idx, layer_idx, len(noise_vars)))
        tflib.set_vars(base_noise)
        zero_vars = noise_vars[layer_idx:]
        if zero_vars:
            tflib.set_vars({
                var: np.zeros(var.shape.as_list(), dtype=np.float32)
                for var in zero_vars
            })  # [height, width]
        images = Gs.run(z, None, **Gs_kwargs)  # NCHW, range [-1, 1]
        img.append(images)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(
            dnnlib.make_run_dir_path('seed%04d.png' % layer_idx))
    img = np.concatenate(img, 0)
    misc.save_image_grid(img,
                         dnnlib.make_run_dir_path('img.png'),
                         drange=[-1, 1],
                         grid_size=[len(layer_cutoffs), 1])
Ejemplo n.º 9
0
def main():
    """Project ./images/example.png with the FFHQ StyleGAN2 model.

    Prints elapsed-time checkpoints (t0..t6) for each stage and saves a
    snapshot of the projection every 10 steps to ./images/project-<step>.png.
    Finally, prints the shapes of the resulting dlatents and noises.
    """
    start = time.time()

    def elapsed():
        return time.time() - start

    print('t0:', start)

    tflib.init_tf()  # 0.82s
    print('t1:', elapsed())

    # pickle.load runs arbitrary code: only load trusted model files.
    with open('./models/stylegan2-ffhq-config-f.pkl', 'rb') as f:
        print('t2:', elapsed())
        _G, _D, Gs = pickle.load(f)  # 13.09s
        print('t3:', elapsed())

    with open('./models/vgg16_zhang_perceptual.pkl', 'rb') as f:
        lpips = pickle.load(f)
        print('t4:', elapsed())

    proj = Projector()
    proj.set_network(Gs, lpips)

    # HWC -> CHW, then rescale pixel values from [0, 255] to [-1, 1].
    pixels = np.array(PIL.Image.open('./images/example.png'))
    target = misc.adjust_dynamic_range(pixels.transpose(2, 0, 1),
                                       [0, 255], [-1, 1])
    print('t5:', elapsed())

    proj.start([target])
    for step in proj.runSteps(1000):
        print('\rstep: %d' % step, end='', flush=True)
        if step % 10:
            continue
        snapshot = misc.convert_to_pil_image(
            misc.create_image_grid(proj.get_images()), drange=[-1, 1])
        snapshot.save('./images/project-%d.png' % step)

    print('t6:', elapsed())

    dlatents = proj.get_dlatents()
    noises = proj.get_noises()
    print('dlatents:', dlatents.shape)
    print('noises:', len(noises), noises[0].shape, noises[-1].shape)
Ejemplo n.º 10
0
def write_video_frame(proj, video_out):
    """Append the projector's current first image to ``video_out``.

    The image is mapped from [-1, 1] to a PIL image, then converted from RGB
    to BGR channel order before being written.
    """
    pil_frame = misc.convert_to_pil_image(proj.get_images()[0],
                                          drange=[-1, 1])
    bgr_frame = cv2.cvtColor(np.array(pil_frame).astype('uint8'),
                             cv2.COLOR_RGB2BGR)
    video_out.write(bgr_frame)
Ejemplo n.º 11
0
def generate_images(network_pkl,
                    network_G_pkl,
                    n_imgs,
                    model_type,
                    n_discrete,
                    n_continuous,
                    use_std_in_m=None,
                    latent_type='uniform',
                    n_samples_per=10):
    """Render latent-traversal grids and GIFs through a pretrained generator.

    For each of ``n_imgs`` iterations, grid latents from ``get_grid_latents``
    are mapped by ``M`` and rendered by ``Gs``. The result is saved as
    ``img_%04d.png`` and as an animated ``latents_trav_%04d.gif``. Each GIF
    frame concatenates one sample from every continuous dimension along
    axis 2.

    Args:
        network_pkl: Pickle holding ``I, M, Is`` (plus ``I_info`` when
            ``model_type == 'hd_dis_model_with_cls'``).
        network_G_pkl: Pickle holding the pretrained ``_G, _D, Gs``.
        n_imgs: Number of grids/GIFs to produce.
        model_type: Selects the layout of ``network_pkl``.
        n_discrete: Number of discrete factors.
        n_continuous: Number of continuous factors.
        use_std_in_m: When not None, keep only the first half of ``M``'s
            output along axis 1.
        latent_type: Forwarded to ``get_grid_latents``.
        n_samples_per: Samples per continuous factor (GIF frame count).

    NOTE(review): the reshape below assumes ``n_continuous * n_samples_per``
    rows, so ``n_discrete > 0`` presumably fails there -- confirm.
    """
    print('Loading networks from "%s"...' % network_pkl)
    tflib.init_tf()
    if model_type == 'hd_dis_model_with_cls':
        I, M, Is, I_info = misc.load_pkl(network_pkl)
    else:
        I, M, Is = misc.load_pkl(network_pkl)

    # Load pretrained GAN
    _G, _D, Gs = misc.load_pkl(network_G_pkl)

    if n_discrete == 0:
        n_grid_rows = n_continuous * n_samples_per
    else:
        n_grid_rows = n_discrete * n_continuous * n_samples_per

    for idx in range(n_imgs):
        print('Generating image %d/%d ...' % (idx, n_imgs))

        grid_labels = np.zeros([n_grid_rows, 0], dtype=np.float32)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete,
            n_continuous,
            n_samples_per,
            _G,
            grid_labels,
            latent_type=latent_type)
        prior_traj_latents = M.run(grid_latents,
                                   is_validation=True,
                                   minibatch_size=4)
        if use_std_in_m is not None:
            # Keep the first half of axis 1 (presumably the mean, with the
            # std in the second half -- confirm).
            half = prior_traj_latents.shape[1] // 2
            prior_traj_latents = prior_traj_latents[:, :half]
        grid_fakes = Gs.run(prior_traj_latents,
                            grid_labels,
                            is_validation=True,
                            minibatch_size=4,
                            randomize_noise=False)
        print(grid_fakes.shape)
        misc.save_image_grid(grid_fakes,
                             dnnlib.make_run_dir_path('img_%04d.png' % idx),
                             drange=[-1, 1],
                             grid_size=grid_size)

        grid_fakes = np.reshape(grid_fakes, [
            n_continuous, n_samples_per, grid_fakes.shape[1],
            grid_fakes.shape[2], grid_fakes.shape[3]
        ])
        frames = []
        for i in range(n_samples_per):
            frame = np.concatenate(
                [grid_fakes[j, i] for j in range(n_continuous)], axis=2)
            # Use the same [-1, 1] range as the grid saved above. Without an
            # explicit drange, convert_to_pil_image falls back to its own
            # default, which is presumably [0, 1], and would mis-scale pixels.
            frames.append(misc.convert_to_pil_image(frame, drange=[-1, 1]))
        frames[0].save(dnnlib.make_run_dir_path('latents_trav_%04d.gif' % idx),
                       format='GIF',
                       append_images=frames[1:],
                       save_all=True,
                       duration=100,
                       loop=0)
Ejemplo n.º 12
0
def generate_image_pairs(network_pkl,
                         n_imgs,
                         model_type,
                         n_discrete,
                         n_continuous,
                         result_dir,
                         batch_size=10,
                         return_atts=True,
                         latent_type='onedim',
                         act_mask_ls=None):
    """Render image pairs from related latents and save their latent deltas.

    For each of ``n_imgs // batch_size`` batches, two normal latents
    ``z_1``/``z_2`` are drawn. With ``latent_type == 'onedim'``, ``z_2``
    copies ``z_1`` except in one dimension picked from ``act_mask_ls``. Each
    pair is rendered side by side (concatenated on axis 2) as
    ``pair_%06d.jpg``. All ``z_1 - z_2`` rows are saved to ``labels.npy``;
    the file holds an empty ``(0, n_continuous)`` array when no batch runs.

    Args:
        network_pkl: Pickle with ``_G, _D, Gs`` (plus ``I`` for
            'info_gan' / 'vc_gan_with_vc_head').
        n_imgs: Requested image count; a remainder below ``batch_size`` is
            dropped.
        model_type: Selects the pickle layout.
        n_discrete: Size of the one-hot categorical prefix (0 for none).
        n_continuous: Number of continuous latent dimensions.
        result_dir: Output directory (created if missing).
        batch_size: Pairs per batch.
        return_atts: Forwarded to ``Gs.run``.
        latent_type: 'onedim' to vary a single dimension per pair.
        act_mask_ls: Dimensions eligible for variation (default: all).
    """
    print('Loading networks from "%s"...' % network_pkl)
    tflib.init_tf()
    if (model_type == 'info_gan') or (model_type == 'vc_gan_with_vc_head'):
        _G, _D, I, Gs = misc.load_pkl(network_pkl)
    else:
        _G, _D, Gs = misc.load_pkl(network_pkl)

    os.makedirs(result_dir, exist_ok=True)

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False
    Gs_kwargs.return_atts = return_atts

    n_batches = n_imgs // batch_size

    if act_mask_ls is None:
        act_mask_ls = np.arange(n_continuous)
    n_act = len(act_mask_ls)  # e.g. act_mask_ls: [0,2,3,5]
    act_mask_dup_array = np.tile(np.array(act_mask_ls)[np.newaxis, ...],
                                 [batch_size, 1])
    grid_labels = np.zeros([batch_size, 0], dtype=np.float32)
    # Collect per-batch deltas and concatenate once (avoids O(n^2) copies).
    label_batches = []
    for i in range(n_batches):
        print('Generating image pairs %d/%d ...' % (i, n_batches))

        if n_discrete > 0:
            cat_dim = np.random.randint(0, n_discrete, size=[batch_size])
            cat_onehot = np.zeros((batch_size, n_discrete))
            cat_onehot[np.arange(cat_dim.size), cat_dim] = 1

        z_1 = np.random.normal(size=[batch_size, n_continuous])
        z_2 = np.random.normal(size=[batch_size, n_continuous])
        if latent_type == 'onedim':
            # Pick one active dimension per row; elsewhere z_2 copies z_1.
            delta_dim_act = np.random.randint(0, n_act, size=[batch_size])
            delta_dim = act_mask_dup_array[np.arange(batch_size),
                                           delta_dim_act]
            delta_onehot = np.zeros((batch_size, n_continuous))
            delta_onehot[np.arange(delta_dim.size), delta_dim] = 1
            z_2 = np.where(delta_onehot > 0, z_2, z_1)

        label_batches.append(z_1 - z_2)

        if n_discrete > 0:
            z_1 = np.concatenate((cat_onehot, z_1), axis=1)
            z_2 = np.concatenate((cat_onehot, z_2), axis=1)

        fakes_1 = get_return_v(
            Gs.run(z_1,
                   grid_labels,
                   is_validation=True,
                   minibatch_size=batch_size,
                   **Gs_kwargs), 1)
        fakes_2 = get_return_v(
            Gs.run(z_2,
                   grid_labels,
                   is_validation=True,
                   minibatch_size=batch_size,
                   **Gs_kwargs), 1)
        print('fakes_1.shape:', fakes_1.shape)
        print('fakes_2.shape:', fakes_2.shape)

        for j in range(fakes_1.shape[0]):
            pair_np = np.concatenate([fakes_1[j], fakes_2[j]], axis=2)
            img = misc.convert_to_pil_image(pair_np, [-1, 1])
            img.save(
                os.path.join(result_dir,
                             'pair_%06d.jpg' % (i * batch_size + j)))

    if label_batches:
        labels = np.concatenate(label_batches, axis=0)
    else:
        # Previously a NameError when n_imgs < batch_size.
        labels = np.zeros([0, n_continuous])
    np.save(os.path.join(result_dir, 'labels.npy'), labels)
Ejemplo n.º 13
0
def generate_images(network_pkl,
                    seeds,
                    truncation_psi,
                    data_dir=None,
                    dataset_name=None,
                    model=None):
    """Render images and intermediate visualisations for each seed.

    Builds ``training.<model>.G_main`` sized from the dataset, copies weights
    from ``network_pkl``, then, per seed, saves the image plus grids of the
    extra ``n_v`` / ``x_v`` / ``m_v`` outputs returned by ``Gs.run``.

    Args:
        network_pkl: Network pickle to copy variables from.
        seeds: Iterable of integer seeds.
        truncation_psi: Forwarded to ``Gs.run`` when not None.
        data_dir: Dataset root passed to ``dataset.load_dataset``.
        dataset_name: tfrecord directory name.
        model: Module name under ``training`` providing ``G_main``; required.

    Raises:
        ValueError: If ``model`` is None.
    """
    if model is None:
        # Previously an opaque TypeError from str + None.
        raise ValueError(
            "generate_images: 'model' is required to select "
            "training.<model>.G_main")

    G_args = EasyDict(func_name='training.' + model + '.G_main')
    dataset_args = EasyDict(tfrecord_dir=dataset_name)
    G_args.fmap_base = 8 << 10
    tflib.init_tf()
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    print('Constructing networks...')
    Gs = tflib.Network('G',
                       num_channels=training_set.shape[0],
                       resolution=training_set.shape[1],
                       label_size=training_set.label_size,
                       **G_args)
    print('Loading networks from "%s"...' % network_pkl)
    _, _, _Gs = pretrained_networks.load_networks(network_pkl)
    Gs.copy_vars_from(_Gs)
    noise_vars = [
        var for name, var in Gs.components.synthesis.vars.items()
        if name.startswith('noise')
    ]

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    def _save_grid(grid, seed, suffix):
        # Save ``grid`` as seed%04d-<suffix>.png in the run directory.
        misc.save_image_grid(grid,
                             dnnlib.make_run_dir_path('seed%04d-%s.png' %
                                                      (seed, suffix)),
                             drange=[-1, 1])

    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' %
              (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:])  # [minibatch, component]
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]
        images, x_v, n_v, m_v = Gs.run(z, None, **Gs_kwargs)

        print(images.shape, n_v.shape, x_v.shape, m_v.shape)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(
            dnnlib.make_run_dir_path('seed%04d.png' % seed))
        _save_grid(adjust_range(n_v), seed, 'nv')
        print(np.linalg.norm(x_v - m_v))
        _save_grid(adjust_range(x_v).transpose([1, 0, 2, 3]), seed, 'xv')
        _save_grid(adjust_range(m_v).transpose([1, 0, 2, 3]), seed, 'mv')
        _save_grid(adjust_range(clip(x_v, 'cat')), seed, 'xvs')
        _save_grid(adjust_range(clip(m_v, 'ss')), seed, 'mvs')
        _save_grid(adjust_range(clip(m_v, 'ffhq')), seed, 'fmvs')
Ejemplo n.º 14
0
def gen_disc_test(proj,
                  D,
                  dataset_name,
                  data_dir,
                  num_snapshots=2,
                  queue=None):
    """Project every dataset image and score it with ``D``.

    Iterates the dataset one image at a time until fetching fails. For each
    image it:

    * projects the image with ``project_image``;
    * records SSIM/MSE, the real and projected images, and the label;
    * records the discriminator output ``D.run(images, None)[0][0]``;
    * records ``check_blink(parse_num(name, data_dir))`` when the label
      carries a name.

    Returns:
        ``[all_ssim, all_mse, labels, discrim, all_real_img, all_proj_img,
        all_blink_detects]``. The same tuple is also put on ``queue`` when
        one is given.
    """
    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=2,
                                       repeat=False,
                                       shuffle_mb=0)

    all_ssim = []
    all_mse = []
    labels = []
    discrim = []
    image_idx = 0
    all_real_img = []
    all_proj_img = []
    all_blink_detects = []

    while True:
        try:
            images, label = dataset_obj.get_minibatch_np(1)
        except Exception:
            # End of data: the dataset is non-repeating, so fetching
            # presumably raises once exhausted. Narrowed from a bare
            # ``except`` so KeyboardInterrupt/SystemExit propagate and label
            # parsing errors below are no longer masked.
            break
        if not len(label[0]) == 1:
            label, name = label[0]
            label = np.array([[label]])
        else:
            name = None
        labels.append(label)

        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        img, temp_ssim, temp_mse = project_image(proj,
                                                 targets=images,
                                                 img_num=image_idx,
                                                 num_snapshots=num_snapshots)
        all_ssim.append(temp_ssim)
        all_mse.append(temp_mse)
        all_real_img.append(np.array(misc.convert_to_pil_image(images[0])))
        all_proj_img.append(img)
        if name is not None:
            all_blink_detects.append(check_blink(parse_num(name, data_dir)))

        discrim.append(D.run(images, None)[0][0])

        image_idx += 1

    results = [
        all_ssim, all_mse, labels, discrim, all_real_img, all_proj_img,
        all_blink_detects
    ]
    if queue is not None:
        queue.put(tuple(results))
    return results
def generate_image_pairs(network_pkl,
                         n_imgs,
                         model_type,
                         n_discrete,
                         n_continuous,
                         result_dir,
                         batch_size=10,
                         latent_type='onedim'):
    """Render image pairs from related uniform latents and save their deltas.

    For each of ``n_imgs // batch_size`` batches, ``z_1``/``z_2`` are drawn
    uniformly from [-2, 2). With ``latent_type == 'onedim'``, ``z_2`` copies
    ``z_1`` except in one random dimension. Each pair is rendered side by side
    (concatenated on axis 2) as ``pair_%06d.jpg``. All ``z_1 - z_2`` rows are
    saved to ``labels.npy``; the file holds an empty ``(0, n_continuous)``
    array when no batch runs.

    Args:
        network_pkl: Pickle with ``_G, _D, Gs`` (plus ``I`` for
            'info_gan' / 'vc_gan_with_vc_head').
        n_imgs: Requested image count; a remainder below ``batch_size`` is
            dropped.
        model_type: Selects the pickle layout.
        n_discrete: Size of the one-hot categorical prefix (0 for none).
        n_continuous: Number of continuous latent dimensions.
        result_dir: Output directory (created if missing).
        batch_size: Pairs per batch.
        latent_type: 'onedim' to vary a single dimension per pair.
    """
    print('Loading networks from "%s"...' % network_pkl)
    tflib.init_tf()
    if (model_type == 'info_gan') or (model_type == 'vc_gan_with_vc_head'):
        _G, _D, I, Gs = misc.load_pkl(network_pkl)
    else:
        _G, _D, Gs = misc.load_pkl(network_pkl)

    os.makedirs(result_dir, exist_ok=True)

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False

    n_batches = n_imgs // batch_size
    grid_labels = np.zeros([batch_size, 0], dtype=np.float32)
    # Collect per-batch deltas and concatenate once (avoids O(n^2) copies).
    label_batches = []

    for i in range(n_batches):
        print('Generating image pairs %d/%d ...' % (i, n_batches))

        if n_discrete > 0:
            cat_dim = np.random.randint(0, n_discrete, size=[batch_size])
            cat_onehot = np.zeros((batch_size, n_discrete))
            cat_onehot[np.arange(cat_dim.size), cat_dim] = 1

        z_1 = np.random.uniform(low=-2,
                                high=2,
                                size=[batch_size, n_continuous])
        z_2 = np.random.uniform(low=-2,
                                high=2,
                                size=[batch_size, n_continuous])
        if latent_type == 'onedim':
            # One varied dimension per row; elsewhere z_2 copies z_1.
            delta_dim = np.random.randint(0, n_continuous, size=[batch_size])
            delta_onehot = np.zeros((batch_size, n_continuous))
            delta_onehot[np.arange(delta_dim.size), delta_dim] = 1
            z_2 = np.where(delta_onehot > 0, z_2, z_1)

        label_batches.append(z_1 - z_2)

        if n_discrete > 0:
            z_1 = np.concatenate((cat_onehot, z_1), axis=1)
            z_2 = np.concatenate((cat_onehot, z_2), axis=1)

        fakes_1 = Gs.run(z_1,
                         grid_labels,
                         is_validation=True,
                         minibatch_size=batch_size,
                         **Gs_kwargs)
        fakes_2 = Gs.run(z_2,
                         grid_labels,
                         is_validation=True,
                         minibatch_size=batch_size,
                         **Gs_kwargs)
        print('fakes_1.shape:', fakes_1.shape)
        print('fakes_2.shape:', fakes_2.shape)

        for j in range(fakes_1.shape[0]):
            pair_np = np.concatenate([fakes_1[j], fakes_2[j]], axis=2)
            img = misc.convert_to_pil_image(pair_np, [-1, 1])
            img.save(
                os.path.join(result_dir,
                             'pair_%06d.jpg' % (i * batch_size + j)))

    if label_batches:
        labels = np.concatenate(label_batches, axis=0)
    else:
        # Previously a NameError when n_imgs < batch_size.
        labels = np.zeros([0, n_continuous])
    np.save(os.path.join(result_dir, 'labels.npy'), labels)