def test_default_image_optim_log_fn_other():
    """Check that the default image log_fn raises TypeError for a ``None`` loss.

    The log function is built with ``log_freq=1`` and invoked with
    ``step == log_freq`` and ``loss=None``; the test expects a ``TypeError``.
    """
    optim_logger = optim.OptimLogger()
    log_freq = 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    step = log_freq
    loss = None
    # Keep only the call under test inside the raises-context so that a
    # TypeError coming from setup code cannot make this test pass spuriously.
    with pytest.raises(TypeError):
        log_fn(step, loss)
def test_default_image_optim_loop_logging_smoke(self):
    """Smoke test: one step of the image optim loop logs at INFO or above.

    ``assertLogs`` fails unless at least one record of level INFO or higher
    is emitted on the ``OptimLogger``'s underlying logger during the loop.
    """
    asset_path = path.join("optim", "default_image_optim_loop")
    asset = self.load_asset(asset_path)

    logger = optim.OptimLogger()
    log_fn = optim.default_image_optim_log_fn(logger, log_freq=1)

    with self.assertLogs(logger.logger, "INFO"):
        inputs = asset.input
        optim.default_image_optim_loop(
            inputs.image,
            inputs.criterion,
            num_steps=1,
            log_fn=log_fn,
        )
def test_default_image_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke test: the image optim loop logs through the given OptimLogger.

    Runs a single optimization step with a log function that fires every
    step, inside the project helper ``asserts.assert_logs`` (presumably it
    fails when nothing was captured by ``caplog`` for that logger — verify).
    """
    asset = optim_asset_loader("default_image_optim_loop")

    logger = optim.OptimLogger()
    log_fn = optim.default_image_optim_log_fn(logger, log_freq=1)

    with asserts.assert_logs(caplog, logger=logger):
        inputs = asset.input
        optim.default_image_optim_loop(
            inputs.image,
            inputs.criterion,
            num_steps=1,
            log_fn=log_fn,
        )
def test_default_image_pyramid_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke test: the image pyramid optim loop emits logs via the OptimLogger.

    ``log_freq`` is set one above the largest ``num_steps`` of any pyramid
    level. NOTE(review): presumably this keeps ``log_fn`` from firing on a
    per-step basis, so the asserted logs come from the loop's own ``logger`` —
    confirm against ``default_image_optim_log_fn``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")
    optim_logger = optim.OptimLogger()

    # Reaches into the pyramid's private level list to find the longest level.
    longest_level = max(level.num_steps for level in asset.input.pyramid._levels)
    log_fn = optim.default_image_optim_log_fn(
        optim_logger, log_freq=longest_level + 1
    )

    with asserts.assert_logs(caplog, logger=optim_logger):
        inputs = asset.input
        optim.default_image_pyramid_optim_loop(
            inputs.image,
            inputs.criterion,
            inputs.pyramid,
            logger=optim_logger,
            log_fn=log_fn,
        )
def test_default_image_pyramid_optim_loop_logging_smoke(self):
    """Smoke test: the image pyramid optim loop logs at INFO or above.

    ``log_freq`` is one more than the largest ``num_steps`` among the pyramid
    levels. NOTE(review): presumably this means ``log_fn`` itself never fires,
    so the INFO records checked by ``assertLogs`` come from the loop's
    ``logger`` argument — confirm.
    """
    asset = self.load_asset(
        path.join("optim", "default_image_pyramid_optim_loop")
    )
    optim_logger = optim.OptimLogger()
    # A generator is enough for max(); no intermediate list is needed.
    log_freq = max(level.num_steps for level in asset.input.pyramid._levels) + 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    with self.assertLogs(optim_logger.logger, "INFO"):
        optim.default_image_pyramid_optim_loop(
            asset.input.image,
            asset.input.criterion,
            asset.input.pyramid,
            logger=optim_logger,
            log_fn=log_fn,
        )
def test_default_transformer_optim_loop_logging_smoke(self):
    """Smoke test: the transformer optim loop logs at INFO or above.

    The log function is built for the loader's length and ``log_freq=1``;
    ``assertLogs`` requires at least one INFO-or-higher record on the
    ``OptimLogger``'s underlying logger while the loop runs.
    """
    asset_path = path.join("optim", "default_transformer_optim_loop")
    asset = self.load_asset(asset_path)
    loader = asset.input.image_loader

    logger = optim.OptimLogger()
    num_batches = len(loader)
    log_fn = optim.default_transformer_optim_log_fn(logger, num_batches, log_freq=1)

    with self.assertLogs(logger.logger, "INFO"):
        inputs = asset.input
        optim.default_transformer_optim_loop(
            loader,
            inputs.transformer,
            inputs.criterion,
            inputs.criterion_update_fn,
            logger=logger,
            log_fn=log_fn,
        )
def test_default_transformer_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke test: the transformer optim loop logs through the OptimLogger.

    Before running, the loader and criterion are passed through
    ``make_torch_ge_1_6_compatible``. NOTE(review): presumably it adapts them
    in place for torch >= 1.6 — verify.
    """
    asset = optim_asset_loader("default_transformer_optim_loop")
    loader = asset.input.image_loader
    loss_criterion = asset.input.criterion
    make_torch_ge_1_6_compatible(loader, loss_criterion)

    logger = optim.OptimLogger()
    num_batches = len(loader)
    log_fn = optim.default_transformer_optim_log_fn(logger, num_batches, log_freq=1)

    with asserts.assert_logs(caplog, logger=logger):
        inputs = asset.input
        optim.default_transformer_optim_loop(
            loader,
            inputs.transformer,
            loss_criterion,
            inputs.criterion_update_fn,
            logger=logger,
            log_fn=log_fn,
        )