Ejemplo n.º 1
0
def test_operator_scaling(dom_eq_ran):
    """Compare scalar multiples of ``MultiplyAndSquareOp`` with NumPy.

    For several scalars, the operator is wrapped with
    ``OperatorLeftScalarMult`` / ``OperatorRightScalarMult`` and also
    combined via the ``*`` operator. Both forms are checked:

    - left scaling must evaluate to ``scalar * mult_sq_np(mat, x)``;
    - right scaling must evaluate to ``mult_sq_np(mat, scalar * x)``.

    Finally, scaling by a complex number or by a sequence (list/tuple)
    must raise ``TypeError``.

    Parameters
    ----------
    dom_eq_ran : bool
        If truthy, the matrix is 3x3; otherwise it is 4x3. The name
        presumably means "domain equals range"; it is likely supplied by
        a pytest fixture or parametrization (TODO confirm).
    """
    n_rows = 3 if dom_eq_ran else 4
    mat = np.random.rand(n_rows, 3)

    op = MultiplyAndSquareOp(mat)
    xarr, x = noise_elements(op.domain)

    def expect_left(factor):
        # Reference result when the scalar multiplies the operator output.
        return factor * mult_sq_np(mat, xarr)

    def expect_right(factor):
        # Reference result when the scalar multiplies the operator input.
        return mult_sq_np(mat, factor * xarr)

    # Generic values plus -1, 0 and 1, which an implementation might
    # special-case.
    for factor in (-1.432, -1, 0, 1, 3.14):
        left_op = OperatorLeftScalarMult(op, factor)
        right_op = OperatorRightScalarMult(op, factor)

        # Scaling a nonlinear operator must not make it linear.
        assert not left_op.is_linear
        assert not right_op.is_linear

        # Explicit wrapper classes.
        check_call(left_op, x, expect_left(factor))
        check_call(right_op, x, expect_right(factor))

        # Same operations through the overloaded `*`.
        check_call(factor * op, x, expect_left(factor))
        check_call(op * factor, x, expect_right(factor))

    # Invalid scalars: a complex number and sequences (list, tuple).
    for bad in (1j, [1, 2], (1, 2)):
        with pytest.raises(TypeError):
            OperatorLeftScalarMult(op, bad)

        with pytest.raises(TypeError):
            OperatorRightScalarMult(op, bad)

        with pytest.raises(TypeError):
            op * bad

        with pytest.raises(TypeError):
            bad * op
Ejemplo n.º 2
0
def test_nonlinear_scale():
    """Check scalar multiplication of ``MultiplyAndSquareOp`` against NumPy.

    A random 4x3 matrix defines the operator, and a random length-3 array
    is the input. For each scalar, the test verifies the following:

    - left scaling (``OperatorLeftScalarMult`` or ``scale * Aop``) equals
      ``scale * mult_sq_np(A, x)``;
    - right scaling (``OperatorRightScalarMult`` or ``Aop * scale``)
      equals ``mult_sq_np(A, scale * x)``;
    - neither scaled operator reports itself as linear.

    Scaling by a complex number or a sequence must raise ``TypeError``.
    """
    A = np.random.rand(4, 3)
    x = np.random.rand(3)

    Aop = MultiplyAndSquareOp(A)
    xvec = Aop.domain.element(x)

    # Test a range of scalars (scalar multiplication could implement
    # optimizations for (-1, 0, 1)).
    scalars = [-1.432, -1, 0, 1, 3.14]
    for scale in scalars:
        lscaled = OperatorLeftScalarMult(Aop, scale)
        rscaled = OperatorRightScalarMult(Aop, scale)

        assert not lscaled.is_linear
        assert not rscaled.is_linear

        # Explicit
        check_call(lscaled, xvec, scale * mult_sq_np(A, x))
        check_call(rscaled, xvec, mult_sq_np(A, scale * x))

        # Using operator overloading
        check_call(scale * Aop, xvec, scale * mult_sq_np(A, x))
        check_call(Aop * scale, xvec, mult_sq_np(A, scale * x))

    # Fail when scaling by a wrong scalar type: a complex number or a
    # sequence (list / tuple).
    wrongscalars = [1j, [1, 2], (1, 2)]
    for wrongscalar in wrongscalars:
        # The bare expressions must raise by themselves. Wrapping them in
        # print() (as before) would let a TypeError from the resulting
        # operator's __str__/__repr__ satisfy pytest.raises, masking a
        # missing type check.
        with pytest.raises(TypeError):
            OperatorLeftScalarMult(Aop, wrongscalar)

        with pytest.raises(TypeError):
            OperatorRightScalarMult(Aop, wrongscalar)

        with pytest.raises(TypeError):
            Aop * wrongscalar

        with pytest.raises(TypeError):
            wrongscalar * Aop