Ejemplo n.º 1
0
def test_nst_smoke(subtests, default_image_pyramid_optim_loop_patch,
                   content_image, style_image):
    """Smoke-test ``paper.nst`` by inspecting the arguments it forwards.

    ``default_image_pyramid_optim_loop_patch`` is presumably a mock patch of
    the image-pyramid optimization loop (TODO confirm against the fixture).
    After ``paper.nst`` runs, the recorded ``call_args`` are unpacked. Each
    forwarded component is then compared against the defaults that
    ``paper`` exposes.
    """
    optim_loop_mock = default_image_pyramid_optim_loop_patch

    def is_default_instance(obj, factory, *factory_args):
        # Compare against the type of a freshly built default rather than a
        # hard-coded class, so the check tracks whatever ``paper`` returns.
        return isinstance(obj, type(factory(*factory_args)))

    paper.nst(content_image, style_image)

    positional, keyword = optim_loop_mock.call_args
    input_image, criterion, image_pyramid = positional
    get_optimizer, preprocessor, postprocessor = (
        keyword[name]
        for name in ("get_optimizer", "preprocessor", "postprocessor")
    )

    with subtests.test("input_image"):
        # The starting image should be the content image resized to the
        # last level of the pyramid.
        expected_input = image_pyramid[-1].resize_image(content_image)
        ptu.assert_allclose(input_image, expected_input)

    with subtests.test("criterion"):
        assert is_default_instance(criterion, paper.perceptual_loss)

    with subtests.test("image_pyramid"):
        assert is_default_instance(image_pyramid, paper.image_pyramid)

    with subtests.test("optimizer"):
        # The optimizer is passed as a factory; build one to check its type.
        assert is_callable(get_optimizer)
        optimizer = get_optimizer(input_image)
        assert is_default_instance(optimizer, paper.optimizer, input_image)

    with subtests.test("preprocessor"):
        assert is_default_instance(preprocessor, paper.preprocessor)

    with subtests.test("postprocessor"):
        assert is_default_instance(postprocessor, paper.postprocessor)
Ejemplo n.º 2
0
    def figure_3_d(content_image, style_image):
        """Replicate Figure 3 (d): style transfer on luminance only.

        The content image is converted to YUV. Its Y channel is repeated to
        three channels and stylized with a three-channel grayscale version of
        the style image. The stylized result is averaged back to one channel
        and recombined with the content's (resized) U/V channels. Finally it
        is converted back to RGB and saved.

        NOTE(review): the ``[:, :1]`` slicing with ``repeat(1, 3, 1, 1)``
        suggests 4-D ``(batch, channel, height, width)`` tensors; confirm.
        ``args`` is read from the enclosing scope.
        """
        yuv = rgb_to_yuv(content_image)
        luma, chroma = yuv[:, :1], yuv[:, 1:]
        # paper.nst works on three-channel images, so replicate the Y channel.
        luma_rgb = luma.repeat(1, 3, 1, 1)

        style_gray = rgb_to_grayscale(style_image, num_output_channels=3)

        print("Replicating Figure 3 (d)")
        stylized = paper.nst(
            luma_rgb,
            style_gray,
            impl_params=args.impl_params,
        )
        # Collapse the three stylized channels back into a single Y channel.
        stylized_luma = stylized.mean(dim=1, keepdim=True)
        # Match the content chroma to the spatial size of the NST output.
        spatial_size = stylized_luma.size()[2:]
        resized_chroma = resize(chroma, spatial_size)

        recombined = torch.cat([stylized_luma, resized_chroma], dim=1)
        output_rgb = yuv_to_rgb(recombined)

        result_name = make_output_filename(
            "gatys_et_al_2017",
            "fig_3",
            "d",
            impl_params=args.impl_params,
        )
        save_result(output_rgb, path.join(args.image_results_dir, result_name))
Ejemplo n.º 3
0
    def figure_2_d(content_image, style_image):
        """Replicate Figure 2 (d) with a plain NST run and save the result.

        ``args`` is read from the enclosing scope. Its ``impl_params`` is
        forwarded to ``paper.nst``, and the image is written to
        ``args.image_results_dir``.
        """
        print("Replicating Figure 2 (d)")
        stylized = paper.nst(
            content_image, style_image, impl_params=args.impl_params
        )
        # NOTE(review): this name does not encode impl_params, so runs with
        # different settings write to the same file.
        destination = path.join(args.image_results_dir, "fig_2__d.jpg")
        save_result(stylized, destination)
Ejemplo n.º 4
0
    def figure_3_e(content_image, style_image, method="cholesky"):
        """Replicate Figure 3 (e): NST with a colour-matched style image.

        The style image is first passed through
        ``match_channelwise_statistics`` against the content image, using
        ``method`` (default ``"cholesky"``). It is then stylized with
        ``paper.nst``. ``args`` is read from the enclosing scope.
        """
        matched_style = match_channelwise_statistics(
            style_image, content_image, method
        )

        print("Replicating Figure 3 (e)")
        stylized = paper.nst(
            content_image, matched_style, impl_params=args.impl_params
        )

        destination = path.join(args.image_results_dir, "fig_3__e.jpg")
        save_result(stylized, destination)
Ejemplo n.º 5
0
 def figure_3_c(content_image, style_image):
     """Replicate Figure 3 (c) with a plain NST run and save the result.

     The output name comes from ``make_output_filename`` and depends on
     ``args.impl_params``. ``args`` is read from the enclosing scope.
     """
     print("Replicating Figure 3 (c)")
     stylized = paper.nst(
         content_image, style_image, impl_params=args.impl_params
     )
     result_name = make_output_filename(
         "gatys_et_al_2017", "fig_3", "c", impl_params=args.impl_params
     )
     save_result(stylized, path.join(args.image_results_dir, result_name))
Ejemplo n.º 6
0
    def figure_3_e(content_image, style_image, method="cholesky"):
        """Replicate Figure 3 (e): NST with a colour-matched style image.

        The style image is first passed through
        ``match_channelwise_statistics`` against the content image, using
        ``method`` (default ``"cholesky"``). It is then stylized with
        ``paper.nst``. The output name comes from ``make_output_filename``
        and depends on ``args.impl_params``. ``args`` is read from the
        enclosing scope.
        """
        matched_style = match_channelwise_statistics(
            style_image, content_image, method
        )

        print("Replicating Figure 3 (e)")
        stylized = paper.nst(
            content_image, matched_style, impl_params=args.impl_params
        )

        result_name = make_output_filename(
            "gatys_et_al_2017", "fig_3", "e", impl_params=args.impl_params
        )
        save_result(stylized, path.join(args.image_results_dir, result_name))