Example #1
def get_gan_model(model_name):
    """
    :param model_name: Please refer to `GAN_MODELS` for valid names.
    :return: gan_model (nn.Module or nn.Sequential)
    """
    gan = build_generator(model_name)
    if model_name.startswith('pggan'):
        gan_list = list(gan.net.children())
        remove_index = (PGGAN_Inter_Output_Layer_1024
                        if model_name == 'pggan_celebahq'
                        else PGGAN_Inter_Output_Layer_256)
        for output_index in remove_index:
            gan_list.pop(output_index)
        return nn.Sequential(*gan_list)
    elif model_name.startswith('style'):
        return gan
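
# Usage sketch (not part of the original example): 'pggan_celebahq' and
# 'stylegan_celebahq' are model names that appear elsewhere in this listing;
# substitute any name registered in `GAN_MODELS`.
pggan_seq = get_gan_model('pggan_celebahq')      # nn.Sequential with the intermediate output layers removed
stylegan = get_gan_model('stylegan_celebahq')    # StyleGAN generator returned as-is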
Example #2
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", default="expr/semantics_layerwise")
    parser.add_argument("--out-dir", default="results/layerplot/layer_loss")
    parser.add_argument("--place", default="paper", help="paper | appendix")
    parser.add_argument("--viz-model", default="LSE")
    parser.add_argument("--repeat", default=2, type=int)
    parser.add_argument("--gpu-id", default=0, type=int)
    parser.add_argument("--show-weight", default=0, type=int)
    args = parser.parse_args()

    G_names = "pggan_celebahq,stylegan_celebahq,stylegan2_ffhq,pggan_church,stylegan_church,stylegan2_church,pggan_bedroom,stylegan_bedroom,stylegan2_bedroom"
    # setup and constants
    data_dir = args.model_dir
    Gs = {G_name: build_generator(G_name).net for G_name in G_names.split(",")}
    is_face = "ffhq" in G_names or "celebahq" in G_names
    unet = FaceSegmenter()

    def P_from_name(G_name):
        if "ffhq" in G_name or "celebahq" in G_name:
            return unet
        return SceneSegmenter(model_name=G_name)

    Ps = {G_name: P_from_name(G_name) for G_name in G_names.split(",")}
    label_list = CELEBA_CATEGORY if is_face else []
    n_class = len(label_list)
    N_repeat = args.repeat
    model_dirs = glob.glob(f"{data_dir}/*")
    model_files = [d for d in model_dirs if os.path.isdir(d)]
    model_files = [glob.glob(f"{f}/*.pth") for f in model_files]  # checkpoints per model dir
Example #3
def main():
    """Main function."""
    args = parse_args()

    work_dir = args.output_dir or f'{args.model_name}_rescore'
    logger_name = f'{args.model_name}_rescore_logger'
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info(f'Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info(f'Preparing latent codes.')
    if args.num <= 0:
        raise ValueError(f'Argument `num` should be specified as a positive '
                         f'number, but `{args.num}` received!')
    latent_codes = model.easy_sample(num=args.num, latent_space_type='z')
    latent_codes = model.easy_synthesize(latent_codes=latent_codes,
                                         latent_space_type='z',
                                         generate_style=False,
                                         generate_image=False)
    for key, val in latent_codes.items():
        np.save(os.path.join(work_dir, f'{key}.npy'), val)

    logger.info(f'Initializing predictor.')
    predictor = build_predictor(args.predictor_name)

    boundaries = parse_boundary_list(args.boundary_list_path)

    logger.info(f'========================================')
    logger.info(f'Rescoring.')
    score_changing = []
    for boundary_info, boundary_path in boundaries.items():
        logger.info(f'----------------------------------------')
        boundary_name, space_type = boundary_info
        logger.info(
            f'Boundary `{boundary_name}` from {space_type.upper()} space.')
        prefix = f'{boundary_name}_{space_type}'
        attr_idx = predictor.attribute_name_to_idx[boundary_name]

        try:
            boundary_file = np.load(boundary_path, allow_pickle=True).item()
            boundary = boundary_file['boundary']
        except ValueError:
            boundary = np.load(boundary_path)

        np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

        if space_type == 'z':
            layerwise_manipulation = False
            is_code_layerwise = False
            is_boundary_layerwise = False
            num_layers = 0
            strength = 1.0
        else:
            layerwise_manipulation = True
            is_code_layerwise = True
            is_boundary_layerwise = (space_type == 'wp')
            num_layers = model.num_layers if args.layerwise_rescoring else 0
            if space_type == 'w':
                strength = get_layerwise_manipulation_strength(
                    model.num_layers, model.truncation_psi,
                    model.truncation_layers)
            else:
                strength = 1.0
            space_type = 'wp'

        # Column 0 holds the original codes; column 1 the all-layer manipulation
        # (l = -1); columns 2.. the per-layer manipulations when rescoring layer-wise.
        codes = []
        codes.append(latent_codes[space_type][:, np.newaxis])
        for l in range(-1, num_layers):
            codes.append(
                manipulate(latent_codes[space_type],
                           boundary,
                           start_distance=2.0,
                           end_distance=2.0,
                           step=1,
                           layerwise_manipulation=layerwise_manipulation,
                           num_layers=model.num_layers,
                           manipulate_layers=None if l < 0 else l,
                           is_code_layerwise=is_code_layerwise,
                           is_boundary_layerwise=is_boundary_layerwise,
                           layerwise_manipulation_strength=strength))
        codes = np.concatenate(codes, axis=1)

        scores = []
        for i in tqdm(range(args.num), leave=False):
            images = model.easy_synthesize(latent_codes=codes[i],
                                           latent_space_type=space_type,
                                           generate_style=False,
                                           generate_image=True)['image']
            scores.append(
                predictor.easy_predict(images)['attribute'][:, attr_idx])
        scores = np.stack(scores, axis=0)
        np.save(os.path.join(work_dir, f'{prefix}_scores.npy'), scores)

        delta = scores[:, 1] - scores[:, 0]
        delta[delta < 0] = 0
        score_changing.append((boundary_name, np.mean(delta)))
        if num_layers:
            layerwise_score_changing = []
            for l in range(num_layers):
                delta = scores[:, l + 2] - scores[:, 0]
                delta[delta < 0] = 0
                layerwise_score_changing.append(
                    (f'Layer {l:02d}', np.mean(delta)))
            layerwise_score_changing.sort(key=lambda x: x[1], reverse=True)
            for layer_name, delta_score in layerwise_score_changing:
                logger.info(f'  {layer_name}: {delta_score:7.4f}')
    logger.info(f'----------------------------------------')
    logger.info(f'Most relevant semantics:')
    score_changing.sort(key=lambda x: x[1], reverse=True)
    for boundary_name, delta_score in score_changing:
        logger.info(f'  {boundary_name.ljust(15)}: {delta_score:7.4f}')
Example #4
                        type=str,
                        default='6-11',
                        help='Indices of the layers to perform manipulation. '
                        'Active ONLY when `layerwise_manipulation` is set '
                        'as `True`. If not specified, all layers will be '
                        'manipulated. Multiple layers should be '
                        'separated by `,`. (default: 6-11)')
    args = parser.parse_args()

    work_dir = 'manipulation_results'
    os.makedirs(work_dir, exist_ok=True)
    prefix = f'{args.model_name}_{args.boundary_name}'
    logger = setup_logger(work_dir, '', 'logger')

    logger.info(f'Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info(f'Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    latent_codes = model.easy_synthesize(
Example #5
def setup_generator(model_name):
    global generator
    setup_model(model_name)
    generator = build_generator(models[model_name]['name'])
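
# Usage sketch (assumptions: `models` maps UI labels to settings dicts with a
# 'name' entry, e.g. {'StyleGAN2 FFHQ': {'name': 'stylegan2_ffhq'}}; not part
# of the original example):
setup_generator('StyleGAN2 FFHQ')
image = generator.easy_synthesize(generator.easy_sample(1))['image'][0]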
Example #6
    # PGGAN Generator.
    if args.pggan or args.all:
        print('==== PGGAN Generator Test ====')
        if args.verbose:
            # Exhaustive test: every PGGAN model registered in MODEL_POOL.
            model_list = []
            for model_name, model_setting in MODEL_POOL.items():
                if model_setting['gan_type'] == 'pggan':
                    model_list.append(model_name)
        else:
            # Quick test with a representative subset.
            model_list = ['pggan_celebahq', 'pggan_bedroom']
        for model_name in model_list:
            logger = setup_logger(
                work_dir=RESULT_DIR,
                logfile_name=f'{model_name}_generator_test.log',
                logger_name=f'{model_name}_generator_logger')
            G = build_generator(model_name, logger=logger)
            G.batch_size = TEST_BATCH_SIZE
            z = G.easy_sample(args.test_num)
            x = G.easy_synthesize(z)['image']
            visualizer = HtmlPageVisualizer(grid_size=args.test_num)
            for i in range(visualizer.num_rows):
                for j in range(visualizer.num_cols):
                    visualizer.set_cell(i,
                                        j,
                                        image=x[i * visualizer.num_cols + j])
            visualizer.save(f'{RESULT_DIR}/{model_name}_generator_test.html')
            del G
        print('Pass!')
        TEST_FLAG = True

    # StyleGAN Generator.
Example #7
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    image_dir = args.image_dir
    image_dir_name = os.path.basename(image_dir.rstrip('/'))
    assert os.path.exists(image_dir)
    assert os.path.exists(f'{image_dir}/image_list.txt')
    assert os.path.exists(f'{image_dir}/inverted_codes.npy')
    boundary_path = args.boundary_path
    assert os.path.exists(boundary_path)
    boundary_name = os.path.splitext(os.path.basename(boundary_path))[0]
    output_dir = args.output_dir or 'results/manipulation'
    job_name = f'{boundary_name.upper()}_{image_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info(f'Loading generator.')
    generator = build_generator(args.model_name)

    # Load image, codes, and boundary.
    logger.info(f'Loading images and corresponding inverted latent codes.')
    image_list = []
    with open(f'{image_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{image_dir}/{name}_ori.png')
            assert os.path.exists(f'{image_dir}/{name}_inv.png')
            image_list.append(name)
    latent_codes = np.load(f'{image_dir}/inverted_codes.npy')
    assert latent_codes.shape[0] == len(image_list)
    num_images = latent_codes.shape[0]
    logger.info(f'Loading boundary.')
    boundary_file = np.load(boundary_path, allow_pickle=True)[()]
    if isinstance(boundary_file, dict):
        boundary = boundary_file['boundary']
        manipulate_layers = boundary_file['meta_data']['manipulate_layers']
    else:
        boundary = boundary_file
        manipulate_layers = args.manipulate_layers
    if manipulate_layers:
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')
    else:
        logger.info(f'  Manipulating on ALL layers.')

    # Manipulate images.
    logger.info(f'Start manipulation.')
    step = args.step
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_images,
                                    num_cols=step + 3,
                                    viz_size=viz_size)
    visualizer.set_headers(['Name', 'Origin', 'Inverted'] +
                           [f'Step {i:02d}' for i in range(1, step + 1)])
    for img_idx, img_name in enumerate(image_list):
        ori_image = load_image(f'{image_dir}/{img_name}_ori.png')
        inv_image = load_image(f'{image_dir}/{img_name}_inv.png')
        visualizer.set_cell(img_idx, 0, text=img_name)
        visualizer.set_cell(img_idx, 1, image=ori_image)
        visualizer.set_cell(img_idx, 2, image=inv_image)

    codes = manipulate(latent_codes=latent_codes,
                       boundary=boundary,
                       start_distance=args.start_distance,
                       end_distance=args.end_distance,
                       step=step,
                       layerwise_manipulation=True,
                       num_layers=generator.num_layers,
                       manipulate_layers=manipulate_layers,
                       is_code_layerwise=True,
                       is_boundary_layerwise=True)

    for img_idx in tqdm(range(num_images), leave=False):
        output_images = generator.easy_synthesize(
            codes[img_idx], latent_space_type='wp')['image']
        for s, output_image in enumerate(output_images):
            visualizer.set_cell(img_idx, s + 3, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
Example #8
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    src_dir = args.src_dir
    src_dir_name = os.path.basename(src_dir.rstrip('/'))
    assert os.path.exists(src_dir)
    assert os.path.exists(f'{src_dir}/image_list.txt')
    assert os.path.exists(f'{src_dir}/inverted_codes.npy')
    dst_dir = args.dst_dir
    dst_dir_name = os.path.basename(dst_dir.rstrip('/'))
    assert os.path.exists(dst_dir)
    assert os.path.exists(f'{dst_dir}/image_list.txt')
    assert os.path.exists(f'{dst_dir}/inverted_codes.npy')
    output_dir = args.output_dir or 'results/interpolation'
    job_name = f'{src_dir_name}_TO_{dst_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info(f'Loading generator.')
    generator = build_generator(args.model_name)

    # Load image and codes.
    logger.info(f'Loading images and corresponding inverted latent codes.')
    src_list = []
    with open(f'{src_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{src_dir}/{name}_ori.png')
            src_list.append(name)
    src_codes = np.load(f'{src_dir}/inverted_codes.npy')
    assert src_codes.shape[0] == len(src_list)
    num_src = src_codes.shape[0]
    dst_list = []
    with open(f'{dst_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{dst_dir}/{name}_ori.png')
            dst_list.append(name)
    dst_codes = np.load(f'{dst_dir}/inverted_codes.npy')
    assert dst_codes.shape[0] == len(dst_list)
    num_dst = dst_codes.shape[0]

    # Interpolate images.
    logger.info(f'Start interpolation.')
    step = args.step + 2
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_src * num_dst,
                                    num_cols=step + 2,
                                    viz_size=viz_size)
    visualizer.set_headers(['Source', 'Source Inversion'] +
                           [f'Step {i:02d}' for i in range(1, step - 1)] +
                           ['Target Inversion', 'Target'])

    for src_idx in tqdm(range(num_src), leave=False):
        src_code = src_codes[src_idx:src_idx + 1]
        src_path = f'{src_dir}/{src_list[src_idx]}_ori.png'
        codes = interpolate(src_codes=np.repeat(src_code, num_dst, axis=0),
                            dst_codes=dst_codes,
                            step=step)
        for dst_idx in tqdm(range(num_dst), leave=False):
            dst_path = f'{dst_dir}/{dst_list[dst_idx]}_ori.png'
            output_images = generator.easy_synthesize(
                codes[dst_idx], latent_space_type='wp')['image']

            row_idx = src_idx * num_dst + dst_idx
            visualizer.set_cell(row_idx, 0, image=load_image(src_path))
            visualizer.set_cell(row_idx, step + 1, image=load_image(dst_path))
            for s, output_image in enumerate(output_images):
                visualizer.set_cell(row_idx, s + 1, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
Example #9
  else:
    mIoU, c_iou = read_results(res_file)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--eval-file', type=str, default='results/scs',
    help='The path to the experiment result file.')
  parser.add_argument('--out-dir', type=str, default='results/scs',
    help='The output directory.')
  parser.add_argument('--gpu-id', default='0',
    help='Which GPU(s) to use. (default: `0`)')
  args = parser.parse_args()
  set_cuda_devices(args.gpu_id)
  Gs = {
      "stylegan2_ffhq": build_generator("stylegan2_ffhq").net,
      "stylegan2_bedroom": build_generator("stylegan2_bedroom").net,
      "stylegan2_church": build_generator("stylegan2_church").net,
  }
  Ps = {
      "stylegan2_ffhq": FaceSegmenter(),
      "stylegan2_bedroom": SceneSegmenter(model_name="stylegan2_bedroom"),
      "stylegan2_church": SceneSegmenter(model_name="stylegan2_church"),
  }

  if ".pth" in args.eval_file:
    eval_single(Gs, Ps, args.eval_file)
  else:
    eval_files = glob.glob(f"{args.eval_file}/*.pth")
    eval_files.sort()
    for eval_file in eval_files:
      eval_single(Gs, Ps, eval_file)  # assumed loop body; the original snippet is cut off here
Example #10
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    style_dir = args.style_dir
    style_dir_name = os.path.basename(style_dir.rstrip('/'))
    assert os.path.exists(style_dir)
    assert os.path.exists(f'{style_dir}/image_list.txt')
    assert os.path.exists(f'{style_dir}/inverted_codes.npy')
    content_dir = args.content_dir
    content_dir_name = os.path.basename(content_dir.rstrip('/'))
    assert os.path.exists(content_dir)
    assert os.path.exists(f'{content_dir}/image_list.txt')
    assert os.path.exists(f'{content_dir}/inverted_codes.npy')
    output_dir = args.output_dir or 'results/style_mixing'
    job_name = f'{style_dir_name}_STYLIZE_{content_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info(f'Loading generator.')
    generator = build_generator(args.model_name)
    mix_layers = list(range(args.mix_layer_start_idx, generator.num_layers))

    # Load image and codes.
    logger.info(f'Loading images and corresponding inverted latent codes.')
    style_list = []
    with open(f'{style_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{style_dir}/{name}_ori.png')
            style_list.append(name)
    logger.info(f'Loading inverted latent codes.')
    style_codes = np.load(f'{style_dir}/inverted_codes.npy')
    assert style_codes.shape[0] == len(style_list)
    num_styles = style_codes.shape[0]
    content_list = []
    with open(f'{content_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{content_dir}/{name}_ori.png')
            content_list.append(name)
    logger.info(f'Loading inverted latent codes.')
    content_codes = np.load(f'{content_dir}/inverted_codes.npy')
    assert content_codes.shape[0] == len(content_list)
    num_contents = content_codes.shape[0]

    # Mix styles.
    logger.info(f'Start style mixing.')
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_styles + 1,
                                    num_cols=num_contents + 1,
                                    viz_size=viz_size)
    visualizer.set_headers(['Style'] +
                           [f'Content {i:03d}' for i in range(num_contents)])
    for style_idx, style_name in enumerate(style_list):
        style_image = load_image(f'{style_dir}/{style_name}_ori.png')
        visualizer.set_cell(style_idx + 1, 0, image=style_image)
    for content_idx, content_name in enumerate(content_list):
        content_image = load_image(f'{content_dir}/{content_name}_ori.png')
        visualizer.set_cell(0, content_idx + 1, image=content_image)

    codes = mix_style(style_codes=style_codes,
                      content_codes=content_codes,
                      num_layers=generator.num_layers,
                      mix_layers=mix_layers)
    for style_idx in tqdm(range(num_styles), leave=False):
        output_images = generator.easy_synthesize(
            codes[style_idx], latent_space_type='wp')['image']
        for content_idx, output_image in enumerate(output_images):
            visualizer.set_cell(style_idx + 1,
                                content_idx + 1,
                                image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
Example #11
  from predictors.scene_segmenter import SceneSegmenter


  print(f"=> Loading from {args.SE}")
  if "baseline" in args.SE:
    SE_name = args.SE
    pred = True
  else:
    SE = load_semantic_extractor(args.SE)
    SE.cuda().eval()
    SE_name = args.SE[args.SE.rfind("/") + 1 : args.SE.rfind(".pth")]
    pred = False
  print(SE_name)
  G_name = listkey_convert(args.SE,
    ["stylegan2_ffhq", "stylegan2_bedroom", "stylegan2_church"])
  out_name = SE_name if pred else f"{G_name}_{SE_name}"
  if os.path.exists(f"{args.out_dir}/{out_name}.pth"):
    print(f"=> Skip {out_name}")
    exit(0)
  G = build_generator(G_name).net
  P = FaceSegmenter() if "ffhq" in G_name else SceneSegmenter(model_name=G_name)
  labels = read_labels(G_name, G, P)
  P_ = P if pred else SE
  if pred:
    out_name = SE_name
  z, wp = SCS.sseg_se(P_, G, labels,
    n_iter=args.n_iter, n_init=args.n_init,
    pred=pred, repeat=args.repeat,
    latent_strategy=args.latent_strategy)
  torch.save([z, wp], f"{args.out_dir}/{out_name}.pth")
Example #12
def load_model(model_name, logger=None):
    model_load_state = st.text('Loading GAN model...')
    model = build_generator(model_name, logger=logger)
    model_load_state.empty()
    return model
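
# Usage sketch inside a Streamlit app (the model name is an assumption; any
# name accepted by `build_generator` works):
model = load_model('stylegan2_ffhq')
z = model.easy_sample(1)
st.image(model.easy_synthesize(z)['image'][0])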
Example #13
def main():
    """Main function."""
    args = parse_args()
    file_name = args.output_file
    work_dir = args.output_dir or f'{args.model_name}_synthesis'
    logger_name = f'{args.model_name}_synthesis_logger'
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info(f'Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info(f'Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive '
                f'number since the latent code path '
                f'`{args.latent_codes_path}` does not exist!')
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    if args.generate_prediction:
        logger.info(f'Initializing predictor.')
        predictor = build_predictor(args.predictor_name)

    if args.generate_html:
        viz_size = None if args.viz_size == 0 else args.viz_size
        visualizer = HtmlPageVisualizer(num_rows=args.html_row,
                                        num_cols=args.html_col,
                                        grid_size=total_num,
                                        viz_size=viz_size)

    logger.info(f'Generating {total_num} samples.')
    results = defaultdict(list)
    predictions = defaultdict(list)
    pbar = tqdm(total=total_num, leave=False)
    for inputs in model.get_batch_inputs(latent_codes):
        outputs = model.easy_synthesize(
            latent_codes=inputs,
            latent_space_type=args.latent_space_type,
            generate_style=args.generate_style,
            generate_image=not args.skip_image)
        for key, val in outputs.items():
            if key == 'image':
                if args.generate_prediction:
                    pred_outputs = predictor.easy_predict(val)
                    for pred_key, pred_val in pred_outputs.items():
                        predictions[pred_key].append(pred_val)
                for image in val:
                    if args.save_raw_synthesis:
                        dest = os.path.join(work_dir, f'{pbar.n:06d}.jpg')
                        if file_name != "":
                            dest = os.path.join(work_dir, file_name)
                        print('saving image to ', dest)
                        save_image(dest, image)
                    if args.generate_html:
                        row_idx = pbar.n // visualizer.num_cols
                        col_idx = pbar.n % visualizer.num_cols
                        visualizer.set_cell(row_idx, col_idx, image=image)
                    pbar.update(1)
            else:
                results[key].append(val)
        if 'image' not in outputs:
            pbar.update(inputs.shape[0])
    pbar.close()

    logger.info(f'Saving results.')
    if args.generate_html:
        visualizer.save(os.path.join(work_dir, args.html_name))
    for key, val in results.items():
        np.save(os.path.join(work_dir, f'{key}.npy'),
                np.concatenate(val, axis=0))
    if predictions:
        if args.predictor_name == 'scene':
            # Categories
            categories = np.concatenate(predictions['category'], axis=0)
            detailed_categories = {
                'score': categories,
                'name_to_idx': predictor.category_name_to_idx,
                'idx_to_name': predictor.category_idx_to_name,
            }
            np.save(os.path.join(work_dir, 'category.npy'),
                    detailed_categories)
            # Attributes
            attributes = np.concatenate(predictions['attribute'], axis=0)
            detailed_attributes = {
                'score': attributes,
                'name_to_idx': predictor.attribute_name_to_idx,
                'idx_to_name': predictor.attribute_idx_to_name,
            }
            np.save(os.path.join(work_dir, 'attribute.npy'),
                    detailed_attributes)
        else:
            for key, val in predictions.items():
                np.save(os.path.join(work_dir, f'{key}.npy'),
                        np.concatenate(val, axis=0))
Example #14
def main():
    """Main function."""
    args = parse_args()
    set_cuda_devices(args.gpu_id)
    work_dir = args.output_dir
    os.makedirs(work_dir, exist_ok=True)

    model = build_generator(
        args.model_name,
        truncation_psi=None if args.truncation < 0 else args.truncation)

    if os.path.isfile(args.latent_codes_path):
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive '
                f'number since the latent code path '
                f'`{args.latent_codes_path}` does not exist!')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    if args.generate_prediction:
        predictor = build_predictor(args.predictor_name)

    if args.generate_html:
        visualizer = HtmlPageVisualizer(num_rows=args.html_row,
                                        num_cols=args.html_col,
                                        grid_size=total_num,
                                        viz_size=args.viz_size)

    results = defaultdict(list)
    predictions = defaultdict(list)
    pbar = tqdm(total=total_num, leave=False)
    for inputs in model.get_batch_inputs(latent_codes):
        outputs = model.easy_synthesize(
            latent_codes=inputs,
            latent_space_type=args.latent_space_type,
            generate_style=args.generate_style,
            generate_image=not args.skip_image)
        for key, val in outputs.items():
            if key == 'image':
                if args.generate_prediction:
                    pred_outputs = predictor.easy_predict(val)
                    for pred_key, pred_val in pred_outputs.items():
                        predictions[pred_key].append(pred_val)
                for image in val:
                    if args.save_raw_synthesis:
                        save_image(os.path.join(work_dir, f'{pbar.n:06d}.jpg'),
                                   image)
                    if args.generate_html:
                        row_idx = pbar.n // visualizer.num_cols
                        col_idx = pbar.n % visualizer.num_cols
                        visualizer.set_cell(row_idx,
                                            col_idx,
                                            text=f'Sample {pbar.n:06d}',
                                            image=image)
                    pbar.update(1)
            else:
                results[key].append(val)
        if 'image' not in outputs:
            pbar.update(inputs.shape[0])
    pbar.close()

    if args.generate_html:
        visualizer.save(os.path.join(work_dir, args.html_name))
    for key, val in results.items():
        np.save(os.path.join(work_dir, f'{key}.npy'),
                np.concatenate(val, axis=0))
    if predictions:
        print(len(predictions))
        predictor.save(predictions, work_dir)
Example #15
def main():
    """Main function."""
    args = parse_args()

    work_dir = args.output_dir or f'{args.model_name}_manipulation'
    logger_name = f'{args.model_name}_manipulation_logger'
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info(f'Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info(f'Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        if args.num <= 0:
            raise ValueError(
                f'Argument `num` should be specified as a positive '
                f'number since the latent code path '
                f'`{args.latent_codes_path}` does not exist!')
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    latent_codes = model.easy_synthesize(
        latent_codes=latent_codes,
        latent_space_type=args.latent_space_type,
        generate_style=False,
        generate_image=False)
    for key, val in latent_codes.items():
        np.save(os.path.join(work_dir, f'{key}.npy'), val)

    boundaries = parse_boundary_list(args.boundary_list_path)

    step = args.step + int(args.step % 2 == 0)  # Make sure it is an odd number, e.g. 4 -> 5, 5 -> 5.

    for boundary_info, boundary_path in boundaries.items():
        boundary_name, space_type = boundary_info
        logger.info(
            f'Boundary `{boundary_name}` from {space_type.upper()} space.')
        prefix = f'{boundary_name}_{space_type}'

        if args.generate_html:
            viz_size = None if args.viz_size == 0 else args.viz_size
            visualizer = HtmlPageVisualizer(num_rows=total_num,
                                            num_cols=step + 1,
                                            viz_size=viz_size)
            visualizer.set_headers(
                [''] + [f'Step {i - step // 2}' for i in range(step // 2)] +
                ['Origin'] + [f'Step {i + 1}' for i in range(step // 2)])

        if args.generate_video:
            setup_images = model.easy_synthesize(
                latent_codes=latent_codes[args.latent_space_type],
                latent_space_type=args.latent_space_type)['image']
            fusion_kwargs = {
                'row': args.row,
                'col': args.col,
                'row_spacing': args.row_spacing,
                'col_spacing': args.col_spacing,
                'border_left': args.border_left,
                'border_right': args.border_right,
                'border_top': args.border_top,
                'border_bottom': args.border_bottom,
                'black_background': not args.white_background,
                'image_size': None if args.viz_size == 0 else args.viz_size,
            }
            setup_image = fuse_images(setup_images, **fusion_kwargs)
            video_writer = VideoWriter(os.path.join(
                work_dir, f'{prefix}_{args.video_name}'),
                                       frame_height=setup_image.shape[0],
                                       frame_width=setup_image.shape[1],
                                       fps=args.fps)

        logger.info(f'  Loading boundary.')
        try:
            boundary_file = np.load(boundary_path, allow_pickle=True).item()
            boundary = boundary_file['boundary']
            manipulate_layers = boundary_file['meta_data']['manipulate_layers']
        except ValueError:
            boundary = np.load(boundary_path)
            manipulate_layers = args.manipulate_layers
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')

        np.save(os.path.join(work_dir, f'{prefix}_boundary.npy'), boundary)

        if args.layerwise_manipulation and space_type != 'z':
            layerwise_manipulation = True
            is_code_layerwise = True
            is_boundary_layerwise = (space_type == 'wp')
            if not args.disable_manipulation_truncation and space_type == 'w':
                strength = get_layerwise_manipulation_strength(
                    model.num_layers, model.truncation_psi,
                    model.truncation_layers)
            else:
                strength = 1.0
            space_type = 'wp'
        else:
            if args.layerwise_manipulation:
                logger.warning(f'  Skip layer-wise manipulation for boundary '
                               f'`{boundary_name}` from Z space. Traditional '
                               f'manipulation is used instead.')
            layerwise_manipulation = False
            is_code_layerwise = False
            is_boundary_layerwise = False
            strength = 1.0

        codes = manipulate(latent_codes=latent_codes[space_type],
                           boundary=boundary,
                           start_distance=args.start_distance,
                           end_distance=args.end_distance,
                           step=step,
                           layerwise_manipulation=layerwise_manipulation,
                           num_layers=model.num_layers,
                           manipulate_layers=manipulate_layers,
                           is_code_layerwise=is_code_layerwise,
                           is_boundary_layerwise=is_boundary_layerwise,
                           layerwise_manipulation_strength=strength)
        np.save(
            os.path.join(work_dir, f'{prefix}_manipulated_{space_type}.npy'),
            codes)

        logger.info(f'  Start manipulating.')
        for s in tqdm(range(step), leave=False):
            images = model.easy_synthesize(
                latent_codes=codes[:, s],
                latent_space_type=space_type)['image']
            if args.generate_video:
                video_writer.write(fuse_images(images, **fusion_kwargs))
            for n, image in enumerate(images):
                if args.save_raw_synthesis:
                    save_image(
                        os.path.join(work_dir,
                                     f'{prefix}_{n:05d}_{s:03d}.jpg'), image)
                if args.generate_html:
                    visualizer.set_cell(n, s + 1, image=image)
                    if s == 0:
                        visualizer.set_cell(n, 0, text=f'Sample {n:05d}')

        if args.generate_html:
            visualizer.save(
                os.path.join(work_dir, f'{prefix}_{args.html_name}'))