Example 1
    def test_default_transformer_optim_loop_logging_smoke(self):
        asset = self.load_asset(
            path.join("optim", "default_transformer_optim_loop"))

        image_loader = asset.input.image_loader
        optim_logger = optim.OptimLogger()
        log_fn = optim.default_transformer_optim_log_fn(optim_logger,
                                                        len(image_loader),
                                                        log_freq=1)

        with self.assertLogs(optim_logger.logger, "INFO"):
            optim.default_transformer_optim_loop(
                image_loader,
                asset.input.transformer,
                asset.input.criterion,
                asset.input.criterion_update_fn,
                logger=optim_logger,
                log_fn=log_fn,
            )
Example 2
def test_default_transformer_optim_loop_logging_smoke(caplog,
                                                      optim_asset_loader):
    asset = optim_asset_loader("default_transformer_optim_loop")

    image_loader = asset.input.image_loader
    criterion = asset.input.criterion
    make_torch_ge_1_6_compatible(image_loader, criterion)

    optim_logger = optim.OptimLogger()
    log_fn = optim.default_transformer_optim_log_fn(optim_logger,
                                                    len(image_loader),
                                                    log_freq=1)

    with asserts.assert_logs(caplog, logger=optim_logger):
        optim.default_transformer_optim_loop(
            image_loader,
            asset.input.transformer,
            criterion,
            asset.input.criterion_update_fn,
            logger=optim_logger,
            log_fn=log_fn,
        )
Example 3
    def test_default_transformer_optim_loop(self):
        asset = self.load_asset(
            path.join("optim", "default_transformer_optim_loop"))

        transformer = asset.input.transformer
        optimizer = asset.params.get_optimizer(transformer)
        transformer = optim.default_transformer_optim_loop(
            asset.input.image_loader,
            transformer,
            asset.input.criterion,
            asset.input.criterion_update_fn,
            optimizer=optimizer,
            quiet=True,
        )

        actual = transformer.parameters()
        desired = asset.output.transformer.parameters()
        self.assertTensorSequenceAlmostEqual(actual, desired, rtol=1e-4)
Example 4
def test_default_transformer_optim_loop(optim_asset_loader):
    asset = optim_asset_loader("default_transformer_optim_loop")

    image_loader = asset.input.image_loader
    criterion = asset.input.criterion
    make_torch_ge_1_6_compatible(image_loader, criterion)

    transformer = asset.input.transformer
    optimizer = asset.params.get_optimizer(transformer)
    transformer = optim.default_transformer_optim_loop(
        image_loader,
        transformer,
        criterion,
        asset.input.criterion_update_fn,
        optimizer=optimizer,
        quiet=True,
    )

    actual = tuple(transformer.parameters())
    desired = tuple(asset.output.transformer.parameters())
    ptu.assert_allclose(actual, desired, rtol=1e-4)
Example 5
def johnson_alahi_li_2016_training(
    content_image_loader: DataLoader,
    style: Union[str, torch.Tensor],
    impl_params: bool = True,
    instance_norm: bool = True,
    transformer: Optional[JohnsonAlahiLi2016Transformer] = None,
    criterion: Optional[PerceptualLoss] = None,
    get_optimizer: Optional[Callable[[JohnsonAlahiLi2016Transformer],
                                     Optimizer]] = None,
    quiet: bool = False,
    logger: Optional[OptimLogger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict], float, float],
                 None]] = None,
):
    if isinstance(style, str):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        images = johnson_alahi_li_2016_images(download=False)
        style_image = images[style].read(device=device)
    else:
        style_image = style
        device = style_image.device
        style = None

    if impl_params:
        preprocessor = johnson_alahi_li_2016_preprocessor()
        preprocessor = preprocessor.to(device)
        style_image = preprocessor(style_image)

    if transformer is None:
        transformer = johnson_alahi_li_2016_transformer(
            impl_params=impl_params, instance_norm=instance_norm)
        transformer = transformer.train()
    transformer = transformer.to(device)

    if criterion is None:
        criterion = johnson_alahi_li_2016_perceptual_loss(
            impl_params=impl_params, instance_norm=instance_norm, style=style)
        criterion = criterion.eval()
    criterion = criterion.to(device)

    if get_optimizer is None:
        get_optimizer = johnson_alahi_li_2016_optimizer

    style_transform = johnson_alahi_li_2016_style_transform(
        impl_params=impl_params, instance_norm=instance_norm, style=style)
    style_transform = style_transform.to(device)
    style_image = style_transform(style_image)
    style_image = batch_up_image(style_image, loader=content_image_loader)

    criterion.set_style_image(style_image)

    def criterion_update_fn(input_image, criterion):
        # The style image was already fixed above via criterion.set_style_image();
        # only the content image changes from batch to batch during training.
        criterion.set_content_image(input_image)

    default_transformer_optim_loop(
        content_image_loader,
        device,
        transformer,
        criterion,
        criterion_update_fn,
        optimizer=get_optimizer(transformer),
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )

    return transformer
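
For orientation, the following is a minimal, hypothetical sketch of how the training entry point from Example 5 might be called. The RandomContentImages dataset, the batch size, the random style tensor, and the commented-out import are placeholders chosen for illustration only; a real run would use an actual content image dataset and a real style image.

import torch
from torch.utils.data import DataLoader, Dataset

# Adjust the import to wherever johnson_alahi_li_2016_training lives in the
# surrounding package, e.g.:
# from pystiche.papers import johnson_alahi_li_2016_training


class RandomContentImages(Dataset):
    """Stand-in for a real content image dataset (e.g. cropped COCO images)."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.rand(3, 256, 256)


content_image_loader = DataLoader(RandomContentImages(), batch_size=4)

# Passing a style tensor instead of a style name keeps the sketch self-contained;
# the training function then infers the device from the tensor.
style_image = torch.rand(1, 3, 256, 256)

transformer = johnson_alahi_li_2016_training(
    content_image_loader,
    style_image,
    impl_params=True,
    instance_norm=True,
    quiet=True,
)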