def test_UnitaryTransform(self):
    """Check the prox of an L2 penalty composed with a unitary FFT.

    Wrapping ``prox.L2Reg(shape, lamda)`` in ``prox.UnitaryTransform`` with
    an FFT linop is expected to give the closed form
    ``x / (1 + lamda * alpha)`` for step size ``alpha``.
    """
    reg_weight = 1.0
    step = 0.1
    dims = [6]

    fft_op = linop.FFT(dims)
    wrapped = prox.UnitaryTransform(prox.L2Reg(dims, reg_weight), fft_op)

    vec = util.randn(dims)
    result = wrapped(step, vec)

    npt.assert_allclose(result, vec / (1 + reg_weight * step))
def test_FFT(self):
    """Run the standard linop checks on FFT operators of varied shape.

    Covers 1-, 2- and 3-dimensional inputs with every axis of length 3, 4,
    5 or 6. Each operator is checked for linearity, adjointness, unitarity
    and picklability, in that order.
    """
    cases = [(ndim, n) for ndim in (1, 2, 3) for n in (3, 4, 5, 6)]
    for ndim, n in cases:
        fft_op = linop.FFT([n] * ndim)
        checks = (
            self.check_linop_linear,
            self.check_linop_adjoint,
            self.check_linop_unitary,
            self.check_linop_pickleable,
        )
        for check in checks:
            check(fft_op)
def test_to_pytorch_function_complex(self):
    """Check a complex FFT linop wrapped as a PyTorch autograd function.

    The 'forward' subtest compares the wrapped function's output with
    ``A(x)``. The 'adjoint' subtest backpropagates the loss
    ``||A x - y||^2 / 2`` and compares the gradient on ``x`` with
    ``A.H(A(x) - y)``. On the NumPy side, complex results are viewed as
    ``float64`` (interleaved real/imaginary parts). The flattened torch
    tensors are compared against that real view.
    """
    A = linop.FFT([3])
    # ``np.complex`` / ``np.float`` were deprecated in NumPy 1.20 and removed
    # in 1.24; ``complex128`` / ``float64`` are the dtypes they resolved to.
    x = np.array([1 + 1j, 2 + 2j, 3 + 3j], np.complex128)
    y = np.ones([3], np.complex128)

    # Built outside the subtests so the 'adjoint' subtest never depends on
    # names bound inside the 'forward' one.
    f = pytorch.to_pytorch_function(
        A, input_iscomplex=True, output_iscomplex=True).apply
    x_torch = pytorch.to_pytorch(x)

    with self.subTest('forward'):
        npt.assert_allclose(f(x_torch).detach().numpy().ravel(),
                            A(x).view(np.float64))

    with self.subTest('adjoint'):
        y_torch = pytorch.to_pytorch(y)
        loss = (f(x_torch) - y_torch).pow(2).sum() / 2
        # NOTE(review): relies on to_pytorch returning a leaf tensor with
        # requires_grad enabled, since x_torch.grad is read below.
        loss.backward()
        npt.assert_allclose(x_torch.grad.detach().numpy().ravel(),
                            A.H(A(x) - y).view(np.float64))