예제 #1
0
def test_default_image_optim_log_fn_other():
    """Passing ``None`` as the loss to the image log fn raises ``TypeError``.

    Only the ``log_fn`` call sits inside ``pytest.raises``. This way the
    expected ``TypeError`` can only come from the function under test and not
    from the test's own setup code.
    """
    optim_logger = optim.OptimLogger()
    log_freq = 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    # Use a step equal to log_freq so the call is presumably not skipped by a
    # log-frequency check before the loss is inspected. TODO confirm.
    step = log_freq
    loss = None
    with pytest.raises(TypeError):
        log_fn(step, loss)
예제 #2
0
    def test_default_image_optim_loop_logging_smoke(self):
        """Smoke test: a one-step image optimization emits INFO log records.

        Runs ``optim.default_image_optim_loop`` for a single step with a log
        function built with ``log_freq=1``. It asserts that at least one
        record at INFO level or above is emitted on the optim logger.
        """
        asset_path = path.join("optim", "default_image_optim_loop")
        asset = self.load_asset(asset_path)
        inputs = asset.input

        logger_wrapper = optim.OptimLogger()
        step_log_fn = optim.default_image_optim_log_fn(logger_wrapper,
                                                       log_freq=1)

        with self.assertLogs(logger_wrapper.logger, "INFO"):
            optim.default_image_optim_loop(
                inputs.image,
                inputs.criterion,
                num_steps=1,
                log_fn=step_log_fn,
            )
예제 #3
0
def test_default_image_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Smoke test: a single image optimization step produces log output.

    ``caplog`` is pytest's log-capture fixture. ``optim_asset_loader`` is
    presumably a project fixture that loads a stored test asset by name.
    """
    asset = optim_asset_loader("default_image_optim_loop")
    inputs = asset.input

    logger_wrapper = optim.OptimLogger()
    step_log_fn = optim.default_image_optim_log_fn(logger_wrapper, log_freq=1)

    # assert_logs is a project helper; presumably it checks that caplog
    # captured records from logger_wrapper during the block.
    with asserts.assert_logs(caplog, logger=logger_wrapper):
        optim.default_image_optim_loop(
            inputs.image,
            inputs.criterion,
            num_steps=1,
            log_fn=step_log_fn,
        )
예제 #4
0
def test_default_image_pyramid_optim_loop_logging_smoke(
        caplog, optim_asset_loader):
    """Smoke test: the image pyramid optimization loop produces log output.

    ``log_freq`` is set one past the largest ``num_steps`` of any pyramid
    level. The captured records therefore presumably come from the pyramid
    loop's own ``logger`` rather than from ``log_fn``. TODO confirm.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")
    inputs = asset.input

    logger_wrapper = optim.OptimLogger()
    longest_level = max(level.num_steps for level in inputs.pyramid._levels)
    step_log_fn = optim.default_image_optim_log_fn(
        logger_wrapper, log_freq=longest_level + 1)

    with asserts.assert_logs(caplog, logger=logger_wrapper):
        optim.default_image_pyramid_optim_loop(
            inputs.image,
            inputs.criterion,
            inputs.pyramid,
            logger=logger_wrapper,
            log_fn=step_log_fn,
        )
예제 #5
0
    def test_default_image_pyramid_optim_loop_logging_smoke(self):
        """Smoke test: the image pyramid optimization loop logs at INFO level.

        ``log_freq`` is one past the largest per-level ``num_steps``. The
        asserted records are therefore presumably emitted by the pyramid
        loop's ``logger`` rather than by ``log_fn``. TODO confirm.
        """
        asset_path = path.join("optim", "default_image_pyramid_optim_loop")
        asset = self.load_asset(asset_path)
        inputs = asset.input

        logger_wrapper = optim.OptimLogger()
        step_counts = [level.num_steps for level in inputs.pyramid._levels]
        step_log_fn = optim.default_image_optim_log_fn(
            logger_wrapper, log_freq=max(step_counts) + 1)

        with self.assertLogs(logger_wrapper.logger, "INFO"):
            optim.default_image_pyramid_optim_loop(
                inputs.image,
                inputs.criterion,
                inputs.pyramid,
                logger=logger_wrapper,
                log_fn=step_log_fn,
            )
예제 #6
0
    def test_default_transformer_optim_loop_logging_smoke(self):
        """Smoke test: the transformer optimization loop logs at INFO level.

        Builds a transformer log function with ``log_freq=1`` and the length
        of the image loader. It asserts that at least one INFO-or-above record
        reaches the optim logger while the loop runs.
        """
        asset_path = path.join("optim", "default_transformer_optim_loop")
        asset = self.load_asset(asset_path)
        inputs = asset.input
        loader = inputs.image_loader

        logger_wrapper = optim.OptimLogger()
        # The loader length is passed positionally; presumably the log fn
        # uses it to report progress through the loader. TODO confirm.
        step_log_fn = optim.default_transformer_optim_log_fn(
            logger_wrapper, len(loader), log_freq=1)

        with self.assertLogs(logger_wrapper.logger, "INFO"):
            optim.default_transformer_optim_loop(
                loader,
                inputs.transformer,
                inputs.criterion,
                inputs.criterion_update_fn,
                logger=logger_wrapper,
                log_fn=step_log_fn,
            )
예제 #7
0
def test_default_transformer_optim_loop_logging_smoke(caplog,
                                                      optim_asset_loader):
    """Smoke test: the transformer optimization loop produces log output.

    ``caplog`` is pytest's log-capture fixture. ``optim_asset_loader`` is
    presumably a project fixture that loads a stored test asset by name.
    """
    asset = optim_asset_loader("default_transformer_optim_loop")
    inputs = asset.input
    loader, loss_criterion = inputs.image_loader, inputs.criterion
    # Project helper; judging by its name it adapts the stored loader and
    # criterion for torch >= 1.6. Presumably it patches them in place.
    make_torch_ge_1_6_compatible(loader, loss_criterion)

    logger_wrapper = optim.OptimLogger()
    loader_len = len(loader)
    step_log_fn = optim.default_transformer_optim_log_fn(
        logger_wrapper, loader_len, log_freq=1)

    with asserts.assert_logs(caplog, logger=logger_wrapper):
        optim.default_transformer_optim_loop(
            loader,
            inputs.transformer,
            loss_criterion,
            inputs.criterion_update_fn,
            logger=logger_wrapper,
            log_fn=step_log_fn,
        )