Code example #1
File: test_optim.py  Project: pystiche/pystiche
def test_model_optimization_image_loader(transformer, criterion, test_image,
                                         supervised):
    image_loader = data.DataLoader(Dataset(test_image, supervised=supervised))

    optim.model_optimization(
        image_loader,
        transformer,
        criterion,
        criterion_update_fn=lambda input_image, criterion: None,
    )

    transformer.assert_called_once_with(test_image)
Code example #2
File: test_optim.py  Project: pystiche/pystiche
def test_model_default_optimization_criterion_update_fn(
    transformer,
    test_image,
):
    image_loader = data.DataLoader(Dataset(test_image))

    content_loss = MSEOperator()
    style_loss = MSEOperator()
    criterion = loss.PerceptualLoss(content_loss, style_loss)

    content_loss.set_target_image(torch.rand_like(test_image))
    style_loss.set_target_image(torch.rand_like(test_image))
    optim.model_optimization(image_loader, transformer, criterion)

    ptu.assert_allclose(content_loss.target_image, test_image)
Code example #3
File: test_optim.py  Project: pystiche/pystiche
def test_default_transformer_optim_loop(optim_asset_loader):
    asset = optim_asset_loader("default_transformer_optim_loop")

    image_loader = asset.input.image_loader
    criterion = asset.input.perceptual_loss
    make_torch_ge_1_6_compatible(image_loader, criterion)

    transformer = asset.input.transformer
    optimizer = asset.params.get_optimizer(transformer)
    transformer = optim.model_optimization(
        image_loader,
        transformer,
        criterion,
        asset.input.criterion_update_fn,
        optimizer=optimizer,
        quiet=True,
    )

    actual = tuple(transformer.parameters())
    desired = tuple(asset.output.transformer.parameters())
    ptu.assert_allclose(actual, desired, rtol=1e-4)
Code example #4
File: test_optim.py  Project: pystiche/pystiche
def test_model_optimization_criterion_update_fn_error(image_loader,
                                                      transformer, criterion):
    with pytest.raises(RuntimeError):
        optim.model_optimization(image_loader, transformer, criterion)
Code example #5
File: _nst.py  Project: jbueltemeier/pystiche_papers
def training(
    content_image_loader: DataLoader,
    style_image: Union[str, torch.Tensor],
    impl_params: bool = True,
    instance_norm: Optional[bool] = None,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> nn.Module:
    r"""Training a transformer for the NST.

    Args:
        content_image_loader: Content images used as input for the ``transformer``.
        style_image: Style image on which the ``transformer`` should be trained. If
            ``str``, the image is read from
            :func:`~pystiche_papers.johnson_alahi_li_2016.images`.
        impl_params: If ``True``, uses the parameters used in the reference
            implementation of the original authors rather than what is described in
            the paper. For details see below.
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather than
            :class:`~torch.nn.BatchNorm2d` as described in the paper. If omitted,
            defaults to ``impl_params``.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.johnson_alahi_li_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization. Defaults
            to ``False``.

    If ``impl_params is True``, an external preprocessing of the images is used.

    """
    if isinstance(style_image, torch.Tensor):
        device = style_image.device
    else:
        device = misc.get_device()
        images = _images()
        style_image = images[style_image].read(device=device)

    if instance_norm is None:
        instance_norm = impl_params

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters()

    transformer = _transformer(impl_params=impl_params,
                               instance_norm=instance_norm)
    transformer = transformer.train().to(device)

    criterion = perceptual_loss(impl_params=impl_params,
                                hyper_parameters=hyper_parameters)
    criterion = criterion.eval().to(device)

    optimizer = _optimizer(transformer)

    style_transform = _style_transform(hyper_parameters=hyper_parameters)
    style_transform = style_transform.to(device)
    style_image = style_transform(style_image)
    style_image = batch_up_image(style_image, loader=content_image_loader)

    preprocessor = _preprocessor()
    preprocessor = preprocessor.to(device)
    style_image = preprocessor(style_image)

    criterion.set_style_image(style_image)

    def criterion_update_fn(input_image: torch.Tensor,
                            criterion: nn.Module) -> None:
        cast(loss.PerceptualLoss, criterion).set_content_image(input_image)

    return optim.model_optimization(
        content_image_loader,
        transformer,
        criterion,
        criterion_update_fn,
        optimizer=optimizer,
        quiet=quiet,
    )
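The ``training`` function above returns the trained transformer module, which can then be applied to new content images. Below is a minimal, hypothetical usage sketch: the import path (``pystiche_papers.johnson_alahi_li_2016``, inferred from the docstring references), the built-in style image name ``"candy"``, and the random stand-in dataset are assumptions for illustration, not part of the snippet above.

# Hypothetical usage sketch for the training() function above. A real run
# would use an actual content image dataset instead of random tensors.
import torch
from torch.utils.data import DataLoader, Dataset

from pystiche_papers.johnson_alahi_li_2016 import training  # assumed module path


class RandomContentImages(Dataset):
    # Stand-in dataset that yields random 3x256x256 "content images".
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.rand(3, 256, 256)


content_image_loader = DataLoader(RandomContentImages(), batch_size=4)

# Train a transformer against a built-in style image; passing a string makes
# training() read the image via the package's images() helper (see docstring).
transformer = training(content_image_loader, "candy", quiet=True)

# Apply the trained transformer to a new content image batch.
with torch.no_grad():
    stylized = transformer(torch.rand(1, 3, 256, 256))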