def test_default_image_optim_log_fn_loss_dict_smoke(self):
    """Smoke-test that the default log function forwards a formatted LossDict.

    A stub logger stores the most recent message it receives. After a single
    call with ``step == log_freq``, that message must equal
    ``loss_dict.format(max_depth=max_depth)``.
    """

    class MockOptimLogger:
        # Minimal stand-in for an optim logger: it only keeps the last message.
        def __init__(self):
            self.msg = None

        @contextlib.contextmanager
        def environment(self, header):
            # No-op context manager; the header argument is ignored.
            yield

        def message(self, msg):
            self.msg = msg

    max_depth = 1
    log_freq = 1
    loss_dict = pystiche.LossDict(
        (("a", torch.tensor(0.0)), ("b.c", torch.tensor(1.0)))
    )

    recorder = MockOptimLogger()
    log_fn = optim.default_image_optim_log_fn(
        recorder, log_freq=log_freq, max_depth=max_depth
    )
    log_fn(log_freq, loss_dict)

    expected = loss_dict.format(max_depth=max_depth)
    self.assertEqual(recorder.msg, expected)
def test_default_image_optim_log_fn_other():
    """Check that the default log function raises ``TypeError`` for a ``None`` loss."""
    optim_logger = optim.OptimLogger()
    log_freq = 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    # Build the arguments outside ``pytest.raises`` so that only the call under
    # test can satisfy the expected-exception check.
    step = log_freq
    loss = None
    with pytest.raises(TypeError):
        log_fn(step, loss)
def test_default_image_optim_loop_logging_smoke(self):
    """Smoke-test that ``default_image_optim_loop`` logs through the optim logger.

    Runs the loop for one step with ``log_freq=1``. ``assertLogs`` requires at
    least one record of level INFO or higher on the wrapped logger.
    """
    asset_path = path.join("optim", "default_image_optim_loop")
    asset = self.load_asset(asset_path)

    optim_logger = optim.OptimLogger()
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=1)

    steps = 1
    with self.assertLogs(optim_logger.logger, "INFO"):
        optim.default_image_optim_loop(
            asset.input.image,
            asset.input.criterion,
            num_steps=steps,
            log_fn=log_fn,
        )
def test_default_image_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke-test that ``default_image_optim_loop`` logs through the optim logger.

    Uses the ``caplog`` fixture via ``asserts.assert_logs`` to check that the
    single-step loop, run with ``log_freq=1``, produces log output on
    ``optim_logger``.
    """
    asset = optim_asset_loader("default_image_optim_loop")

    optim_logger = optim.OptimLogger()
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=1)

    steps = 1
    with asserts.assert_logs(caplog, logger=optim_logger):
        optim.default_image_optim_loop(
            asset.input.image,
            asset.input.criterion,
            num_steps=steps,
            log_fn=log_fn,
        )
def test_default_image_pyramid_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke-test that ``default_image_pyramid_optim_loop`` emits log output.

    ``log_freq`` is set one above the largest per-level ``num_steps``.
    NOTE(review): presumably this keeps step-frequency logging from firing,
    so the captured records come from the pyramid loop's own logging via
    ``logger=optim_logger``. Confirm against ``default_image_optim_log_fn``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")
    optim_logger = optim.OptimLogger()

    per_level_steps = (lvl.num_steps for lvl in asset.input.pyramid._levels)
    log_freq = max(per_level_steps) + 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    with asserts.assert_logs(caplog, logger=optim_logger):
        optim.default_image_pyramid_optim_loop(
            asset.input.image,
            asset.input.criterion,
            asset.input.pyramid,
            logger=optim_logger,
            log_fn=log_fn,
        )
def test_default_image_pyramid_optim_loop_logging_smoke(self):
    """Smoke-test that ``default_image_pyramid_optim_loop`` emits INFO logs.

    ``log_freq`` is set one above the largest per-level ``num_steps``.
    NOTE(review): presumably this keeps step-frequency logging from firing,
    so the INFO records required by ``assertLogs`` come from the pyramid
    loop's own logging via ``logger=optim_logger``. Confirm against
    ``default_image_optim_log_fn``.
    """
    asset = self.load_asset(path.join("optim", "default_image_pyramid_optim_loop"))
    optim_logger = optim.OptimLogger()
    # Generator instead of a throwaway list; matches the pytest variant.
    log_freq = max(level.num_steps for level in asset.input.pyramid._levels) + 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    with self.assertLogs(optim_logger.logger, "INFO"):
        optim.default_image_pyramid_optim_loop(
            asset.input.image,
            asset.input.criterion,
            asset.input.pyramid,
            logger=optim_logger,
            log_fn=log_fn,
        )