def main():
  """Inverts every image listed in `args.image_list` and saves the results.

  For each listed image the function writes three files into the output
  directory: `<name>_ori.png` (the resized input), `<name>_enc.png` (entry 1
  of the visualization results returned by `inverter.easy_invert`) and
  `<name>_inv.png` (the last entry). After all images are processed it also
  saves a copy of the image list (`image_list.txt`), the concatenated latent
  codes (`inverted_codes.npy`) and an HTML overview page (`inversion.html`).

  The output directory is `args.output_dir` if given, otherwise
  `results/inversion/<image list base name>`.
  """
  # Imported locally because only the final image-list copy needs it.
  import shutil

  args = parse_args()
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
  assert os.path.exists(args.image_list)
  image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
  output_dir = args.output_dir or f'results/inversion/{image_list_name}'
  logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger')

  logger.info('Loading model.')
  inverter = StyleGANInverter(
      args.model_name,
      learning_rate=args.learning_rate,
      iteration=args.num_iterations,
      reconstruction_loss_weight=1.0,
      perceptual_loss_weight=args.loss_weight_feat,
      regularization_loss_weight=args.loss_weight_enc,
      logger=logger)
  image_size = inverter.G.resolution

  # Load image list: one image path per line.
  logger.info('Loading image list.')
  with open(args.image_list, 'r') as f:
    image_list = [line.strip() for line in f]

  # Initialize visualizer. One column is added for every `save_interval`-th
  # optimization step, plus one for the final step.
  # NOTE(review): `save_interval` is 0 when `num_results > num_iterations`,
  # which makes the modulo below raise ZeroDivisionError.
  save_interval = args.num_iterations // args.num_results
  headers = ['Name', 'Original Image', 'Encoder Output']
  headers += [
      f'Step {step:06d}'
      for step in range(1, args.num_iterations + 1)
      if step == args.num_iterations or step % save_interval == 0
  ]
  viz_size = None if args.viz_size == 0 else args.viz_size
  visualizer = HtmlPageVisualizer(
      num_rows=len(image_list), num_cols=len(headers), viz_size=viz_size)
  visualizer.set_headers(headers)

  # Invert images.
  logger.info('Start inversion.')
  latent_codes = []
  for img_idx, image_path in enumerate(tqdm(image_list, leave=False)):
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    image = resize_image(load_image(image_path), (image_size, image_size))
    code, viz_results = inverter.easy_invert(image, num_viz=args.num_results)
    latent_codes.append(code)
    save_image(f'{output_dir}/{image_name}_ori.png', image)
    save_image(f'{output_dir}/{image_name}_enc.png', viz_results[1])
    save_image(f'{output_dir}/{image_name}_inv.png', viz_results[-1])
    visualizer.set_cell(img_idx, 0, text=image_name)
    visualizer.set_cell(img_idx, 1, image=image)
    # Entry 0 of `viz_results` is skipped; column 1 already shows the input.
    for viz_idx, viz_img in enumerate(viz_results[1:]):
      visualizer.set_cell(img_idx, viz_idx + 2, image=viz_img)

  # Save results.
  # Copy the list without going through a shell, so that paths containing
  # spaces or shell metacharacters are handled correctly.
  try:
    shutil.copyfile(args.image_list, f'{output_dir}/image_list.txt')
  except shutil.SameFileError:
    # The list already lives at the destination; nothing to copy.
    pass
  np.save(f'{output_dir}/inverted_codes.npy',
          np.concatenate(latent_codes, axis=0))
  visualizer.save(f'{output_dir}/inversion.html')
예제 #2
0
def Vis(bname, suffix, out, rownames=None, colnames=None):
    """Render a grid of images as an HTML page under `./html`.

    The page has one row per entry along `out.shape[0]` and one image column
    per entry along `out.shape[1]`, preceded by a text column holding the row
    name. It is saved to `./html/<bname>_<suffix>.html`.

    Args:
        bname: Base name of the output file.
        suffix: Suffix appended to `bname`, joined by an underscore.
        out: Image batch indexed as `out[i, k, :, :, :]`; presumably a 5-D
            array of (row, step, <image axes>) -- the image axis layout is
            whatever `HtmlPageVisualizer.set_cell` expects (TODO confirm).
        rownames: Optional labels for the rows. Defaults to the row indices
            as strings.
        colnames: Optional labels for the image columns. Defaults to
            `Step 01`, `Step 02`, ...
    """
    # Imported locally so the block does not rely on a file-level import.
    import os

    num_images, step = out.shape[0], out.shape[1]

    if colnames is None:
        colnames = [f'Step {i:02d}' for i in range(1, step + 1)]
    if rownames is None:
        rownames = [str(i) for i in range(num_images)]

    visualizer = HtmlPageVisualizer(num_rows=num_images,
                                    num_cols=step + 1,
                                    viz_size=256)
    visualizer.set_headers(['Name', *colnames])

    for i in range(num_images):
        visualizer.set_cell(i, 0, text=rownames[i])
        for k in range(step):
            visualizer.set_cell(i, 1 + k, image=out[i, k, :, :, :])

    # Save results. Make sure the target directory exists first, so the
    # page can be written on a fresh checkout.
    os.makedirs('./html', exist_ok=True)
    visualizer.save(f'./html/{bname}_{suffix}.html')
예제 #3
0
파일: test.py 프로젝트: yunhao1996/higan
        for model_name in model_list:
            logger = setup_logger(
                work_dir=RESULT_DIR,
                logfile_name=f'{model_name}_generator_test.log',
                logger_name=f'{model_name}_generator_logger')
            G = build_generator(model_name, logger=logger)
            G.batch_size = TEST_BATCH_SIZE
            z = G.easy_sample(args.test_num)
            x = G.easy_synthesize(z)['image']
            visualizer = HtmlPageVisualizer(grid_size=args.test_num)
            for i in range(visualizer.num_rows):
                for j in range(visualizer.num_cols):
                    visualizer.set_cell(i,
                                        j,
                                        image=x[i * visualizer.num_cols + j])
            visualizer.save(f'{RESULT_DIR}/{model_name}_generator_test.html')
            del G
        print('Pass!')
        TEST_FLAG = True

    # StyleGAN Generator.
    if args.stylegan or args.all:
        print('==== StyleGAN Generator Test ====')
        if args.verbose:
            model_list = []
            for model_name, model_setting in MODEL_POOL.items():
                if model_setting['gan_type'] == 'stylegan':
                    model_list.append(model_name)
        else:
            model_list = ['stylegan_ffhq', 'stylegan_car', 'stylegan_bedroom']
        for model_name in model_list:
