def test_equal_reference():
    """Yield checks comparing model methods against reference implementations.

    Every combination of model class (TxRef / RxRef), (L, M, N) size triple,
    R value and precision produces two checks built by
    ``check_equal_reference``: one for ``forward`` and one for
    ``adjoint_x``. NumPy's global RNG is seeded with 1 before any inputs are
    drawn, so the generated inputs are the same on every run.
    """
    def labelled(check, L, M, N, R, sdtype):
        # Fill the check's description template with this case's parameters.
        check.description = check.description.format(
            L, M, N, R, np.dtype(sdtype).str
        )
        return check

    # Each model class is paired with its own reference implementations.
    class_reference_pairs = zip(
        (TxRef, RxRef),
        (
            dict(forward=txref_forward_reference,
                 adjoint_x=txref_adjoint_x_reference),
            dict(forward=rxref_forward_reference,
                 adjoint_x=rxref_adjoint_x_reference),
        ),
    )
    # (L, M, N) triples exercised for every class / R / precision.
    size_triples = ((13, 37, 13), (13, 37, 64), (13, 10, 27))

    np.random.seed(1)

    cases = itertools.product(
        class_reference_pairs, size_triples, (1, 2, 3),
        (np.float32, np.complex128),
    )
    for (model_cls, reference), (L, M, N), R, sdtype in cases:
        model = model_cls(L=L, M=M, N=N, R=R, precision=sdtype)

        # forward: random s and x, zero-initialised y, passed as (s, x, y).
        s = get_random_oncircle(model.sshape, model.sdtype)
        x = get_random_uniform(model.xshape, model.xydtype)
        y = np.zeros(model.yshape, model.xydtype)
        check = check_equal_reference(
            model.forward, reference['forward'], (s, x, y)
        )
        yield labelled(check, L, M, N, R, sdtype)

        # adjoint_x: random y and s, zero-initialised x, passed as (y, s, x).
        y = get_random_uniform(model.yshape, model.xydtype)
        s = get_random_oncircle(model.sshape, model.sdtype)
        x = np.zeros(model.xshape, model.xydtype)
        check = check_equal_reference(
            model.adjoint_x, reference['adjoint_x'], (y, s, x)
        )
        yield labelled(check, L, M, N, R, sdtype)
def check_adjointness(cls, L, M, N, R, sdtype):
    """Return a callable that asserts the model's operator pair are adjoints.

    A length-``L`` vector is drawn with ``get_random_oncircle`` and scaled to
    unit L2 norm. ``cls(L=L, M=M, N=N, R=R, precision=sdtype)`` is then
    instantiated and called with ``s=`` to obtain the operator. The returned
    zero-argument callable runs ``adjointness_error(op, its=100)`` and
    asserts that the errors are almost equal to zero (numpy's default
    decimal precision). On failure it reports the names of ``op.A`` and
    ``op.As`` and the largest absolute error. The callable's
    ``description`` attribute names the class and the case parameters.
    """
    drawn = get_random_oncircle((L,), sdtype)
    s = drawn / np.linalg.norm(drawn)  # unit L2 norm

    op = cls(L=L, M=M, N=N, R=R, precision=sdtype)(s=s)

    err_msg = '{0} and {1} are not adjoints, with max error of {2}'

    def call():
        errs = adjointness_error(op, its=100)
        failure_text = err_msg.format(
            op.A.__name__, op.As.__name__, np.max(np.abs(errs)),
        )
        np.testing.assert_array_almost_equal(errs, 0, err_msg=failure_text)

    label_template = '{5}: L={0}, M={1}, N={2}, R={3}, {4}'
    call.description = label_template.format(
        L, M, N, R, np.dtype(sdtype).str, cls.__name__,
    )
    return call
def check_opnorm(cls, L, M, N, R, sdtype):
    """Return a callable that checks the estimated operator norms equal 1.

    A length-``L`` vector is drawn with ``get_random_oncircle`` and scaled to
    unit L2 norm. The operator is built as
    ``cls(L=L, M=M, N=N, R=R, precision=sdtype)(s=s)``. The returned
    zero-argument callable estimates the forward and adjoint norms with
    ``opnorm(op, reltol=1e-10, abstol=1e-8, maxits=100)``. It then asserts
    that each is close to 1 (rtol=1e-4, atol=1e-2), checking forward first
    and adjoint second. The callable's ``description`` attribute names the
    class and the case parameters.
    """
    drawn = get_random_oncircle((L,), sdtype)
    s = drawn / np.linalg.norm(drawn)  # unit L2 norm

    op = cls(L=L, M=M, N=N, R=R, precision=sdtype)(s=s)

    # Expected norms for both directions.
    # NOTE(review): presumably 1 because s has unit norm -- confirm against
    # the model definition.
    true_Anorm = 1
    true_Asnorm = 1

    err_msg = 'Estimated {0} norm ({1}) does not match true {0} norm ({2})'

    def call():
        Anorm, Asnorm, _ = opnorm(op, reltol=1e-10, abstol=1e-8, maxits=100)
        comparisons = (
            ('forward', Anorm, true_Anorm),
            ('adjoint', Asnorm, true_Asnorm),
        )
        for direction, estimate, expected in comparisons:
            np.testing.assert_allclose(
                estimate, expected, rtol=1e-4, atol=1e-2,
                err_msg=err_msg.format(direction, estimate, expected),
            )

    label_template = '{5}: L={0}, M={1}, N={2}, R={3}, {4}'
    call.description = label_template.format(
        L, M, N, R, np.dtype(sdtype).str, cls.__name__,
    )
    return call