Beispiel #1
0
def benchmark_nst(images_root, results_root, device):
    """Run neural style transfer on every content/style pair and save the results.

    For each content image (from ``get_npr_general_files()``) and each style
    image (from ``get_style_image_files()``), ``perform_nst`` runs twice: once
    with ``ssim_loss=False`` and once with ``ssim_loss=True``. Each result is
    written to ``<results_root>/nst_benchmark/<content>__<style>__<se|ssim>.jpg``.

    Args:
        images_root: Directory that the content/style file names are joined to.
        results_root: Directory containing the ``nst_benchmark`` output folder.
        device: Device every loaded image is moved to via ``.to(device)``.
    """

    def process_image(file):
        # Base name without directory or extension; used to build output names.
        name = path.splitext(path.basename(file))[0]
        image = read_image(path.join(images_root, file)).to(device)
        return name, image

    content_files = get_npr_general_files()
    # Style images do not depend on the content image, so load them once
    # instead of re-reading them for every content image.
    # Materializing into a tuple also keeps this correct if
    # get_style_image_files() returns a one-shot iterator.
    # NOTE(review): all style images are now kept on `device` at the same time,
    # and this assumes perform_nst does not modify style_image in place.
    styles = tuple(process_image(style_file) for style_file in get_style_image_files())

    output_dir = path.join(results_root, "nst_benchmark")

    for content_file in content_files:
        content_name, content_image = process_image(content_file)
        for style_name, style_image in styles:
            for ssim_loss in (False, True):
                output_image = perform_nst(
                    content_image, style_image, ssim_loss=ssim_loss, quiet=False
                )

                output_name = "__".join(
                    (content_name, style_name, "ssim" if ssim_loss else "se")
                )
                output_file = path.join(output_dir, f"{output_name}.jpg")
                write_image(output_image, output_file)
Beispiel #2
0
 def create_images(root):
     """Write three seeded random images into ``root`` as PNG files.

     Each image is ``torch.rand(1, 3, 32, 32)`` drawn after
     ``torch.manual_seed(0)``, so the contents are reproducible.

     Returns:
         Dict mapping the names ``"0"``, ``"1"``, ``"2"`` to the written paths.
     """
     torch.manual_seed(0)
     paths_by_name = {}
     for name in map(str, range(3)):
         target = path.join(root, f"{name}.png")
         write_image(torch.rand(1, 3, 32, 32), target)
         paths_by_name[name] = target
     return paths_by_name
Beispiel #3
0
def test_write_image(tmpdir):
    """Round-trip a seeded random image through write_image and read_image."""
    torch.manual_seed(0)
    original = torch.rand(3, 100, 100)

    target = path.join(tmpdir, "tmp_image.png")
    image_.write_image(original, target)

    # Argument order is (actual, desired).
    pyimagetest.assert_images_almost_equal(image_.read_image(file=target), original)
Beispiel #4
0
def main(images, space_width_factor=0.1, annotation_width_factor=0.2):
    """Compose ``images`` into an annotated banner and save it as ``banner.jpg``.

    Args:
        images: Images to combine. They are passed to ``pad_images`` together
            with their heights before the banner is built.
        space_width_factor: ``space_width`` given to ``create_banner``, as a
            fraction of the largest image width (truncated to ``int``).
        annotation_width_factor: Annotation width, as a fraction of the smaller
            of the largest height and the largest width (truncated to ``int``).
    """
    heights, widths = extract_image_sizes(images)
    pad_images(images, heights)

    banner, anchors = create_banner(
        images,
        heights,
        widths,
        space_width=int(space_width_factor * max(widths)),
    )

    smaller_extent = min(max(heights), max(widths))
    banner = annotate_banner(
        banner, anchors, width=int(annotation_width_factor * smaller_extent)
    )
    write_image(banner, "banner.jpg")
Beispiel #5
0
def main(raw_args: Optional[List[str]] = None) -> None:
    """Command-line entry point: optimize an image and write the result.

    Args:
        raw_args: Argument list forwarded to ``parse_args``. ``None`` leaves the
            choice of source to ``parse_args`` (presumably ``sys.argv``; verify).
    """
    config = make_config(parse_args(raw_args))

    # Register the target images on the perceptual loss before optimizing.
    config.perceptual_loss.set_content_image(config.content_image)
    config.perceptual_loss.set_style_image(config.style_image)

    result = image_optimization(
        config.input_image, config.perceptual_loss, num_steps=config.num_steps
    )

    # Only announce the output path when the user did not choose it explicitly.
    if not config.output_image_specified:
        print(f"Saving result to {config.output_image}")
    write_image(result, config.output_image)
