コード例 #1
0
def test_target_transforms_call_smoke(target_image):
    """Smoke-test that every target transform maps an image to a tensor.

    The implementation hyper-parameters are used with the number of scale
    and rotate steps forced to 1. Presumably this makes the generated
    transforms non-trivial; TODO confirm.
    """
    params = paper.hyper_parameters(impl_params=True)
    transform_params = params.target_transforms
    transform_params.num_scale_steps = 1
    transform_params.num_rotate_steps = 1

    transforms = paper.target_transforms(impl_params=True, hyper_parameters=params)
    for target_transform in transforms:
        output = target_transform(target_image)
        assert isinstance(output, torch.Tensor)
コード例 #2
0
def test_image_pyramid(subtests, mocker, impl_params):
    """Check that ``paper.image_pyramid`` builds an ``OctaveImagePyramid``.

    The pyramid class is replaced by a spy that wraps the real class, so the
    constructor arguments can be compared field by field against
    ``hyper_parameters.image_pyramid``.
    """
    pyramid_cls = pyramid.OctaveImagePyramid
    spy = mocker.patch(
        "pystiche_papers.li_wand_2016._pyramid.pyramid.OctaveImagePyramid",
        wraps=pyramid_cls,
    )

    image_pyramid = paper.image_pyramid(impl_params=impl_params)
    assert isinstance(image_pyramid, pyramid_cls)

    expected = paper.hyper_parameters(impl_params=impl_params).image_pyramid
    actual = utils.call_args_to_namespace(spy.call_args, pyramid_cls)

    # Each constructor argument is checked in its own subtest, in a fixed order.
    checked_fields = (
        "max_edge_size",
        "num_steps",
        "num_levels",
        "min_edge_size",
        "edge",
    )
    for field in checked_fields:
        with subtests.test(field):
            assert getattr(actual, field) == getattr(expected, field)
コード例 #3
0
def test_style_loss(subtests, impl_params):
    """Check ``paper.style_loss`` against the ``style_loss`` hyper-parameters.

    Verifies the container type, that every child is an ``MRFLoss``, the
    per-layer encoder layer, weight, patch size and stride, and the overall
    score weight.
    """
    style_loss = paper.style_loss(impl_params=impl_params)
    assert isinstance(style_loss, pystiche.loss.MultiLayerEncodingLoss)

    expected = paper.hyper_parameters(impl_params=impl_params).style_loss

    with subtests.test("losses"):
        assert all(isinstance(op, paper.MRFLoss) for op in style_loss.children())

    # One row per child loss, transposed into per-attribute tuples.
    per_layer = [
        (op.encoder.layer, op.score_weight, op.patch_size, op.stride)
        for op in style_loss.children()
    ]
    layers, layer_weights, patch_size, stride = zip(*per_layer)
    num_layers = len(layers)

    with subtests.test("layers"):
        assert layers == expected.layers

    with subtests.test("layer_weights"):
        # Every layer is expected to carry a weight of 1.0.
        assert layer_weights == pytest.approx((1.0, ) * num_layers)

    with subtests.test("patch_size"):
        expected_patch_size = misc.to_2d_arg(expected.patch_size)
        assert patch_size == (expected_patch_size, ) * num_layers

    with subtests.test("stride"):
        expected_stride = misc.to_2d_arg(expected.stride)
        assert stride == (expected_stride, ) * num_layers

    with subtests.test("score_weight"):
        assert style_loss.score_weight == pytest.approx(expected.score_weight)
コード例 #4
0
def test_hyper_parameters_nst(subtests, impl_params):
    """Check the ``nst`` hyper-parameters.

    The starting point is expected to be ``"content"`` with implementation
    parameters and ``"random"`` with paper parameters.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "nst"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("starting_point"):
        # The conditional must be evaluated before the comparison. Without
        # parentheses, ``assert a == b if c else d`` parses as
        # ``assert (a == b) if c else d``. When ``c`` is false it asserts the
        # truthy string ``d`` and checks nothing.
        expected = "content" if impl_params else "random"
        assert hyper_parameters.starting_point == expected
コード例 #5
0
def test_regularization(subtests, impl_params):
    """Check ``paper.regularization`` against the ``regularization`` hyper-parameters.

    The returned loss must be a ``TotalVariationLoss`` whose score weight
    matches the configured one.
    """
    regularization_loss = paper.regularization(impl_params=impl_params)
    assert isinstance(regularization_loss, paper.TotalVariationLoss)

    expected_weight = paper.hyper_parameters(
        impl_params=impl_params).regularization.score_weight

    with subtests.test("score_weight"):
        assert regularization_loss.score_weight == pytest.approx(expected_weight)
コード例 #6
0
def test_hyper_parameters_content_loss(subtests, impl_params):
    """Check the ``content_loss`` hyper-parameters.

    Implementation parameters are expected to use layer ``"relu4_1"`` with a
    score weight of 2e1. Paper parameters are expected to use ``"relu4_2"``
    with a score weight of 1e0.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "content_loss"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("layer"):
        # Resolve the conditional first. Unparenthesized,
        # ``assert a == b if c else d`` becomes ``assert (a == b) if c else d``
        # and passes vacuously for paper parameters.
        expected_layer = "relu4_1" if impl_params else "relu4_2"
        assert hyper_parameters.layer == expected_layer

    with subtests.test("score_weight"):
        assert hyper_parameters.score_weight == pytest.approx(
            2e1 if impl_params else 1e0)
