def test_to_pytorch_function(self):
    """Check the PyTorch function wrapper of a real-valued linear operator.

    Wraps ``linop.Resize([5], [3])`` with ``pytorch.to_pytorch_function``
    and verifies two things against the operator applied directly:

    * forward: the wrapped function's output equals ``A(x)``;
    * adjoint: after back-propagating a least-squares loss, the gradient
      stored on the input tensor equals ``A.H(A(x) - y)``.
    """
    A = linop.Resize([5], [3])
    # ``np.float`` was only an alias of the builtin ``float`` and has been
    # removed from NumPy (1.24+); ``np.float64`` names the same dtype on
    # every NumPy release.
    x = np.array([1, 2, 3], np.float64)
    y = np.ones([5])

    with self.subTest('forward'):
        f = pytorch.to_pytorch_function(A).apply
        x_torch = pytorch.to_pytorch(x)
        npt.assert_allclose(f(x_torch).detach().numpy(), A(x))

    # NOTE: this subtest reuses ``f`` and ``x_torch`` created above.
    with self.subTest('adjoint'):
        y_torch = pytorch.to_pytorch(y)
        loss = (f(x_torch) - y_torch).pow(2).sum() / 2
        loss.backward()
        # The gradient of 0.5 * ||A x - y||^2 w.r.t. x is A^H (A x - y).
        npt.assert_allclose(x_torch.grad.detach().numpy(), A.H(A(x) - y))
def test_to_pytorch_complex(self):
    """Check that writes to the converted tensor reach the complex array.

    For every complex dtype and every entry of ``devices``, a complex array
    is converted with ``pytorch.to_pytorch`` and ``tensor[0, 0]`` is set to
    zero.  The test expects the first array element to become ``1j``, i.e.
    only its real part is cleared.
    """
    # NOTE(review): another test named ``test_to_pytorch_complex`` appears in
    # this file; if both live in the same class, the later definition
    # shadows this one -- confirm and rename one of them if so.
    initial_values = [1 + 1j, 2 + 2j, 3 + 3j]
    values_after_write = [1j, 2 + 2j, 3 + 3j]
    cases = [
        (dtype, device)
        for dtype in (np.complex64, np.complex128)
        for device in devices
    ]

    for dtype, device in cases:
        with self.subTest(device=device, dtype=dtype):
            array_module = device.xp
            source = array_module.array(initial_values, dtype=dtype)
            converted = pytorch.to_pytorch(source)
            # Write through the tensor; the assertion reads the array.
            converted[0, 0] = 0
            array_module.testing.assert_allclose(source, values_after_write)
def test_to_pytorch(self):
    """Check that writes to the converted tensor reach the real array.

    For every real floating dtype and every entry of ``devices``, an array
    is converted with ``pytorch.to_pytorch``, the tensor's first element is
    set to zero, and the original array is expected to read ``[0, 2, 3]``.
    """
    initial_values = [1, 2, 3]
    values_after_write = [0, 2, 3]
    cases = [
        (dtype, device)
        for dtype in (np.float32, np.float64)
        for device in devices
    ]

    for dtype, device in cases:
        with self.subTest(device=device, dtype=dtype):
            array_module = device.xp
            source = array_module.array(initial_values, dtype=dtype)
            converted = pytorch.to_pytorch(source)
            # Write through the tensor; the assertion reads the array.
            converted[0] = 0
            array_module.testing.assert_allclose(source, values_after_write)
def test_to_pytorch_function_complex(self):
    """Check the PyTorch function wrapper of a complex-valued operator.

    Wraps ``linop.FFT([3])`` with ``pytorch.to_pytorch_function`` using
    ``input_iscomplex=True`` and ``output_iscomplex=True`` and verifies:

    * forward: the flattened output tensor equals ``A(x)`` viewed as
      float64;
    * adjoint: after back-propagating a least-squares loss, the flattened
      gradient on the input tensor equals ``A.H(A(x) - y)`` viewed as
      float64.
    """
    A = linop.FFT([3])
    # ``np.complex`` / ``np.float`` were only aliases of the builtins and
    # have been removed from NumPy (1.24+); ``np.complex128`` /
    # ``np.float64`` name the same dtypes on every NumPy release.
    x = np.array([1 + 1j, 2 + 2j, 3 + 3j], np.complex128)
    y = np.ones([3], np.complex128)

    with self.subTest('forward'):
        f = pytorch.to_pytorch_function(
            A, input_iscomplex=True, output_iscomplex=True).apply
        x_torch = pytorch.to_pytorch(x)
        # Viewing a complex128 array as float64 yields interleaved
        # (real, imag) values, which is what the raveled tensor is
        # compared against.
        npt.assert_allclose(f(x_torch).detach().numpy().ravel(),
                            A(x).view(np.float64))

    # NOTE: this subtest reuses ``f`` and ``x_torch`` created above.
    with self.subTest('adjoint'):
        y_torch = pytorch.to_pytorch(y)
        loss = (f(x_torch) - y_torch).pow(2).sum() / 2
        loss.backward()
        # The gradient of 0.5 * ||A x - y||^2 w.r.t. x is A^H (A x - y).
        npt.assert_allclose(x_torch.grad.detach().numpy().ravel(),
                            A.H(A(x) - y).view(np.float64))
def test_to_pytorch_complex(self):
    """Check that writes to the complex array reach the converted tensor.

    For every complex dtype and every entry of ``devices``, a complex array
    is converted with ``pytorch.to_pytorch`` and the array's first element
    is then set to zero.  The tensor is expected to equal
    ``[[0, 0], [2, 2], [3, 3]]`` -- one row per array element.
    """
    cases = [
        (dtype, device)
        for dtype in (np.complex64, np.complex128)
        for device in devices
    ]

    for dtype, device in cases:
        with self.subTest(device=device, dtype=dtype):
            array_module = device.xp
            source = array_module.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=dtype)
            converted = pytorch.to_pytorch(source)
            # Write through the array; the assertion reads the tensor.
            source[0] = 0
            # Build the reference with the tensor's own dtype and device so
            # that only the values are being compared.
            expected = torch.tensor(
                [[0, 0], [2, 2], [3, 3]],
                dtype=converted.dtype,
                device=converted.device,
            )
            torch.testing.assert_allclose(converted, expected)