示例#1
0
def test_transformer_smoke(subtests, image_medium):
    """Check the default transformer's type and that its forward pass keeps the input size.

    ``paper.transformer()`` must return a ``paper.Transformer``; applying it to
    ``image_medium`` must yield an output whose ``size()`` equals the input's.
    """
    model = paper.transformer()
    assert isinstance(model, paper.Transformer)

    with subtests.test("forward size"):
        stylized = model(image_medium)
        expected_size = image_medium.size()
        assert stylized.size() == expected_size
示例#2
0
def test_stylization_main_smoke(subtests, images, dataset, stylization,
                                stylization_script, stylization_args):
    """Smoke-test the stylization script's ``main`` entry point.

    Runs ``stylization_script.main(stylization_args)`` and inspects the calls
    recorded on ``stylization`` (it exposes ``call_count`` / ``call_args_list``):

    * it is called once per entry in ``stylization_args.content``,
    * each call's first positional argument is a ``torch.Tensor``,
    * each call's second positional argument has the type returned by
      ``paper.transformer()``.

    ``images`` and ``dataset`` are not referenced in the body; presumably they
    are requested for fixture side effects. TODO confirm before removing them.
    """
    stylization_script.main(stylization_args)

    assert stylization.call_count == len(stylization_args.content)

    with subtests.test("input_image"):
        for call_args in stylization.call_args_list:
            args, _ = call_args
            input_image, _ = args
            assert isinstance(input_image, torch.Tensor)

    # This subtest validates the second positional argument, the transformer,
    # so it is labelled after it to make failure reports point at the right thing.
    with subtests.test("transformer"):
        transformer_type = type(paper.transformer())
        for call_args in stylization.call_args_list:
            args, _ = call_args
            _, transformer = args
            assert isinstance(transformer, transformer_type)
示例#3
0
def test_transformer_pretrained(subtests):
    """Check that ``paper.transformer`` fetches and loads pretrained weights for a style.

    For every combination of ``impl_params`` and ``instance_norm``, the
    ``select_url`` and ``load_state_dict_from_url`` helpers used by
    ``johnson_alahi_li_2016._modules`` are patched. The test then verifies that:

    * ``select_url`` receives the requested framework, style and config,
    * the URL returned by ``select_url`` is the one handed to
      ``load_state_dict_from_url``, and
    * the returned transformer carries the mocked state dict.
    """

    @contextlib.contextmanager
    def patch(target, **kwargs):
        # ``make_mock_target`` (project helper) builds the dotted path to patch,
        # presumably inside ``johnson_alahi_li_2016._modules``. TODO confirm.
        target = make_mock_target("johnson_alahi_li_2016", "_modules", target)
        with unittest.mock.patch(target, **kwargs) as mock:
            yield mock

    @contextlib.contextmanager
    def patch_select_url(url):
        with patch("select_url", return_value=url) as mock:
            yield mock

    @contextlib.contextmanager
    def patch_load_state_dict_from_url(state_dict):
        with patch("load_state_dict_from_url", return_value=state_dict) as mock:
            yield mock

    framework = "framework"
    style = "style"
    url = "url"
    for config in generate_param_combinations(
        impl_params=(True, False), instance_norm=(True, False)
    ):
        # Use the weights of a fresh transformer with the same config as the
        # "downloaded" state dict, so loading them into the result is expected to work.
        state_dict = paper.Transformer(**config).state_dict()
        with subtests.test(**config), patch_select_url(
            url
        ) as select_url, patch_load_state_dict_from_url(
            state_dict
        ) as load_state_dict_from_url:
            transformer = paper.transformer(framework=framework, style=style, **config)

            with subtests.test("select_url"):
                kwargs = call_args_to_kwargs_only(
                    select_url.call_args,
                    "framework",
                    "style",
                    "impl_params",
                    "instance_norm",
                )
                assert kwargs["framework"] == framework
                assert kwargs["style"] == style
                assert kwargs["impl_params"] is config["impl_params"]
                assert kwargs["instance_norm"] is config["instance_norm"]

            # The loader mock returns ``state_dict`` for any argument, so without
            # this check the test could not detect the selected URL being dropped.
            with subtests.test("load_state_dict_from_url"):
                kwargs = call_args_to_kwargs_only(
                    load_state_dict_from_url.call_args, "url"
                )
                assert kwargs["url"] == url

            ptu.assert_allclose(transformer.state_dict(), state_dict)
示例#4
0
def test_training_smoke(subtests, training, image_loader):
    """Smoke-test what ``training(image_loader)`` reports.

    ``training`` returns an ``(args, kwargs, output)`` triple. The positional
    args are unpacked into loader, transformer, criterion and criterion update
    function, and ``kwargs["optimizer"]`` holds the optimizer. Each is checked
    in its own subtest, and the output must be the transformer itself.
    """
    call_args, call_kwargs, result = training(image_loader)
    loader, model, loss, loss_update_fn = call_args
    optim = call_kwargs["optimizer"]

    with subtests.test("content_image_loader"):
        assert loader is image_loader

    with subtests.test("transformer"):
        expected_type = type(paper.transformer())
        assert isinstance(model, expected_type)

    with subtests.test("criterion"):
        expected_type = type(paper.perceptual_loss())
        assert isinstance(loss, expected_type)

    with subtests.test("criterion_update_fn"):
        assert is_callable(loss_update_fn)

    with subtests.test("optimizer"):
        # An optimizer is built from the transformer under test purely to
        # obtain the expected optimizer type.
        expected_type = type(paper.optimizer(model))
        assert isinstance(optim, expected_type)

    with subtests.test("output"):
        assert result is model
示例#5
0
 def load(style=None):
     """Return ``paper.transformer`` for ``style``.

     ``impl_params`` and ``instance_norm`` are not defined here; they are read
     from the enclosing scope (not visible in this excerpt).
     """
     # Pass ``style`` by keyword, matching how other call sites invoke paper.transformer.
     return paper.transformer(
         style=style, impl_params=impl_params, instance_norm=instance_norm
     )