Example 1
def test_get_input_image_tensor_style():
    starting_point = "style"
    image = torch.tensor(0.0)

    actual = misc.get_input_image(starting_point, style_image=image)
    desired = image
    assert actual == ptu.approx(desired)

    with pytest.raises(RuntimeError):
        misc.get_input_image(starting_point, content_image=image)
Example 2
def test_get_input_image_tensor():
    image = torch.tensor(0.0)

    starting_point = image
    actual = misc.get_input_image(starting_point)
    desired = image
    assert actual is not desired
    ptu.assert_allclose(actual, desired)
Example 3
def nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> torch.Tensor:
    r"""NST from :cite:`LW2016`.

    Args:
        content_image: Content image for the NST.
        style_image: Style image for the NST.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <li_wand_2016-impl_params>`.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.li_wand_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization. Defaults
            to ``False``.
    """
    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters(impl_params=impl_params)

    device = content_image.device

    criterion = perceptual_loss(impl_params=impl_params,
                                hyper_parameters=hyper_parameters)
    criterion = criterion.to(device)

    image_pyramid = _image_pyramid(hyper_parameters=hyper_parameters,
                                   resize_targets=(criterion, ))

    initial_resize = image_pyramid[-1].resize_image
    content_image = initial_resize(content_image)
    style_image = initial_resize(style_image)

    input_image = misc.get_input_image(
        starting_point=hyper_parameters.nst.starting_point,
        content_image=content_image)

    preprocessor = _preprocessor().to(device)
    postprocessor = _postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return optim.pyramid_image_optimization(
        input_image,
        criterion,
        image_pyramid,
        get_optimizer=optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
    )
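A minimal usage sketch for the function above, assuming it is exposed as ``pystiche_papers.li_wand_2016.nst``; the random placeholder images and their size are stand-ins for real content and style images loaded with pystiche's image utilities.

import torch

from pystiche_papers import li_wand_2016 as paper

# Placeholder 1x3xHxW images; real usage would load actual content and
# style images and move them to the desired device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
content_image = torch.rand(1, 3, 256, 256, device=device)
style_image = torch.rand(1, 3, 256, 256, device=device)

output_image = paper.nst(content_image, style_image, quiet=True)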
Example 4
def test_get_input_image_tensor_random():
    starting_point = "random"
    content_image = torch.tensor(0.0, dtype=torch.float32)
    style_image = torch.tensor(0.0, dtype=torch.float64)

    actual = misc.get_input_image(starting_point, content_image=content_image)
    desired = content_image
    ptu.assert_tensor_attributes_equal(actual, desired)

    actual = misc.get_input_image(starting_point, style_image=style_image)
    desired = style_image
    ptu.assert_tensor_attributes_equal(actual, desired)

    actual = misc.get_input_image(starting_point,
                                  content_image=content_image,
                                  style_image=style_image)
    desired = content_image
    ptu.assert_tensor_attributes_equal(actual, desired)

    with pytest.raises(RuntimeError):
        misc.get_input_image(starting_point)
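Taken together, Examples 1, 2, and 4 pin down how ``misc.get_input_image`` dispatches on ``starting_point``. The sketch below only restates that contract for reference; it is inferred from the tests and is not the pystiche implementation.

import torch

# Contract inferred from the tests above (illustration only): a tensor
# starting point is cloned, "content"/"style" return the corresponding
# image, and "random" produces noise with the dtype/device of the
# available reference image, preferring the content image. A missing
# image raises a RuntimeError.
def get_input_image_sketch(starting_point, content_image=None, style_image=None):
    if isinstance(starting_point, torch.Tensor):
        return starting_point.clone()
    if starting_point == "content":
        if content_image is None:
            raise RuntimeError("'content' starting point requires a content_image")
        return content_image.clone()
    if starting_point == "style":
        if style_image is None:
            raise RuntimeError("'style' starting point requires a style_image")
        return style_image.clone()
    if starting_point == "random":
        reference = content_image if content_image is not None else style_image
        if reference is None:
            raise RuntimeError("'random' starting point requires a reference image")
        return torch.rand_like(reference)
    raise ValueError(f"unknown starting_point: {starting_point!r}")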
Example 5
def nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> torch.Tensor:
    r"""NST from :cite:`GEB2016`.

    Args:
        content_image: Content image for the NST.
        style_image: Style image for the NST.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <gatys_ecker_bethge_2016-impl_params>`.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.gatys_ecker_bethge_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization. Defaults
            to ``False``.
    """
    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters()

    device = content_image.device

    criterion = perceptual_loss(impl_params=impl_params,
                                hyper_parameters=hyper_parameters)
    criterion = criterion.to(device)

    input_image = misc.get_input_image(
        starting_point=hyper_parameters.nst.starting_point,
        content_image=content_image)

    preprocessor = _preprocessor().to(device)
    postprocessor = _postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    torch.autograd.set_detect_anomaly(True)

    return optim.image_optimization(
        input_image,
        criterion,
        optimizer=optimizer,
        num_steps=hyper_parameters.nst.num_steps,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
    )
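A hedged usage sketch for the function above, assuming it is exposed as ``pystiche_papers.gatys_ecker_bethge_2016.nst`` together with a package-level ``hyper_parameters()``; the placeholder images are stand-ins for real inputs.

import torch

from pystiche_papers import gatys_ecker_bethge_2016 as paper

# Placeholder images standing in for real content and style images.
content_image = torch.rand(1, 3, 256, 256)
style_image = torch.rand(1, 3, 256, 256)

# Resolving the hyper parameters explicitly mirrors the fallback inside
# nst() and makes values such as hyper_parameters.nst.num_steps easy to
# inspect before the optimization starts.
hyper_parameters = paper.hyper_parameters()

output_image = paper.nst(
    content_image,
    style_image,
    impl_params=True,
    hyper_parameters=hyper_parameters,
    quiet=True,
)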
Example 6
def gatys_et_al_2017_nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    criterion: Optional[PerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[Callable[[int, Union[torch.Tensor, pystiche.LossDict]],
                              None]] = None,
) -> torch.Tensor:
    if criterion is None:
        criterion = gatys_et_al_2017_perceptual_loss(impl_params=impl_params)

    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion, ))

    device = content_image.device
    criterion = criterion.to(device)

    initial_resize = pyramid[-1].resize_image
    content_image = initial_resize(content_image)
    style_image = initial_resize(style_image)
    input_image = get_input_image(starting_point="content",
                                  content_image=content_image)

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
Example 7
def gatys_ecker_bethge_2015_nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    num_steps: int = 500,
    impl_params: bool = True,
    criterion: Optional[PerceptualLoss] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[Callable[[int, Union[torch.Tensor, pystiche.LossDict]],
                              None]] = None,
) -> torch.Tensor:
    if criterion is None:
        criterion = gatys_ecker_bethge_2015_perceptual_loss(
            impl_params=impl_params)

    device = content_image.device
    criterion = criterion.to(device)

    starting_point = "content" if impl_params else "random"
    input_image = get_input_image(starting_point=starting_point,
                                  content_image=content_image)

    preprocessor = gatys_ecker_bethge_2015_preprocessor().to(device)
    postprocessor = gatys_ecker_bethge_2015_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return default_image_optim_loop(
        input_image,
        criterion,
        get_optimizer=gatys_ecker_bethge_2015_optimizer,
        num_steps=num_steps,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
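Both legacy entry points above also accept a ``logger`` and a ``log_fn`` callback. Below is a minimal sketch of a callback matching the annotated signature; how the optimization loop invokes it is inferred from the type hint alone, not from the loop's implementation.

import logging
from typing import Union

import pystiche
import torch

logger = logging.getLogger("nst")

# Matches Callable[[int, Union[torch.Tensor, pystiche.LossDict]], None]:
# the loop is expected to pass the current step and the current loss.
def log_fn(step: int, loss: Union[torch.Tensor, pystiche.LossDict]) -> None:
    # Only str() is used here to avoid assuming anything about the
    # LossDict interface beyond what the annotation guarantees.
    logger.info("step %d: %s", step, loss)

It would then be passed along with the images, e.g. ``gatys_ecker_bethge_2015_nst(content_image, style_image, logger=logger, log_fn=log_fn)``.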
Example 8
# It is good practice to resize the images upfront to the largest size the ``pyramid``
# will handle.

top_level = pyramid[-1]
content_image = top_level.resize_image(content_image)
style_image = top_level.resize_image(style_image)

########################################################################################
# As a last preliminary step, we set the previously loaded images as targets for the
# perceptual loss (``criterion``) and create the input image.

criterion.set_content_image(content_image)
criterion.set_style_image(style_image)

starting_point = "content"
input_image = get_input_image(starting_point, content_image=content_image)
show_image(input_image, title="Input image")

########################################################################################
# Finally we run the NST with the
# :func:`~pystiche.optim.optim.default_image_pyramid_optim_loop`. If ``get_optimizer``
# is not specified, as is the case here, the
# :func:`~pystiche.optim.optim.default_image_optimizer`, i.e.
# :class:`~torch.optim.lbfgs.LBFGS` is used.

output_image = default_image_pyramid_optim_loop(input_image,
                                                criterion,
                                                pyramid,
                                                logger=demo_logger())

# sphinx_gallery_thumbnail_number = 4
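########################################################################################
# If a different optimizer is preferred, a ``get_optimizer`` callable can be passed to
# the loop. The sketch below is an illustration only: it assumes the callable receives
# the ``input_image`` and returns a :class:`torch.optim.Optimizer` over it (mirroring
# the default LBFGS optimizer described above), and swaps in plain Adam instead. It
# reuses ``input_image``, ``criterion``, ``pyramid``, ``demo_logger``, and ``torch``
# from earlier in this script.

def get_optimizer(input_image):
    return torch.optim.Adam([input_image.requires_grad_(True)], lr=1e-1)

output_image = default_image_pyramid_optim_loop(input_image,
                                                criterion,
                                                pyramid,
                                                get_optimizer=get_optimizer,
                                                logger=demo_logger())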
Example 9
def gatys_et_al_2017_guided_nst(
    content_image: torch.Tensor,
    content_guides: Dict[str, torch.Tensor],
    style_images_and_guides: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    impl_params: bool = True,
    criterion: Optional[GuidedPerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[Callable[[int, Union[torch.Tensor, pystiche.LossDict]],
                              None]] = None,
) -> torch.Tensor:
    regions = set(content_guides.keys())
    if regions != set(style_images_and_guides.keys()):
        raise RuntimeError(
            "The regions of content_guides and style_images_and_guides do not match."
        )
    regions = sorted(regions)

    if criterion is None:
        criterion = gatys_et_al_2017_guided_perceptual_loss(
            regions, impl_params=impl_params)

    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion, ))

    device = content_image.device
    criterion = criterion.to(device)

    initial_image_resize = pyramid[-1].resize_image
    initial_guide_resize = pyramid[-1].resize_guide

    content_image = initial_image_resize(content_image)
    content_guides = {
        region: initial_guide_resize(guide)
        for region, guide in content_guides.items()
    }
    style_images_and_guides = {
        region: (initial_image_resize(image), initial_guide_resize(guide))
        for region, (image, guide) in style_images_and_guides.items()
    }
    input_image = get_input_image(starting_point="content",
                                  content_image=content_image)

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))

    for region, (image, guide) in style_images_and_guides.items():
        criterion.set_style_guide(region, guide)
        criterion.set_style_image(region, preprocessor(image))

    for region, guide in content_guides.items():
        criterion.set_content_guide(region, guide)

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
Example 10
def guided_nst(
    content_image: torch.Tensor,
    content_guides: Dict[str, torch.Tensor],
    style_images_and_guides: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    impl_params: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> torch.Tensor:
    r"""Guided NST from :cite:`GEB+2017`.

    Args:
        content_image: Content image for the guided NST.
        content_guides: Content image guides for the guided NST.
        style_images_and_guides: Dictionary with the style images and the corresponding
            guides for each region.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <gatys_et_al_2017-impl_params>`.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.gatys_et_al_2017.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization. Defaults
            to ``False``.
    """
    regions = set(content_guides.keys())
    if regions != set(style_images_and_guides.keys()):
        raise RuntimeError(
            "The regions of content_guides and style_images_and_guides do not match."
        )
    regions = sorted(regions)

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters()

    device = content_image.device

    criterion = guided_perceptual_loss(regions,
                                       impl_params=impl_params,
                                       hyper_parameters=hyper_parameters)
    criterion = criterion.to(device)

    image_pyramid = _image_pyramid(hyper_parameters=hyper_parameters,
                                   resize_targets=(criterion, ))
    initial_image_resize = image_pyramid[-1].resize_image
    initial_guide_resize = image_pyramid[-1].resize_guide

    content_image = initial_image_resize(content_image)
    content_guides = {
        region: initial_guide_resize(guide)
        for region, guide in content_guides.items()
    }
    style_images_and_guides = {
        region: (initial_image_resize(image), initial_guide_resize(guide))
        for region, (image, guide) in style_images_and_guides.items()
    }
    input_image = misc.get_input_image(starting_point="content",
                                       content_image=content_image)

    preprocessor = _preprocessor().to(device)
    postprocessor = _postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))
    for region, guide in content_guides.items():
        criterion.set_content_guide(guide, region=region)

    for region, (image, guide) in style_images_and_guides.items():
        criterion.set_style_image(preprocessor(image),
                                  guide=guide,
                                  region=region)

    return optim.pyramid_image_optimization(
        input_image,
        criterion,
        image_pyramid,
        get_optimizer=optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
    )
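A usage sketch for the guided NST above, assuming it is exposed as ``pystiche_papers.gatys_et_al_2017.guided_nst``. The ``1x1xHxW`` binary guide layout and the left/right split are illustrative assumptions; real guides would come from a segmentation of the images, and both dictionaries must share the same region names.

import torch

from pystiche_papers import gatys_et_al_2017 as paper

height = width = 256

# Placeholder content image plus a left/right split as dummy guides.
content_image = torch.rand(1, 3, height, width)
left = torch.zeros(1, 1, height, width)
left[..., : width // 2] = 1.0
content_guides = {"left": left, "right": 1.0 - left}

# One (style_image, style_guide) pair per region; the content guides are
# reused here so the regions cover the style images completely.
style_images_and_guides = {
    region: (torch.rand(1, 3, height, width), guide.clone())
    for region, guide in content_guides.items()
}

output_image = paper.guided_nst(
    content_image,
    content_guides,
    style_images_and_guides,
    quiet=True,
)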