def test_target_transforms_call_smoke(target_image):
    """Smoke test: each target transform yields a ``torch.Tensor`` when called.

    The implementation hyper-parameters are used with both the number of scale
    steps and the number of rotate steps set to 1 before the transforms are
    created.
    """
    params = paper.hyper_parameters(impl_params=True)
    params.target_transforms.num_scale_steps = 1
    params.target_transforms.num_rotate_steps = 1

    transforms = paper.target_transforms(impl_params=True, hyper_parameters=params)
    for transform in transforms:
        transformed = transform(target_image)
        assert isinstance(transformed, torch.Tensor)
def test_image_pyramid(subtests, mocker, impl_params):
    """Check that ``paper.image_pyramid`` passes on the pyramid hyper-parameters.

    ``OctaveImagePyramid`` is patched with a wrapping mock so that the arguments
    of the constructor call can be compared, one subtest per argument, against
    the ``image_pyramid`` section of the paper hyper-parameters.
    """
    # Keep a reference to the real class before patching; it is needed both as
    # the ``wraps`` target and for the ``isinstance`` check below.
    pyramid_cls = pyramid.OctaveImagePyramid
    constructor_spy = mocker.patch(
        "pystiche_papers.li_wand_2016._pyramid.pyramid.OctaveImagePyramid",
        wraps=pyramid_cls,
    )

    created = paper.image_pyramid(impl_params=impl_params)
    assert isinstance(created, pyramid_cls)

    expected = paper.hyper_parameters(impl_params=impl_params).image_pyramid
    actual = utils.call_args_to_namespace(constructor_spy.call_args, pyramid_cls)

    argument_names = (
        "max_edge_size",
        "num_steps",
        "num_levels",
        "min_edge_size",
        "edge",
    )
    for name in argument_names:
        with subtests.test(name):
            assert getattr(actual, name) == getattr(expected, name)
def test_style_loss(subtests, impl_params):
    """Check the structure and configuration of ``paper.style_loss``.

    The loss has to be a ``MultiLayerEncodingLoss`` whose children are all
    ``paper.MRFLoss`` instances. Layers, per-layer score weights, patch size,
    stride, and the overall score weight are compared with the ``style_loss``
    hyper-parameters.
    """
    loss_under_test = paper.style_loss(impl_params=impl_params)
    assert isinstance(loss_under_test, pystiche.loss.MultiLayerEncodingLoss)

    expected = paper.hyper_parameters(impl_params=impl_params).style_loss

    with subtests.test("losses"):
        for child in loss_under_test.children():
            assert isinstance(child, paper.MRFLoss)

    # One row per child loss, transposed afterwards into one tuple per attribute.
    rows = [
        (child.encoder.layer, child.score_weight, child.patch_size, child.stride)
        for child in loss_under_test.children()
    ]
    layers, child_score_weights, patch_sizes, strides = zip(*rows)
    num_layers = len(layers)

    with subtests.test("layers"):
        assert layers == expected.layers

    with subtests.test("layer_weights"):
        assert child_score_weights == pytest.approx((1.0,) * num_layers)

    with subtests.test("patch_size"):
        expected_patch_size = misc.to_2d_arg(expected.patch_size)
        assert patch_sizes == (expected_patch_size,) * num_layers

    with subtests.test("stride"):
        expected_stride = misc.to_2d_arg(expected.stride)
        assert strides == (expected_stride,) * num_layers

    with subtests.test("score_weight"):
        assert loss_under_test.score_weight == pytest.approx(expected.score_weight)
def test_hyper_parameters_nst(subtests, impl_params):
    """Check the ``nst`` section of the paper hyper-parameters.

    The starting point is expected to be ``"content"`` with implementation
    parameters and ``"random"`` otherwise.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "nst"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("starting_point"):
        # The expected value is computed first: ``assert a == b if c else d``
        # parses as ``assert (a == b) if c else d`` and would always pass for
        # a falsy ``c`` because ``d`` is a non-empty string.
        expected = "content" if impl_params else "random"
        assert hyper_parameters.starting_point == expected
def test_regularization(subtests, impl_params):
    """Check the type and the score weight of ``paper.regularization``.

    The loss has to be a ``paper.TotalVariationLoss`` whose score weight
    matches the ``regularization`` hyper-parameters.
    """
    loss_under_test = paper.regularization(impl_params=impl_params)
    assert isinstance(loss_under_test, paper.TotalVariationLoss)

    expected = paper.hyper_parameters(impl_params=impl_params).regularization

    with subtests.test("score_weight"):
        expected_weight = pytest.approx(expected.score_weight)
        assert loss_under_test.score_weight == expected_weight
def test_hyper_parameters_content_loss(subtests, impl_params):
    """Check the ``content_loss`` section of the paper hyper-parameters.

    Expected values: layer ``"relu4_1"`` and score weight ``2e1`` with
    implementation parameters, layer ``"relu4_2"`` and score weight ``1e0``
    otherwise.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "content_loss"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("layer"):
        # The expected value is computed first: ``assert a == b if c else d``
        # parses as ``assert (a == b) if c else d`` and would always pass for
        # a falsy ``c`` because ``d`` is a non-empty string.
        expected_layer = "relu4_1" if impl_params else "relu4_2"
        assert hyper_parameters.layer == expected_layer

    with subtests.test("score_weight"):
        assert hyper_parameters.score_weight == pytest.approx(
            2e1 if impl_params else 1e0
        )
