Exemplo n.º 1
0
 def test_UnitaryTransform(self):
     """Check UnitaryTransform wrapping an L2Reg prox with an FFT operator.

     Builds ``prox.UnitaryTransform(prox.L2Reg(shape, lamda), linop.FFT(shape))``,
     applies it with step size 0.1 to a random vector ``x`` and asserts the
     result matches ``x / (1 + lamda * 0.1)``.
     """
     step = 0.1
     lamda = 1.0
     shape = [6]

     fft = linop.FFT(shape)
     l2_prox = prox.L2Reg(shape, lamda)
     P = prox.UnitaryTransform(l2_prox, fft)

     x = util.randn(shape)
     y = P(step, x)

     # Expected value is the same scaling the test would use for the
     # un-transformed L2Reg prox (presumably because FFT is unitary here).
     expected = x / (1 + lamda * step)
     npt.assert_allclose(y, expected)
Exemplo n.º 2
0
 def test_FFT(self):
     """Run the generic linop checks on FFT for 1-D, 2-D and 3-D cubes.

     For every dimensionality in {1, 2, 3} and side length in {3, 4, 5, 6}
     an ``FFT`` operator with input shape ``[n] * ndim`` is built and passed
     through the linearity, adjoint, unitarity and pickling checks.
     """
     # Same visiting order as nested loops: ndim outermost, n innermost.
     ishapes = [[n] * ndim for ndim in (1, 2, 3) for n in (3, 4, 5, 6)]

     for ishape in ishapes:
         A = linop.FFT(ishape)
         self.check_linop_linear(A)
         self.check_linop_adjoint(A)
         self.check_linop_unitary(A)
         self.check_linop_pickleable(A)
Exemplo n.º 3
0
        def test_to_pytorch_function_complex(self):
            """Check the PyTorch wrapper of a complex FFT linop, forward and backward.

            The 'forward' subtest compares the wrapped function's output with
            ``A(x)`` reinterpreted as real numbers.  The 'adjoint' subtest
            back-propagates the loss ``sum((f(x) - y)**2) / 2`` and compares the
            gradient stored on the input tensor with ``A.H(A(x) - y)``.
            """
            A = linop.FFT([3])
            # np.complex / np.float were deprecated in NumPy 1.20 and removed in
            # 1.24; they were aliases of the builtins, i.e. complex128 / float64.
            x = np.array([1 + 1j, 2 + 2j, 3 + 3j], np.complex128)
            y = np.ones([3], np.complex128)

            # Shared by both subtests, so built outside of them: a failure here
            # is reported once instead of surfacing as a NameError in 'adjoint'.
            f = pytorch.to_pytorch_function(
                A,
                input_iscomplex=True,
                output_iscomplex=True).apply
            x_torch = pytorch.to_pytorch(x)

            with self.subTest('forward'):
                # .view(np.float64) exposes the complex array as interleaved
                # float64 values, matching the flattened tensor output.
                npt.assert_allclose(f(x_torch).detach().numpy().ravel(),
                                    A(x).view(np.float64))

            with self.subTest('adjoint'):
                y_torch = pytorch.to_pytorch(y)
                loss = (f(x_torch) - y_torch).pow(2).sum() / 2
                loss.backward()
                # NOTE(review): relies on to_pytorch returning a tensor that
                # tracks gradients, since x_torch.grad is read here - confirm.
                npt.assert_allclose(x_torch.grad.detach().numpy().ravel(),
                                    A.H(A(x) - y).view(np.float64))