示例#1
0
def test_content_transform_grayscale_image(subtests, content_image,
                                           impl_params, instance_norm):
    """Check the content transform on a grayscale version of the content image.

    The reference output is the grayscale image either resized to ``edge_size``
    (``impl_params=True``) or center cropped to ``edge_size`` (otherwise), with
    its channel dimension repeated three times. The combination
    ``impl_params=True, instance_norm=True`` is only executed, not compared.
    """
    edge_size = 16
    grayscale_image = F.rgb_to_grayscale(content_image)

    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )
    hyper_parameters.content_transform.edge_size = edge_size
    content_transform = paper.content_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )

    if instance_norm:
        # Fix the RNG state before running the transform.
        utils.make_reproducible()
    actual = content_transform(grayscale_image)

    if impl_params and instance_norm:
        # Since the transform involves an uncontrollable random component, we
        # can't do a reference test here.
        return

    if impl_params:
        reference = F.resize(grayscale_image, edge_size)
    else:
        reference = transforms.CenterCrop(edge_size)(grayscale_image)

    # The transform is expected to expand the single channel to three.
    ptu.assert_allclose(actual, reference.repeat(1, 3, 1, 1))
示例#2
0
def main(root):
    """Seed the RNGs, then run every save/generate step against ``root`` in order.

    ``make_reproducible`` is called first, so all subsequent steps run with a
    freshly seeded RNG state.
    """
    make_reproducible()
    # Order matters: the RNG-state snapshot and image generation must follow
    # the seeding above and keep their original relative order.
    for step in (
        save_versions,
        save_rng_states,
        generate_small_images,
        generate_medium_image,
        generate_large_image,
    ):
        step(root)
示例#3
0
def test_make_reproducible_cudnn(mocker):
    """If cuDNN is available, make_reproducible sets deterministic and disables benchmark."""
    mocked_cudnn = mocker.patch(
        "pystiche_papers.utils.misc.torch.backends.cudnn")
    # Pretend cuDNN is present regardless of the machine running the test.
    mocked_cudnn.is_available = lambda: True

    utils.make_reproducible()

    assert mocked_cudnn.deterministic
    assert not mocked_cudnn.benchmark
示例#4
0
def test_style_transform(subtests):
    """The style transform should equal resizing a 17x32 input to 136x256."""
    make_reproducible()
    input_image = torch.rand(1, 1, 17, 32)

    actual = paper.style_transform()(input_image)
    expected = F.resize(input_image, [136, 256])

    ptu.assert_allclose(actual, expected)
示例#5
0
def test_content_transform():
    """Without impl params, the leftmost 16x16 region is resized to 256 and made 3-channel."""
    make_reproducible()
    input_image = torch.rand(1, 1, 16, 31)

    transform = paper.content_transform(impl_params=False)
    actual = transform(input_image)

    # The input is 16 rows tall, so this keeps all rows and the first 16 columns.
    cropped = input_image[:, :, :16, :16]
    expected = F.resize(cropped, 256).repeat(1, 3, 1, 1)

    ptu.assert_allclose(actual, expected)
示例#6
0
def test_make_reproducible_seeds(subtests, mocker):
    """make_reproducible(seed) should pass ``seed`` to every patched seeding function.

    The standard-library, NumPy and PyTorch seeders are patched inside
    ``pystiche_papers.utils.misc``. Each one is checked in its own subtest, so
    a failure reports which library was not seeded correctly.
    """
    mocks = [
        (name, mocker.patch(f"pystiche_papers.utils.misc.{rel_import}"))
        for name, rel_import in (
            ("standard library", "random.seed"),
            ("numpy", "np.random.seed"),
            ("torch", "torch.manual_seed"),
        )
    ]

    seed = 123
    utils.make_reproducible(seed)

    for name, mock in mocks:
        with subtests.test(msg=name):
            # call_args is None for a mock that was never called; check that
            # first so the failure is a clear assertion, not a TypeError.
            assert mock.called, f"{name} RNG was not seeded"
            assert mock.call_args[0][0] == seed
示例#7
0
def test_make_reproducible():
    """Reseeding must replay identical random draws; continuing without reseeding must not."""
    def draw():
        return torch.rand(10), torch.randn(10), torch.randint(10, (10, ))

    utils.make_reproducible()
    first_run = draw()

    utils.make_reproducible()
    replayed_run = draw()
    continued_run = draw()

    # Same seed -> identical tensors ...
    ptu.assert_allclose(first_run, replayed_run)

    # ... while drawing again from the advanced generator gives different ones.
    with pytest.raises(AssertionError):
        ptu.assert_allclose(replayed_run, continued_run)
示例#8
0
def test_make_reproducible_no_standard_library(mocker):
    """With seed_standard_library=False, random.seed must not be called."""
    random_seed_mock = mocker.patch("pystiche_papers.utils.misc.random.seed")

    utils.make_reproducible(seed_standard_library=False)

    assert random_seed_mock.call_count == 0
示例#9
0
def test_make_reproducible_uint32_seed():
    """The uint32 maximum is returned unchanged; the next value maps to 0."""
    largest = (1 << 32) - 1

    for seed, expected in ((largest, largest), (largest + 1, 0)):
        assert utils.make_reproducible(seed) == expected