def test_content_loss(subtests, impl_params):
    """Check type, encoder layer, and score weight of ``paper.content_loss``.

    The loss has to be a ``paper.FeatureReconstructionLoss`` configured
    according to the ``content_loss`` hyper-parameters.
    """
    loss_under_test = paper.content_loss(impl_params=impl_params)
    assert isinstance(loss_under_test, paper.FeatureReconstructionLoss)

    expected = paper.hyper_parameters(impl_params=impl_params).content_loss

    with subtests.test("layer"):
        assert loss_under_test.encoder.layer == expected.layer

    with subtests.test("score_weight"):
        expected_weight = pytest.approx(expected.score_weight)
        assert loss_under_test.score_weight == expected_weight
def test_li_wand_2016_nst_smoke(subtests, mocker, content_image, style_image):
    """Smoke test for ``paper.nst`` with the pyramid optimization mocked out.

    ``get_input_image`` is wrapped by a spy and ``pyramid_image_optimization``
    is replaced by a mock. After calling ``paper.nst`` the arguments received
    by the mock are inspected: input image creation, style target images,
    criterion, pyramid, optimizer getter, preprocessor, and postprocessor.
    """
    input_image_spy = mocker.patch(
        mocks.make_mock_target("li_wand_2016", "_nst", "misc", "get_input_image"),
        wraps=get_input_image,
    )
    optimization_mock = mocker.patch(
        mocks.make_mock_target(
            "li_wand_2016", "_nst", "optim", "pyramid_image_optimization"
        )
    )

    hyper_parameters = paper.hyper_parameters()

    paper.nst(content_image, style_image)

    positional, keywords = optimization_mock.call_args
    input_image, criterion, image_pyramid = positional
    get_optimizer = keywords["get_optimizer"]
    preprocessor = keywords["preprocessor"]
    postprocessor = keywords["postprocessor"]

    # Resize function of the last pyramid level, used to build reference images.
    initial_resize = image_pyramid[-1].resize_image

    with subtests.test("input_image"):
        spy_args = utils.call_args_to_namespace(
            input_image_spy.call_args, get_input_image
        )
        assert spy_args.starting_point == hyper_parameters.nst.starting_point

        actual_size = extract_image_size(spy_args.content_image)
        expected_size = extract_image_size(initial_resize(content_image))
        assert actual_size == expected_size

    with subtests.test("style_image"):
        desired_style_image = preprocessor(initial_resize(style_image))
        for child in criterion.style_loss.children():
            ptu.assert_allclose(child.target_image, desired_style_image)

    with subtests.test("criterion"):
        assert isinstance(criterion, type(paper.perceptual_loss()))

    with subtests.test("pyramid"):
        assert isinstance(image_pyramid, type(paper.image_pyramid()))

    with subtests.test("optimizer"):
        assert is_callable(get_optimizer)
        created_optimizer = get_optimizer(input_image)
        assert isinstance(created_optimizer, type(paper.optimizer(input_image)))

    with subtests.test("preprocessor"):
        assert isinstance(preprocessor, type(paper.preprocessor()))

    with subtests.test("postprocessor"):
        assert isinstance(postprocessor, type(paper.postprocessor()))