예제 #4
0
def main():
    """Interpolate between the inverted codes of two result directories.

    Both `args.src_dir` and `args.dst_dir` must contain `image_list.txt`,
    `inverted_codes.npy` and a `<name>_ori.png` for every listed image. Every
    source code is paired with every target code; the codes returned by
    `interpolate` are synthesized in batches of `args.batch_size` with the
    generator loaded from `args.model_path`. The resulting images are laid
    out in an HTML page (one row per source/target pair) saved as
    `<output_dir>/<src>_TO_<dst>.html`.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    src_dir = args.src_dir
    src_dir_name = os.path.basename(src_dir.rstrip('/'))
    assert os.path.exists(src_dir)
    assert os.path.exists(f'{src_dir}/image_list.txt')
    assert os.path.exists(f'{src_dir}/inverted_codes.npy')
    dst_dir = args.dst_dir
    dst_dir_name = os.path.basename(dst_dir.rstrip('/'))
    assert os.path.exists(dst_dir)
    assert os.path.exists(f'{dst_dir}/image_list.txt')
    assert os.path.exists(f'{dst_dir}/inverted_codes.npy')
    output_dir = args.output_dir or 'results/interpolation'
    job_name = f'{src_dir_name}_TO_{dst_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    # NOTE(review): `pickle.load` can execute arbitrary code; only point
    # `--model_path` at trusted files.
    logger.info('Loading generator.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        _, _, _, Gs = pickle.load(f)

    # Build graph: a fixed-size placeholder of latent codes feeding the
    # synthesis network.
    logger.info('Building graph.')
    sess = tf.get_default_session()
    num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3]
    wp = tf.placeholder(tf.float32, [args.batch_size, num_layers, latent_dim],
                        name='latent_code')
    x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)

    # Load image names and codes for both sides.
    logger.info('Loading images and corresponding inverted latent codes.')
    src_list = []
    with open(f'{src_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{src_dir}/{name}_ori.png')
            src_list.append(name)
    src_codes = np.load(f'{src_dir}/inverted_codes.npy')
    assert src_codes.shape[0] == len(src_list)
    num_src = src_codes.shape[0]
    dst_list = []
    with open(f'{dst_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{dst_dir}/{name}_ori.png')
            dst_list.append(name)
    dst_codes = np.load(f'{dst_dir}/inverted_codes.npy')
    assert dst_codes.shape[0] == len(dst_list)
    num_dst = dst_codes.shape[0]

    # Interpolate images.
    logger.info('Start interpolation.')
    # Two extra steps: the end points (source and target inversions).
    step = args.step + 2
    viz_size = None if args.viz_size == 0 else args.viz_size
    # Columns: source image, `step` synthesized images, target image.
    visualizer = HtmlPageVisualizer(num_rows=num_src * num_dst,
                                    num_cols=step + 2,
                                    viz_size=viz_size)
    visualizer.set_headers(['Source', 'Source Inversion'] +
                           [f'Step {i:02d}' for i in range(1, step - 1)] +
                           ['Target Inversion', 'Target'])

    # Reusable feed buffer; only the first `len(batch)` rows are meaningful
    # for a partial batch, and only those outputs are kept.
    inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32)
    for src_idx in tqdm(range(num_src), leave=False):
        src_code = src_codes[src_idx:src_idx + 1]
        src_path = f'{src_dir}/{src_list[src_idx]}_ori.png'
        codes = interpolate(src_codes=np.repeat(src_code, num_dst, axis=0),
                            dst_codes=dst_codes,
                            step=step)
        for dst_idx in tqdm(range(num_dst), leave=False):
            dst_path = f'{dst_dir}/{dst_list[dst_idx]}_ori.png'
            output_images = []
            for idx in range(0, step, args.batch_size):
                batch = codes[dst_idx, idx:idx + args.batch_size]
                inputs[0:len(batch)] = batch
                images = sess.run(x, feed_dict={wp: inputs})
                output_images.append(images[0:len(batch)])
            output_images = adjust_pixel_range(
                np.concatenate(output_images, axis=0))

            row_idx = src_idx * num_dst + dst_idx
            visualizer.set_cell(row_idx, 0, image=load_image(src_path))
            visualizer.set_cell(row_idx, step + 1, image=load_image(dst_path))
            for s, output_image in enumerate(output_images):
                # Single hand-picked frame (row 2, step 5) dumped as a
                # standalone image. It is written next to the other results
                # so a custom `--output_dir` is honoured; with the default
                # output directory this is the same file as before.
                if s == 5 and row_idx == 2:
                    save_image(f'{output_dir}/005_int.png', output_image)
                visualizer.set_cell(row_idx, s + 1, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
예제 #5
0
def main():
    """Synthesize `args.num` random samples with a registered generator.

    Depending on the flags, each sample is saved as a JPEG under
    `<work_dir>/<model_name>_<num>/` (`args.save_raw_synthesis`) and/or placed
    in an HTML grid saved as `<work_dir>/<model_name>_<num>.html`
    (`args.generate_html`). The function returns immediately when there is
    nothing to do (`args.num <= 0` or both flags off).

    The checkpoint is read from `checkpoints/<model_name>.pth` and is
    downloaded with `wget` from the URL registered in `MODEL_ZOO` if missing.

    Raises:
        SystemExit: If the model is not registered, or the checkpoint download
            fails.
    """
    args = parse_args()
    if args.num <= 0:
        return
    if not args.save_raw_synthesis and not args.generate_html:
        return

    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Get work directory and job name.
    if args.save_dir:
        work_dir = args.save_dir
    else:
        work_dir = os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build generator and get synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print('Finish building generator.')

    # Load pre-trained weights.
    os.makedirs('checkpoints', exist_ok=True)
    checkpoint_path = os.path.join('checkpoints', args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        print(f'  Downloading checkpoint from `{url}` ...')
        status = subprocess.call(
            ['wget', '--quiet', '-O', checkpoint_path, url])
        if status != 0:
            # `wget -O` creates the output file even when the transfer fails.
            # Remove it so the next run retries the download instead of
            # trying to load an empty or partial checkpoint.
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            raise SystemExit(f'Failed to download checkpoint from `{url}` '
                             f'(wget exit code {status})!')
        print('  Finish downloading checkpoint.')
    # NOTE(review): `torch.load` unpickles the file; only use trusted URLs.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Prefer the `generator_smooth` weights when the checkpoint has them.
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print('Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Sample and synthesize.
    print(f'Synthesizing {args.num} samples ...')
    indices = list(range(args.num))
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=args.num)
    for batch_idx in tqdm(range(0, args.num, args.batch_size)):
        # The last batch may be shorter than `args.batch_size`.
        sub_indices = indices[batch_idx:batch_idx + args.batch_size]
        code = torch.randn(len(sub_indices), generator.z_space_dim).cuda()
        with torch.no_grad():
            images = generator(code, **synthesis_kwargs)['image']
            images = postprocess(images)
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(work_dir, job_name,
                                         f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
                # Fill the grid row by row.
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx,
                              col_idx,
                              image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {args.num} samples.')
예제 #6
0
def main():
    """Move inverted latent codes along a boundary and visualize the edits.

    Reads `image_list.txt` and `inverted_codes.npy` from `args.image_dir`,
    passes the codes and the boundary loaded from `args.boundary_path` to
    `manipulate`, synthesizes every resulting code with the generator, and
    writes an HTML page showing the original image, its inversion and the
    manipulation steps side by side.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Directory holding the image list, the codes and the per-image PNGs.
    image_dir = args.image_dir
    image_dir_name = os.path.basename(image_dir.rstrip('/'))
    for required_path in (image_dir,
                          f'{image_dir}/image_list.txt',
                          f'{image_dir}/inverted_codes.npy'):
        assert os.path.exists(required_path)

    boundary_path = args.boundary_path
    assert os.path.exists(boundary_path)
    boundary_name, _ = os.path.splitext(os.path.basename(boundary_path))

    output_dir = args.output_dir or 'results/manipulation'
    job_name = f'{boundary_name.upper()}_{image_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Generator used to render the manipulated codes.
    logger.info(f'Loading generator.')
    generator = build_generator(args.model_name)

    # Image names (extension stripped) plus their latent codes.
    logger.info(f'Loading images and corresponding inverted latent codes.')
    image_list = []
    with open(f'{image_dir}/image_list.txt', 'r') as list_file:
        for entry in list_file:
            stem, _ = os.path.splitext(os.path.basename(entry.strip()))
            assert os.path.exists(f'{image_dir}/{stem}_ori.png')
            assert os.path.exists(f'{image_dir}/{stem}_inv.png')
            image_list.append(stem)
    latent_codes = np.load(f'{image_dir}/inverted_codes.npy')
    num_images = latent_codes.shape[0]
    assert num_images == len(image_list)

    # The boundary file is either a bare array or a pickled dict that also
    # carries the layers to manipulate.
    logger.info(f'Loading boundary.')
    loaded = np.load(boundary_path, allow_pickle=True)[()]
    if not isinstance(loaded, dict):
        boundary = loaded
        manipulate_layers = args.manipulate_layers
    else:
        boundary = loaded['boundary']
        manipulate_layers = loaded['meta_data']['manipulate_layers']
    if not manipulate_layers:
        logger.info(f'  Manipulating on ALL layers.')
    else:
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

    # Page layout: name, original, inversion, then one column per step.
    logger.info(f'Start manipulation.')
    step = args.step
    viz_size = args.viz_size if args.viz_size != 0 else None
    visualizer = HtmlPageVisualizer(num_rows=num_images,
                                    num_cols=step + 3,
                                    viz_size=viz_size)
    step_headers = [f'Step {i:02d}' for i in range(1, step + 1)]
    visualizer.set_headers(['Name', 'Origin', 'Inverted'] + step_headers)

    # Fill the three fixed columns of every row.
    for row, stem in enumerate(image_list):
        original = load_image(f'{image_dir}/{stem}_ori.png')
        inverted = load_image(f'{image_dir}/{stem}_inv.png')
        visualizer.set_cell(row, 0, text=stem)
        visualizer.set_cell(row, 1, image=original)
        visualizer.set_cell(row, 2, image=inverted)

    manipulate_kwargs = dict(
        latent_codes=latent_codes,
        boundary=boundary,
        start_distance=args.start_distance,
        end_distance=args.end_distance,
        step=step,
        layerwise_manipulation=True,
        num_layers=generator.num_layers,
        manipulate_layers=manipulate_layers,
        is_code_layerwise=True,
        is_boundary_layerwise=True,
    )
    codes = manipulate(**manipulate_kwargs)

    # Render each image's manipulated codes into the step columns.
    for row in tqdm(range(num_images), leave=False):
        synthesized = generator.easy_synthesize(codes[row],
                                                latent_space_type='wp')
        for offset, frame in enumerate(synthesized['image']):
            visualizer.set_cell(row, offset + 3, image=frame)

    # Write the page.
    visualizer.save(f'{output_dir}/{job_name}.html')