Beispiel #6
0
def figure_6(args):
    """Replicate the top and bottom halves of figure 6 and save each result.

    Downloads the paper's images into ``args.image_source_dir``. For each half,
    it stylizes the content image with its style image via ``paper.nst`` and
    writes the output into ``args.image_results_dir``. When
    ``args.impl_params`` is truthy, the hyper-parameters are adjusted to match
    the reference implementation.
    """
    images = paper.images()
    images.download(args.image_source_dir)

    # (position, content image, style image) for each half of the figure.
    halves = (
        ("top", images["blue_bottle"], images["self-portrait"]),
        ("bottom", images["s"], images["composition_viii"]),
    )
    params_label = "implementation" if args.impl_params else "paper"

    for position, content_source, style_source in halves:
        content_image = content_source.read(device=args.device)
        style_image = style_source.read(device=args.device)

        print(
            f"Replicating the {position} half of figure 6 "
            f"with {params_label} parameters"
        )

        hyper_parameters = paper.hyper_parameters(impl_params=args.impl_params)
        if args.impl_params:
            # Settings taken from the reference implementation:
            # https://github.com/pmeier/CNNMRF/blob/fddcf4d01e2a6ce201059d8bc38597f74a09ba3f/run_trans.lua#L66
            hyper_parameters.content_loss.layer = "relu4_2"
            hyper_parameters.target_transforms.num_scale_steps = 1
            hyper_parameters.target_transforms.num_rotate_steps = 1

        output_image = paper.nst(
            content_image,
            style_image,
            impl_params=args.impl_params,
            hyper_parameters=hyper_parameters,
        )
        output_file = path.join(
            args.image_results_dir,
            make_output_filename(
                "li_wand_2016", "fig_6", position, impl_params=args.impl_params
            ),
        )
        print(f"Saving result to {output_file}")
        write_image(output_image, output_file)
        # Separator line sized from the COLUMNS environment variable (default 80).
        print("#" * int(os.environ.get("COLUMNS", "80")))
        print("#" * int(os.environ.get("COLUMNS", "80")))
Beispiel #7
0
def training(args):
    """Train one transformer per style and stylize a fixed set of content images.

    For every style in ``styles``:

    - A transformer is trained with ``paper.training`` on the dataset in
      ``args.dataset_dir``.
    - Its state dict is saved to ``args.model_dir``.
    - Each image in ``contents`` is stylized with it and written as a PNG
      to ``args.image_results_dir``.

    Model and output names get ``__impl_params`` / ``__instance_norm``
    suffixes when the corresponding flags on ``args`` are truthy.
    """
    contents = ("karya", "tiger", "neckarfront", "bird", "kitty")
    styles = (
        "candy",
        "the_scream",
        "jean_metzinger",
        "mosaic",
        "pleades",
        "starry",
        "turner",
    )

    def add_param_suffixes(name):
        # Encode the impl_params / instance_norm flags in a name. Model names
        # and output image names share this so their suffixes cannot drift.
        if args.impl_params:
            name += "__impl_params"
        if args.instance_norm:
            name += "__instance_norm"
        return name

    dataset = paper.dataset(
        args.dataset_dir,
        impl_params=args.impl_params,
        instance_norm=args.instance_norm,
    )
    image_loader = paper.image_loader(
        dataset,
        impl_params=args.impl_params,
        instance_norm=args.instance_norm,
        # Pinned host memory only helps when copying to a CUDA device.
        pin_memory=str(args.device).startswith("cuda"),
    )

    images = paper.images()
    images.download(args.image_source_dir)

    for style in styles:
        style_image = images[style].read(device=args.device)

        transformer = paper.training(
            image_loader,
            style_image,
            impl_params=args.impl_params,
            instance_norm=args.instance_norm,
        )

        model_name = add_param_suffixes(f"ulyanov_et_al_2016__{style}")
        utils.save_state_dict(transformer, model_name, root=args.model_dir)

        for content in contents:
            content_image = images[content].read(device=args.device)
            output_image = paper.stylization(
                content_image,
                transformer,
                impl_params=args.impl_params,
                instance_norm=args.instance_norm,
            )

            output_name = add_param_suffixes(f"{style}_{content}")
            output_file = path.join(args.image_results_dir, f"{output_name}.png")
            image.write_image(output_image, output_file)
Beispiel #8
0
def save_result(output_image, output_file):
    """Write ``output_image`` to ``output_file`` and print a separator line.

    The separator is ``#`` repeated to the width in the ``COLUMNS``
    environment variable. It falls back to 80 when ``COLUMNS`` is unset or is
    not an integer.
    """
    print(f"Saving result to {output_file}")
    write_image(output_image, output_file)
    # COLUMNS can be empty or non-numeric in some shells/CI runners. The image
    # is already saved at this point, so a malformed value must not crash.
    try:
        width = int(os.environ.get("COLUMNS", "80"))
    except ValueError:
        width = 80
    print("#" * width)
Beispiel #9
0
def save_ouput_image(image, root, content, style, impl_params, instance_norm):
    """Write ``image`` to ``<root>/<name>.jpg``.

    ``name`` is ``make_name(content, style, impl_params, instance_norm)``.
    The function name misspells "output". It is kept unchanged because callers
    refer to it by this name.
    """
    file_name = f"{make_name(content, style, impl_params, instance_norm)}.jpg"
    write_image(image, path.join(root, file_name))