Ejemplo n.º 1
0
def test_default_image_pyramid_optim_loop_logging_smoke(
        caplog, optim_asset_loader):
    """Smoke-test the logging of ``default_image_pyramid_optim_loop``.

    The input image, criterion and pyramid are taken from the stored
    ``default_image_pyramid_optim_loop`` asset. The loop is run with an
    explicit ``OptimLogger`` inside ``asserts.assert_logs``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")
    inputs = asset.input

    logger = optim.OptimLogger()
    # The log frequency is one larger than the step count of every pyramid
    # level. NOTE(review): presumably this keeps the periodic log_fn from
    # firing, so only the loop's own records are checked. Confirm.
    steps_per_level = [level.num_steps for level in inputs.pyramid._levels]
    log_fn = optim.default_image_optim_log_fn(
        logger, log_freq=max(steps_per_level) + 1)

    with asserts.assert_logs(caplog, logger=logger):
        optim.default_image_pyramid_optim_loop(
            inputs.image,
            inputs.criterion,
            inputs.pyramid,
            logger=logger,
            log_fn=log_fn,
        )
Ejemplo n.º 2
0
    def test_default_image_pyramid_optim_loop_logging_smoke(self):
        """Check that the pyramid optimization loop logs at level ``INFO``.

        The input image, criterion and pyramid come from the stored
        ``optim/default_image_pyramid_optim_loop`` asset. The loop runs with
        an explicit ``OptimLogger``, and ``assertLogs`` watches its
        underlying ``logger``.
        """
        asset_name = path.join("optim", "default_image_pyramid_optim_loop")
        asset = self.load_asset(asset_name)
        inputs = asset.input

        opt_logger = optim.OptimLogger()
        # Log frequency: one above the largest per-level step count.
        # NOTE(review): presumably this keeps the periodic log_fn from firing,
        # so only the loop's own records are checked. Confirm.
        log_freq = 1 + max(level.num_steps
                           for level in inputs.pyramid._levels)
        log_fn = optim.default_image_optim_log_fn(opt_logger,
                                                  log_freq=log_freq)

        with self.assertLogs(opt_logger.logger, "INFO"):
            optim.default_image_pyramid_optim_loop(
                inputs.image,
                inputs.criterion,
                inputs.pyramid,
                logger=opt_logger,
                log_fn=log_fn,
            )
Ejemplo n.º 3
0
def test_default_image_pyramid_optim_loop(optim_asset_loader):
    """Compare the loop's output against the stored reference image.

    The optimizer factory comes from the asset parameters, and progress
    output is disabled with ``quiet=True``. The result must match
    ``asset.output.image`` within a relative tolerance of ``1e-4``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")
    image, criterion, pyramid = (
        asset.input.image,
        asset.input.criterion,
        asset.input.pyramid,
    )

    actual = optim.default_image_pyramid_optim_loop(
        image, criterion, pyramid,
        get_optimizer=asset.params.get_optimizer,
        quiet=True,
    )
    ptu.assert_allclose(actual, asset.output.image, rtol=1e-4)
Ejemplo n.º 4
0
    def test_default_image_pyramid_optim_loop(self):
        """Compare the loop's output against the stored reference image.

        Uses the optimizer factory from the asset parameters, disables
        progress output with ``quiet=True``, and requires agreement with
        ``asset.output.image`` up to a relative tolerance of ``1e-4``.
        """
        asset_name = path.join("optim", "default_image_pyramid_optim_loop")
        asset = self.load_asset(asset_name)
        inputs = asset.input

        actual = optim.default_image_pyramid_optim_loop(
            inputs.image, inputs.criterion, inputs.pyramid,
            get_optimizer=asset.params.get_optimizer,
            quiet=True,
        )
        self.assertTensorAlmostEqual(actual, asset.output.image, rtol=1e-4)
Ejemplo n.º 5
0
def gatys_et_al_2017_nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    criterion: Optional[PerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[Callable[[int, Union[torch.Tensor, pystiche.LossDict]],
                              None]] = None,
) -> torch.Tensor:
    """Run the neural style transfer of Gatys et al. (2017).

    Args:
        content_image: Content image. Its ``device`` decides where the
            criterion and the pre-/postprocessors are placed.
        style_image: Style image.
        impl_params: Forwarded to ``gatys_et_al_2017_perceptual_loss`` when
            no ``criterion`` is given.
        criterion: Perceptual loss to optimize. Defaults to
            ``gatys_et_al_2017_perceptual_loss(impl_params=impl_params)``.
        pyramid: Image pyramid to optimize over. Defaults to
            ``gatys_et_al_2017_image_pyramid`` with ``criterion`` passed as
            a resize target.
        quiet: Forwarded to ``default_image_pyramid_optim_loop``.
        logger: Forwarded to ``default_image_pyramid_optim_loop``.
        log_fn: Forwarded to ``default_image_pyramid_optim_loop``.

    Returns:
        The image returned by ``default_image_pyramid_optim_loop``.
    """
    if criterion is None:
        criterion = gatys_et_al_2017_perceptual_loss(impl_params=impl_params)
    # The criterion is handed to the pyramid as a resize target.
    # NOTE(review): presumably so its targets are resized per level. Confirm.
    pyramid = (gatys_et_al_2017_image_pyramid(resize_targets=(criterion, ))
               if pyramid is None else pyramid)

    device = content_image.device
    criterion = criterion.to(device)

    # Both images are resized with the last pyramid level before optimizing.
    # NOTE(review): presumably the highest-resolution level. Verify ordering.
    resize_image = pyramid[-1].resize_image
    content_image, style_image = (resize_image(content_image),
                                  resize_image(style_image))
    input_image = get_input_image(starting_point="content",
                                  content_image=content_image)

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    # The loss targets are set from the preprocessed images.
    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return default_image_pyramid_optim_loop(
        input_image, criterion, pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor, postprocessor=postprocessor,
        quiet=quiet, logger=logger, log_fn=log_fn,
    )