コード例 #7
0
def test_content_loss(subtests, impl_params):
    """Check ``paper.content_loss`` against the ``content_loss`` hyper-parameters.

    The loss must be a ``FeatureReconstructionLoss`` whose encoder layer and
    score weight match the configuration.
    """
    content_loss = paper.content_loss(impl_params=impl_params)
    assert isinstance(content_loss, paper.FeatureReconstructionLoss)

    expected = paper.hyper_parameters(impl_params=impl_params).content_loss

    with subtests.test("layer"):
        actual_layer = content_loss.encoder.layer
        assert actual_layer == expected.layer

    with subtests.test("score_weight"):
        assert content_loss.score_weight == pytest.approx(expected.score_weight)
コード例 #8
0
def test_li_wand_2016_nst_smoke(subtests, mocker, content_image, style_image):
    """Smoke-test ``paper.nst`` by intercepting the calls it delegates to.

    ``misc.get_input_image`` is wrapped by a spy. ``optim.pyramid_image_optimization``
    is replaced by a mock, so no optimization actually runs. The arguments
    ``paper.nst`` passed along are then checked against the paper's default
    building blocks.
    """
    input_spy = mocker.patch(
        mocks.make_mock_target("li_wand_2016", "_nst", "misc",
                               "get_input_image"),
        wraps=get_input_image,
    )
    optim_mock = mocker.patch(
        mocks.make_mock_target("li_wand_2016", "_nst", "optim",
                               "pyramid_image_optimization"))

    hyper_parameters = paper.hyper_parameters()

    paper.nst(content_image, style_image)

    call_args, call_kwargs = optim_mock.call_args
    input_image, criterion, image_pyramid = call_args
    get_optimizer = call_kwargs["get_optimizer"]
    preprocessor = call_kwargs["preprocessor"]
    postprocessor = call_kwargs["postprocessor"]
    # The resize of the last pyramid level serves as the reference image size.
    initial_resize = image_pyramid[-1].resize_image

    with subtests.test("input_image"):
        input_call = utils.call_args_to_namespace(input_spy.call_args,
                                                  get_input_image)
        assert input_call.starting_point == hyper_parameters.nst.starting_point
        assert extract_image_size(input_call.content_image) == extract_image_size(
            initial_resize(content_image))

    with subtests.test("style_image"):
        expected_style_image = preprocessor(initial_resize(style_image))
        for style_op in criterion.style_loss.children():
            ptu.assert_allclose(style_op.target_image, expected_style_image)

    with subtests.test("criterion"):
        assert isinstance(criterion, type(paper.perceptual_loss()))

    with subtests.test("pyramid"):
        assert isinstance(image_pyramid, type(paper.image_pyramid()))

    with subtests.test("optimizer"):
        assert is_callable(get_optimizer)
        built_optimizer = get_optimizer(input_image)
        assert isinstance(built_optimizer, type(paper.optimizer(input_image)))

    with subtests.test("preprocessor"):
        assert isinstance(preprocessor, type(paper.preprocessor()))

    with subtests.test("postprocessor"):
        assert isinstance(postprocessor, type(paper.postprocessor()))
