def test_invertiblegradupdate():
    torch.manual_seed(42)

    def grad_fun(x, z):
        # Toy gradient function for the invertible gradient update.
        return x - z

    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    z = torch.randn(10, 16, 28, 28)

    model, model_i2l = create_model_and_i2l_copy(InvertibleGradUpdate,
                                                 grad_fun, z.size(1))

    # A forward pass followed by the reverse pass should reconstruct the
    # input, and the invert-to-learn copy should yield the same parameter
    # gradients as standard backpropagation.
    x_est = forward_reverse(model, x, z)
    gradients_forward, gradients_reverse = model_gradients(model, x, y, z)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y, z)

    assert_allclose(x_est, x)

    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)

    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)
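# ---------------------------------------------------------------------------
# The helpers forward_reverse and model_gradients used throughout these
# tests are defined elsewhere in the test module.  The sketch below only
# illustrates the behaviour the assertions rely on and assumes the
# invertible modules expose a reverse() inverse pass; the real helpers may
# differ in detail, and the *_sketch names are hypothetical.
# ---------------------------------------------------------------------------
def forward_reverse_sketch(model, x, *args):
    # Run the model forward and then invert it; an invertible module should
    # reconstruct its input.
    with torch.no_grad():
        y = model.forward(x, *args)
        x_est = model.reverse(y, *args)
    return x_est


def model_gradients_sketch(model, x, y, *args):
    # Parameter gradients of an MSE loss, computed once through forward()
    # and once through reverse(), so the tests can compare the standard
    # model with its invert-to-learn copy.
    def param_grads(fn):
        model.zero_grad()
        with torch.enable_grad():
            loss = torch.nn.functional.mse_loss(fn(x, *args), y)
            loss.backward()
        return [p.grad.detach().clone() for p in model.parameters()
                if p.grad is not None]

    return param_grads(model.forward), param_grads(model.reverse)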
def test_irim():
    torch.manual_seed(42)

    def grad_fun(x, z):
        # Toy gradient function used by the IRIM update, normalised by the
        # input's spatial L2 norm.
        grad = x - z
        grad = grad / x.norm(2, dim=(-2, -1), keepdim=True)
        return grad

    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    z = torch.randn(10, 16, 28, 28)

    def get_model():
        # All ten entries are the same InvertibleUnet instance, i.e. the
        # IRIM iterations share weights.
        unets = torch.nn.ModuleList(
            [InvertibleUnet([64, 32, 16], [32, 8, 24], [1, 2, 4])] * 10)
        model = IRIM(unets, grad_fun, n_channels=z.size(1))
        return model

    model, model_i2l = create_model_and_i2l_copy(get_model)

    x_est = forward_reverse(model, x, z)
    gradients_forward, gradients_reverse = model_gradients(model, x, y, z)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y, z)

    assert_allclose(x_est, x, rtol=1e-3, atol=1e-4)

    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)

    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)

    # Gradients with respect to the input z should also agree between the
    # standard model and its invert-to-learn copy.
    with torch.enable_grad():
        z.detach_().requires_grad_(True)
        y_est = model.forward(x, z)
        loss = torch.nn.functional.mse_loss(y_est, y)
        grad_z = torch.autograd.grad(loss, z)[0]

    with torch.enable_grad():
        # Repeat the computation through the invert-to-learn copy.
        z.detach_().requires_grad_(True)
        x.detach_().requires_grad_(True)

        y_est = model_i2l.forward(x, z)
        loss = torch.nn.functional.mse_loss(y_est, y)
        grad_z_i2l = torch.autograd.grad(loss, z)[0]

    assert_allclose(grad_z, grad_z_i2l)


def test_householder1x1():
    torch.manual_seed(42)
    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)

    model, model_i2l = create_model_and_i2l_copy(Housholder1x1, 64, 32, 2)

    x_est = forward_reverse(model, x)
    gradients_forward, gradients_reverse = model_gradients(model, x, y)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y)

    assert_allclose(x_est, x)

    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)

    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)


def test_revnetlayer():
    torch.manual_seed(42)
    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)

    model, model_i2l = create_model_and_i2l_copy(RevNetLayer,
                                                 64,
                                                 32,
                                                 conv_nd=2)

    x_est = forward_reverse(model, x)
    gradients_forward, gradients_reverse = model_gradients(model, x, y)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y)

    assert_allclose(x_est, x)

    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)

    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)


def test_invertibleunet():
    torch.manual_seed(42)
    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)

    depth = 10
    model, model_i2l = create_model_and_i2l_copy(InvertibleUnet,
                                                 [64, 32, 16] * depth,
                                                 [32, 8, 24] * depth,
                                                 [1, 2, 4] * depth)

    x_est = forward_reverse(model, x)
    gradients_forward, gradients_reverse = model_gradients(model, x, y)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y)

    assert_allclose(x_est, x)

    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)

    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)
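# ---------------------------------------------------------------------------
# create_model_and_i2l_copy is likewise defined elsewhere in the test
# module.  Judging from how the tests call it, it builds the model from the
# given class or factory plus arguments and returns it together with an
# identically initialised copy that computes gradients with the memory-free
# "invert to learn" backward pass.  A minimal sketch of that contract (the
# invert-to-learn wrapping itself is library-specific and omitted here; the
# *_sketch name is hypothetical):
# ---------------------------------------------------------------------------
import copy


def create_model_and_i2l_copy_sketch(constructor, *args):
    # Build the model and an identically initialised deep copy; the real
    # helper additionally sets the copy up for invert-to-learn gradients.
    model = constructor(*args)
    model_i2l = copy.deepcopy(model)
    return model, model_i2l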