예제 #1
0
    def test_default_image_optim_log_fn_loss_dict_smoke(self):
        """Check that the default log fn sends a formatted LossDict to the logger.

        A recording stand-in replaces the real optim logger; after one call to
        the log fn, the last message it received must equal
        ``loss_dict.format(max_depth=max_depth)``.
        """

        class RecordingOptimLogger:
            # Minimal stand-in that only remembers the most recent message.
            def __init__(self):
                self.msg = None

            @contextlib.contextmanager
            def environment(self, header):
                # The header is ignored; the context only needs to be enterable.
                yield

            def message(self, msg):
                self.msg = msg

        entries = (
            ("a", torch.tensor(0.0)),
            ("b.c", torch.tensor(1.0)),
        )
        loss_dict = pystiche.LossDict(entries)

        log_freq, max_depth = 1, 1
        recorder = RecordingOptimLogger()
        log_fn = optim.default_image_optim_log_fn(
            recorder, log_freq=log_freq, max_depth=max_depth
        )

        # Call at step == log_freq; the assertion below expects a message.
        log_fn(log_freq, loss_dict)

        self.assertEqual(recorder.msg, loss_dict.format(max_depth=max_depth))
예제 #2
0
def test_default_image_optim_log_fn_other():
    """Check that the default log fn raises ``TypeError`` for a ``None`` loss."""
    optim_logger = optim.OptimLogger()
    log_freq = 1
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=log_freq)

    step = log_freq
    loss = None
    # Keep only the call under test inside the context manager, so that a
    # TypeError raised during setup cannot make this test pass spuriously.
    with pytest.raises(TypeError):
        log_fn(step, loss)
예제 #3
0
    def test_default_image_optim_loop_logging_smoke(self):
        """Check that ``default_image_optim_loop`` produces INFO logs.

        Runs one optimization step with a log fn built on ``optim_logger``
        (``log_freq=1``) and asserts INFO records on its underlying logger.
        """
        asset_path = path.join("optim", "default_image_optim_loop")
        asset = self.load_asset(asset_path)

        optim_logger = optim.OptimLogger()
        log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=1)

        loop_kwargs = dict(num_steps=1, log_fn=log_fn)
        with self.assertLogs(optim_logger.logger, "INFO"):
            optim.default_image_optim_loop(
                asset.input.image, asset.input.criterion, **loop_kwargs
            )
예제 #4
0
def test_default_image_optim_loop_logging_smoke(caplog, optim_asset_loader):
    """Check that ``default_image_optim_loop`` emits log records.

    Runs a single optimization step with a log fn built on ``optim_logger``
    (``log_freq=1``) and asserts that records are captured for that logger.
    """
    asset = optim_asset_loader("default_image_optim_loop")

    optim_logger = optim.OptimLogger()
    log_fn = optim.default_image_optim_log_fn(optim_logger, log_freq=1)

    loop_kwargs = dict(num_steps=1, log_fn=log_fn)
    with asserts.assert_logs(caplog, logger=optim_logger):
        optim.default_image_optim_loop(
            asset.input.image, asset.input.criterion, **loop_kwargs
        )
예제 #5
0
def test_default_image_pyramid_optim_loop_logging_smoke(
        caplog, optim_asset_loader):
    """Check that ``default_image_pyramid_optim_loop`` emits log records.

    ``log_freq`` is set one above the largest per-level step count. This is
    presumably so that the per-step ``log_fn`` never fires and the captured
    records come from the loop's own ``logger``. TODO: confirm against
    ``optim``.
    """
    asset = optim_asset_loader("default_image_pyramid_optim_loop")

    optim_logger = optim.OptimLogger()
    pyramid = asset.input.pyramid
    most_level_steps = max(level.num_steps for level in pyramid._levels)
    log_fn = optim.default_image_optim_log_fn(
        optim_logger, log_freq=most_level_steps + 1
    )

    with asserts.assert_logs(caplog, logger=optim_logger):
        optim.default_image_pyramid_optim_loop(
            asset.input.image,
            asset.input.criterion,
            pyramid,
            logger=optim_logger,
            log_fn=log_fn,
        )
예제 #6
0
    def test_default_image_pyramid_optim_loop_logging_smoke(self):
        """Check that ``default_image_pyramid_optim_loop`` produces INFO logs.

        ``log_freq`` is one above the largest per-level step count. This is
        presumably so that the per-step ``log_fn`` stays silent and the INFO
        records come from the loop's own ``logger``. TODO: confirm against
        ``optim``.
        """
        asset_path = path.join("optim", "default_image_pyramid_optim_loop")
        asset = self.load_asset(asset_path)

        optim_logger = optim.OptimLogger()
        pyramid = asset.input.pyramid
        most_level_steps = max(level.num_steps for level in pyramid._levels)
        log_fn = optim.default_image_optim_log_fn(
            optim_logger, log_freq=most_level_steps + 1
        )

        with self.assertLogs(optim_logger.logger, "INFO"):
            optim.default_image_pyramid_optim_loop(
                asset.input.image,
                asset.input.criterion,
                pyramid,
                logger=optim_logger,
                log_fn=log_fn,
            )