예제 #7
0
def main():
    """Stitch a crop of each target image into each context image and invert it.

    For every (target, context) pair, the `args.crop_size` square centered at
    (`args.center_x`, `args.center_y`) is copied from the target into the
    context image. The stitched image is fed to the encoder to initialize a
    latent code, which is then optimized with Adam for `args.num_iterations`
    steps against pixel and perceptual losses computed on the masked (crop
    region only) images.

    Outputs written to the output directory: the encoder codes
    (`<job>_encoded_codes.npy`), the optimized codes
    (`<job>_inverted_codes.npy`) and an HTML page (`<job>.html`) with one row
    per pair.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    assert os.path.exists(args.target_list)
    target_list_name = os.path.splitext(os.path.basename(args.target_list))[0]
    assert os.path.exists(args.context_list)
    context_list_name = os.path.splitext(os.path.basename(
        args.context_list))[0]
    output_dir = args.output_dir or f'results/diffusion'
    job_name = f'{target_list_name}_TO_{context_list_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # NOTE(review): `pickle.load` can execute arbitrary code; only load
    # trusted model files.
    logger.info(f'Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size (the encoder input must be square) and build the crop
    # mask: 1.0 inside the crop square, 0.0 elsewhere.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]
    crop_size = args.crop_size
    crop_x = args.center_x - crop_size // 2
    crop_y = args.center_y - crop_size // 2
    mask = np.zeros((1, image_size, image_size, 3), dtype=np.float32)
    mask[:, crop_y:crop_y + crop_size, crop_x:crop_x + crop_size, :] = 1.0

    # Build graph.
    logger.info(f'Building graph.')
    sess = tf.get_default_session()
    # NOTE(review): the batch dimension is overwritten in place; if
    # `E.input_shape` returns the network's own list this mutates it --
    # confirm that is intended.
    input_shape = E.input_shape
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    # Move axis 1 last (presumably NCHW -> NHWC, TODO confirm) and apply the
    # mask in [0, 2] range so that everything outside the crop becomes -1.
    x_mask = (tf.transpose(x, [0, 2, 3, 1]) + 1) * mask - 1
    # Rescale from [-1, 1] to [0, 255] for the perceptual model.
    x_mask_255 = (x_mask + 1) / 2 * 255
    latent_shape = Gs.components.synthesis.input_shape
    latent_shape[0] = args.batch_size
    # Trainable latent code; this is the only variable being optimized.
    wp = tf.get_variable(shape=latent_shape, name='latent_code')
    x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
    x_rec_mask = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) * mask - 1
    x_rec_mask_255 = (x_rec_mask + 1) / 2 * 255

    # Running `setter` overwrites `wp` with the encoder output for `x`.
    w_enc = E.get_output_for(x, phase=False)
    wp_enc = tf.reshape(w_enc, latent_shape)
    setter = tf.assign(wp, wp_enc)

    # Settings for optimization: per-sample pixel loss plus weighted
    # perceptual loss, both computed on the masked images.
    logger.info(f'Setting configuration for optimization.')
    perceptual_model = PerceptualModel([image_size, image_size], False)
    x_feat = perceptual_model(x_mask_255)
    x_rec_feat = perceptual_model(x_rec_mask_255)
    loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1])
    loss_pix = tf.reduce_mean(tf.square(x_mask - x_rec_mask), axis=[1, 2, 3])

    loss = loss_pix + args.loss_weight_feat * loss_feat
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    train_op = optimizer.minimize(loss, var_list=[wp])
    tflib.init_uninitialized_vars()

    # Load image lists: one path per line.
    logger.info(f'Loading target images and context images.')
    target_list = []
    with open(args.target_list, 'r') as f:
        for line in f:
            target_list.append(line.strip())
    num_targets = len(target_list)
    context_list = []
    with open(args.context_list, 'r') as f:
        for line in f:
            context_list.append(line.strip())
    num_contexts = len(context_list)
    num_pairs = num_targets * num_contexts

    # Set up the result page. A column is added for every `save_interval`-th
    # optimization step, plus one for the final step.
    # NOTE(review): `save_interval` is 0 when `num_results > num_iterations`,
    # which makes the modulo below raise ZeroDivisionError.
    logger.info(f'Start diffusion.')
    save_interval = args.num_iterations // args.num_results
    headers = [
        'Target Image', 'Context Image', 'Stitched Image', 'Encoder Output'
    ]
    for step in range(1, args.num_iterations + 1):
        if step == args.num_iterations or step % save_interval == 0:
            headers.append(f'Step {step:06d}')
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_pairs,
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    # Reusable input buffer. For a partial last batch the trailing slots keep
    # whatever was written earlier; only the first `len(batch)` results are
    # used below.
    images = np.zeros(input_shape, np.uint8)
    latent_codes_enc = []
    latent_codes = []
    for target_idx in tqdm(range(num_targets), desc='Target ID', leave=False):
        # Load target. It is shown only in the first row of this target's
        # group of rows.
        target_image = resize_image(load_image(target_list[target_idx]),
                                    (image_size, image_size))
        visualizer.set_cell(target_idx * num_contexts, 0, image=target_image)
        for context_idx in tqdm(range(0, num_contexts, args.batch_size),
                                desc='Context ID',
                                leave=False):
            # Page row of the first pair in this batch.
            row_idx = target_idx * num_contexts + context_idx
            batch = context_list[context_idx:context_idx + args.batch_size]
            for i, context_image_path in enumerate(batch):
                context_image = resize_image(load_image(context_image_path),
                                             (image_size, image_size))
                visualizer.set_cell(row_idx + i, 1, image=context_image)
                # Paste the target crop into the context image (in place).
                context_image[crop_y:crop_y + crop_size, crop_x:crop_x +
                              crop_size] = (target_image[crop_y:crop_y +
                                                         crop_size,
                                                         crop_x:crop_x +
                                                         crop_size])
                visualizer.set_cell(row_idx + i, 2, image=context_image)
                images[i] = np.transpose(context_image, [2, 0, 1])
            # Rescale uint8 [0, 255] to float [-1, 1].
            inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
            # Run encoder to initialize `wp`, then read back code and image.
            sess.run([setter], {x: inputs})
            outputs = sess.run([wp, x_rec])
            latent_codes_enc.append(outputs[0][0:len(batch)])
            outputs[1] = adjust_pixel_range(outputs[1])
            for i, _ in enumerate(batch):
                visualizer.set_cell(row_idx + i, 3, image=outputs[1][i])
            # Optimize latent codes, snapshotting into columns 4, 5, ...
            col_idx = 4
            for step in tqdm(range(1, args.num_iterations + 1), leave=False):
                sess.run(train_op, {x: inputs})
                if step == args.num_iterations or step % save_interval == 0:
                    outputs = sess.run([wp, x_rec])
                    outputs[1] = adjust_pixel_range(outputs[1])
                    for i, _ in enumerate(batch):
                        visualizer.set_cell(row_idx + i,
                                            col_idx,
                                            image=outputs[1][i])
                    col_idx += 1
            # `outputs` holds the most recent snapshot, i.e. the final step.
            latent_codes.append(outputs[0][0:len(batch)])

    # Save results, reshaped to (num_targets, num_contexts, *latent dims).
    code_shape = [num_targets, num_contexts] + list(latent_shape[1:])
    np.save(f'{output_dir}/{job_name}_encoded_codes.npy',
            np.concatenate(latent_codes_enc, axis=0).reshape(code_shape))
    np.save(f'{output_dir}/{job_name}_inverted_codes.npy',
            np.concatenate(latent_codes, axis=0).reshape(code_shape))
    visualizer.save(f'{output_dir}/{job_name}.html')
예제 #8
0
def main():
    """Extract encoder features for every listed image and reconstruct them.

    Each image in `args.image_list` is resized to the encoder's input size and
    passed through the encoder `E`; the resulting feature is fed to a
    `G_StyleMod` network (initialized from the synthesis network's variables)
    to obtain a reconstruction. Per image, `<name>_ori.png` and
    `<name>_enc.png` are saved; at the end the image list, the concatenated
    features (`ghfeat.npy`) and an HTML page (`reconstruction.html`) are
    written to the output directory.
    """
    # Imported locally because only the final image-list copy needs it.
    import shutil

    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    assert os.path.exists(args.image_list)
    image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
    output_dir = args.output_dir or f'results/ghfeat/{image_list_name}'
    logger = setup_logger(output_dir, 'extract_feature.log',
                          'inversion_logger')

    # NOTE(review): `pickle.load` can execute arbitrary code; only load
    # trusted model files.
    logger.info('Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size (the encoder input must be square).
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    # Build a fresh synthesis network and copy the trained variable values
    # into it by name.
    G_args = EasyDict(func_name='training.networks_stylegan.G_synthesis')
    G_style_mod = tflib.Network('G_StyleMod',
                                resolution=image_size,
                                label_size=0,
                                **G_args)
    Gs_vars_pairs = {
        name: tflib.run(val)
        for name, val in Gs.components.synthesis.vars.items()
    }
    for g_name, g_val in G_style_mod.vars.items():
        tflib.set_vars({g_val: Gs_vars_pairs[g_name]})

    # Build graph.
    logger.info('Building graph.')
    sess = tf.get_default_session()
    input_shape = E.input_shape
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    ghfeat = E.get_output_for(x, is_training=False)
    x_rec = G_style_mod.get_output_for(ghfeat, randomize_noise=False)

    # Load image list: one image path per line.
    logger.info('Loading image list.')
    with open(args.image_list, 'r') as f:
        image_list = [line.strip() for line in f]

    # Extract GH-Feat from images.
    logger.info('Start feature extraction.')
    headers = ['Name', 'Original Image', 'Encoder Output']
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=len(image_list),
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    # Reusable input buffers. For a partial last batch the trailing slots
    # keep earlier content; only the first `len(batch)` results are used.
    images = np.zeros(input_shape, np.uint8)
    names = ['' for _ in range(args.batch_size)]
    features = []
    for img_idx in tqdm(range(0, len(image_list), args.batch_size),
                        leave=False):
        # Load inputs.
        batch = image_list[img_idx:img_idx + args.batch_size]
        for i, image_path in enumerate(batch):
            image = resize_image(load_image(image_path),
                                 (image_size, image_size))
            images[i] = np.transpose(image, [2, 0, 1])
            names[i] = os.path.splitext(os.path.basename(image_path))[0]
        # Rescale uint8 [0, 255] to float [-1, 1].
        inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
        # Run encoder.
        outputs = sess.run([ghfeat, x_rec], {x: inputs})
        features.append(outputs[0][0:len(batch)])
        outputs[1] = adjust_pixel_range(outputs[1])
        for i, _ in enumerate(batch):
            image = np.transpose(images[i], [1, 2, 0])
            save_image(f'{output_dir}/{names[i]}_ori.png', image)
            save_image(f'{output_dir}/{names[i]}_enc.png', outputs[1][i])
            visualizer.set_cell(i + img_idx, 0, text=names[i])
            visualizer.set_cell(i + img_idx, 1, image=image)
            visualizer.set_cell(i + img_idx, 2, image=outputs[1][i])

    # Save results.
    # Copy the list without going through a shell, so that paths containing
    # spaces or shell metacharacters are handled correctly.
    try:
        shutil.copyfile(args.image_list, f'{output_dir}/image_list.txt')
    except shutil.SameFileError:
        # The list already lives at the destination; nothing to copy.
        pass
    np.save(f'{output_dir}/ghfeat.npy', np.concatenate(features, axis=0))
    visualizer.save(f'{output_dir}/reconstruction.html')
예제 #9
0
def main():
    """Invert every listed image into the generator's latent space.

    For each batch of images from `args.image_list`, the latent code `wp` is
    initialized (randomly if `args.random_init`, otherwise from the encoder
    output) and optimized with Adam for `args.num_iterations` steps against a
    pixel loss, a weighted perceptual loss and, if `args.domain_regularizer`
    is set, a weighted loss between `wp` and the encoder output of the
    reconstruction.

    Per image, `<name>_ori.png`, `<name>_enc.png` (initial reconstruction) and
    `<name>_inv.png` (final reconstruction) are saved. At the end the image
    list, the initial codes (`encoded_codes.npy`), the optimized codes
    (`inverted_codes.npy`) and an HTML page (`inversion.html`) are written to
    the output directory.
    """
    # Imported locally because only the final image-list copy needs it.
    import shutil

    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    assert os.path.exists(args.image_list)
    image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
    output_dir = args.output_dir or f'results/inversion/{image_list_name}'
    logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger')

    # NOTE(review): `pickle.load` can execute arbitrary code; only load
    # trusted model files.
    logger.info('Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size (the encoder input must be square).
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    # Build graph.
    logger.info('Building graph.')
    sess = tf.get_default_session()
    input_shape = E.input_shape
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    # Axis 1 moved last and values rescaled from [-1, 1] to [0, 255] for the
    # perceptual model.
    x_255 = (tf.transpose(x, [0, 2, 3, 1]) + 1) / 2 * 255
    latent_shape = Gs.components.synthesis.input_shape
    latent_shape[0] = args.batch_size
    # Trainable latent code; this is the only variable being optimized.
    wp = tf.get_variable(shape=latent_shape, name='latent_code')
    x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
    x_rec_255 = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) / 2 * 255
    # Running `setter` (re)initializes `wp` for a new batch.
    if args.random_init:
        logger.info('  Use random initialization for optimization.')
        wp_rnd = tf.random.normal(shape=latent_shape, name='latent_code_init')
        setter = tf.assign(wp, wp_rnd)
    else:
        logger.info(
            '  Use encoder output as the initialization for optimization.')
        w_enc = E.get_output_for(x, is_training=False)
        wp_enc = tf.reshape(w_enc, latent_shape)
        setter = tf.assign(wp, wp_enc)

    # Settings for optimization.
    logger.info('Setting configuration for optimization.')
    perceptual_model = PerceptualModel([image_size, image_size], False)
    x_feat = perceptual_model(x_255)
    x_rec_feat = perceptual_model(x_rec_255)
    loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1])
    loss_pix = tf.reduce_mean(tf.square(x - x_rec), axis=[1, 2, 3])
    if args.domain_regularizer:
        logger.info('  Involve encoder for optimization.')
        w_enc_new = E.get_output_for(x_rec, is_training=False)
        wp_enc_new = tf.reshape(w_enc_new, latent_shape)
        loss_enc = tf.reduce_mean(tf.square(wp - wp_enc_new), axis=[1, 2])
    else:
        logger.info('  Do NOT involve encoder for optimization.')
        loss_enc = 0
    loss = (loss_pix + args.loss_weight_feat * loss_feat +
            args.loss_weight_enc * loss_enc)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    train_op = optimizer.minimize(loss, var_list=[wp])
    tflib.init_uninitialized_vars()

    # Load image list: one image path per line.
    logger.info('Loading image list.')
    with open(args.image_list, 'r') as f:
        image_list = [line.strip() for line in f]

    # Invert images. A column is added for every `save_interval`-th
    # optimization step, plus one for the final step.
    # NOTE(review): `save_interval` is 0 when `num_results > num_iterations`,
    # which makes the modulo below raise ZeroDivisionError.
    logger.info('Start inversion.')
    save_interval = args.num_iterations // args.num_results
    headers = ['Name', 'Original Image', 'Encoder Output']
    for step in range(1, args.num_iterations + 1):
        if step == args.num_iterations or step % save_interval == 0:
            headers.append(f'Step {step:06d}')
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=len(image_list),
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    # Reusable input buffers. For a partial last batch the trailing slots
    # keep earlier content; only the first `len(batch)` results are used.
    images = np.zeros(input_shape, np.uint8)
    names = ['' for _ in range(args.batch_size)]
    latent_codes_enc = []
    latent_codes = []
    for img_idx in tqdm(range(0, len(image_list), args.batch_size),
                        leave=False):
        # Load inputs.
        batch = image_list[img_idx:img_idx + args.batch_size]
        for i, image_path in enumerate(batch):
            image = resize_image(load_image(image_path),
                                 (image_size, image_size))
            images[i] = np.transpose(image, [2, 0, 1])
            names[i] = os.path.splitext(os.path.basename(image_path))[0]
        # Rescale uint8 [0, 255] to float [-1, 1].
        inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
        # Initialize `wp`, then read back the initial code and image.
        sess.run([setter], {x: inputs})
        outputs = sess.run([wp, x_rec])
        latent_codes_enc.append(outputs[0][0:len(batch)])
        outputs[1] = adjust_pixel_range(outputs[1])
        for i, _ in enumerate(batch):
            image = np.transpose(images[i], [1, 2, 0])
            save_image(f'{output_dir}/{names[i]}_ori.png', image)
            save_image(f'{output_dir}/{names[i]}_enc.png', outputs[1][i])
            visualizer.set_cell(i + img_idx, 0, text=names[i])
            visualizer.set_cell(i + img_idx, 1, image=image)
            visualizer.set_cell(i + img_idx, 2, image=outputs[1][i])
        # Optimize latent codes, snapshotting into columns 3, 4, ...
        col_idx = 3
        for step in tqdm(range(1, args.num_iterations + 1), leave=False):
            sess.run(train_op, {x: inputs})
            if step == args.num_iterations or step % save_interval == 0:
                outputs = sess.run([wp, x_rec])
                outputs[1] = adjust_pixel_range(outputs[1])
                for i, _ in enumerate(batch):
                    if step == args.num_iterations:
                        save_image(f'{output_dir}/{names[i]}_inv.png',
                                   outputs[1][i])
                    visualizer.set_cell(i + img_idx,
                                        col_idx,
                                        image=outputs[1][i])
                col_idx += 1
        # `outputs` holds the most recent snapshot, i.e. the final step.
        latent_codes.append(outputs[0][0:len(batch)])

    # Save results.
    # Copy the list without going through a shell, so that paths containing
    # spaces or shell metacharacters are handled correctly.
    try:
        shutil.copyfile(args.image_list, f'{output_dir}/image_list.txt')
    except shutil.SameFileError:
        # The list already lives at the destination; nothing to copy.
        pass
    np.save(f'{output_dir}/encoded_codes.npy',
            np.concatenate(latent_codes_enc, axis=0))
    np.save(f'{output_dir}/inverted_codes.npy',
            np.concatenate(latent_codes, axis=0))
    visualizer.save(f'{output_dir}/inversion.html')
def main():
    """Renders a style-mixing grid from two directories of inverted images.

    Both `args.style_dir` and `args.content_dir` must contain
    `image_list.txt`, `inverted_codes.npy`, and one `<name>_ori.png` for every
    entry of the image list. The two sets of codes are combined with
    `mix_style()`, synthesized by the generator, and placed into an HTML page
    whose first column shows the style images and whose first row shows the
    content images. The page is saved to `<output_dir>/<job_name>.html`.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Make sure both input directories hold everything that is read below.
    style_dir = args.style_dir
    style_dir_name = os.path.basename(style_dir.rstrip('/'))
    for required_path in (style_dir,
                          f'{style_dir}/image_list.txt',
                          f'{style_dir}/inverted_codes.npy'):
        assert os.path.exists(required_path)
    content_dir = args.content_dir
    content_dir_name = os.path.basename(content_dir.rstrip('/'))
    for required_path in (content_dir,
                          f'{content_dir}/image_list.txt',
                          f'{content_dir}/inverted_codes.npy'):
        assert os.path.exists(required_path)

    output_dir = args.output_dir or 'results/style_mixing'
    job_name = f'{style_dir_name}_STYLIZE_{content_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    def read_image_names(directory):
        """Returns the extension-less image names listed in `directory`."""
        names = []
        with open(f'{directory}/image_list.txt', 'r') as list_file:
            for entry in list_file:
                stem = os.path.splitext(os.path.basename(entry.strip()))[0]
                assert os.path.exists(f'{directory}/{stem}_ori.png')
                names.append(stem)
        return names

    # Load model. Layers from `mix_layer_start_idx` up to the last one are
    # the ones handed to `mix_style()`.
    logger.info('Loading generator.')
    generator = build_generator(args.model_name)
    mix_layers = list(range(args.mix_layer_start_idx, generator.num_layers))

    # Load image names and codes; each code array must match its list length.
    logger.info('Loading images and corresponding inverted latent codes.')
    style_list = read_image_names(style_dir)
    logger.info('Loading inverted latent codes.')
    style_codes = np.load(f'{style_dir}/inverted_codes.npy')
    num_styles = style_codes.shape[0]
    assert num_styles == len(style_list)
    content_list = read_image_names(content_dir)
    logger.info('Loading inverted latent codes.')
    content_codes = np.load(f'{content_dir}/inverted_codes.npy')
    num_contents = content_codes.shape[0]
    assert num_contents == len(content_list)

    # Mix styles.
    logger.info('Start style mixing.')
    visualizer = HtmlPageVisualizer(
        num_rows=num_styles + 1,
        num_cols=num_contents + 1,
        viz_size=args.viz_size if args.viz_size != 0 else None)
    headers = ['Style']
    headers.extend(f'Content {i:03d}' for i in range(num_contents))
    visualizer.set_headers(headers)

    # Row 0 / column 0 hold the reference images; the mixing results start
    # at (1, 1).
    for row_idx, style_name in enumerate(style_list, start=1):
        reference = load_image(f'{style_dir}/{style_name}_ori.png')
        visualizer.set_cell(row_idx, 0, image=reference)
    for col_idx, content_name in enumerate(content_list, start=1):
        reference = load_image(f'{content_dir}/{content_name}_ori.png')
        visualizer.set_cell(0, col_idx, image=reference)

    codes = mix_style(style_codes=style_codes,
                      content_codes=content_codes,
                      num_layers=generator.num_layers,
                      mix_layers=mix_layers)
    for style_idx in tqdm(range(num_styles), leave=False):
        synthesized = generator.easy_synthesize(codes[style_idx],
                                                latent_space_type='wp')
        for col_idx, output_image in enumerate(synthesized['image'], start=1):
            visualizer.set_cell(style_idx + 1, col_idx, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
예제 #11
0
def main():
    """Fits latent codes to ground-truth images and dumps the results.

    The routine builds the generator registered under `args.model_name`,
    loads its checkpoint (downloading it first if it is not on disk), and
    optimizes one latent code per image returned by `load_data()` with Adam.
    The objective is the mean absolute pixel difference plus a small L2
    penalty on the code. Afterwards both the re-synthesized images and the
    ground-truth images are written as JPEGs and/or HTML pages, depending on
    `args.save_raw_synthesis` and `args.generate_html`.

    Raises:
        SystemExit: If `args.model_name` is not registered in `MODEL_ZOO`.
        subprocess.CalledProcessError: If downloading the checkpoint fails.
    """
    args = parse_args()

    # Nothing would be written, so skip the expensive optimization entirely.
    if not args.save_raw_synthesis and not args.generate_html:
        return

    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Get work directory and job name.
    if args.save_dir:
        work_dir = args.save_dir
    else:
        work_dir = os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build generation and get synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print(f'Finish building generator.')

    # Load pre-trained weights.
    os.makedirs('checkpoints', exist_ok=True)
    checkpoint_path = os.path.join('checkpoints', args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        print(f'  Downloading checkpoint from `{url}` ...')
        try:
            subprocess.check_call(
                ['wget', '--quiet', '-O', checkpoint_path, url])
        except (subprocess.CalledProcessError, OSError):
            # `wget -O` creates the output file even when the download fails.
            # Remove the leftover so that the next run retries the download
            # instead of trying to load a truncated checkpoint.
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            raise
        print(f'  Finish downloading checkpoint.')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])

    # Read the latent dimension before the generator is wrapped below.
    code_dim = generator.z_space_dim
    device_ids = [int(i) for i in args.device_ids.split(',')]
    generator = torch.nn.DataParallel(generator, device_ids=device_ids).cuda()
    generator.eval()
    print(f'Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load ground-truth images.
    gt_imgs = load_data().cuda()
    print(gt_imgs.shape)

    total_num = gt_imgs.shape[0]
    # One latent code per ground-truth image, all starting from zero.
    code = np.zeros((total_num, code_dim), dtype=np.float32)

    # Ceiling division: the last batch may be smaller than `batch_size`.
    num_batches = (total_num + args.batch_size - 1) // args.batch_size

    for batch_idx in tqdm(range(0, total_num, args.batch_size)):
        frm_idx_start = batch_idx
        frm_idx_end = min(total_num, frm_idx_start + args.batch_size)

        print('Optim lightingcode based on pca batch [%d/%d], frame [%d/%d]'
              % (batch_idx // args.batch_size + 1, num_batches,
                 frm_idx_start, frm_idx_end))

        # Get the code batch. `torch.tensor` copies, so `code` is only
        # updated when the optimized values are written back below.
        batch_code = torch.tensor(code[frm_idx_start:frm_idx_end],
                                  dtype=torch.float32).cuda()
        batch_code.requires_grad = True

        # Get the matching ground-truth image batch.
        batch_gt_img = gt_imgs[frm_idx_start:frm_idx_end]

        # A fresh optimizer per batch, so Adam state is not shared across
        # batches.
        code_optimizer = torch.optim.Adam([
            {'params': [batch_code], 'lr': args.lr}
        ])

        # Optimize.
        for iter_id in range(args.max_iter_num):
            images = generator(batch_code, **synthesis_kwargs)['image']

            # Value-range diagnostics for ground truth and synthesis.
            print(batch_gt_img.max(), batch_gt_img.min(), '========')
            print(images.max(), images.min())

            # Loss: mean absolute pixel error + weighted L2 code penalty.
            global_pix_loss = (batch_gt_img - images).abs().mean()
            code_reg_loss = (batch_code ** 2).mean()
            loss = global_pix_loss + code_reg_loss * 0.01

            code_optimizer.zero_grad()
            loss.backward()
            code_optimizer.step()

            # Report errors on the first and on every 100th iteration.
            if iter_id == 0 or (iter_id + 1) % 100 == 0:
                print(' iter [%d/%d]: global %f, code_reg %f'
                      % (iter_id + 1, args.max_iter_num,
                         global_pix_loss.item(), code_reg_loss.item()))

                # Halve the learning rate.
                # NOTE(review): this also fires at iteration 0, right after
                # the first step -- confirm whether that is intended.
                code_optimizer.param_groups[0]['lr'] *= 0.5
                print(' Reduce learning rate to %f'
                      % (code_optimizer.param_groups[0]['lr']))

        # Write the optimized codes back into the global array.
        code[frm_idx_start:frm_idx_end] = batch_code.detach().cpu().numpy()

    # Synthesize from the optimized codes.
    print(f'Synthesizing {total_num} samples ...')
    indices = list(range(total_num))
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=total_num)
    for batch_idx in tqdm(range(0, total_num, args.batch_size)):
        sub_indices = indices[batch_idx:batch_idx + args.batch_size]
        print(code.shape)
        t_code = torch.FloatTensor(code[sub_indices]).cuda()
        with torch.no_grad():
            images = generator(t_code, **synthesis_kwargs)['image']
            print(images.shape)
            images = postprocess(images)
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {total_num} samples.')

    # Dump the ground-truth images in the same layout (`_gt` suffix) so the
    # two outputs can be compared side by side.
    print(f'Synthesizing {total_num} samples ...')
    indices = list(range(total_num))
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=total_num)
    for batch_idx in tqdm(range(0, total_num, args.batch_size)):
        sub_indices = indices[batch_idx:batch_idx + args.batch_size]
        with torch.no_grad():
            images = gt_imgs[sub_indices]
            print(images.shape)
            images = postprocess(images)
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}_gt.jpg')
                save_image(save_path, image)
            if args.generate_html:
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}_gt.html'))
    print(f'Finish synthesizing {total_num} samples.')
