Example #1
import pickle
import time

import numpy as np
import PIL.Image

# The imports below assume this script lives inside the NVIDIA StyleGAN2 code
# tree; Projector here appears to be a customized version of the repository's
# projector.py (it exposes set_network(Gs, lpips) and runSteps()).
import dnnlib
import dnnlib.tflib as tflib
from training import misc
from projector import Projector


def main():
    t0 = time.time()
    print('t0:', t0)

    # Initialize TensorFlow.
    tflib.init_tf()  # 0.82s

    print('t1:', time.time() - t0)

    # Load pre-trained network.
    with open('./models/stylegan2-ffhq-config-f.pkl', 'rb') as f:
        print('t2:', time.time() - t0)

        _G, _D, Gs = pickle.load(f)  # 13.09s

        print('t3:', time.time() - t0)

    # Load the VGG16-based LPIPS network used as the perceptual distance metric.
    with open('./models/vgg16_zhang_perceptual.pkl', 'rb') as f:
        lpips = pickle.load(f)

        print('t4:', time.time() - t0)

    # Set up the projector with the generator and the LPIPS network.
    proj = Projector()
    proj.set_network(Gs, lpips)

    # Load the target image and convert it to the generator's input format:
    # CHW layout with pixel values scaled to [-1, 1].
    image = PIL.Image.open('./images/example.png')
    #image = image.resize((Di.input_shape[2], Di.input_shape[3]), PIL.Image.ANTIALIAS)
    image_array = np.array(image).swapaxes(0, 2).swapaxes(1, 2)  # HWC -> CHW
    image_array = misc.adjust_dynamic_range(image_array, [0, 255], [-1, 1])

    print('t5:', time.time() - t0)

    # Run projection for 1000 steps, saving a snapshot of the current
    # reconstruction every 10 steps.
    proj.start([image_array])
    for step in proj.runSteps(1000):
        print('\rstep: %d' % step, end='', flush=True)
        if step % 10 == 0:
            results = proj.get_images()
            pilImage = misc.convert_to_pil_image(
                misc.create_image_grid(results), drange=[-1, 1])
            pilImage.save('./images/project-%d.png' % step)

    print('t6:', time.time() - t0)

    dlatents = proj.get_dlatents()
    noises = proj.get_noises()
    print('dlatents:', dlatents.shape)
    print('noises:', len(noises), noises[0].shape, noises[-1].shape)
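
    # Minimal follow-up sketch, assuming the standard StyleGAN2 TensorFlow Gs
    # API: feed the recovered dlatents back through the synthesis sub-network
    # to regenerate the projected image.
    images = Gs.components.synthesis.run(
        dlatents, randomize_noise=False,
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
    PIL.Image.fromarray(images[0], 'RGB').save('./images/project-final.png')
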
Example #2
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

# Generator and Projector come from a PyTorch StyleGAN2 port; the exact module
# paths below are assumptions and depend on the port being used.
from model import Generator
from projector import Projector

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def project_images(image_files, target_dir):
    # Load the pretrained 1024px generator and freeze its weights so that only
    # the latent code is optimized during projection.
    g = Generator(1024, 512, 8, pretrained=True).to(device).train()
    for param in g.parameters():
        param.requires_grad = False

    proj = Projector(g)

    for i, file in tqdm(enumerate(sorted(image_files))):
        print('Projecting {}'.format(file))

        # Load the image as RGB and normalize it to [-1, 1],
        # matching the generator's output range.
        target_image = Image.open(file).convert('RGB')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        target_image = transform(target_image).to(device)

        # Run projector
        proj.run(target_image)

        # Collect results
        generated = proj.get_images()
        latents = proj.get_latents()

        # Save results (target_dir is expected to end with a path separator).
        save_str = target_dir + file.split('/')[-1].split('.')[0]
        print('Saving {}'.format(save_str + '_p.png'))
        save_image(generated, save_str + '_p.png', normalize=True)
        torch.save(latents.detach().cpu(), save_str + '.pt')
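
# Usage sketch for the project_images() wrapper defined above; the input and
# output paths are placeholders, and target_dir must end with a path separator.
if __name__ == '__main__':
    import glob
    project_images(glob.glob('./images/*.png'), './projected/')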