# Resize the content and the style image with the last level of the pyramid.
top_level = pyramid[-1]
content_image, style_image = (top_level.resize_image(content_image),
                              top_level.resize_image(style_image))

########################################################################################
# As a last preliminary step, the previously loaded images are set as the targets of
# the perceptual loss (``criterion``) and the input image is created.

criterion.set_content_image(content_image)
criterion.set_style_image(style_image)

starting_point = "content"
input_image = get_input_image(starting_point, content_image=content_image)
show_image(input_image, title="Input image")

########################################################################################
# Finally, the NST is run with
# :func:`~pystiche.optim.optim.default_image_pyramid_optim_loop`. Since
# ``get_optimizer`` is not specified here, the
# :func:`~pystiche.optim.optim.default_image_optimizer`, i.e.
# :class:`~torch.optim.lbfgs.LBFGS`, is used.

output_image = default_image_pyramid_optim_loop(
    input_image, criterion, pyramid, logger=demo_logger()
)

# sphinx_gallery_thumbnail_number = 4
show_image(output_image, title="Output image")
Ejemplo n.º 7
0
def gatys_et_al_2017_guided_nst(
    content_image: torch.Tensor,
    content_guides: Dict[str, torch.Tensor],
    style_images_and_guides: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    impl_params: bool = True,
    criterion: Optional[GuidedPerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[Callable[[int, Union[torch.Tensor, pystiche.LossDict]],
                              None]] = None,
) -> torch.Tensor:
    """Run the region-guided neural style transfer of Gatys et al. (2017).

    Args:
        content_image: Content image. Its ``device`` decides where the
            criterion and the pre-/postprocessors are placed.
        content_guides: Mapping from region name to the content guide of that
            region.
        style_images_and_guides: Mapping from region name to a
            ``(style_image, guide)`` pair. Must contain exactly the same
            regions as ``content_guides``.
        impl_params: Forwarded to ``gatys_et_al_2017_guided_perceptual_loss``
            when no ``criterion`` is given.
        criterion: Guided perceptual loss to optimize. Defaults to
            ``gatys_et_al_2017_guided_perceptual_loss`` built for the sorted
            region names.
        pyramid: Image pyramid to optimize over. Defaults to
            ``gatys_et_al_2017_image_pyramid`` with ``criterion`` passed as
            a resize target.
        quiet: Forwarded to ``default_image_pyramid_optim_loop``.
        logger: Forwarded to ``default_image_pyramid_optim_loop``.
        log_fn: Forwarded to ``default_image_pyramid_optim_loop``.

    Returns:
        The image returned by ``default_image_pyramid_optim_loop``.

    Raises:
        RuntimeError: If ``content_guides`` and ``style_images_and_guides``
            do not have the same set of regions.
    """
    content_regions = set(content_guides.keys())
    style_regions = set(style_images_and_guides.keys())
    if content_regions != style_regions:
        raise RuntimeError(
            "The regions of the content guides and of the style images and "
            "guides have to match. Regions without a style image and guide: "
            "{}. Regions without a content guide: {}.".format(
                sorted(content_regions - style_regions),
                sorted(style_regions - content_regions),
            ))
    # Sorted so the default criterion is built with a deterministic order.
    regions = sorted(content_regions)

    if criterion is None:
        criterion = gatys_et_al_2017_guided_perceptual_loss(
            regions, impl_params=impl_params)

    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion, ))

    device = content_image.device
    criterion = criterion.to(device)

    # Images and guides are resized with the last pyramid level up front,
    # using the level's dedicated image and guide resize functions.
    initial_image_resize = pyramid[-1].resize_image
    initial_guide_resize = pyramid[-1].resize_guide

    content_image = initial_image_resize(content_image)
    content_guides = {
        region: initial_guide_resize(guide)
        for region, guide in content_guides.items()
    }
    style_images_and_guides = {
        region: (initial_image_resize(image), initial_guide_resize(guide))
        for region, (image, guide) in style_images_and_guides.items()
    }
    input_image = get_input_image(starting_point="content",
                                  content_image=content_image)

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    criterion.set_content_image(preprocessor(content_image))

    for region, (image, guide) in style_images_and_guides.items():
        # The guide is set before the style image of the same region.
        # NOTE(review): presumably required because the style target depends
        # on the guide. Confirm before reordering.
        criterion.set_style_guide(region, guide)
        criterion.set_style_image(region, preprocessor(image))

    for region, guide in content_guides.items():
        criterion.set_content_guide(region, guide)

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )