Example #1
0
def main():
    """Extract GH-Feat from a list of images and save the reconstructions.

    Parses the command-line arguments with `parse_args()`, loads the pickled
    networks from `args.model_path`, runs the encoder on every image listed in
    `args.image_list` (one path per line, blank lines are skipped) and writes
    the following into the output directory:

    * `<name>_ori.png` / `<name>_enc.png`: resized input and reconstruction.
    * `image_list.txt`: a copy of the input image list.
    * `ghfeat.npy`: encoder outputs of all images, concatenated along axis 0.
    * `reconstruction.html`: page showing each input next to its
      reconstruction.

    Raises:
        FileNotFoundError: If `args.image_list` does not exist.
        ValueError: If `args.image_list` contains no image path.
    """
    import shutil  # Local import: only needed to copy the image list.

    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.image_list):
        raise FileNotFoundError(
            f'Image list `{args.image_list}` does not exist!')
    image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
    output_dir = args.output_dir or f'results/ghfeat/{image_list_name}'
    logger = setup_logger(output_dir, 'extract_feature.log',
                          'inversion_logger')

    logger.info('Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    # NOTE: `pickle.load` can execute arbitrary code. Only load model files
    # that come from a trusted source.
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size. Entries 2 and 3 of the encoder input shape must agree,
    # i.e. the encoder is expected to take square images.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    # Build a fresh synthesis network and copy the weights of the loaded
    # generator's synthesis component into it, variable by variable.
    G_args = EasyDict(func_name='training.networks_stylegan.G_synthesis')
    G_style_mod = tflib.Network('G_StyleMod',
                                resolution=image_size,
                                label_size=0,
                                **G_args)
    Gs_vars_pairs = {
        name: tflib.run(val)
        for name, val in Gs.components.synthesis.vars.items()
    }
    for g_name, g_val in G_style_mod.vars.items():
        tflib.set_vars({g_val: Gs_vars_pairs[g_name]})

    # Build graph.
    logger.info('Building graph.')
    sess = tf.get_default_session()
    # Work on a copy so that overriding the batch dimension does not mutate
    # the shape object owned by the encoder.
    input_shape = list(E.input_shape)
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    ghfeat = E.get_output_for(x, is_training=False)
    x_rec = G_style_mod.get_output_for(ghfeat, randomize_noise=False)

    # Load image list, ignoring blank lines (e.g. a trailing newline).
    logger.info('Loading image list.')
    with open(args.image_list, 'r') as f:
        image_list = [line.strip() for line in f if line.strip()]
    if not image_list:
        # Fail early with a clear message; `np.concatenate` below would
        # otherwise raise an opaque `ValueError` on an empty feature list.
        raise ValueError(f'No image path found in `{args.image_list}`!')

    # Extract GH-Feat from images.
    logger.info('Start feature extraction.')
    headers = ['Name', 'Original Image', 'Encoder Output']
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=len(image_list),
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    # `images` and `names` are fixed-size buffers reused for every batch. In a
    # short final batch the trailing slots still hold data from the previous
    # batch, which is why only the first `len(batch)` results are consumed.
    images = np.zeros(input_shape, np.uint8)
    names = ['' for _ in range(args.batch_size)]
    features = []
    for img_idx in tqdm(range(0, len(image_list), args.batch_size),
                        leave=False):
        # Load inputs.
        batch = image_list[img_idx:img_idx + args.batch_size]
        for i, image_path in enumerate(batch):
            image = resize_image(load_image(image_path),
                                 (image_size, image_size))
            # Move the last axis to the front to match the buffer layout.
            images[i] = np.transpose(image, [2, 0, 1])
            names[i] = os.path.splitext(os.path.basename(image_path))[0]
        # Rescale uint8 pixel values from [0, 255] to float32 in [-1, 1].
        inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
        # Run encoder.
        outputs = sess.run([ghfeat, x_rec], {x: inputs})
        features.append(outputs[0][0:len(batch)])
        # Post-process the reconstructions before saving; presumably this
        # maps them back to displayable images - see `adjust_pixel_range`.
        outputs[1] = adjust_pixel_range(outputs[1])
        for i, _ in enumerate(batch):
            image = np.transpose(images[i], [1, 2, 0])
            save_image(f'{output_dir}/{names[i]}_ori.png', image)
            save_image(f'{output_dir}/{names[i]}_enc.png', outputs[1][i])
            visualizer.set_cell(i + img_idx, 0, text=names[i])
            visualizer.set_cell(i + img_idx, 1, image=image)
            visualizer.set_cell(i + img_idx, 2, image=outputs[1][i])

    # Save results.
    # Copy without going through a shell, so paths containing spaces or shell
    # metacharacters are handled correctly.
    try:
        shutil.copy(args.image_list, f'{output_dir}/image_list.txt')
    except shutil.SameFileError:
        # The list already lives at the destination; nothing to copy.
        pass
    np.save(f'{output_dir}/ghfeat.npy', np.concatenate(features, axis=0))
    visualizer.save(f'{output_dir}/reconstruction.html')
Example #2
0
def main():
    """Manipulate latent codes along a list of boundaries and save results.

    Steps, all driven by the arguments returned from `parse_args()`:

    1. Build the generator named by `args.model_name`.
    2. Load latent codes from `args.latent_codes_path` if that file exists,
       otherwise sample `args.num` codes randomly.
    3. Save the codes returned by `model.easy_synthesize` as `<key>.npy`.
    4. For every boundary from `parse_boundary_list`, move the codes along the
       boundary, synthesize one batch of images per step and, depending on
       the flags, save raw images, an HTML page and/or a video.

    Raises:
        ValueError: If the latent code file does not exist and `args.num` is
            not positive.
    """
    args = parse_args()

    logger_name = f'{args.model_name}_manipulation_logger'
    work_dir = args.output_dir or f'{args.model_name}_manipulation'
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info('Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info('Preparing latent codes.')
    codes_path = args.latent_codes_path
    if not os.path.isfile(codes_path):
        # No file to load from: fall back to random sampling.
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive number '
                f'since the latent code path `{codes_path}` does not exist!')
        logger.info('  Sample latent codes randomly.')
        raw_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    else:
        logger.info(f'  Load latent codes from `{codes_path}`.')
        raw_codes = model.preprocess(
            latent_codes=np.load(codes_path),
            latent_space_type=args.latent_space_type)
    num_samples = raw_codes.shape[0]

    # Ask the model for the codes only (no styles, no images) and store every
    # returned entry on disk.
    code_dict = model.easy_synthesize(latent_codes=raw_codes,
                                      latent_space_type=args.latent_space_type,
                                      generate_style=False,
                                      generate_image=False)
    for code_name, code_value in code_dict.items():
        np.save(os.path.join(work_dir, f'{code_name}.npy'), code_value)

    boundaries = parse_boundary_list(args.boundary_list_path)

    # Force an odd number of steps by bumping even values by one.
    num_steps = args.step if args.step % 2 else args.step + 1
    half_steps = num_steps // 2

    for (boundary_name, space_type), boundary_path in boundaries.items():
        logger.info(
            f'Boundary `{boundary_name}` from {space_type.upper()} space.')
        prefix = f'{boundary_name}_{space_type}'

        if args.generate_html:
            # One leading column for the sample name, then one per step.
            column_titles = ['']
            column_titles.extend(
                f'Step {offset}' for offset in range(-half_steps, 0))
            column_titles.append('Origin')
            column_titles.extend(
                f'Step {offset}' for offset in range(1, half_steps + 1))
            html_viz_size = args.viz_size if args.viz_size != 0 else None
            visualizer = HtmlPageVisualizer(num_rows=num_samples,
                                            num_cols=num_steps + 1,
                                            viz_size=html_viz_size)
            visualizer.set_headers(column_titles)

        if args.generate_video:
            # Synthesize the unmodified samples once to learn the frame size.
            preview_images = model.easy_synthesize(
                latent_codes=code_dict[args.latent_space_type],
                latent_space_type=args.latent_space_type)['image']
            fusion_kwargs = dict(
                row=args.row,
                col=args.col,
                row_spacing=args.row_spacing,
                col_spacing=args.col_spacing,
                border_left=args.border_left,
                border_right=args.border_right,
                border_top=args.border_top,
                border_bottom=args.border_bottom,
                black_background=not args.white_background,
                image_size=args.viz_size if args.viz_size != 0 else None,
            )
            preview_frame = fuse_images(preview_images, **fusion_kwargs)
            video_path = os.path.join(work_dir,
                                      f'{prefix}_{args.video_name}')
            video_writer = VideoWriter(video_path,
                                       frame_height=preview_frame.shape[0],
                                       frame_width=preview_frame.shape[1],
                                       fps=args.fps)

        logger.info('  Loading boundary.')
        try:
            # Dict-style boundary file carrying its own layer selection.
            boundary_file = np.load(boundary_path, allow_pickle=True).item()
            boundary = boundary_file['boundary']
            manipulate_layers = boundary_file['meta_data']['manipulate_layers']
        except ValueError:
            # Fall back to a plain array; layers then come from the arguments.
            boundary = np.load(boundary_path)
            manipulate_layers = args.manipulate_layers
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

        np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

        if not args.layerwise_manipulation or space_type == 'z':
            if args.layerwise_manipulation:
                logger.warning(
                    f'  Skip layer-wise manipulation for boundary '
                    f'`{boundary_name}` from Z space. Traditional '
                    f'manipulation is used instead.')
            layerwise_manipulation = False
            is_code_layerwise = False
            is_boundary_layerwise = False
            strength = 1.0
        else:
            layerwise_manipulation = True
            is_code_layerwise = True
            is_boundary_layerwise = (space_type == 'wp')
            if args.disable_manipulation_truncation or space_type != 'w':
                strength = 1.0
            else:
                strength = get_layerwise_manipulation_strength(
                    model.num_layers, model.truncation_psi,
                    model.truncation_layers)
            # Layer-wise manipulation always operates on the `wp` codes.
            space_type = 'wp'

        codes = manipulate(latent_codes=code_dict[space_type],
                           boundary=boundary,
                           start_distance=args.start_distance,
                           end_distance=args.end_distance,
                           step=num_steps,
                           layerwise_manipulation=layerwise_manipulation,
                           num_layers=model.num_layers,
                           manipulate_layers=manipulate_layers,
                           is_code_layerwise=is_code_layerwise,
                           is_boundary_layerwise=is_boundary_layerwise,
                           layerwise_manipulation_strength=strength)
        manipulated_path = os.path.join(
            work_dir, f'{prefix}_manipulated_{space_type}.npy')
        np.save(manipulated_path, codes)

        logger.info('  Start manipulating.')
        for step_idx in tqdm(range(num_steps), leave=False):
            step_codes = codes[:, step_idx]
            images = model.easy_synthesize(
                latent_codes=step_codes,
                latent_space_type=space_type)['image']
            if args.generate_video:
                video_writer.write(fuse_images(images, **fusion_kwargs))
            for sample_idx, image in enumerate(images):
                if args.save_raw_synthesis:
                    raw_name = f'{prefix}_{sample_idx:05d}_{step_idx:03d}.jpg'
                    save_image(os.path.join(work_dir, raw_name), image)
                if args.generate_html:
                    visualizer.set_cell(sample_idx, step_idx + 1, image=image)
                    # Label each row once, while handling the first step.
                    if step_idx == 0:
                        visualizer.set_cell(sample_idx,
                                            0,
                                            text=f'Sample {sample_idx:05d}')

        if args.generate_html:
            html_path = os.path.join(work_dir, f'{prefix}_{args.html_name}')
            visualizer.save(html_path)