def test_get_input_image_tensor_style():
    # A "style" starting point has to reproduce the given style image and must fail
    # if only a content image is supplied.
    style_image = torch.tensor(0.0)

    input_image = misc.get_input_image("style", style_image=style_image)
    assert input_image == ptu.approx(style_image)

    with pytest.raises(RuntimeError):
        misc.get_input_image("style", content_image=style_image)
def test_get_input_image_tensor():
    # A tensor used as starting point yields a distinct tensor with the same values.
    starting_image = torch.tensor(0.0)

    result = misc.get_input_image(starting_image)

    assert result is not starting_image
    ptu.assert_allclose(result, starting_image)
def nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> torch.Tensor:
    r"""NST from :cite:`LW2016`.

    Args:
        content_image: Content image for the NST.
        style_image: Style image for the NST.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the
            paper. For details see :ref:`here <li_wand_2016-impl_params>`.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.li_wand_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization.
            Defaults to ``False``.

    Returns:
        The output image of the image pyramid optimization.
    """
    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters(impl_params=impl_params)

    device = content_image.device

    criterion = perceptual_loss(impl_params=impl_params, hyper_parameters=hyper_parameters)
    criterion = criterion.to(device)

    # The criterion is passed as resize target to the pyramid (presumably so its
    # targets are resized together with each pyramid level).
    image_pyramid = _image_pyramid(
        hyper_parameters=hyper_parameters, resize_targets=(criterion,)
    )

    # Bring both images to the size of the top (last) pyramid level upfront.
    initial_resize = image_pyramid[-1].resize_image
    content_image = initial_resize(content_image)
    style_image = initial_resize(style_image)

    input_image = misc.get_input_image(
        starting_point=hyper_parameters.nst.starting_point, content_image=content_image
    )

    preprocessor = _preprocessor().to(device)
    postprocessor = _postprocessor().to(device)

    # Targets are set on preprocessed images; the optimization receives the same
    # preprocessor / postprocessor pair.
    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return optim.pyramid_image_optimization(
        input_image,
        criterion,
        image_pyramid,
        get_optimizer=optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
    )
def test_get_input_image_tensor_random():
    # A "random" starting point takes its tensor attributes from the content image if
    # given, otherwise from the style image, and fails if neither is given. The two
    # images use different dtypes so it is observable which one was used.
    content_image = torch.tensor(0.0, dtype=torch.float32)
    style_image = torch.tensor(0.0, dtype=torch.float64)

    cases = (
        (dict(content_image=content_image), content_image),
        (dict(style_image=style_image), style_image),
        (dict(content_image=content_image, style_image=style_image), content_image),
    )
    for image_kwargs, reference in cases:
        input_image = misc.get_input_image("random", **image_kwargs)
        ptu.assert_tensor_attributes_equal(input_image, reference)

    with pytest.raises(RuntimeError):
        misc.get_input_image("random")
def nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> torch.Tensor:
    r"""NST from :cite:`GEB2016`.

    Args:
        content_image: Content image for the NST.
        style_image: Style image for the NST.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the
            paper. For details see :ref:`here <gatys_ecker_bethge_2016-impl_params>`.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.gatys_ecker_bethge_2016.hyper_parameters` is
            used.
        quiet: If ``True``, no information is logged during the optimization.
            Defaults to ``False``.

    Returns:
        The output image of the image optimization.
    """
    if hyper_parameters is None:
        # NOTE(review): impl_params is not forwarded to _hyper_parameters here;
        # confirm the hyper-parameters of this paper do not depend on it.
        hyper_parameters = _hyper_parameters()

    device = content_image.device

    criterion = perceptual_loss(impl_params=impl_params, hyper_parameters=hyper_parameters)
    criterion = criterion.to(device)

    input_image = misc.get_input_image(
        starting_point=hyper_parameters.nst.starting_point, content_image=content_image
    )

    preprocessor = _preprocessor().to(device)
    postprocessor = _postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    # No torch.autograd.set_detect_anomaly(True) here: it is a process-wide
    # debugging switch that would stay enabled after this call and slow down
    # every subsequent backward pass.
    return optim.image_optimization(
        input_image,
        criterion,
        optimizer=optimizer,
        num_steps=hyper_parameters.nst.num_steps,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
    )
