Example #1
0
def test_toeplitz_solve(check_lazy_shapes):
    """Check `toeplitz_solve` against `s_toeplitz_solve` and its gradient.

    The first two arguments have lengths 3 and 2; presumably they describe the
    Toeplitz matrix (first column and remainder of the first row) — TODO
    confirm. The third argument is exercised both as a length-3 vector and as
    a 3x4 matrix.
    """
    rhs_shapes = [(3,), (3, 4)]

    # All sensitivity checks run first, then all gradient checks, matching
    # the original call order.
    for rhs_shape in rhs_shapes:
        check_sensitivity(
            toeplitz_solve,
            s_toeplitz_solve,
            (B.randn(3), B.randn(2), B.randn(*rhs_shape)),
        )
    for rhs_shape in rhs_shapes:
        check_grad(
            toeplitz_solve,
            (B.randn(3), B.randn(2), B.randn(*rhs_shape)),
        )
Example #2
0
def test_bvn_cdf(check_lazy_shapes):
    """Test `bvn_cdf`: sensitivity, gradient, dtype support and shape checks."""
    check_sensitivity(bvn_cdf, s_bvn_cdf, (B.rand(3), B.rand(3), B.rand(3)))
    check_grad(bvn_cdf, (B.rand(3), B.rand(3), B.rand(3)))

    # The function must accept both `float64` and `float32` inputs and give
    # approximately the same result for each.
    inputs = B.rand(3), B.rand(3), B.rand(3)
    approx(
        B.bvn_cdf(*inputs),
        B.bvn_cdf(*(B.cast(np.float32, x) for x in inputs)),
    )

    # With JAX inputs, a length mismatch in any one of the three arguments
    # must raise a `ValueError`.
    for short_position in range(3):
        lengths = [3, 3, 3]
        lengths[short_position] = 2
        with pytest.raises(ValueError):
            B.bvn_cdf(*(B.rand(jnp.float32, n) for n in lengths))
Example #3
0
def test_sensitivity():
    """Exercise `check_sensitivity` on hand-written functions and sensitivities.

    Covers a correct two-argument sensitivity, two incorrect ones that must be
    rejected, forwarding of keyword arguments, and rejection of sensitivity
    functions that return the wrong number of sensitivities.
    """
    # Test two-argument case.
    def f(a, b):
        return a * b

    def s_f_correct(s_y, y, a, b):
        return s_y * b, a * s_y

    # Wrong sensitivity for `a`: offset by one.
    def s_f_incorrect1(s_y, y, a, b):
        return s_y * b + 1, a * s_y

    # Wrong sensitivity for `b`: scaled by 1.1.
    def s_f_incorrect2(s_y, y, a, b):
        return s_y * b, a * s_y * 1.1

    check_sensitivity(f, s_f_correct, (1, 2))
    with pytest.raises(AssertionError):
        check_sensitivity(f, s_f_incorrect1, (2, 3))
    with pytest.raises(AssertionError):
        check_sensitivity(f, s_f_incorrect2, (4, 5))

    # Test one-argument case with a keyword argument. Both `g` and `s_g`
    # assert `option`, so the check is expected to fail without keyword
    # arguments and to pass when `option=True` is supplied.
    def g(a, option=False):
        assert option
        return 2 * a

    def s_g(s_y, y, a, option=False):
        assert option
        return s_y * 2

    with pytest.raises(AssertionError):
        check_sensitivity(g, s_g, (1,))
    check_sensitivity(g, s_g, (1,), {"option": True})

    # Check that the number of sensitivities must match the number of
    # arguments.
    def s_f_too_few(s_y, y, a, b):
        return s_y * a

    def s_g_too_many(s_y, y, a, option=True):
        return 2, 3

    with pytest.raises(AssertionError):
        check_sensitivity(f, s_f_too_few, (1, 2))
    # Pass `option=True`. Without it, `g` is evaluated with its default
    # `option=False`, and its own `assert option` raises. That would make
    # this check pass without ever exercising the sensitivity-count
    # validation.
    with pytest.raises(AssertionError):
        check_sensitivity(g, s_g_too_many, (1,), {"option": True})
Example #4
0
def test_logm(check_lazy_shapes):
    """Check `logm` against `s_logm` and its gradient.

    The input is the 3x3 identity plus a small random perturbation (scale 0.1).
    Presumably this keeps the matrix well inside the region where the matrix
    logarithm is well behaved — TODO confirm.
    """
    identity = B.eye(3)
    near_identity = identity + 0.1 * B.randn(3, 3)

    # The same matrix is used for both the sensitivity and gradient checks.
    args = (near_identity,)
    check_sensitivity(logm, s_logm, args)
    check_grad(logm, args)
Example #5
0
def test_expm(check_lazy_shapes):
    """Check `expm` against `s_expm` and its gradient on random 3x3 matrices.

    The sensitivity check and the gradient check each draw their own
    independent random input.
    """
    sensitivity_input = B.randn(3, 3)
    check_sensitivity(expm, s_expm, (sensitivity_input,))

    gradient_input = B.randn(3, 3)
    check_grad(expm, (gradient_input,))