def test_content_transform_grayscale_image(
    subtests, content_image, impl_params, instance_norm
):
    """Check the content transform on a single-channel (grayscale) input.

    The input is converted to grayscale and the transform's edge size is set to 16.
    The expected output is the reference single-channel image repeated along the
    channel dimension to three channels.
    """
    grayscale_image = F.rgb_to_grayscale(content_image)
    edge_size = 16

    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )
    hyper_parameters.content_transform.edge_size = edge_size
    content_transform = paper.content_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )

    if instance_norm:
        utils.make_reproducible()
    actual = content_transform(grayscale_image)

    if impl_params and instance_norm:
        # Since the transform involves an uncontrollable random component, we
        # can't do a reference test here
        return

    if impl_params:
        single_channel = F.resize(grayscale_image, edge_size)
    else:
        single_channel = transforms.CenterCrop(edge_size)(grayscale_image)

    desired = single_channel.repeat(1, 3, 1, 1)
    ptu.assert_allclose(actual, desired)
def main(root):
    """Call ``make_reproducible`` once, then run each save/generate step on ``root``.

    The steps run in a fixed order: versions, RNG states, then small, medium
    and large images.
    """
    make_reproducible()

    steps = (
        save_versions,
        save_rng_states,
        generate_small_images,
        generate_medium_image,
        generate_large_image,
    )
    for step in steps:
        step(root)
def test_make_reproducible_cudnn(mocker):
    """make_reproducible must set cuDNN to deterministic and disable benchmark mode."""
    cudnn = mocker.patch("pystiche_papers.utils.misc.torch.backends.cudnn")

    # Report cuDNN as available regardless of the host running the test.
    def _always_available():
        return True

    cudnn.is_available = _always_available

    utils.make_reproducible()

    assert cudnn.deterministic
    assert not cudnn.benchmark
def test_style_transform(subtests):
    """The style transform must resize a 17x32 image to 136x256."""
    make_reproducible()
    image = torch.rand(1, 1, 17, 32)

    result = paper.style_transform()(image)
    expected = F.resize(image, [136, 256])

    ptu.assert_allclose(result, expected)
def test_content_transform():
    """Check the content transform with ``impl_params=False``.

    For a 16x31 single-channel input, the expected output is the top-left
    16x16 region resized to 256 and repeated to three channels.
    """
    make_reproducible()
    image = torch.rand(1, 1, 16, 31)

    transform = paper.content_transform(impl_params=False)
    actual = transform(image)

    top_left = image[:, :, :16, :16]
    desired = F.resize(top_left, 256).repeat(1, 3, 1, 1)
    ptu.assert_allclose(actual, desired)
def test_make_reproducible_seeds(subtests, mocker):
    """make_reproducible(seed) must pass ``seed`` to every patched seeding function.

    Each library is checked in its own subtest. A failure then names the library
    that was not seeded correctly, and the remaining libraries are still checked.
    """
    seed_targets = (
        ("standard library", "random.seed"),
        ("numpy", "np.random.seed"),
        ("torch", "torch.manual_seed"),
    )
    mocks = [
        (name, mocker.patch(f"pystiche_papers.utils.misc.{rel_import}"))
        for name, rel_import in seed_targets
    ]

    seed = 123
    utils.make_reproducible(seed)

    for name, mock in mocks:
        with subtests.test(name):
            # call_args is None for an uncalled mock; check first so a missing
            # call fails with a clear assertion instead of a TypeError.
            assert mock.called, f"{name} was never seeded"
            assert mock.call_args[0][0] == seed
def test_make_reproducible():
    """Re-seeding reproduces the same random draws; consecutive draws differ."""

    def draw():
        return (
            torch.rand(10),
            torch.randn(10),
            torch.randint(10, (10,)),
        )

    utils.make_reproducible()
    first = draw()

    utils.make_reproducible()
    second = draw()
    third = draw()

    ptu.assert_allclose(first, second)
    with pytest.raises(AssertionError):
        ptu.assert_allclose(second, third)
def test_make_reproducible_no_standard_library(mocker):
    """With ``seed_standard_library=False``, ``random.seed`` must not be called."""
    random_seed = mocker.patch("pystiche_papers.utils.misc.random.seed")

    utils.make_reproducible(seed_standard_library=False)

    assert random_seed.call_count == 0
def test_make_reproducible_uint32_seed():
    """The uint32 maximum is returned unchanged; one past it comes back as 0."""
    uint32_max = 2**32 - 1
    cases = (
        (uint32_max, uint32_max),
        (uint32_max + 1, 0),
    )
    for seed, expected in cases:
        assert utils.make_reproducible(seed) == expected