def gatys_et_al_2017_nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    criterion: Optional[PerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict]], None]
    ] = None,
) -> torch.Tensor:
    """Run the Gatys et al. 2017 NST as an image pyramid optimization.

    Args:
        content_image: Content image.
        style_image: Style image.
        impl_params: Forwarded to ``gatys_et_al_2017_perceptual_loss`` when
            ``criterion`` is omitted.
        criterion: Perceptual loss. Defaults to
            ``gatys_et_al_2017_perceptual_loss(impl_params=impl_params)``.
        pyramid: Image pyramid. Defaults to ``gatys_et_al_2017_image_pyramid`` with
            ``criterion`` registered as resize target.
        quiet: Forwarded to ``default_image_pyramid_optim_loop``.
        logger: Forwarded to ``default_image_pyramid_optim_loop``.
        log_fn: Forwarded to ``default_image_pyramid_optim_loop``.

    Returns:
        The result of ``default_image_pyramid_optim_loop``.
    """
    criterion = (
        gatys_et_al_2017_perceptual_loss(impl_params=impl_params)
        if criterion is None
        else criterion
    )
    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion,))

    device = content_image.device
    criterion = criterion.to(device)

    # Both images start out at the size of the top (last) pyramid level.
    top_level_resize = pyramid[-1].resize_image
    content_image, style_image = (
        top_level_resize(content_image),
        top_level_resize(style_image),
    )

    input_image = get_input_image(starting_point="content", content_image=content_image)

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
def gatys_ecker_bethge_2015_nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    num_steps: int = 500,
    impl_params: bool = True,
    criterion: Optional[PerceptualLoss] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict]], None]
    ] = None,
) -> torch.Tensor:
    """Run the Gatys, Ecker, and Bethge 2015 NST as a single-scale image optimization.

    Args:
        content_image: Content image.
        style_image: Style image.
        num_steps: Number of optimization steps.
        impl_params: Forwarded to ``gatys_ecker_bethge_2015_perceptual_loss`` when
            ``criterion`` is omitted; also selects the starting point ("content" if
            ``True``, "random" otherwise).
        criterion: Perceptual loss. Defaults to
            ``gatys_ecker_bethge_2015_perceptual_loss(impl_params=impl_params)``.
        quiet: Forwarded to ``default_image_optim_loop``.
        logger: Forwarded to ``default_image_optim_loop``.
        log_fn: Forwarded to ``default_image_optim_loop``.

    Returns:
        The result of ``default_image_optim_loop``.
    """
    device = content_image.device

    if criterion is None:
        criterion = gatys_ecker_bethge_2015_perceptual_loss(impl_params=impl_params)
    criterion = criterion.to(device)

    input_image = get_input_image(
        starting_point="content" if impl_params else "random",
        content_image=content_image,
    )

    preprocessor = gatys_ecker_bethge_2015_preprocessor().to(device)
    postprocessor = gatys_ecker_bethge_2015_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return default_image_optim_loop(
        input_image,
        criterion,
        get_optimizer=gatys_ecker_bethge_2015_optimizer,
        num_steps=num_steps,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
# is good practice to resize the images upfront to the largest size the ``pyramid`` # will handle. top_level = pyramid[-1] content_image = top_level.resize_image(content_image) style_image = top_level.resize_image(style_image) ######################################################################################## # As a last preliminary step the previously loaded images are set as targets for the # perceptual loss (``criterion``) and we create the input image. criterion.set_content_image(content_image) criterion.set_style_image(style_image) starting_point = "content" input_image = get_input_image(starting_point, content_image=content_image) show_image(input_image, title="Input image") ######################################################################################## # Finally we run the NST with the # :func:`~pystiche.optim.optim.default_image_pyramid_optim_loop`. If ``get_optimizer`` # is not specified, as is the case here, the # :func:`~pystiche.optim.optim.default_image_optimizer`, i.e. # :class:`~torch.optim.lbfgs.LBFGS` is used. output_image = default_image_pyramid_optim_loop(input_image, criterion, pyramid, logger=demo_logger()) # sphinx_gallery_thumbnail_number = 4
def gatys_et_al_2017_guided_nst(
    content_image: torch.Tensor,
    content_guides: Dict[str, torch.Tensor],
    style_images_and_guides: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    impl_params: bool = True,
    criterion: Optional[GuidedPerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict]], None]
    ] = None,
) -> torch.Tensor:
    """Run the guided Gatys et al. 2017 NST as an image pyramid optimization.

    Args:
        content_image: Content image.
        content_guides: Guide of the content image for each region.
        style_images_and_guides: Style image and corresponding guide for each region.
        impl_params: Forwarded to ``gatys_et_al_2017_guided_perceptual_loss`` when
            ``criterion`` is omitted.
        criterion: Guided perceptual loss. Defaults to
            ``gatys_et_al_2017_guided_perceptual_loss`` for the sorted regions.
        pyramid: Image pyramid. Defaults to ``gatys_et_al_2017_image_pyramid`` with
            ``criterion`` registered as resize target.
        quiet: Forwarded to ``default_image_pyramid_optim_loop``.
        logger: Forwarded to ``default_image_pyramid_optim_loop``.
        log_fn: Forwarded to ``default_image_pyramid_optim_loop``.

    Returns:
        The result of ``default_image_pyramid_optim_loop``.

    Raises:
        RuntimeError: If ``content_guides`` and ``style_images_and_guides`` do not
            cover the same regions.
    """
    content_regions = set(content_guides.keys())
    style_regions = set(style_images_and_guides.keys())
    if content_regions != style_regions:
        # key=str keeps message construction from raising on non-comparable keys,
        # so callers always get the RuntimeError.
        raise RuntimeError(
            "The regions of content_guides and style_images_and_guides do not match. "
            "Regions without a style image and guide: "
            f"{sorted(content_regions - style_regions, key=str)}. "
            "Regions without a content guide: "
            f"{sorted(style_regions - content_regions, key=str)}."
        )
    regions = sorted(content_regions)

    if criterion is None:
        criterion = gatys_et_al_2017_guided_perceptual_loss(
            regions, impl_params=impl_params
        )
    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion,))

    device = content_image.device
    criterion = criterion.to(device)

    # Images and guides start out at the size of the top (last) pyramid level.
    initial_image_resize = pyramid[-1].resize_image
    initial_guide_resize = pyramid[-1].resize_guide

    content_image = initial_image_resize(content_image)
    content_guides = {
        region: initial_guide_resize(guide) for region, guide in content_guides.items()
    }
    style_images_and_guides = {
        region: (initial_image_resize(image), initial_guide_resize(guide))
        for region, (image, guide) in style_images_and_guides.items()
    }

    input_image = get_input_image(starting_point="content", content_image=content_image)

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))

    for region, (image, guide) in style_images_and_guides.items():
        criterion.set_style_guide(region, guide)
        criterion.set_style_image(region, preprocessor(image))

    for region, guide in content_guides.items():
        criterion.set_content_guide(region, guide)

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
def guided_nst(
    content_image: torch.Tensor,
    content_guides: Dict[str, torch.Tensor],
    style_images_and_guides: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    impl_params: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> torch.Tensor:
    r"""Guided NST from :cite:`GEB+2017`.

    Args:
        content_image: Content image for the guided NST.
        content_guides: Content image guides for the guided NST.
        style_images_and_guides: Dictionary with the style images and the
            corresponding guides for each region.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the
            paper. For details see :ref:`here <gatys_et_al_2017-impl_params>`.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.gatys_et_al_2017.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization.
            Defaults to ``False``.

    Returns:
        The output image of the image pyramid optimization.

    Raises:
        RuntimeError: If ``content_guides`` and ``style_images_and_guides`` do not
            cover the same regions.
    """
    content_regions = set(content_guides.keys())
    style_regions = set(style_images_and_guides.keys())
    if content_regions != style_regions:
        # key=str keeps message construction from raising on non-comparable keys,
        # so callers always get the RuntimeError.
        raise RuntimeError(
            "The regions of content_guides and style_images_and_guides do not match. "
            "Regions without a style image and guide: "
            f"{sorted(content_regions - style_regions, key=str)}. "
            "Regions without a content guide: "
            f"{sorted(style_regions - content_regions, key=str)}."
        )
    regions = sorted(content_regions)

    if hyper_parameters is None:
        # NOTE(review): impl_params is not forwarded to _hyper_parameters here;
        # confirm the hyper-parameters of this paper do not depend on it.
        hyper_parameters = _hyper_parameters()

    device = content_image.device

    criterion = guided_perceptual_loss(
        regions, impl_params=impl_params, hyper_parameters=hyper_parameters
    )
    criterion = criterion.to(device)

    image_pyramid = _image_pyramid(
        hyper_parameters=hyper_parameters, resize_targets=(criterion,)
    )

    # Images and guides start out at the size of the top (last) pyramid level.
    initial_image_resize = image_pyramid[-1].resize_image
    initial_guide_resize = image_pyramid[-1].resize_guide

    content_image = initial_image_resize(content_image)
    content_guides = {
        region: initial_guide_resize(guide) for region, guide in content_guides.items()
    }
    style_images_and_guides = {
        region: (initial_image_resize(image), initial_guide_resize(guide))
        for region, (image, guide) in style_images_and_guides.items()
    }

    input_image = misc.get_input_image(
        starting_point="content", content_image=content_image
    )

    preprocessor = _preprocessor().to(device)
    postprocessor = _postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))

    for region, guide in content_guides.items():
        criterion.set_content_guide(guide, region=region)

    for region, (image, guide) in style_images_and_guides.items():
        criterion.set_style_image(preprocessor(image), guide=guide, region=region)

    return optim.pyramid_image_optimization(
        input_image,
        criterion,
        image_pyramid,
        get_optimizer=optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
    )