def test_hyper_parameters_target_transforms(subtests, impl_params):
    """Check the ``target_transforms`` section of the paper hyper-parameters.

    Expected values: 0 scale and 0 rotate steps with implementation
    parameters, 3 scale and 2 rotate steps otherwise. The step widths are
    ``5e-2`` (scale) and ``7.5`` (rotate) in both cases.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "target_transforms"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    # The expected values are computed before asserting:
    # ``assert a == b if c else d`` parses as ``assert (a == b) if c else d``
    # and would therefore never compare against ``d``.
    with subtests.test("num_scale_steps"):
        expected_scale_steps = 0 if impl_params else 3
        assert hyper_parameters.num_scale_steps == expected_scale_steps

    with subtests.test("scale_step_width"):
        assert hyper_parameters.scale_step_width == pytest.approx(5e-2)

    with subtests.test("num_rotate_steps"):
        expected_rotate_steps = 0 if impl_params else 2
        assert hyper_parameters.num_rotate_steps == expected_rotate_steps

    with subtests.test("rotate_step_width"):
        assert hyper_parameters.rotate_step_width == pytest.approx(7.5)
def figure_6(args):
    """Replicate both halves of figure 6 and save the resulting images.

    Reads ``image_source_dir``, ``image_results_dir``, ``device``, and
    ``impl_params`` from ``args``. The images are downloaded first, then one
    NST is run per content/style pair and its result is written to
    ``image_results_dir``.
    """
    images = paper.images()
    images.download(args.image_source_dir)

    # (position in the figure, content image, style image)
    halves = (
        ("top", images["blue_bottle"], images["self-portrait"]),
        ("bottom", images["s"], images["composition_viii"]),
    )

    for position, content, style in halves:
        content_image = content.read(device=args.device)
        style_image = style.read(device=args.device)

        parameter_kind = "implementation" if args.impl_params else "paper"
        print(
            f"Replicating the {position} half of figure 6 "
            f"with {parameter_kind} parameters"
        )

        hyper_parameters = paper.hyper_parameters(impl_params=args.impl_params)
        if args.impl_params:
            # https://github.com/pmeier/CNNMRF/blob/fddcf4d01e2a6ce201059d8bc38597f74a09ba3f/run_trans.lua#L66
            hyper_parameters.content_loss.layer = "relu4_2"
            transform_parameters = hyper_parameters.target_transforms
            transform_parameters.num_scale_steps = 1
            transform_parameters.num_rotate_steps = 1

        output_image = paper.nst(
            content_image,
            style_image,
            impl_params=args.impl_params,
            hyper_parameters=hyper_parameters,
        )

        output_file = path.join(
            args.image_results_dir,
            make_output_filename(
                "li_wand_2016", "fig_6", position, impl_params=args.impl_params
            ),
        )
        print(f"Saving result to {output_file}")
        write_image(output_image, output_file)

        # Separator line spanning the width given by COLUMNS (80 if unset).
        separator_width = int(os.environ.get("COLUMNS", "80"))
        print("#" * separator_width)
def test_hyper_parameters_style_loss(subtests, impl_params):
    """Check the ``style_loss`` section of the paper hyper-parameters.

    Layers, layer weights, and patch size are the same for both parameter
    sets. Stride is expected to be 2 and score weight ``1e-4`` with
    implementation parameters; 1 and ``1e0`` otherwise.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "style_loss"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("layer"):
        assert hyper_parameters.layers == ("relu3_1", "relu4_1")

    with subtests.test("layer_weights"):
        assert hyper_parameters.layer_weights == "sum"

    with subtests.test("patch_size"):
        assert hyper_parameters.patch_size == 3

    with subtests.test("stride"):
        # The expected value is computed first: ``assert a == b if c else d``
        # parses as ``assert (a == b) if c else d`` and would always pass for
        # a falsy ``c`` because ``d`` is a non-zero integer.
        expected_stride = 2 if impl_params else 1
        assert hyper_parameters.stride == expected_stride

    with subtests.test("score_weight"):
        assert hyper_parameters.score_weight == pytest.approx(
            1e-4 if impl_params else 1e0
        )
def test_hyper_parameters_image_pyramid(subtests, impl_params):
    """Check the ``image_pyramid`` section of the paper hyper-parameters.

    Expected values: 100 steps and 3 levels with implementation parameters,
    200 steps and ``None`` levels otherwise. Maximum edge size (384), minimum
    edge size (64), and edge (``"long"``) are the same in both cases.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "image_pyramid"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("max_edge_size"):
        assert hyper_parameters.max_edge_size == 384

    with subtests.test("num_steps"):
        # The expected value is computed first: ``assert a == b if c else d``
        # parses as ``assert (a == b) if c else d`` and would always pass for
        # a falsy ``c`` because ``d`` is a non-zero integer.
        expected_num_steps = 100 if impl_params else 200
        assert hyper_parameters.num_steps == expected_num_steps

    with subtests.test("num_levels"):
        if impl_params:
            assert hyper_parameters.num_levels == 3
        else:
            assert hyper_parameters.num_levels is None

    with subtests.test("min_edge_size"):
        assert hyper_parameters.min_edge_size == 64

    with subtests.test("edge"):
        assert hyper_parameters.edge == "long"
def test_hyper_parameters(subtests):
    """Check that ``paper.hyper_parameters`` returns a ``HyperParameters`` object."""
    result = paper.hyper_parameters()
    assert isinstance(result, HyperParameters)