def main():
  """Diffuses every target image into every context image.

  Each line of `args.target_list` and `args.context_list` is used as an image
  path. For every (target, context) pair, `StyleGANInverter.easy_diffuse()` is
  run and its visualization results fill one row of an HTML page. The latent
  codes of all pairs are concatenated and saved next to copies of the two
  list files in the output directory.
  """
  import shutil  # Function-scope import; only needed to copy the list files.

  args = parse_args()
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
  assert os.path.exists(args.target_list)
  target_list_name = os.path.splitext(os.path.basename(args.target_list))[0]
  assert os.path.exists(args.context_list)
  context_list_name = os.path.splitext(os.path.basename(args.context_list))[0]
  output_dir = args.output_dir or 'results/diffusion'
  job_name = f'{target_list_name}_TO_{context_list_name}'
  logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

  logger.info('Loading model.')
  inverter = StyleGANInverter(
      args.model_name,
      learning_rate=args.learning_rate,
      iteration=args.num_iterations,
      reconstruction_loss_weight=1.0,
      perceptual_loss_weight=args.loss_weight_feat,
      regularization_loss_weight=0.0,
      logger=logger)
  image_size = inverter.G.resolution

  # Load image lists.
  logger.info('Loading target images and context images.')
  with open(args.target_list, 'r') as f:
    target_list = [line.strip() for line in f]
  num_targets = len(target_list)
  with open(args.context_list, 'r') as f:
    context_list = [line.strip() for line in f]
  num_contexts = len(context_list)
  num_pairs = num_targets * num_contexts

  # Initialize visualizer. One extra column is added for every step that
  # falls on the save interval, plus the final step.
  save_interval = args.num_iterations // args.num_results
  headers = ['Target Image', 'Context Image', 'Stitched Image',
             'Encoder Output']
  headers.extend(
      f'Step {step:06d}'
      for step in range(1, args.num_iterations + 1)
      if step == args.num_iterations or step % save_interval == 0)
  viz_size = None if args.viz_size == 0 else args.viz_size
  visualizer = HtmlPageVisualizer(
      num_rows=num_pairs, num_cols=len(headers), viz_size=viz_size)
  visualizer.set_headers(headers)

  # Diffuse images.
  logger.info('Start diffusion.')
  latent_codes = []
  for target_idx, target_path in enumerate(
      tqdm(target_list, desc='Target ID', leave=False)):
    # Load target. It is only shown on the first row of its group of pairs.
    target_image = resize_image(load_image(target_path),
                                (image_size, image_size))
    visualizer.set_cell(target_idx * num_contexts, 0, image=target_image)
    for context_idx, context_path in enumerate(
        tqdm(context_list, desc='Context ID', leave=False)):
      row_idx = target_idx * num_contexts + context_idx
      context_image = resize_image(load_image(context_path),
                                   (image_size, image_size))
      visualizer.set_cell(row_idx, 1, image=context_image)
      code, viz_results = inverter.easy_diffuse(target=target_image,
                                                context=context_image,
                                                center_x=args.center_x,
                                                center_y=args.center_y,
                                                crop_x=args.crop_size,
                                                crop_y=args.crop_size,
                                                num_viz=args.num_results)
      # Columns 0 and 1 hold the inputs; the results start at column 2.
      for viz_idx, viz_img in enumerate(viz_results):
        visualizer.set_cell(row_idx, viz_idx + 2, image=viz_img)
      latent_codes.append(code)

  # Save results. The list files are copied with `shutil` rather than a shell
  # `cp`, so paths containing spaces or shell metacharacters are handled and
  # failures are not silently ignored.
  for list_path, copy_name in ((args.target_list, 'target_list.txt'),
                               (args.context_list, 'context_list.txt')):
    try:
      shutil.copy(list_path, f'{output_dir}/{copy_name}')
    except shutil.SameFileError:
      # The list already lives at the destination; nothing to copy.
      pass
  np.save(f'{output_dir}/{job_name}_inverted_codes.npy',
          np.concatenate(latent_codes, axis=0))
  visualizer.save(f'{output_dir}/{job_name}.html')
