def test_default_image_pyramid_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke-test that the pyramid optim loop logs through a custom ``OptimLogger``.

    The loop is run on the ``default_image_pyramid_optim_loop`` asset with a
    ``log_fn`` built by ``optim.default_image_optim_log_fn``; the test only checks
    that log records are captured by ``asserts.assert_logs``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")

    logger = optim.OptimLogger()
    # One more than the largest per-level step count of the pyramid.
    # NOTE(review): presumably chosen so the periodic log_fn never fires on a
    # regular step; confirm against default_image_optim_log_fn.
    most_steps = max(level.num_steps for level in asset.input.pyramid._levels)
    log_fn = optim.default_image_optim_log_fn(logger, log_freq=most_steps + 1)

    inputs = asset.input
    with asserts.assert_logs(caplog, logger=logger):
        optim.default_image_pyramid_optim_loop(
            inputs.image,
            inputs.criterion,
            inputs.pyramid,
            logger=logger,
            log_fn=log_fn,
        )
def test_default_image_pyramid_optim_loop_logging_smoke(self):
    """Smoke-test that the pyramid optim loop emits INFO records via ``OptimLogger``.

    Loads the ``optim/default_image_pyramid_optim_loop`` asset, runs the loop with
    a ``log_fn`` from ``optim.default_image_optim_log_fn`` and asserts that the
    underlying ``logging`` logger receives at least one INFO record.
    """
    asset_path = path.join("optim", "default_image_pyramid_optim_loop")
    asset = self.load_asset(asset_path)

    opt_logger = optim.OptimLogger()
    # One more than the largest per-level step count of the pyramid.
    # NOTE(review): presumably chosen so the periodic log_fn never fires on a
    # regular step; confirm against default_image_optim_log_fn.
    most_steps = max(level.num_steps for level in asset.input.pyramid._levels)
    log_fn = optim.default_image_optim_log_fn(opt_logger, log_freq=most_steps + 1)

    inputs = asset.input
    with self.assertLogs(opt_logger.logger, "INFO"):
        optim.default_image_pyramid_optim_loop(
            inputs.image,
            inputs.criterion,
            inputs.pyramid,
            logger=opt_logger,
            log_fn=log_fn,
        )
