Example #1
# Imports assumed by this excerpt (pickle is stdlib; config, tflib and util ship
# with the StyleGAN codebase; FACE_MODEL_URL is defined elsewhere in the project):
import pickle
import config                 # StyleGAN repo config module (defines cache_dir)
import dnnlib.tflib as tflib
import dnnlib.util as util    # util.open_url downloads and caches model pickles

def make_generator():
    """Build a function that maps latent vectors to generated images."""
    with util.open_url(FACE_MODEL_URL, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        def gen(latents):
            images = Gs.run(latents, None, truncation_psi=0.7,
                            randomize_noise=True,
                            output_transform=fmt)
            return images
        return gen
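A minimal usage sketch for the function above; the latent size of 512 and the 1024x1024 output shape are the StyleGAN FFHQ defaults, assumed here rather than taken from the source:

import numpy as np

tflib.init_tf()                    # Gs.run needs an active TF session
gen = make_generator()
latents = np.random.randn(4, 512)  # four random latent vectors
images = gen(latents)              # uint8 NHWC array, e.g. (4, 1024, 1024, 3)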
Example #2
def build(self, input_shape):
    with util.open_url(FACE_MODEL_URL, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
    self.Gs = Gs.clone()
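For context, a sketch of the Keras layer this build() method plausibly belongs to; the class name and the call() body are assumptions (though Gs.get_output_for is the graph-building entry point the StyleGAN codebase provides):

import pickle
import tensorflow as tf

class StyleGANLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        # load the pretrained networks once, when the layer is first built
        with util.open_url(FACE_MODEL_URL, cache_dir=config.cache_dir) as f:
            _G, _D, Gs = pickle.load(f)
        self.Gs = Gs.clone()  # clone so the layer owns its own variable copies

    def call(self, latents):
        # wire the generator into the surrounding graph; noise held fixed
        return self.Gs.get_output_for(latents, None,
                                      is_validation=True, randomize_noise=False)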
Example #3
def main():
    parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Output directories setting
    parser.add_argument('src_dir', help='Directory with images for encoding')
    parser.add_argument('generated_images_dir', help='Directory for storing generated images')
    parser.add_argument('guessed_images_dir', help='Directory for storing initially guessed images')
    parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')

    # General params
    parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
    parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
    parser.add_argument('--use_resnet', default=True, help='Use pretrained ResNet for approximating dlatents', type=lambda x: (str(x).lower() == 'true'))

    # Perceptual model params
    parser.add_argument('--iterations', default=100, help='Number of optimization steps for each batch', type=int)
    parser.add_argument('--lr', default=0.02, help='Learning rate for perceptual model', type=float)
    parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
    parser.add_argument('--decay_steps', default=10, help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
    parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
    parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int)

    # Loss function options
    parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int)
    parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_mssim_loss', default=100, help='Use MS-SSIM perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_l1_penalty', default=1, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float)

    # Generator params
    parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of dlatent values outside of this threshold', type=float)

    # Masking params
    parser.add_argument('--mask_dir', default='masks/latent_interpolation', help='Directory for storing optional masks')
    parser.add_argument('--face_mask', default=False, help='Generate a mask for predicting only the face area', type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument('--use_grabcut', default=True, help='Use grabcut algorithm on the face mask to better segment the foreground', type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument('--scale_mask', default=1.5, help='Look over a wider section of foreground for grabcut', type=float)

    args, other_args = parser.parse_known_args()

    args.decay_steps *= 0.01 * args.iterations  # Calculate steps as a percent of total iterations
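    # e.g. the defaults (decay_steps=10, iterations=100) give 10 * 0.01 * 100 = 10,
    # so the learning rate decays every 10 optimization steps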

    ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
    ref_images = sorted(list(filter(os.path.isfile, ref_images)))

    if len(ref_images) == 0:
        raise Exception('%s is empty' % args.src_dir)

    # Create output directories
    os.makedirs('data', exist_ok=True)
    os.makedirs(args.generated_images_dir, exist_ok=True)
    os.makedirs(args.guessed_images_dir, exist_ok=True)
    os.makedirs(args.dlatent_dir, exist_ok=True)
    if args.face_mask:
        os.makedirs(args.mask_dir, exist_ok=True)

    # Initialize generator
    tflib.init_tf()
    with open_url(url_styleGAN, cache_dir='cache') as f:
        generator_network, discriminator_network, Gs_network = pickle.load(f)

    generator = Generator(model=Gs_network,
                          batch_size=args.batch_size,
                          clipping_threshold=args.clipping_threshold,
                          tiled_dlatent=args.tile_dlatents,
                          model_res=args.model_res,
                          randomize_noise=args.randomize_noise)

    # Initialize perceptual model
    perc_model = None
    if args.use_lpips_loss > 1e-7:
        with open_url(url_VGG_perceptual, cache_dir='cache') as f:
            perc_model = pickle.load(f)
    perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
    perceptual_model.build_perceptual_model(generator)

    # Initialize ResNet model
    resnet_model = None
    if args.use_resnet:
        print("\nLoading ResNet Model:")
        resnet_model_fn = 'data/finetuned_resnet.h5'
        gdown.download(url_resnet, resnet_model_fn, quiet=True)
        resnet_model = load_model(resnet_model_fn)

    # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
    for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=(len(ref_images) + args.batch_size - 1) // args.batch_size):
        names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]
        perceptual_model.set_reference_images(images_batch)

        # predict initial dlatents with ResNet model
        if resnet_model is not None:
            dlatents = resnet_model.predict(preprocess_input(load_images(images_batch, image_size=args.resnet_image_size)))
            generator.set_dlatents(dlatents)

        # Generate and save initially guessed images
        initial_images = generator.generate_images()
        for img_array, img_name in zip(initial_images, names):
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(os.path.join(args.guessed_images_dir, f'{img_name}.png'), 'PNG')

        # Optimization process to find best latent vectors
        op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations)
        progress_bar = tqdm(op, leave=False, total=args.iterations)
        best_loss = None
        best_dlatent = None
        for loss_dict in progress_bar:
            progress_bar.set_description(" ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
            if best_loss is None or loss_dict["loss"] < best_loss:
                best_loss = loss_dict["loss"]
                best_dlatent = generator.get_dlatents()
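            # stochastic clipping: dlatent components that drift outside
            # ±clipping_threshold are re-randomized rather than hard-clipped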
            generator.stochastic_clip_dlatents()
        print(" ".join(names), " Loss {:.4f}".format(best_loss))

        # Save found dlatents
        generator.set_dlatents(best_dlatent)
        generated_dlatents = generator.get_dlatents()
        for dlatent, img_name in zip(generated_dlatents, names):
            np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)
        generator.reset_dlatents()

    # Concatenate and save the dlatent vectors
    list_dlatents = sorted(os.listdir(args.dlatent_dir))
    final_w_vectors = np.array([np.load(os.path.join(args.dlatent_dir, dlatent)) for dlatent in list_dlatents])
    np.save(os.path.join(args.dlatent_dir, 'output_vectors.npy'), final_w_vectors)

    # Perform face morphing by interpolating the latent space
    w1, w2 = create_morphing_lists(final_w_vectors)
    ref_images_1, ref_images_2 = create_morphing_lists(ref_images)
    for i in range(len(ref_images_1)):
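        # Midpoint morph: average the two W vectors, then reshape to the
        # (batch, 18, 512) per-layer dlatent layout the generator expects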
        avg_w_vector = (0.5 * (w1[i] + w2[i])).reshape((-1, 18, 512))
        generator.set_dlatents(avg_w_vector)
        img_array = generator.generate_images()[0]
        img = PIL.Image.fromarray(img_array, 'RGB')
        img_name = os.path.splitext(os.path.basename(ref_images_1[i]))[0] + '_vs_' + os.path.splitext(os.path.basename(ref_images_2[i]))[0]
        img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG')
    generator.reset_dlatents()
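Two helpers used in Example #3 but not shown in the excerpt; these are plausible sketches of their assumed behavior, not the source project's implementations:

from itertools import combinations

# split_to_batches presumably yields successive slices of at most batch_size items
def split_to_batches(items, batch_size):
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

# create_morphing_lists presumably expands a sequence into two parallel lists,
# one entry per unordered pair, so every pair of inputs produces one morph
def create_morphing_lists(items):
    pairs = list(combinations(items, 2))
    return [a for a, _ in pairs], [b for _, b in pairs]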
Example #4
if args.use_fp16:
    K.set_floatx('float16')
    K.set_epsilon(1e-4)

tflib.init_tf()

model = get_resnet_model(args.model_path,
                         model_res=args.model_res,
                         depth=args.model_depth,
                         size=args.model_size,
                         activation=args.activation,
                         optimizer=args.optimizer,
                         loss=args.loss)

with open_url(args.model_url, cache_dir='cache') as f:
    generator_network, discriminator_network, Gs_network = pickle.load(f)


def load_Gs():
    """Zero-argument loader so downstream code can fetch the generator lazily."""
    return Gs_network


if args.freeze_first:
    model.layers[1].trainable = False
    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)

model.summary()

if args.freeze_first:  # run a training iteration first while pretrained model is frozen, then unfreeze.
    finetune_resnet(model,