def test_toeplitz_solve(check_lazy_shapes):
    """Check the sensitivity and gradient of `toeplitz_solve`.

    Both checks are run with a third argument of shape `(3,)` and of shape
    `(3, 4)`; the first two arguments always have shapes `(3,)` and `(2,)`.
    Fresh random inputs are drawn for every individual check.
    """
    third_arg_shapes = ((3,), (3, 4))

    # All sensitivity checks run before any gradient check.
    for shape in third_arg_shapes:
        inputs = (B.randn(3), B.randn(2), B.randn(*shape))
        check_sensitivity(toeplitz_solve, s_toeplitz_solve, inputs)

    for shape in third_arg_shapes:
        inputs = (B.randn(3), B.randn(2), B.randn(*shape))
        check_grad(toeplitz_solve, inputs)
def test_bvn_cdf(check_lazy_shapes):
    """Check `bvn_cdf`: sensitivity, gradient, dtype support, and shape checks.

    Covers three things:

    * the sensitivity and gradient on random length-3 inputs,
    * agreement between the result on the default inputs and on the same
      inputs cast to `float32`, and
    * that JAX `float32` inputs of mismatching lengths raise `ValueError`.
    """
    sensitivity_inputs = (B.rand(3), B.rand(3), B.rand(3))
    check_sensitivity(bvn_cdf, s_bvn_cdf, sensitivity_inputs)

    grad_inputs = (B.rand(3), B.rand(3), B.rand(3))
    check_grad(bvn_cdf, grad_inputs)

    # Check that function runs on both `float32`s and `float64`s.
    a, b, c = B.rand(3), B.rand(3), B.rand(3)
    reference = B.bvn_cdf(a, b, c)
    a32 = B.cast(np.float32, a)
    b32 = B.cast(np.float32, b)
    c32 = B.cast(np.float32, c)
    approx(reference, B.bvn_cdf(a32, b32, c32))

    # Check that, in JAX, the function checks the shape of the inputs: each
    # case makes exactly one of the three inputs shorter than the others.
    mismatched_lengths = ((2, 3, 3), (3, 2, 3), (3, 3, 2))
    for n_a, n_b, n_c in mismatched_lengths:
        with pytest.raises(ValueError):
            B.bvn_cdf(
                B.rand(jnp.float32, n_a),
                B.rand(jnp.float32, n_b),
                B.rand(jnp.float32, n_c),
            )
def test_sensitivity():
    """Exercise `check_sensitivity` itself on small hand-written functions.

    The test verifies that `check_sensitivity`

    * accepts a correct sensitivity and rejects incorrect ones,
    * forwards keyword arguments to both the function and its sensitivity,
    * raises `AssertionError` when the number of returned sensitivities does
      not match the number of positional arguments.
    """

    # Test two-argument case.
    def f(a, b):
        return a * b

    def s_f_correct(s_y, y, a, b):
        return s_y * b, a * s_y

    def s_f_incorrect1(s_y, y, a, b):
        # First sensitivity is off by an additive constant.
        return s_y * b + 1, a * s_y

    def s_f_incorrect2(s_y, y, a, b):
        # Second sensitivity is off by a multiplicative factor.
        return s_y * b, a * s_y * 1.1

    check_sensitivity(f, s_f_correct, (1, 2))
    with pytest.raises(AssertionError):
        check_sensitivity(f, s_f_incorrect1, (2, 3))
    with pytest.raises(AssertionError):
        check_sensitivity(f, s_f_incorrect2, (4, 5))

    # Test one-argument case with a keyword argument.
    def g(a, option=False):
        assert option
        return 2 * a

    def s_g(s_y, y, a, option=False):
        assert option
        return s_y * 2

    # Without the keyword argument, the `assert option` statements fail.
    with pytest.raises(AssertionError):
        check_sensitivity(g, s_g, (1,))
    check_sensitivity(g, s_g, (1,), {"option": True})

    # Check that the number of sensitivities must match the number of
    # arguments.
    def s_f_too_few(s_y, y, a, b):
        return s_y * a

    def s_g_too_many(s_y, y, a, option=True):
        return 2, 3

    with pytest.raises(AssertionError):
        check_sensitivity(f, s_f_too_few, (1, 2))
    with pytest.raises(AssertionError):
        # `option=True` must be passed here. As shown above, calling with `g`
        # and no keyword arguments already raises `AssertionError` through
        # `assert option`, so without it this check would pass regardless of
        # how many sensitivities `s_g_too_many` returns.
        check_sensitivity(g, s_g_too_many, (1,), {"option": True})
def test_logm(check_lazy_shapes):
    """Check the sensitivity and gradient of `logm`.

    Both checks use the same input: the 3x3 identity plus random noise
    scaled by 0.1.
    """
    perturbed_identity = B.eye(3) + 0.1 * B.randn(3, 3)
    inputs = (perturbed_identity,)

    check_sensitivity(logm, s_logm, inputs)
    check_grad(logm, inputs)
def test_expm(check_lazy_shapes):
    """Check the sensitivity and gradient of `expm` on random 3x3 inputs.

    Each check draws its own random matrix.
    """
    sensitivity_input = B.randn(3, 3)
    check_sensitivity(expm, s_expm, (sensitivity_input,))

    grad_input = B.randn(3, 3)
    check_grad(expm, (grad_input,))