def test_model_optimization_image_loader(transformer, criterion, test_image, supervised):
    """The transformer is invoked exactly once, with the single loaded image."""
    dataset = Dataset(test_image, supervised=supervised)
    loader = data.DataLoader(dataset)

    def noop_update(input_image, criterion):
        # Deliberately leave the criterion untouched; only the transformer call
        # is under test here.
        return None

    optim.model_optimization(
        loader,
        transformer,
        criterion,
        criterion_update_fn=noop_update,
    )

    transformer.assert_called_once_with(test_image)
def test_model_default_optimization_criterion_update_fn(
    transformer,
    test_image,
):
    """Without an explicit update fn, the content target ends up as the input image."""
    loader = data.DataLoader(Dataset(test_image))

    content_op = MSEOperator()
    style_op = MSEOperator()
    perceptual = loss.PerceptualLoss(content_op, style_op)

    # Seed both targets with random noise so that the final assertion can only
    # pass if the content target was overwritten during the optimization.
    content_op.set_target_image(torch.rand_like(test_image))
    style_op.set_target_image(torch.rand_like(test_image))

    optim.model_optimization(loader, transformer, perceptual)

    ptu.assert_allclose(content_op.target_image, test_image)
def test_default_transformer_optim_loop(optim_asset_loader):
    """Regression check of the default loop against a stored reference run."""
    asset = optim_asset_loader("default_transformer_optim_loop")

    image_loader = asset.input.image_loader
    criterion = asset.input.perceptual_loss
    make_torch_ge_1_6_compatible(image_loader, criterion)

    transformer = asset.input.transformer
    optimizer = asset.params.get_optimizer(transformer)

    trained = optim.model_optimization(
        image_loader,
        transformer,
        criterion,
        asset.input.criterion_update_fn,
        optimizer=optimizer,
        quiet=True,
    )

    # Compare every parameter tensor of the freshly trained transformer with the
    # reference transformer stored in the asset.
    ptu.assert_allclose(
        tuple(trained.parameters()),
        tuple(asset.output.transformer.parameters()),
        rtol=1e-4,
    )
def test_model_optimization_criterion_update_fn_error(image_loader, transformer, criterion):
    """model_optimization raises RuntimeError for these fixtures when no update fn is given."""
    pytest.raises(
        RuntimeError,
        optim.model_optimization,
        image_loader,
        transformer,
        criterion,
    )
def training(
    content_image_loader: DataLoader,
    style_image: Union[str, torch.Tensor],
    impl_params: bool = True,
    instance_norm: Optional[bool] = None,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> nn.Module:
    r"""Train a transformer for the NST.

    Args:
        content_image_loader: Content images used as input for the ``transformer``.
        style_image: Style image on which the ``transformer`` should be trained. If
            ``str``, the image is read from
            :func:`~pystiche_papers.johnson_alahi_li_2016.images`.
        impl_params: If ``True``, uses the parameters used in the reference
            implementation of the original authors rather than what is described in
            the paper. For details see below.
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather than
            :class:`~torch.nn.BatchNorm2d` (the latter is what the paper describes).
            If omitted, defaults to ``impl_params``.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.johnson_alahi_li_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization.
            Defaults to ``False``.

    Returns:
        The module returned by ``optim.model_optimization``, i.e. the trained
        ``transformer``.

    If ``impl_params is True``, an external preprocessing of the images is used.
    """
    # A tensor style image dictates the device; otherwise pick the default device
    # and load the named image from the paper's image collection onto it.
    if isinstance(style_image, torch.Tensor):
        device = style_image.device
    else:
        device = misc.get_device()
        images = _images()
        style_image = images[style_image].read(device=device)

    if instance_norm is None:
        instance_norm = impl_params

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters()

    transformer = _transformer(impl_params=impl_params, instance_norm=instance_norm)
    transformer = transformer.train().to(device)

    # The criterion only scores images; it is kept in eval mode throughout.
    criterion = perceptual_loss(impl_params=impl_params, hyper_parameters=hyper_parameters)
    criterion = criterion.eval().to(device)

    optimizer = _optimizer(transformer)

    style_transform = _style_transform(hyper_parameters=hyper_parameters)
    style_transform = style_transform.to(device)
    style_image = style_transform(style_image)
    # NOTE(review): presumably repeats the style image to match the loader's
    # batch size — confirm against batch_up_image.
    style_image = batch_up_image(style_image, loader=content_image_loader)

    # NOTE(review): the preprocessor is applied here regardless of impl_params,
    # while the docstring ties external preprocessing to impl_params — confirm
    # whether the impl_params dependency lives in _preprocessor/_transformer.
    preprocessor = _preprocessor()
    preprocessor = preprocessor.to(device)
    style_image = preprocessor(style_image)

    criterion.set_style_image(style_image)

    def criterion_update_fn(input_image: torch.Tensor, criterion: nn.Module) -> None:
        # Each content batch becomes the content target for its optimization step.
        cast(loss.PerceptualLoss, criterion).set_content_image(input_image)

    return optim.model_optimization(
        content_image_loader,
        transformer,
        criterion,
        criterion_update_fn,
        optimizer=optimizer,
        quiet=quiet,
    )