def test_invertiblegradupdate():
    """Check InvertibleGradUpdate round-trips and matches its i2l copy.

    The tensor returned by ``forward_reverse`` must equal ``x``, and the
    forward/reverse gradients reported by ``model_gradients`` must agree
    between the regular model and the i2l copy built by
    ``create_model_and_i2l_copy``.
    """
    torch.manual_seed(42)

    def grad_fun(x, z):
        # Gradient function under test: the plain difference x - z.
        return x - z

    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    z = torch.randn(10, 16, 28, 28)
    model, model_i2l = create_model_and_i2l_copy(
        InvertibleGradUpdate, grad_fun, z.size(1))

    x_est = forward_reverse(model, x, z)
    gradients_forward, gradients_reverse = model_gradients(model, x, y, z)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y, z)

    assert_allclose(x_est, x)
    # zip() stops at the shorter sequence; compare lengths first so a missing
    # or extra gradient cannot make the element-wise checks pass silently.
    assert len(gradients_forward_i2l) == len(gradients_forward)
    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)
    assert len(gradients_reverse_i2l) == len(gradients_reverse)
    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)
def test_irim():
    """Check IRIM round-trips and matches its i2l copy, including d(loss)/dz.

    Verifies that the reconstruction from ``forward_reverse`` matches ``x``.
    Verifies that the forward/reverse gradients from ``model_gradients``
    agree between the regular model and its i2l copy. Verifies that the
    gradient of an MSE loss with respect to the extra input ``z`` is the
    same for both models.
    """
    torch.manual_seed(42)

    def grad_fun(x, z):
        # Difference x - z, divided by the L2 norm of x taken over the last
        # two dimensions (keepdim so it broadcasts back over them).
        grad = x - z
        grad = grad / x.norm(2, dim=(-2, -1), keepdim=True)
        return grad

    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    z = torch.randn(10, 16, 28, 28)

    def get_model():
        # NOTE(review): ``[unet] * 10`` places the *same* InvertibleUnet
        # instance in all ten slots, so every step shares weights -- confirm
        # this is intended rather than ten independent UNets.
        unets = torch.nn.ModuleList(
            [InvertibleUnet([64, 32, 16], [32, 8, 24], [1, 2, 4])] * 10)
        model = IRIM(unets, grad_fun, n_channels=z.size(1))
        return model

    model, model_i2l = create_model_and_i2l_copy(get_model)

    x_est = forward_reverse(model, x, z)
    gradients_forward, gradients_reverse = model_gradients(model, x, y, z)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y, z)

    assert_allclose(x_est, x, rtol=1e-3, atol=1e-4)
    # zip() stops at the shorter sequence; compare lengths first so a missing
    # or extra gradient cannot make the element-wise checks pass silently.
    assert len(gradients_forward_i2l) == len(gradients_forward)
    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)
    assert len(gradients_reverse_i2l) == len(gradients_reverse)
    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)

    # Gradient of the loss w.r.t. z. Prepare the leaves once so both models
    # are evaluated under identical conditions (x and z both requiring grad).
    x.detach_().requires_grad_(True)
    z.detach_().requires_grad_(True)
    grads_z = []
    for net in (model, model_i2l):
        with torch.enable_grad():
            y_est = net.forward(x, z)
            loss = torch.nn.functional.mse_loss(y_est, y)
            grads_z.append(torch.autograd.grad(loss, z)[0])
    grad_z, grad_z_i2l = grads_z
    assert_allclose(grad_z, grad_z_i2l)
def test_householder1x1():
    """Check Housholder1x1 round-trips and matches its i2l copy.

    The tensor returned by ``forward_reverse`` must equal ``x``, and the
    forward/reverse gradients reported by ``model_gradients`` must agree
    between the regular model and the i2l copy.
    """
    torch.manual_seed(42)
    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    # NOTE: ``Housholder1x1`` is the class name as spelled in the project.
    model, model_i2l = create_model_and_i2l_copy(Housholder1x1, 64, 32, 2)

    x_est = forward_reverse(model, x)
    gradients_forward, gradients_reverse = model_gradients(model, x, y)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y)

    assert_allclose(x_est, x)
    # zip() stops at the shorter sequence; compare lengths first so a missing
    # or extra gradient cannot make the element-wise checks pass silently.
    assert len(gradients_forward_i2l) == len(gradients_forward)
    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)
    assert len(gradients_reverse_i2l) == len(gradients_reverse)
    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)
def test_revnetlayer():
    """Check RevNetLayer round-trips and matches its i2l copy.

    The tensor returned by ``forward_reverse`` must equal ``x``, and the
    forward/reverse gradients reported by ``model_gradients`` must agree
    between the regular model and the i2l copy.
    """
    torch.manual_seed(42)
    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    model, model_i2l = create_model_and_i2l_copy(
        RevNetLayer, 64, 32, conv_nd=2)

    x_est = forward_reverse(model, x)
    gradients_forward, gradients_reverse = model_gradients(model, x, y)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y)

    assert_allclose(x_est, x)
    # zip() stops at the shorter sequence; compare lengths first so a missing
    # or extra gradient cannot make the element-wise checks pass silently.
    assert len(gradients_forward_i2l) == len(gradients_forward)
    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)
    assert len(gradients_reverse_i2l) == len(gradients_reverse)
    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)
def test_invertibleunet():
    """Check a deep InvertibleUnet round-trips and matches its i2l copy.

    The tensor returned by ``forward_reverse`` must equal ``x``, and the
    forward/reverse gradients reported by ``model_gradients`` must agree
    between the regular model and the i2l copy.
    """
    torch.manual_seed(42)
    x = torch.randn(10, 64, 28, 28)
    y = torch.randn(10, 64, 28, 28)
    # Each configuration list repeats its three entries ``depth`` times,
    # giving 3 * depth = 30 entries per list.
    depth = 10
    model, model_i2l = create_model_and_i2l_copy(
        InvertibleUnet,
        [64, 32, 16] * depth,
        [32, 8, 24] * depth,
        [1, 2, 4] * depth)

    x_est = forward_reverse(model, x)
    gradients_forward, gradients_reverse = model_gradients(model, x, y)
    gradients_forward_i2l, gradients_reverse_i2l = model_gradients(
        model_i2l, x, y)

    assert_allclose(x_est, x)
    # zip() stops at the shorter sequence; compare lengths first so a missing
    # or extra gradient cannot make the element-wise checks pass silently.
    assert len(gradients_forward_i2l) == len(gradients_forward)
    for g_est, g in zip(gradients_forward_i2l, gradients_forward):
        assert_allclose(g_est, g)
    assert len(gradients_reverse_i2l) == len(gradients_reverse)
    for g_est, g in zip(gradients_reverse_i2l, gradients_reverse):
        assert_allclose(g_est, g)