예제 #13
0
파일: synthesize.py 프로젝트: VogtAI/higan
def main():
    """Synthesizes images (and optional predictions) from latent codes.

    Latent codes are loaded from `args.latent_codes_path` when that file
    exists; otherwise `args.num` codes are sampled from the generator. Every
    batch is synthesized, images are optionally saved and/or placed into an
    HTML page, and all non-image outputs as well as predictor outputs are
    stored as `.npy` files in the work directory.

    Raises:
        ValueError: If the latent codes file does not exist and `args.num`
            is not positive.
    """
    args = parse_args()
    file_name = args.output_file
    work_dir = args.output_dir or f'{args.model_name}_synthesis'
    logger = setup_logger(work_dir,
                          args.logfile_name,
                          f'{args.model_name}_synthesis_logger')

    logger.info('Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info('Preparing latent codes.')
    if not os.path.isfile(args.latent_codes_path):
        # No codes on disk, so they have to be sampled.
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive '
                f'number since the latent code path '
                f'`{args.latent_codes_path}` does not exist!')
        logger.info('  Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    else:
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = model.preprocess(
            latent_codes=np.load(args.latent_codes_path),
            latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    if args.generate_prediction:
        logger.info('Initializing predictor.')
        predictor = build_predictor(args.predictor_name)

    if args.generate_html:
        visualizer = HtmlPageVisualizer(
            num_rows=args.html_row,
            num_cols=args.html_col,
            grid_size=total_num,
            viz_size=args.viz_size if args.viz_size != 0 else None)

    logger.info(f'Generating {total_num} samples.')
    results = defaultdict(list)
    predictions = defaultdict(list)
    # The progress bar counter doubles as the running index of the image.
    pbar = tqdm(total=total_num, leave=False)
    for batch_codes in model.get_batch_inputs(latent_codes):
        outputs = model.easy_synthesize(
            latent_codes=batch_codes,
            latent_space_type=args.latent_space_type,
            generate_style=args.generate_style,
            generate_image=not args.skip_image)
        for key, val in outputs.items():
            if key != 'image':
                results[key].append(val)
                continue
            if args.generate_prediction:
                batch_predictions = predictor.easy_predict(val)
                for pred_key, pred_val in batch_predictions.items():
                    predictions[pred_key].append(pred_val)
            for image in val:
                sample_idx = pbar.n
                if args.save_raw_synthesis:
                    # An explicit output file name overrides the numbered one.
                    if file_name != "":
                        dest = os.path.join(work_dir, file_name)
                    else:
                        dest = os.path.join(work_dir,
                                            f'{sample_idx:06d}.jpg')
                    print('saving image to ', dest)
                    save_image(dest, image)
                if args.generate_html:
                    row_idx, col_idx = divmod(sample_idx,
                                              visualizer.num_cols)
                    visualizer.set_cell(row_idx, col_idx, image=image)
                pbar.update(1)
        # Without images the bar was not advanced per sample above.
        if 'image' not in outputs:
            pbar.update(batch_codes.shape[0])
    pbar.close()

    logger.info('Saving results.')
    if args.generate_html:
        visualizer.save(os.path.join(work_dir, args.html_name))
    for key, chunks in results.items():
        stacked = np.concatenate(chunks, axis=0)
        np.save(os.path.join(work_dir, f'{key}.npy'), stacked)
    if not predictions:
        return
    if args.predictor_name != 'scene':
        for key, chunks in predictions.items():
            stacked = np.concatenate(chunks, axis=0)
            np.save(os.path.join(work_dir, f'{key}.npy'), stacked)
        return
    # Scene predictor: store scores together with the name/index lookup
    # tables, first for categories and then for attributes.
    for kind in ('category', 'attribute'):
        scores = np.concatenate(predictions[kind], axis=0)
        detailed = {
            'score': scores,
            'name_to_idx': getattr(predictor, f'{kind}_name_to_idx'),
            'idx_to_name': getattr(predictor, f'{kind}_idx_to_name'),
        }
        np.save(os.path.join(work_dir, f'{kind}.npy'), detailed)
예제 #14
0
def main():
    """Manipulates inverted images along a semantic boundary.

    `args.image_dir` must contain `image_list.txt`, `inverted_codes.npy`, and
    a `<name>_ori.png` / `<name>_inv.png` pair for every listed image. The
    inverted codes are moved along the boundary loaded from
    `args.boundary_path` with `manipulate()`, rendered through the synthesis
    network of the pickled generator, and collected in an HTML page saved to
    `<output_dir>/<job_name>.html`.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    image_dir = args.image_dir
    image_dir_name = os.path.basename(image_dir.rstrip('/'))
    for required_path in (image_dir,
                          f'{image_dir}/image_list.txt',
                          f'{image_dir}/inverted_codes.npy'):
        assert os.path.exists(required_path)
    boundary_path = args.boundary_path
    assert os.path.exists(boundary_path)
    boundary_name = os.path.splitext(os.path.basename(boundary_path))[0]

    output_dir = args.output_dir or 'results/manipulation'
    job_name = f'{boundary_name.upper()}_{image_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    # NOTE: `pickle.load` can execute arbitrary code; only use trusted model
    # files.
    logger.info('Loading generator.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as model_file:
        _, _, _, Gs = pickle.load(model_file)

    # Build graph: a fixed-size placeholder feeding the synthesis network.
    logger.info('Building graph.')
    sess = tf.get_default_session()
    synthesis = Gs.components.synthesis
    num_layers, latent_dim = synthesis.input_shape[1:3]
    code_shape = [args.batch_size, num_layers, latent_dim]
    wp = tf.placeholder(tf.float32, code_shape, name='latent_code')
    x = synthesis.get_output_for(wp, randomize_noise=False)

    # Load image, codes, and boundary.
    logger.info('Loading images and corresponding inverted latent codes.')
    image_list = []
    with open(f'{image_dir}/image_list.txt', 'r') as list_file:
        for entry in list_file:
            stem = os.path.splitext(os.path.basename(entry.strip()))[0]
            for suffix in ('ori', 'inv'):
                assert os.path.exists(f'{image_dir}/{stem}_{suffix}.png')
            image_list.append(stem)
    latent_codes = np.load(f'{image_dir}/inverted_codes.npy')
    num_images = latent_codes.shape[0]
    assert num_images == len(image_list)
    logger.info('Loading boundary.')
    # `[()]` unwraps a 0-d object array (e.g. a saved dict); regular arrays
    # are returned unchanged.
    boundary_file = np.load(boundary_path, allow_pickle=True)[()]
    if not isinstance(boundary_file, dict):
        boundary = boundary_file
        manipulate_layers = args.manipulate_layers
    else:
        boundary = boundary_file['boundary']
        manipulate_layers = boundary_file['meta_data']['manipulate_layers']
    if not manipulate_layers:
        logger.info('  Manipulating on ALL layers.')
    else:
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

    # Manipulate images.
    logger.info('Start manipulation.')
    step = args.step
    visualizer = HtmlPageVisualizer(
        num_rows=num_images,
        num_cols=step + 3,
        viz_size=args.viz_size if args.viz_size != 0 else None)
    headers = ['Name', 'Origin', 'Inverted']
    headers.extend(f'Step {i:02d}' for i in range(1, step + 1))
    visualizer.set_headers(headers)
    for row_idx, img_name in enumerate(image_list):
        ori_image = load_image(f'{image_dir}/{img_name}_ori.png')
        inv_image = load_image(f'{image_dir}/{img_name}_inv.png')
        visualizer.set_cell(row_idx, 0, text=img_name)
        visualizer.set_cell(row_idx, 1, image=ori_image)
        visualizer.set_cell(row_idx, 2, image=inv_image)

    codes = manipulate(latent_codes=latent_codes,
                       boundary=boundary,
                       start_distance=args.start_distance,
                       end_distance=args.end_distance,
                       step=step,
                       layerwise_manipulation=True,
                       num_layers=num_layers,
                       manipulate_layers=manipulate_layers,
                       is_code_layerwise=True,
                       is_boundary_layerwise=True)

    # The placeholder has a fixed batch size, so every chunk is copied into
    # a reusable buffer and only the valid leading outputs are kept.
    feed_buffer = np.zeros(code_shape, np.float32)
    for img_idx in tqdm(range(num_images), leave=False):
        rendered_chunks = []
        for start in range(0, step, args.batch_size):
            batch = codes[img_idx, start:start + args.batch_size]
            valid = len(batch)
            feed_buffer[0:valid] = batch
            rendered = sess.run(x, feed_dict={wp: feed_buffer})
            rendered_chunks.append(rendered[0:valid])
        output_images = adjust_pixel_range(
            np.concatenate(rendered_chunks, axis=0))
        # Columns 0-2 are name / origin / inversion; steps start at column 3.
        for col_idx, output_image in enumerate(output_images, start=3):
            visualizer.set_cell(img_idx, col_idx, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
예제 #15
0
    visualizer = HtmlPageVisualizer(num_rows=total_num, num_cols=step + 1)
    visualizer.set_headers([''] +
                           [f'Step {i - step // 2}'
                            for i in range(step // 2)] + ['Origin'] +
                           [f'Step {i + 1}' for i in range(step // 2)])
    for n in range(total_num):
        visualizer.set_cell(n, 0, text=f'Sample {n:05d}')

    strength = get_layerwise_manipulation_strength(model.num_layers,
                                                   model.truncation_psi,
                                                   model.truncation_layers)
    codes = manipulate(latent_codes=latent_codes['wp'],
                       boundary=boundary,
                       start_distance=args.start_distance,
                       end_distance=args.end_distance,
                       step=step,
                       layerwise_manipulation=True,
                       num_layers=model.num_layers,
                       manipulate_layers=manipulate_layers,
                       is_code_layerwise=True,
                       is_boundary_layerwise=False,
                       layerwise_manipulation_strength=strength)
    np.save(os.path.join(work_dir, f'{prefix}_manipulated_wp.npy'), codes)

    for s in tqdm(range(step), leave=False):
        images = model.easy_synthesize(codes[:, s],
                                       latent_space_type='wp')['image']
        for n, image in enumerate(images):
            visualizer.set_cell(n, s + 1, image=image)
    visualizer.save(os.path.join(work_dir, f'{prefix}.html'))
예제 #16
0
def main():
    """Interpolates between every source and every target inverted image.

    Both `args.src_dir` and `args.dst_dir` must contain `image_list.txt`,
    `inverted_codes.npy`, and one `<name>_ori.png` for every listed image.
    For each (source, target) pair the codes returned by `interpolate()` are
    synthesized and placed in one row of an HTML page, framed by the original
    source and target images. The page is saved to
    `<output_dir>/<job_name>.html`.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    src_dir = args.src_dir
    src_dir_name = os.path.basename(src_dir.rstrip('/'))
    for required_path in (src_dir,
                          f'{src_dir}/image_list.txt',
                          f'{src_dir}/inverted_codes.npy'):
        assert os.path.exists(required_path)
    dst_dir = args.dst_dir
    dst_dir_name = os.path.basename(dst_dir.rstrip('/'))
    for required_path in (dst_dir,
                          f'{dst_dir}/image_list.txt',
                          f'{dst_dir}/inverted_codes.npy'):
        assert os.path.exists(required_path)

    output_dir = args.output_dir or 'results/interpolation'
    job_name = f'{src_dir_name}_TO_{dst_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info('Loading generator.')
    generator = build_generator(args.model_name)

    # Load image names and codes for the source side, then the target side.
    logger.info('Loading images and corresponding inverted latent codes.')
    src_list = []
    with open(f'{src_dir}/image_list.txt', 'r') as list_file:
        for entry in list_file:
            stem = os.path.splitext(os.path.basename(entry.strip()))[0]
            assert os.path.exists(f'{src_dir}/{stem}_ori.png')
            src_list.append(stem)
    src_codes = np.load(f'{src_dir}/inverted_codes.npy')
    num_src = src_codes.shape[0]
    assert num_src == len(src_list)
    dst_list = []
    with open(f'{dst_dir}/image_list.txt', 'r') as list_file:
        for entry in list_file:
            stem = os.path.splitext(os.path.basename(entry.strip()))[0]
            assert os.path.exists(f'{dst_dir}/{stem}_ori.png')
            dst_list.append(stem)
    dst_codes = np.load(f'{dst_dir}/inverted_codes.npy')
    num_dst = dst_codes.shape[0]
    assert num_dst == len(dst_list)

    # Interpolate images. Two is added to the requested step count; the
    # headers label the first and last synthesized columns as the source and
    # target inversions.
    logger.info('Start interpolation.')
    step = args.step + 2
    last_col = step + 1
    visualizer = HtmlPageVisualizer(
        num_rows=num_src * num_dst,
        num_cols=step + 2,
        viz_size=args.viz_size if args.viz_size != 0 else None)
    headers = ['Source', 'Source Inversion']
    headers.extend(f'Step {i:02d}' for i in range(1, step - 1))
    headers.extend(['Target Inversion', 'Target'])
    visualizer.set_headers(headers)

    for src_idx in tqdm(range(num_src), leave=False):
        src_path = f'{src_dir}/{src_list[src_idx]}_ori.png'
        # Repeat the single source code so it pairs with every target code.
        repeated_src = np.repeat(src_codes[src_idx:src_idx + 1],
                                 num_dst,
                                 axis=0)
        codes = interpolate(src_codes=repeated_src,
                            dst_codes=dst_codes,
                            step=step)
        for dst_idx in tqdm(range(num_dst), leave=False):
            dst_path = f'{dst_dir}/{dst_list[dst_idx]}_ori.png'
            synthesized = generator.easy_synthesize(codes[dst_idx],
                                                    latent_space_type='wp')
            output_images = synthesized['image']

            row_idx = src_idx * num_dst + dst_idx
            visualizer.set_cell(row_idx, 0, image=load_image(src_path))
            visualizer.set_cell(row_idx, last_col, image=load_image(dst_path))
            for col_idx, output_image in enumerate(output_images, start=1):
                visualizer.set_cell(row_idx, col_idx, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
예제 #17
0
    def synthesize(self,
                   num,
                   z=None,
                   html_name=None,
                   save_raw_synthesis=False):
        """Synthesizes images and optionally visualizes them.

        Every rank generates the samples whose index is congruent to its own
        rank modulo the world size and writes them as JPEG files into
        `<work_dir>/synthesize_results`. After a barrier, rank 0 optionally
        assembles the files into an HTML page and removes the directory
        unless raw results were requested.

        Args:
            num: Number of images to synthesize. Capped by the number of rows
                of `z` when `z` is given.
            z: Latent codes used for generation, as a 2-D `np.ndarray` whose
                second dimension equals `self.z_space_dim`. If not specified,
                latent codes are sampled randomly. (default: None)
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved.
                (default: None)
            save_raw_synthesis: Whether to keep the raw synthesis on the
                disk. (default: False)
        """
        # Nothing would be kept, so there is nothing to do.
        if not (html_name or save_raw_synthesis):
            return

        self.set_mode('val')

        temp_dir = os.path.join(self.work_dir, 'synthesize_results')
        os.makedirs(temp_dir, exist_ok=True)

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            num = min(num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not num:
            return
        # TODO: Use same z during the entire training process.

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('Synthesize', total=num)

        # Indices handled by this rank: rank, rank + world_size, ...
        rank_indices = list(range(self.rank, num, self.world_size))
        for start in range(0, len(rank_indices), self.val_batch_size):
            sub_indices = rank_indices[start:start + self.val_batch_size]
            batch_size = len(sub_indices)
            code = (torch.randn(batch_size, self.z_space_dim).cuda()
                    if z is None else z[sub_indices].cuda())
            with torch.no_grad():
                # Use the smoothed generator whenever it is available.
                model_key = ('generator_smooth'
                             if 'generator_smooth' in self.models
                             else 'generator')
                G = self.models[model_key]
                images = self.postprocess(
                    G(code, **self.G_kwargs_val)['image'])
            for sub_idx, image in zip(sub_indices, images):
                image_path = os.path.join(temp_dir, f'{sub_idx:06d}.jpg')
                save_image(image_path, image)
            # Advance by the batch size times the world size, which counts
            # the batches of the other ranks as well.
            self.logger.update_pbar(task1, batch_size * self.world_size)

        # Wait for all ranks to finish writing; only rank 0 continues.
        dist.barrier()
        if self.rank != 0:
            return

        if html_name:
            task2 = self.logger.add_pbar_task('Visualize', total=num)
            html = HtmlPageVisualizer(grid_size=num)
            for image_idx in range(num):
                image_path = os.path.join(temp_dir, f'{image_idx:06d}.jpg')
                image = load_image(image_path)
                html.set_cell(image_idx // html.num_cols,
                              image_idx % html.num_cols,
                              image=image,
                              text=f'Sample {image_idx:06d}')
                self.logger.update_pbar(task2, 1)
            html.save(os.path.join(self.work_dir, html_name))
        if not save_raw_synthesis:
            shutil.rmtree(temp_dir)

        self.logger.close_pbar()
def _require_inversion_dir(directory):
  """Checks that `directory` contains the files this script reads.

  The directory itself, `image_list.txt` and `inverted_codes.npy` must all
  exist. Explicit exceptions are used instead of `assert` so that the checks
  still run under `python -O` and report which path is missing.

  Args:
    directory: Path to a directory holding inversion results.

  Raises:
    FileNotFoundError: If any of the required paths does not exist.
  """
  required_paths = (directory,
                    f'{directory}/image_list.txt',
                    f'{directory}/inverted_codes.npy')
  for path in required_paths:
    if not os.path.exists(path):
      raise FileNotFoundError(f'Required path `{path}` does not exist!')


def _load_inversion_results(directory, logger):
  """Loads image names and inverted latent codes stored under `directory`.

  Each line of `image_list.txt` is reduced to its base name without extension,
  and the matching `{name}_ori.png` must exist in `directory`.

  Args:
    directory: Path to a directory holding inversion results.
    logger: Logger used to report progress.

  Returns:
    A tuple `(names, codes)`, where `names` is the list of image names and
    `codes` is the array loaded from `inverted_codes.npy`.

  Raises:
    FileNotFoundError: If the `_ori.png` image of a listed name is missing.
    ValueError: If the number of codes differs from the number of names.
  """
  names = []
  with open(f'{directory}/image_list.txt', 'r') as f:
    for line in f:
      name = os.path.splitext(os.path.basename(line.strip()))[0]
      image_path = f'{directory}/{name}_ori.png'
      if not os.path.exists(image_path):
        raise FileNotFoundError(
            f'Original image `{image_path}` does not exist!')
      names.append(name)
  logger.info('Loading inverted latent codes.')
  codes = np.load(f'{directory}/inverted_codes.npy')
  if codes.shape[0] != len(names):
    raise ValueError(
        f'Number of latent codes ({codes.shape[0]}) in `{directory}` does '
        f'not match number of listed images ({len(names)})!')
  return names, codes


def main():
  """Main function.

  Reads inverted latent codes of a set of style images and a set of content
  images, mixes them with `mix_style()`, synthesizes one image per
  (style, content) pair, and saves the results as an HTML grid. Rows of the
  grid correspond to styles and columns to contents.

  Raises:
    FileNotFoundError: If a required input directory, list file, code file or
      original image is missing.
    ValueError: If the number of latent codes in a directory does not match
      the number of listed images.
  """
  args = parse_args()
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
  style_dir = args.style_dir
  style_dir_name = os.path.basename(style_dir.rstrip('/'))
  _require_inversion_dir(style_dir)
  content_dir = args.content_dir
  content_dir_name = os.path.basename(content_dir.rstrip('/'))
  _require_inversion_dir(content_dir)
  output_dir = args.output_dir or 'results/style_mixing'
  job_name = f'{style_dir_name}_STYLIZE_{content_dir_name}'
  logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

  # Load model.
  logger.info('Loading generator.')
  tflib.init_tf({'rnd.np_random_seed': 1000})
  # NOTE: Unpickling can execute arbitrary code. Only pass model files that
  # come from a trusted source.
  with open(args.model_path, 'rb') as f:
    _, _, _, Gs = pickle.load(f)

  # Build graph.
  logger.info('Building graph.')
  sess = tf.get_default_session()
  num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3]
  # The placeholder has a fixed batch dimension, so every `sess.run()` below
  # feeds exactly `args.batch_size` codes.
  wp = tf.placeholder(
      tf.float32, [args.batch_size, num_layers, latent_dim], name='latent_code')
  x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
  mix_layers = list(range(args.mix_layer_start_idx, num_layers))

  # Load image and codes.
  logger.info('Loading images and corresponding inverted latent codes.')
  style_list, style_codes = _load_inversion_results(style_dir, logger)
  num_styles = style_codes.shape[0]
  content_list, content_codes = _load_inversion_results(content_dir, logger)
  num_contents = content_codes.shape[0]

  # Mix styles.
  logger.info('Start style mixing.')
  viz_size = None if args.viz_size == 0 else args.viz_size
  # Row 0 and column 0 are reserved for the original content / style images.
  visualizer = HtmlPageVisualizer(
      num_rows=num_styles + 1, num_cols=num_contents + 1, viz_size=viz_size)
  visualizer.set_headers(
      ['Style'] +
      [f'Content {i:03d}' for i in range(num_contents)]
  )
  for style_idx, style_name in enumerate(style_list):
    style_image = load_image(f'{style_dir}/{style_name}_ori.png')
    visualizer.set_cell(style_idx + 1, 0, image=style_image)
  for content_idx, content_name in enumerate(content_list):
    content_image = load_image(f'{content_dir}/{content_name}_ori.png')
    visualizer.set_cell(0, content_idx + 1, image=content_image)

  # `codes` is indexed as `codes[style_idx, content_idx]` below.
  codes = mix_style(style_codes=style_codes,
                    content_codes=content_codes,
                    num_layers=num_layers,
                    mix_layers=mix_layers)
  # Reusable feed buffer. For a final, smaller batch the trailing rows keep
  # stale codes from the previous batch; their outputs are discarded.
  inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32)
  for style_idx in tqdm(range(num_styles), leave=False):
    output_images = []
    for idx in range(0, num_contents, args.batch_size):
      batch = codes[style_idx, idx:idx + args.batch_size]
      inputs[0:len(batch)] = batch
      images = sess.run(x, feed_dict={wp: inputs})
      output_images.append(images[0:len(batch)])
    output_images = adjust_pixel_range(np.concatenate(output_images, axis=0))
    for content_idx, output_image in enumerate(output_images):
      visualizer.set_cell(style_idx + 1, content_idx + 1, image=output_image)

  # Save results.
  visualizer.save(f'{output_dir}/{job_name}.html')
# Example #19
# 0
def main():
    """Main function.

    Manipulates latent codes along each boundary listed in
    `args.boundary_list_path` and synthesizes the results with the generator
    named by `args.model_name`.

    Depending on the command-line flags, the manipulated images are saved as
    raw image files, as an HTML page, and/or as a video. Latent codes,
    boundaries, and manipulated codes are always saved to the working
    directory as `.npy` files.

    Raises:
        ValueError: If `args.latent_codes_path` is not a file and `args.num`
            is not positive.
    """
    args = parse_args()

    # All outputs (logs, codes, images, pages, videos) go under `work_dir`.
    work_dir = args.output_dir or f'{args.model_name}_manipulation'
    logger_name = f'{args.model_name}_manipulation_logger'
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info(f'Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    # Either load the latent codes from disk or sample them randomly.
    logger.info(f'Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive '
                f'number since the latent code path '
                f'`{args.latent_codes_path}` does not exist!')
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    # From here on `latent_codes` is the dict returned by `easy_synthesize()`.
    # It is looked up by latent space type below, so presumably it maps space
    # names (e.g. `z`, `w`, `wp`) to codes -- TODO confirm against the model.
    latent_codes = model.easy_synthesize(
        latent_codes=latent_codes,
        latent_space_type=args.latent_space_type,
        generate_style=False,
        generate_image=False)
    for key, val in latent_codes.items():
        np.save(os.path.join(work_dir, f'{key}.npy'), val)

    # `boundaries` maps `(boundary_name, space_type)` to a boundary file path.
    boundaries = parse_boundary_list(args.boundary_list_path)

    step = args.step + int(args.step % 2
                           == 0)  # Make sure it is an odd number.

    for boundary_info, boundary_path in boundaries.items():
        boundary_name, space_type = boundary_info
        logger.info(
            f'Boundary `{boundary_name}` from {space_type.upper()} space.')
        # Prefix shared by all files saved for this boundary.
        prefix = f'{boundary_name}_{space_type}'

        if args.generate_html:
            viz_size = None if args.viz_size == 0 else args.viz_size
            # One row per sample; column 0 holds the sample name, followed by
            # one column per manipulation step.
            visualizer = HtmlPageVisualizer(num_rows=total_num,
                                            num_cols=step + 1,
                                            viz_size=viz_size)
            # Headers run from `Step -(step // 2)` through `Origin` to
            # `Step (step // 2)`.
            visualizer.set_headers(
                [''] + [f'Step {i - step // 2}' for i in range(step // 2)] +
                ['Origin'] + [f'Step {i + 1}' for i in range(step // 2)])

        if args.generate_video:
            # Synthesize the unmodified samples once and fuse them, only to
            # obtain the frame height and width for the video writer.
            setup_images = model.easy_synthesize(
                latent_codes=latent_codes[args.latent_space_type],
                latent_space_type=args.latent_space_type)['image']
            fusion_kwargs = {
                'row': args.row,
                'col': args.col,
                'row_spacing': args.row_spacing,
                'col_spacing': args.col_spacing,
                'border_left': args.border_left,
                'border_right': args.border_right,
                'border_top': args.border_top,
                'border_bottom': args.border_bottom,
                'black_background': not args.white_background,
                'image_size': None if args.viz_size == 0 else args.viz_size,
            }
            setup_image = fuse_images(setup_images, **fusion_kwargs)
            video_writer = VideoWriter(os.path.join(
                work_dir, f'{prefix}_{args.video_name}'),
                                       frame_height=setup_image.shape[0],
                                       frame_width=setup_image.shape[1],
                                       fps=args.fps)

        logger.info(f'  Loading boundary.')
        # First try the dict format, which carries the boundary together with
        # the layers to manipulate. On `ValueError` (presumably raised by
        # `.item()` when the file holds a plain array -- TODO confirm), fall
        # back to the raw array and the layers given on the command line.
        try:
            boundary_file = np.load(boundary_path, allow_pickle=True).item()
            boundary = boundary_file['boundary']
            manipulate_layers = boundary_file['meta_data']['manipulate_layers']
        except ValueError:
            boundary = np.load(boundary_path)
            manipulate_layers = args.manipulate_layers
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

        np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

        if args.layerwise_manipulation and space_type != 'z':
            layerwise_manipulation = True
            is_code_layerwise = True
            # The boundary counts as layer-wise only if it comes from WP space.
            is_boundary_layerwise = (space_type == 'wp')
            if (not args.disable_manipulation_truncation
                ) and space_type == 'w':
                strength = get_layerwise_manipulation_strength(
                    model.num_layers, model.truncation_psi,
                    model.truncation_layers)
            else:
                strength = 1.0
            # Layer-wise manipulation always operates on the `wp` codes, so
            # the lookup, file name and synthesis below use `wp` even for a
            # boundary from W space.
            space_type = 'wp'
        else:
            if args.layerwise_manipulation:
                logger.warning(f'  Skip layer-wise manipulation for boundary '
                               f'`{boundary_name}` from Z space. Traditional '
                               f'manipulation is used instead.')
            layerwise_manipulation = False
            is_code_layerwise = False
            is_boundary_layerwise = False
            strength = 1.0

        # `codes` is indexed as `codes[:, s]` below, i.e. its second axis
        # enumerates the manipulation steps.
        codes = manipulate(latent_codes=latent_codes[space_type],
                           boundary=boundary,
                           start_distance=args.start_distance,
                           end_distance=args.end_distance,
                           step=step,
                           layerwise_manipulation=layerwise_manipulation,
                           num_layers=model.num_layers,
                           manipulate_layers=manipulate_layers,
                           is_code_layerwise=is_code_layerwise,
                           is_boundary_layerwise=is_boundary_layerwise,
                           layerwise_manipulation_strength=strength)
        np.save(
            os.path.join(work_dir, f'{prefix}_manipulated_{space_type}.npy'),
            codes)

        logger.info(f'  Start manipulating.')
        # Synthesize all samples at one manipulation step per iteration.
        for s in tqdm(range(step), leave=False):
            images = model.easy_synthesize(
                latent_codes=codes[:,
                                   s], latent_space_type=space_type)['image']
            if args.generate_video:
                # One fused frame per manipulation step.
                video_writer.write(fuse_images(images, **fusion_kwargs))
            for n, image in enumerate(images):
                if args.save_raw_synthesis:
                    save_image(
                        os.path.join(work_dir,
                                     f'{prefix}_{n:05d}_{s:03d}.jpg'), image)
                if args.generate_html:
                    visualizer.set_cell(n, s + 1, image=image)
                    # Write the row label only once per sample.
                    if s == 0:
                        visualizer.set_cell(n, 0, text=f'Sample {n:05d}')

        if args.generate_html:
            visualizer.save(
                os.path.join(work_dir, f'{prefix}_{args.html_name}'))