def test_default_image_pyramid_optim_loop(optim_asset_loader):
    """Check the pyramid optim loop reproduces the stored asset output.

    The result is compared against ``asset.output.image`` with ``rtol=1e-4``;
    logging is silenced via ``quiet=True``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")
    inputs, params = asset.input, asset.params

    output = optim.default_image_pyramid_optim_loop(
        inputs.image,
        inputs.criterion,
        inputs.pyramid,
        get_optimizer=params.get_optimizer,
        quiet=True,
    )

    ptu.assert_allclose(output, asset.output.image, rtol=1e-4)
def test_default_image_pyramid_optim_loop(self):
    """Check the pyramid optim loop reproduces the stored asset output.

    Loads ``optim/default_image_pyramid_optim_loop`` and compares the loop result
    against ``asset.output.image`` with ``rtol=1e-4``; logging is silenced via
    ``quiet=True``.
    """
    asset_path = path.join("optim", "default_image_pyramid_optim_loop")
    asset = self.load_asset(asset_path)
    inputs, params = asset.input, asset.params

    output = optim.default_image_pyramid_optim_loop(
        inputs.image,
        inputs.criterion,
        inputs.pyramid,
        get_optimizer=params.get_optimizer,
        quiet=True,
    )

    self.assertTensorAlmostEqual(output, asset.output.image, rtol=1e-4)
def gatys_et_al_2017_nst(
    content_image: torch.Tensor,
    style_image: torch.Tensor,
    impl_params: bool = True,
    criterion: Optional[PerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict]], None]
    ] = None,
) -> torch.Tensor:
    """Run the Gatys et al. 2017 NST on a single content / style image pair.

    Args:
        content_image: Content image. Its device is used for the criterion and the
            pre- and postprocessor.
        style_image: Style image.
        impl_params: Forwarded to ``gatys_et_al_2017_perceptual_loss`` when no
            ``criterion`` is given.
        criterion: Optional perceptual loss. Defaults to
            ``gatys_et_al_2017_perceptual_loss(impl_params=impl_params)``.
        pyramid: Optional image pyramid. Defaults to
            ``gatys_et_al_2017_image_pyramid`` with the criterion as resize target.
        quiet: Forwarded to ``default_image_pyramid_optim_loop``.
        logger: Forwarded to ``default_image_pyramid_optim_loop``.
        log_fn: Forwarded to ``default_image_pyramid_optim_loop``.

    Returns:
        The result of ``default_image_pyramid_optim_loop``.
    """
    if criterion is None:
        criterion = gatys_et_al_2017_perceptual_loss(impl_params=impl_params)
    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion,))

    device = content_image.device
    criterion = criterion.to(device)

    # Both images are resized with the last pyramid level before optimizing.
    resize_to_last_level = pyramid[-1].resize_image
    content_image, style_image = (
        resize_to_last_level(content_image),
        resize_to_last_level(style_image),
    )

    input_image = get_input_image(
        starting_point="content", content_image=content_image
    )

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    # The loss targets are set from the preprocessed images.
    criterion.set_content_image(preprocessor(content_image))
    criterion.set_style_image(preprocessor(style_image))

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
top_level = pyramid[-1]  # last level of the image pyramid; used to resize both images
content_image = top_level.resize_image(content_image)
style_image = top_level.resize_image(style_image)


########################################################################################
# As a last preliminary step, the previously loaded images are set as targets for the
# perceptual loss (``criterion``), and we create the input image.

criterion.set_content_image(content_image)
criterion.set_style_image(style_image)

# The input image is derived from the ``"content"`` starting point (presumably a copy
# of the content image; see ``get_input_image``).
starting_point = "content"
input_image = get_input_image(starting_point, content_image=content_image)
show_image(input_image, title="Input image")


########################################################################################
# Finally, we run the NST with the
# :func:`~pystiche.optim.optim.default_image_pyramid_optim_loop`. If ``get_optimizer``
# is not specified, as is the case here, the
# :func:`~pystiche.optim.optim.default_image_optimizer`, i.e.
# :class:`~torch.optim.lbfgs.LBFGS`, is used.

output_image = default_image_pyramid_optim_loop(
    input_image, criterion, pyramid, logger=demo_logger()
)

# sphinx_gallery_thumbnail_number = 4
show_image(output_image, title="Output image")
def gatys_et_al_2017_guided_nst(
    content_image: torch.Tensor,
    content_guides: Dict[str, torch.Tensor],
    style_images_and_guides: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    impl_params: bool = True,
    criterion: Optional[GuidedPerceptualLoss] = None,
    pyramid: Optional[ImagePyramid] = None,
    quiet: bool = False,
    logger: Optional[logging.Logger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict]], None]
    ] = None,
) -> torch.Tensor:
    """Run the region-guided Gatys et al. 2017 NST.

    Args:
        content_image: Content image. Its device is used for the criterion and the
            pre- and postprocessor.
        content_guides: Mapping from region name to the content guide of that
            region.
        style_images_and_guides: Mapping from region name to a
            ``(style_image, style_guide)`` pair. Must define exactly the same
            regions as ``content_guides``.
        impl_params: Forwarded to ``gatys_et_al_2017_guided_perceptual_loss`` when
            no ``criterion`` is given.
        criterion: Optional guided perceptual loss. Defaults to
            ``gatys_et_al_2017_guided_perceptual_loss`` over the sorted regions.
        pyramid: Optional image pyramid. Defaults to
            ``gatys_et_al_2017_image_pyramid`` with the criterion as resize target.
        quiet: Forwarded to ``default_image_pyramid_optim_loop``.
        logger: Forwarded to ``default_image_pyramid_optim_loop``.
        log_fn: Forwarded to ``default_image_pyramid_optim_loop``.

    Returns:
        The result of ``default_image_pyramid_optim_loop``.

    Raises:
        RuntimeError: If ``content_guides`` and ``style_images_and_guides`` do not
            define the same set of regions.
    """
    content_regions = set(content_guides.keys())
    style_regions = set(style_images_and_guides.keys())
    if content_regions != style_regions:
        # Report both directions of the mismatch so the caller can fix the inputs.
        raise RuntimeError(
            "content_guides and style_images_and_guides have to define the same "
            "regions. Regions without a style image and guide: "
            f"{sorted(content_regions - style_regions, key=str)}. "
            "Regions without a content guide: "
            f"{sorted(style_regions - content_regions, key=str)}."
        )
    # Sorted so the region order handed to the criterion does not depend on dict order.
    regions = sorted(content_regions)

    if criterion is None:
        criterion = gatys_et_al_2017_guided_perceptual_loss(
            regions, impl_params=impl_params
        )
    if pyramid is None:
        pyramid = gatys_et_al_2017_image_pyramid(resize_targets=(criterion,))

    device = content_image.device
    criterion = criterion.to(device)

    # Images and guides are resized with the last pyramid level before optimizing.
    initial_image_resize = pyramid[-1].resize_image
    initial_guide_resize = pyramid[-1].resize_guide

    content_image = initial_image_resize(content_image)
    content_guides = {
        region: initial_guide_resize(guide) for region, guide in content_guides.items()
    }
    style_images_and_guides = {
        region: (initial_image_resize(image), initial_guide_resize(guide))
        for region, (image, guide) in style_images_and_guides.items()
    }

    input_image = get_input_image(
        starting_point="content", content_image=content_image
    )

    preprocessor = gatys_et_al_2017_preprocessor().to(device)
    postprocessor = gatys_et_al_2017_postprocessor().to(device)

    # Targets are set from preprocessed images; guides are set unprocessed.
    criterion.set_content_image(preprocessor(content_image))

    for region, (image, guide) in style_images_and_guides.items():
        criterion.set_style_guide(region, guide)
        criterion.set_style_image(region, preprocessor(image))

    for region, guide in content_guides.items():
        criterion.set_content_guide(region, guide)

    return default_image_pyramid_optim_loop(
        input_image,
        criterion,
        pyramid,
        get_optimizer=gatys_et_al_2017_optimizer,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )