Esempio n. 1
0
def auto_pad_conv_params():
    """Collect the conv hyper-parameter combinations for the auto-padding tests.

    ``kernel_size``, ``stride`` and ``dilation`` are each given as two scalars
    and as an asymmetric 2-tuple in both orders. The combinations produced by
    ``generate_param_combinations`` are materialised into a tuple.
    """
    param_grid = dict(
        kernel_size=(3, 4, (3, 4), (4, 3)),
        stride=(1, 2, (1, 2), (2, 1)),
        dilation=(1, 2, (1, 2), (2, 1)),
    )
    combinations = generate_param_combinations(**param_grid)
    return tuple(combinations)
Esempio n. 2
0
def test_get_conv(subtests):
    """Check that ``paper.conv`` builds the expected layer and forwards its arguments.

    One subtest runs per combination of ``padding`` (``None`` or ``1``) and
    ``upsample`` (``True``/``False``). ``upsample=True`` must yield an
    ``nn.ConvTranspose2d``, otherwise an ``nn.Conv2d``. ``padding`` is only
    compared when it was passed explicitly.
    """
    in_channels = out_channels = 3
    kernel_size = 3
    stride = 1
    param_grid = generate_param_combinations(
        padding=(None, 1), upsample=(True, False)
    )
    for params in param_grid:
        with subtests.test(**params):
            layer = paper.conv(
                in_channels, out_channels, kernel_size, stride=stride, **params
            )
            expected_cls = nn.ConvTranspose2d if params["upsample"] else nn.Conv2d
            assert isinstance(layer, expected_cls)

            with subtests.test("in_channels"):
                assert layer.in_channels == in_channels
            with subtests.test("out_channels"):
                assert layer.out_channels == out_channels
            with subtests.test("kernel_size"):
                assert layer.kernel_size == misc.to_2d_arg(kernel_size)
            with subtests.test("stride"):
                assert layer.stride == misc.to_2d_arg(stride)

            # A padding of None lets paper.conv pick its own value, so there is
            # nothing explicit to compare against.
            padding = params["padding"]
            if padding is not None:
                with subtests.test("padding"):
                    assert layer.padding == misc.to_2d_arg(padding)
Esempio n. 3
0
def model_url_configs(styles):
    """Collect every model-URL configuration for the given ``styles`` into a tuple.

    Covers both frameworks (``"pystiche"`` and ``"luatorch"``), every entry of
    ``styles``, and both values of the ``impl_params`` and ``instance_norm``
    flags, as produced by ``generate_param_combinations``.
    """
    both_flags = (True, False)
    configs = generate_param_combinations(
        framework=("pystiche", "luatorch"),
        style=styles,
        impl_params=both_flags,
        instance_norm=both_flags,
    )
    return tuple(configs)
Esempio n. 4
0
def test_transformer_pretrained(subtests):
    """Check that ``paper.transformer`` loads pretrained weights via ``select_url``.

    ``select_url`` and ``load_state_dict_from_url`` in the
    ``johnson_alahi_li_2016._modules`` namespace are patched. For every
    ``impl_params`` / ``instance_norm`` combination the test verifies that
    ``select_url`` receives the requested framework, style and flags, and that
    the returned transformer carries the state dict handed out by the patched
    loader.
    """

    @contextlib.contextmanager
    def patch(target, **kwargs):
        # Resolve the short name to its full dotted path inside the module under test.
        full_target = make_mock_target("johnson_alahi_li_2016", "_modules", target)
        with unittest.mock.patch(full_target, **kwargs) as mock:
            yield mock

    def patch_select_url(url):
        return patch("select_url", return_value=url)

    def patch_load_state_dict_from_url(state_dict):
        return patch("load_state_dict_from_url", return_value=state_dict)

    framework = "framework"
    style = "style"
    url = "url"
    configs = generate_param_combinations(
        impl_params=(True, False), instance_norm=(True, False)
    )
    for config in configs:
        expected_state_dict = paper.Transformer(**config).state_dict()
        with subtests.test(**config):
            with patch_select_url(url) as select_url:
                with patch_load_state_dict_from_url(expected_state_dict):
                    transformer = paper.transformer(
                        framework=framework, style=style, **config
                    )

                    with subtests.test("select_url"):
                        kwargs = call_args_to_kwargs_only(
                            select_url.call_args,
                            "framework",
                            "style",
                            "impl_params",
                            "instance_norm",
                        )
                        assert kwargs["framework"] == framework
                        assert kwargs["style"] == style
                        for flag in ("impl_params", "instance_norm"):
                            assert kwargs[flag] is config[flag]

                    ptu.assert_allclose(transformer.state_dict(), expected_state_dict)
Esempio n. 5
0
def test_conv_block(subtests):
    """Check the layer layout produced by ``paper.conv_block``.

    For every combination of ``padding``, ``upsample``, ``relu``, ``inplace``
    and ``instance_norm``, the result must be an ``nn.Sequential`` holding:

    - a conv layer of the same type ``paper.conv`` builds for the same
      ``padding`` / ``upsample``;
    - a norm layer of the same type ``paper.norm`` builds for the same
      ``instance_norm``;
    - only when ``relu=True``, an ``nn.ReLU`` whose ``inplace`` flag matches
      the parameter.
    """
    in_channels = out_channels = 3
    kernel_size = 3
    stride = 1
    for params in generate_param_combinations(
        padding=(None, 1),
        upsample=(True, False),
        relu=(True, False),
        inplace=(True, False),
        instance_norm=(True, False),
    ):
        # Label each iteration with its parameter combination so a failure
        # report identifies which configuration broke.
        with subtests.test(**params):
            conv_block = paper.conv_block(
                in_channels, out_channels, kernel_size, stride=stride, **params
            )
            assert isinstance(conv_block, nn.Sequential)
            # The explicit expected value matters: ``assert len(x) == 3 if relu else 2``
            # parses as ``assert (len(x) == 3) if relu else 2``. That asserts the
            # truthy constant 2 whenever relu is False, so it can never fail.
            expected_len = 3 if params["relu"] else 2
            assert len(conv_block) == expected_len

            with subtests.test("conv"):
                reference_conv = paper.conv(
                    1, 1, 1, padding=params["padding"], upsample=params["upsample"]
                )
                assert isinstance(conv_block[0], type(reference_conv))

            with subtests.test("norm"):
                reference_norm = paper.norm(1, params["instance_norm"])
                assert isinstance(conv_block[1], type(reference_norm))

            if params["relu"]:
                with subtests.test("relu"):
                    assert isinstance(conv_block[2], nn.ReLU)
                    assert conv_block[2].inplace is params["inplace"]