コード例 #9
0
def test_hyper_parameters_target_transforms(subtests, impl_params):
    """Check the ``target_transforms`` hyper-parameters.

    With implementation parameters both the scale and rotate step counts are
    expected to be 0. With paper parameters they are expected to be 3 and 2.
    The step widths (5e-2 for scale, 7.5 for rotate) do not depend on
    ``impl_params``.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "target_transforms"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    # The expected values are computed up front. Unparenthesized,
    # ``assert a == b if c else d`` parses as ``assert (a == b) if c else d``.
    # It then only asserts the (truthy) paper value when ``impl_params`` is false.
    with subtests.test("num_scale_steps"):
        expected_num_scale_steps = 0 if impl_params else 3
        assert hyper_parameters.num_scale_steps == expected_num_scale_steps

    with subtests.test("scale_step_width"):
        assert hyper_parameters.scale_step_width == pytest.approx(5e-2)

    with subtests.test("num_rotate_steps"):
        expected_num_rotate_steps = 0 if impl_params else 2
        assert hyper_parameters.num_rotate_steps == expected_num_rotate_steps

    with subtests.test("rotate_step_width"):
        assert hyper_parameters.rotate_step_width == pytest.approx(7.5)
コード例 #10
0
def figure_6(args):
    """Replicate both halves of figure 6 and write the results to disk.

    Downloads the paper images into ``args.image_source_dir``. For the top and
    bottom half of the figure it runs ``paper.nst`` on the corresponding
    content/style pair and saves the output into ``args.image_results_dir``.
    With ``args.impl_params`` a few hyper-parameters are overridden before the
    run (see the linked script).
    """
    images = paper.images()
    images.download(args.image_source_dir)

    # Figure half -> (content image, style image); looked up eagerly, in order.
    halves = {
        "top": (images["blue_bottle"], images["self-portrait"]),
        "bottom": (images["s"], images["composition_viii"]),
    }

    for position, (content_entry, style_entry) in halves.items():
        content_image = content_entry.read(device=args.device)
        style_image = style_entry.read(device=args.device)

        params_source = "implementation" if args.impl_params else "paper"
        print(f"Replicating the {position} half of figure 6 with {params_source} parameters")

        hyper_parameters = paper.hyper_parameters(impl_params=args.impl_params)
        if args.impl_params:
            # Overrides taken from the reference implementation's run script:
            # https://github.com/pmeier/CNNMRF/blob/fddcf4d01e2a6ce201059d8bc38597f74a09ba3f/run_trans.lua#L66
            hyper_parameters.content_loss.layer = "relu4_2"
            transform_params = hyper_parameters.target_transforms
            transform_params.num_scale_steps = 1
            transform_params.num_rotate_steps = 1

        output_image = paper.nst(
            content_image,
            style_image,
            impl_params=args.impl_params,
            hyper_parameters=hyper_parameters,
        )

        output_file = path.join(
            args.image_results_dir,
            make_output_filename(
                "li_wand_2016",
                "fig_6",
                position,
                impl_params=args.impl_params,
            ),
        )
        print(f"Saving result to {output_file}")
        write_image(output_image, output_file)

        # Separator line spanning the terminal width (COLUMNS, default 80).
        separator_width = int(os.environ.get("COLUMNS", "80"))
        print("#" * separator_width)
コード例 #11
0
def test_hyper_parameters_style_loss(subtests, impl_params):
    """Check the ``style_loss`` hyper-parameters.

    Layers, layer weighting and patch size are the same for both parameter
    sets. The stride is expected to be 2 with implementation parameters and
    1 with paper parameters. The score weight is expected to be 1e-4 and 1e0
    respectively.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "style_loss"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("layer"):
        assert hyper_parameters.layers == ("relu3_1", "relu4_1")

    with subtests.test("layer_weights"):
        assert hyper_parameters.layer_weights == "sum"

    with subtests.test("patch_size"):
        assert hyper_parameters.patch_size == 3

    with subtests.test("stride"):
        # Unparenthesized, ``assert a == b if c else d`` parses as
        # ``assert (a == b) if c else d`` and passes vacuously for paper
        # parameters (``1`` is truthy). Compute the expectation first.
        expected_stride = 2 if impl_params else 1
        assert hyper_parameters.stride == expected_stride

    with subtests.test("score_weight"):
        assert hyper_parameters.score_weight == pytest.approx(
            1e-4 if impl_params else 1e0)
コード例 #12
0
def test_hyper_parameters_image_pyramid(subtests, impl_params):
    """Check the ``image_pyramid`` hyper-parameters.

    Edge sizes and edge selection are shared by both parameter sets.
    Implementation parameters are expected to use 100 steps and 3 levels.
    Paper parameters are expected to use 200 steps and ``num_levels=None``.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "image_pyramid"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("max_edge_size"):
        assert hyper_parameters.max_edge_size == 384

    with subtests.test("num_steps"):
        # Unparenthesized, ``assert a == b if c else d`` parses as
        # ``assert (a == b) if c else d`` and only asserts the truthy ``200``
        # for paper parameters. Compute the expectation first.
        expected_num_steps = 100 if impl_params else 200
        assert hyper_parameters.num_steps == expected_num_steps

    with subtests.test("num_levels"):
        if impl_params:
            assert hyper_parameters.num_levels == 3
        else:
            assert hyper_parameters.num_levels is None

    with subtests.test("min_edge_size"):
        assert hyper_parameters.min_edge_size == 64

    with subtests.test("edge"):
        assert hyper_parameters.edge == "long"
コード例 #13
0
def test_hyper_parameters(subtests):
    """Check that the default ``paper.hyper_parameters()`` is a ``HyperParameters``."""
    assert isinstance(paper.hyper_parameters(), HyperParameters)