def main():
    """Rank semantic boundaries by how much they move predictor scores.

    `args.num` latent codes are sampled in Z space and expanded to all latent
    spaces. For every boundary listed in `args.boundary_list_path`, the codes
    are pushed a distance of 2.0 along the boundary. When layer-wise
    rescoring is enabled for W/WP boundaries, the push is also repeated for
    each individual layer. Original and edited codes are synthesized, and the
    predictor score of the attribute named after the boundary is recorded.

    The mean score gain, with negative changes counted as zero, is logged per
    boundary. It is also logged per layer when applicable. Boundaries are
    listed from largest to smallest gain. Latent codes, boundaries and raw
    scores are dumped as `.npy` files into the work directory.

    Raises:
        ValueError: If `args.num` is not a positive number.
    """
    args = parse_args()

    work_dir = args.output_dir or f'{args.model_name}_rescore'
    logger = setup_logger(work_dir, args.logfile_name,
                          f'{args.model_name}_rescore_logger')

    logger.info('Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info('Preparing latent codes.')
    if args.num <= 0:
        raise ValueError(f'Argument `num` should be specified as a positive '
                         f'number, but `{args.num}` received!')
    z_samples = model.easy_sample(num=args.num, latent_space_type='z')
    latent_codes = model.easy_synthesize(latent_codes=z_samples,
                                         latent_space_type='z',
                                         generate_style=False,
                                         generate_image=False)
    for code_name, code_array in latent_codes.items():
        np.save(os.path.join(work_dir, f'{code_name}.npy'), code_array)

    logger.info('Initializing predictor.')
    predictor = build_predictor(args.predictor_name)

    boundaries = parse_boundary_list(args.boundary_list_path)

    def _mean_gain(score_table, column):
        """Mean of `score_table[:, column] - score_table[:, 0]`, negatives as 0."""
        gain = score_table[:, column] - score_table[:, 0]
        gain[gain < 0] = 0
        return np.mean(gain)

    logger.info('========================================')
    logger.info('Rescoring.')
    relevance = []
    for boundary_info, boundary_path in boundaries.items():
        logger.info('----------------------------------------')
        boundary_name, space_type = boundary_info
        logger.info(
            f'Boundary `{boundary_name}` from {space_type.upper()} space.')
        prefix = f'{boundary_name}_{space_type}'
        attr_idx = predictor.attribute_name_to_idx[boundary_name]

        # Boundaries are either pickled dicts with a 'boundary' entry or plain
        # arrays; `.item()` raises ValueError for the latter.
        try:
            boundary = np.load(boundary_path,
                               allow_pickle=True).item()['boundary']
        except ValueError:
            boundary = np.load(boundary_path)

        np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

        if space_type != 'z':
            # Non-Z boundaries are applied to the layer-wise WP codes.
            layerwise = True
            code_layerwise = True
            boundary_layerwise = (space_type == 'wp')
            num_layers = model.num_layers if args.layerwise_rescoring else 0
            if space_type == 'w':
                strength = get_layerwise_manipulation_strength(
                    model.num_layers, model.truncation_psi,
                    model.truncation_layers)
            else:
                strength = 1.0
            space_type = 'wp'
        else:
            layerwise = code_layerwise = boundary_layerwise = False
            num_layers = 0
            strength = 1.0

        # Column 0: original code; column 1: all layers edited;
        # column 2 + k: only layer k edited (when num_layers > 0).
        original = latent_codes[space_type][:, np.newaxis]
        edited = [
            manipulate(latent_codes[space_type],
                       boundary,
                       start_distance=2.0,
                       end_distance=2.0,
                       step=1,
                       layerwise_manipulation=layerwise,
                       num_layers=model.num_layers,
                       manipulate_layers=None if layer < 0 else layer,
                       is_code_layerwise=code_layerwise,
                       is_boundary_layerwise=boundary_layerwise,
                       layerwise_manipulation_strength=strength)
            for layer in range(-1, num_layers)
        ]
        codes = np.concatenate([original] + edited, axis=1)

        per_sample_scores = []
        for sample_idx in tqdm(range(args.num), leave=False):
            synthesized = model.easy_synthesize(
                latent_codes=codes[sample_idx],
                latent_space_type=space_type,
                generate_style=False,
                generate_image=True)['image']
            attributes = predictor.easy_predict(synthesized)['attribute']
            per_sample_scores.append(attributes[:, attr_idx])
        scores = np.stack(per_sample_scores, axis=0)
        np.save(os.path.join(work_dir, f'{prefix}_scores.npy'), scores)

        relevance.append((boundary_name, _mean_gain(scores, 1)))
        if num_layers:
            per_layer = sorted(
                ((f'Layer {layer:02d}', _mean_gain(scores, layer + 2))
                 for layer in range(num_layers)),
                key=lambda pair: pair[1],
                reverse=True)
            for layer_name, gain in per_layer:
                logger.info(f'  {layer_name}: {gain:7.4f}')

    logger.info('----------------------------------------')
    logger.info('Most relevant semantics:')
    ranked = sorted(relevance, key=lambda pair: pair[1], reverse=True)
    for boundary_name, gain in ranked:
        logger.info(f'  {boundary_name.ljust(15)}: {gain:7.4f}')
Exemple #2
0
        error = 0
        for i in range(num):
            for j in range(step):
                error += np.average(np.abs(res[i, j] - res[i, 0] - diff * j))
        print('Error:', error)

        print('==== Layer-wise Manipulation Test (single latent code, '
              'single boundary) ====')
        num = 64
        start_distance = -10
        end_distance = -start_distance
        step = 21
        truncation_psi = 1.0
        truncation_layers = 10
        strength = get_layerwise_manipulation_strength(num_layers,
                                                       truncation_psi,
                                                       truncation_layers)
        indices = parse_indices('0-8, 10-12',
                                min_val=0,
                                max_val=num_layers - 1)
        x = np.random.randint(0, high=10000, size=(num, dim))
        b = np.random.randint(0, high=10000, size=(1, dim))
        res = manipulate(latent_codes=x,
                         boundary=b,
                         start_distance=start_distance,
                         end_distance=end_distance,
                         step=step,
                         layerwise_manipulation=True,
                         num_layers=num_layers,
                         manipulate_layers=indices,
                         is_code_layerwise=False,
Exemple #3
0
    logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

    np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

    step = args.step + int(args.step % 2
                           == 0)  # Make sure it is an odd number.
    visualizer = HtmlPageVisualizer(num_rows=total_num, num_cols=step + 1)
    visualizer.set_headers([''] +
                           [f'Step {i - step // 2}'
                            for i in range(step // 2)] + ['Origin'] +
                           [f'Step {i + 1}' for i in range(step // 2)])
    for n in range(total_num):
        visualizer.set_cell(n, 0, text=f'Sample {n:05d}')

    strength = get_layerwise_manipulation_strength(model.num_layers,
                                                   model.truncation_psi,
                                                   model.truncation_layers)
    codes = manipulate(latent_codes=latent_codes['wp'],
                       boundary=boundary,
                       start_distance=args.start_distance,
                       end_distance=args.end_distance,
                       step=step,
                       layerwise_manipulation=True,
                       num_layers=model.num_layers,
                       manipulate_layers=manipulate_layers,
                       is_code_layerwise=True,
                       is_boundary_layerwise=False,
                       layerwise_manipulation_strength=strength)
    np.save(os.path.join(work_dir, f'{prefix}_manipulated_wp.npy'), codes)

    for s in tqdm(range(step), leave=False):
Exemple #4
0
def main():
    """Manipulate latent codes along semantic boundaries and export results.

    Latent codes are loaded from `args.latent_codes_path` when that file
    exists. Otherwise `args.num` codes are sampled in
    `args.latent_space_type`. For each boundary listed in
    `args.boundary_list_path`, the codes are moved from
    `args.start_distance` to `args.end_distance` in `step` steps, where
    `step` is `args.step` forced to be odd. The manipulated codes are then
    synthesized.

    Depending on the flags, raw images, an HTML grid and/or a video are
    written to the work directory. The codes, boundaries and manipulated
    codes are also dumped as `.npy` files.

    Raises:
        ValueError: If no latent code file exists and `args.num` is not
            positive.
    """
    args = parse_args()

    work_dir = args.output_dir or f'{args.model_name}_manipulation'
    logger_name = f'{args.model_name}_manipulation_logger'
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info('Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info('Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive '
                f'number since the latent code path '
                f'`{args.latent_codes_path}` does not exist!')
        logger.info('  Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    # From here on `latent_codes` is the dict returned by `easy_synthesize`,
    # indexed by latent space type.
    latent_codes = model.easy_synthesize(
        latent_codes=latent_codes,
        latent_space_type=args.latent_space_type,
        generate_style=False,
        generate_image=False)
    for key, val in latent_codes.items():
        np.save(os.path.join(work_dir, f'{key}.npy'), val)

    boundaries = parse_boundary_list(args.boundary_list_path)

    step = args.step + int(args.step % 2 == 0)  # Make sure it is an odd number.
    # A `viz_size` of 0 is passed on as `None` (no explicit size).
    viz_size = None if args.viz_size == 0 else args.viz_size

    if args.generate_video:
        fusion_kwargs = {
            'row': args.row,
            'col': args.col,
            'row_spacing': args.row_spacing,
            'col_spacing': args.col_spacing,
            'border_left': args.border_left,
            'border_right': args.border_right,
            'border_top': args.border_top,
            'border_bottom': args.border_bottom,
            'black_background': not args.white_background,
            'image_size': viz_size,
        }
        # The frame size only depends on the unmanipulated codes and the
        # fusion options, never on the boundary. So the full batch is
        # synthesized once here instead of once per boundary.
        setup_images = model.easy_synthesize(
            latent_codes=latent_codes[args.latent_space_type],
            latent_space_type=args.latent_space_type)['image']
        setup_image = fuse_images(setup_images, **fusion_kwargs)
        frame_height, frame_width = setup_image.shape[:2]
        del setup_images, setup_image  # Only the frame size is needed.

    for boundary_info, boundary_path in boundaries.items():
        boundary_name, space_type = boundary_info
        logger.info(
            f'Boundary `{boundary_name}` from {space_type.upper()} space.')
        prefix = f'{boundary_name}_{space_type}'

        if args.generate_html:
            visualizer = HtmlPageVisualizer(num_rows=total_num,
                                            num_cols=step + 1,
                                            viz_size=viz_size)
            visualizer.set_headers(
                [''] + [f'Step {i - step // 2}' for i in range(step // 2)] +
                ['Origin'] + [f'Step {i + 1}' for i in range(step // 2)])

        if args.generate_video:
            # NOTE(review): the writer is never explicitly released; this
            # presumably relies on VideoWriter's own cleanup when the object
            # is dropped. Confirm against VideoWriter.
            video_writer = VideoWriter(os.path.join(
                work_dir, f'{prefix}_{args.video_name}'),
                                       frame_height=frame_height,
                                       frame_width=frame_width,
                                       fps=args.fps)

        logger.info('  Loading boundary.')
        # Boundaries are either pickled dicts (with 'boundary' and
        # 'meta_data') or plain arrays; `.item()` raises ValueError for the
        # latter, in which case the CLI layer selection is used.
        try:
            boundary_file = np.load(boundary_path, allow_pickle=True).item()
            boundary = boundary_file['boundary']
            manipulate_layers = boundary_file['meta_data']['manipulate_layers']
        except ValueError:
            boundary = np.load(boundary_path)
            manipulate_layers = args.manipulate_layers
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

        np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

        if args.layerwise_manipulation and space_type != 'z':
            # W / WP boundaries are applied to the layer-wise WP codes.
            layerwise_manipulation = True
            is_code_layerwise = True
            is_boundary_layerwise = (space_type == 'wp')
            if (not args.disable_manipulation_truncation
                    and space_type == 'w'):
                strength = get_layerwise_manipulation_strength(
                    model.num_layers, model.truncation_psi,
                    model.truncation_layers)
            else:
                strength = 1.0
            space_type = 'wp'
        else:
            if args.layerwise_manipulation:
                logger.warning(f'  Skip layer-wise manipulation for boundary '
                               f'`{boundary_name}` from Z space. Traditional '
                               f'manipulation is used instead.')
            layerwise_manipulation = False
            is_code_layerwise = False
            is_boundary_layerwise = False
            strength = 1.0

        codes = manipulate(latent_codes=latent_codes[space_type],
                           boundary=boundary,
                           start_distance=args.start_distance,
                           end_distance=args.end_distance,
                           step=step,
                           layerwise_manipulation=layerwise_manipulation,
                           num_layers=model.num_layers,
                           manipulate_layers=manipulate_layers,
                           is_code_layerwise=is_code_layerwise,
                           is_boundary_layerwise=is_boundary_layerwise,
                           layerwise_manipulation_strength=strength)
        np.save(
            os.path.join(work_dir, f'{prefix}_manipulated_{space_type}.npy'),
            codes)

        logger.info('  Start manipulating.')
        for s in tqdm(range(step), leave=False):
            images = model.easy_synthesize(
                latent_codes=codes[:, s],
                latent_space_type=space_type)['image']
            if args.generate_video:
                video_writer.write(fuse_images(images, **fusion_kwargs))
            for n, image in enumerate(images):
                if args.save_raw_synthesis:
                    save_image(
                        os.path.join(work_dir,
                                     f'{prefix}_{n:05d}_{s:03d}.jpg'), image)
                if args.generate_html:
                    visualizer.set_cell(n, s + 1, image=image)
                    if s == 0:
                        visualizer.set_cell(n, 0, text=f'Sample {n:05d}')

        if args.generate_html:
            visualizer.save(
                os.path.join(work_dir, f'{prefix}_{args.html_name}'))