def test_transformer_smoke(subtests, image_medium):
    """Build the default transformer and check that a forward pass keeps the image size."""
    model = paper.transformer()
    assert isinstance(model, paper.Transformer)

    with subtests.test("forward size"):
        stylized = model(image_medium)
        # Compare the full size tuples of input and output.
        assert image_medium.size() == stylized.size()
def test_stylization_main_smoke(
    subtests, images, dataset, stylization, stylization_script, stylization_args
):
    """Run the stylization script's ``main`` and inspect how it calls ``stylization``.

    ``stylization`` is used as a mock here (its ``call_count`` and
    ``call_args_list`` are read). The script must call it once per entry in
    ``stylization_args.content``. Each call must receive a ``torch.Tensor``
    input image followed by a transformer of the paper's default type.

    NOTE(review): ``images`` and ``dataset`` are requested but not referenced
    in the body. Presumably they are needed only for their fixture side effects.
    Confirm before removing them.
    """
    stylization_script.main(stylization_args)
    assert stylization.call_count == len(stylization_args.content)

    with subtests.test("input_image"):
        for call_args in stylization.call_args_list:
            args, _ = call_args
            input_image, _ = args
            assert isinstance(input_image, torch.Tensor)

    # The original mislabelled this subtest "content_image_loader". It checks
    # the transformer argument, so name it accordingly.
    with subtests.test("transformer"):
        transformer_type = type(paper.transformer())
        for call_args in stylization.call_args_list:
            args, _ = call_args
            _, transformer = args
            assert isinstance(transformer, transformer_type)
def test_transformer_pretrained(subtests):
    """Check that ``paper.transformer`` resolves and loads pretrained weights.

    The test covers every ``impl_params`` / ``instance_norm`` combination. The
    URL lookup and the state-dict download are both patched. The test then
    verifies two things:

    - ``select_url`` receives the requested framework, style and flags.
    - The returned transformer carries exactly the patched state dict.
    """

    @contextlib.contextmanager
    def patch(target, **kwargs):
        # Resolve the attribute path inside the paper's ``_modules`` module.
        full_target = make_mock_target("johnson_alahi_li_2016", "_modules", target)
        with unittest.mock.patch(full_target, **kwargs) as mocked:
            yield mocked

    framework = "framework"
    style = "style"
    url = "url"

    param_grid = generate_param_combinations(
        impl_params=(True, False), instance_norm=(True, False)
    )
    for params in param_grid:
        expected_state_dict = paper.Transformer(**params).state_dict()
        with subtests.test(**params), patch(
            "select_url", return_value=url
        ) as select_url, patch(
            "load_state_dict_from_url", return_value=expected_state_dict
        ):
            transformer = paper.transformer(framework=framework, style=style, **params)

            with subtests.test("select_url"):
                kwargs = call_args_to_kwargs_only(
                    select_url.call_args,
                    "framework",
                    "style",
                    "impl_params",
                    "instance_norm",
                )
                # String arguments are compared by value.
                for key, expected in (("framework", framework), ("style", style)):
                    assert kwargs[key] == expected
                # Boolean flags must be forwarded as the very same objects.
                for key in ("impl_params", "instance_norm"):
                    assert kwargs[key] is params[key]

            ptu.assert_allclose(transformer.state_dict(), expected_state_dict)
def test_training_smoke(subtests, training, image_loader):
    """Check the arguments the ``training`` fixture forwards, and its return value.

    ``training(image_loader)`` yields the positional args, the keyword args
    and the output of the underlying call. Each forwarded object is checked
    against the paper's default factories.
    """
    call_args, call_kwargs, result = training(image_loader)
    loader, transformer, criterion, update_fn = call_args
    optimizer = call_kwargs["optimizer"]

    with subtests.test("content_image_loader"):
        assert loader is image_loader

    with subtests.test("transformer"):
        expected_transformer_cls = type(paper.transformer())
        assert isinstance(transformer, expected_transformer_cls)

    with subtests.test("criterion"):
        expected_criterion_cls = type(paper.perceptual_loss())
        assert isinstance(criterion, expected_criterion_cls)

    with subtests.test("criterion_update_fn"):
        assert is_callable(update_fn)

    with subtests.test("optimizer"):
        expected_optimizer_cls = type(paper.optimizer(transformer))
        assert isinstance(optimizer, expected_optimizer_cls)

    with subtests.test("output"):
        # The trained transformer must be the very object that was passed in.
        assert result is transformer
def load(style=None):
    """Build a transformer for ``style`` using the enclosing ``impl_params`` / ``instance_norm``.

    NOTE(review): ``impl_params`` and ``instance_norm`` are free names here.
    Presumably they come from an enclosing scope not visible in this chunk.
    """
    options = dict(impl_params=impl_params, instance_norm=instance_norm)
    return paper.transformer(style, **options)