def test_default_transformer_optim_loop_logging_smoke(self):
    """Smoke test: the default transformer optim loop logs at INFO level.

    Runs the loop once on the stored asset with a per-step log function and
    only checks that the logger emitted at least one INFO record.
    """
    asset = self.load_asset(path.join("optim", "default_transformer_optim_loop"))
    inputs = asset.input
    loader = inputs.image_loader

    logger = optim.OptimLogger()
    log_fn = optim.default_transformer_optim_log_fn(logger, len(loader), log_freq=1)

    with self.assertLogs(logger.logger, "INFO"):
        optim.default_transformer_optim_loop(
            loader,
            inputs.transformer,
            inputs.criterion,
            inputs.criterion_update_fn,
            logger=logger,
            log_fn=log_fn,
        )
def test_default_transformer_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke test: the default transformer optim loop emits log records.

    The loader and criterion are patched for torch>=1.6 compatibility before
    the loop runs; ``asserts.assert_logs`` checks that the logger produced output.
    """
    asset = optim_asset_loader("default_transformer_optim_loop")
    loader, criterion = asset.input.image_loader, asset.input.criterion
    make_torch_ge_1_6_compatible(loader, criterion)

    logger = optim.OptimLogger()
    log_fn = optim.default_transformer_optim_log_fn(logger, len(loader), log_freq=1)

    with asserts.assert_logs(caplog, logger=logger):
        optim.default_transformer_optim_loop(
            loader,
            asset.input.transformer,
            criterion,
            asset.input.criterion_update_fn,
            logger=logger,
            log_fn=log_fn,
        )
def test_default_transformer_optim_loop(self):
    """Check that the optim loop reproduces the stored trained transformer.

    The transformer is trained quietly with the asset's optimizer, and its
    parameters are compared to the reference output with ``rtol=1e-4``.
    """
    asset = self.load_asset(path.join("optim", "default_transformer_optim_loop"))

    initial = asset.input.transformer
    opt = asset.params.get_optimizer(initial)

    trained = optim.default_transformer_optim_loop(
        asset.input.image_loader,
        initial,
        asset.input.criterion,
        asset.input.criterion_update_fn,
        optimizer=opt,
        quiet=True,
    )

    self.assertTensorSequenceAlmostEqual(
        trained.parameters(), asset.output.transformer.parameters(), rtol=1e-4
    )
def test_default_transformer_optim_loop(optim_asset_loader):
    """Check that the optim loop reproduces the stored trained transformer.

    The loader and criterion are patched for torch>=1.6 compatibility. The
    trained parameters are then compared to the reference output with
    ``rtol=1e-4``.
    """
    asset = optim_asset_loader("default_transformer_optim_loop")
    loader, criterion = asset.input.image_loader, asset.input.criterion
    make_torch_ge_1_6_compatible(loader, criterion)

    initial = asset.input.transformer
    opt = asset.params.get_optimizer(initial)

    trained = optim.default_transformer_optim_loop(
        loader,
        initial,
        criterion,
        asset.input.criterion_update_fn,
        optimizer=opt,
        quiet=True,
    )

    ptu.assert_allclose(
        tuple(trained.parameters()),
        tuple(asset.output.transformer.parameters()),
        rtol=1e-4,
    )
def johnson_alahi_li_2016_training(
    content_image_loader: DataLoader,
    style: Union[str, torch.Tensor],
    impl_params: bool = True,
    instance_norm: bool = True,
    transformer: Optional[JohnsonAlahiLi2016Transformer] = None,
    criterion: Optional[PerceptualLoss] = None,
    get_optimizer: Optional[
        Callable[[JohnsonAlahiLi2016Transformer], Optimizer]
    ] = None,
    quiet: bool = False,
    logger: Optional[OptimLogger] = None,
    log_fn: Optional[
        Callable[[int, Union[torch.Tensor, pystiche.LossDict], float, float], None]
    ] = None,
):
    """Train a Johnson, Alahi, and Li (2016) transformer for one style.

    Args:
        content_image_loader: Loader of content image batches. It is also used
            to batch up the style image.
        style: Either the name of an image in ``johnson_alahi_li_2016_images()``
            or a style image tensor. A named image is read onto CUDA if
            available, otherwise onto CPU. For a tensor, the tensor's device is
            used and no style name is passed to the style-dependent factories.
        impl_params: If ``True``, the style image is run through
            ``johnson_alahi_li_2016_preprocessor()``. The flag is also forwarded
            to the transformer, loss, and style-transform factories.
        instance_norm: Forwarded to the transformer, loss, and style-transform
            factories.
        transformer: Transformer to train. Defaults to
            ``johnson_alahi_li_2016_transformer(...)``. It is put into train
            mode and moved to the training device.
        criterion: Perceptual loss. Defaults to
            ``johnson_alahi_li_2016_perceptual_loss(...)``. It is put into eval
            mode and moved to the training device.
        get_optimizer: Factory that is called with the transformer to build the
            optimizer. Defaults to ``johnson_alahi_li_2016_optimizer``.
        quiet: Forwarded to ``default_transformer_optim_loop``.
        logger: Forwarded to ``default_transformer_optim_loop``.
        log_fn: Forwarded to ``default_transformer_optim_loop``.

    Returns:
        The transformer returned by ``default_transformer_optim_loop``.
    """
    if isinstance(style, str):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        images = johnson_alahi_li_2016_images(download=False)
        style_image = images[style].read(device=device)
    else:
        style_image = style
        device = style_image.device
        # No style name is available for a raw tensor, so the factories below
        # receive style=None.
        style = None

    if impl_params:
        preprocessor = johnson_alahi_li_2016_preprocessor()
        preprocessor = preprocessor.to(device)
        style_image = preprocessor(style_image)

    if transformer is None:
        transformer = johnson_alahi_li_2016_transformer(
            impl_params=impl_params, instance_norm=instance_norm
        )
    transformer = transformer.train()
    transformer = transformer.to(device)

    if criterion is None:
        criterion = johnson_alahi_li_2016_perceptual_loss(
            impl_params=impl_params, instance_norm=instance_norm, style=style
        )
    criterion = criterion.eval()
    criterion = criterion.to(device)

    if get_optimizer is None:
        get_optimizer = johnson_alahi_li_2016_optimizer

    style_transform = johnson_alahi_li_2016_style_transform(
        impl_params=impl_params, instance_norm=instance_norm, style=style
    )
    style_transform = style_transform.to(device)
    style_image = style_transform(style_image)
    # Repeat the style image to match the batches produced by the loader.
    style_image = batch_up_image(style_image, loader=content_image_loader)
    criterion.set_style_image(style_image)

    def criterion_update_fn(input_image, criterion):
        criterion.set_content_image(input_image)

    # `optimizer=` takes an Optimizer instance, not a factory, so build it
    # here from the (already on-device) transformer.
    optimizer = get_optimizer(transformer)

    # The loop's positional parameters are (image_loader, transformer,
    # criterion, criterion_update_fn). `device` must not be passed
    # positionally, since it would shift every following argument by one
    # slot. The transformer and criterion were already moved to `device` above.
    transformer = default_transformer_optim_loop(
        content_image_loader,
        transformer,
        criterion,
        criterion_update_fn,
        optimizer=optimizer,
        quiet=quiet,
        logger=logger,
        log_fn=log_fn,
    )
    return transformer