예제 #1
0
    def test_default_image_optimizer(self):
        """Check that the default image optimizer is an ``Optimizer`` whose
        first parameter (first group, first entry) matches the input image."""
        torch.manual_seed(0)
        input_image = torch.rand(1, 3, 128, 128)

        default_optimizer = optim.default_image_optimizer(input_image)
        self.assertIsInstance(default_optimizer, Optimizer)

        # The image is expected to be the leading parameter of the first group.
        first_group = default_optimizer.param_groups[0]
        first_param = first_group["params"][0]
        self.assertTensorAlmostEqual(first_param, input_image)
예제 #2
0
def test_default_image_optimizer():
    """Check that the default image optimizer is a ``torch.optim.Optimizer``
    whose first parameter (first group, first entry) matches the input image."""
    torch.manual_seed(0)
    input_image = torch.rand(1, 3, 128, 128)

    default_optimizer = optim.default_image_optimizer(input_image)
    assert isinstance(default_optimizer, torch.optim.Optimizer)

    # The image is expected to be the leading parameter of the first group.
    first_param = default_optimizer.param_groups[0]["params"][0]
    ptu.assert_allclose(first_param, input_image)
예제 #3
0
def test_image_optimization_optimizer_preprocessor():
    """Check that ``optim.image_optimization`` raises ``RuntimeError`` when
    given both an explicit optimizer and a ``preprocessor``."""
    image = torch.empty(1)
    # Bare placeholder modules; presumably only the argument combination
    # matters for this check (assumption — confirm against image_optimization).
    loss_module = nn.Module()
    image_optimizer = optim.default_image_optimizer(image)
    pre_module = nn.Module()

    with pytest.raises(RuntimeError):
        optim.image_optimization(
            image, loss_module, image_optimizer, preprocessor=pre_module
        )