def test_gradient():
    """Compare JAX gradients of ``hh`` with finite-difference estimates.

    ``hh`` is differentiated with respect to its first three positional
    arguments at ``inputs``. ``hh0``, ``hh1`` and ``hh2`` are presumably the
    single-argument restrictions of ``hh`` (TODO confirm against their
    definitions); the ``fdm`` gradient of the k-th restriction at
    ``inputs[k]`` is compared with the k-th JAX partial. Every comparison
    runs under ``check`` so all mismatches are reported, not just the first.
    """
    # Reverse-mode partials with respect to arguments 0, 1 and 2 in one call.
    autodiff = jax.grad(hh, (0, 1, 2))(*inputs)

    # Finite-difference gradients, one per single-argument restriction.
    numeric = [
        fdm.gradient(restricted)(inputs[k])
        for k, restricted in enumerate((hh0, hh1, hh2))
    ]

    for k, estimate in enumerate(numeric):
        with check:
            assert np.allclose(estimate, autodiff[k])
Example #2
0
def test_grad():
    """Compare JAX and finite-difference gradients of scalar test functions.

    First checks the single-argument ``h_scalar`` at ``x_input``, then checks
    both partial derivatives of the two-argument ``f_scalar`` at
    ``(x_input, y_input)``. Each comparison is wrapped in ``check`` so all
    failures are reported rather than stopping at the first.
    """
    fdm_grad = fdm.gradient(h_scalar)(x_input)
    jax_grad = jax.grad(h_scalar)(x_input)

    with check:
        assert np.allclose(fdm_grad, jax_grad)

    jax_grads = jax.grad(f_scalar, (0, 1))(x_input, y_input)
    # Differentiate f_scalar itself in each argument while holding the other
    # fixed. Reusing the h_scalar gradient above would only be correct if
    # h_scalar(x) == f_scalar(x, y_input), a coupling this test should not
    # silently depend on.
    fdm_grad0 = fdm.gradient(lambda x: f_scalar(x, y_input))(x_input)
    fdm_grad1 = fdm.gradient(lambda y: f_scalar(x_input, y))(y_input)

    with check:
        assert np.allclose(fdm_grad0, jax_grads[0])
    with check:
        assert np.allclose(fdm_grad1, jax_grads[1])
def test_jvp_directional():
    """Check that the JVP equals the gradient projected onto the direction.

    For the linear function ``sum(coeffs * x)`` the directional derivative
    along ``direction`` must match ``gradient . direction``. Both sides are
    estimated with the ``central_fdm(10, 1)`` method.
    NOTE(review): ``allclose`` is assumed to be an asserting helper, since its
    return value is discarded here; confirm.
    """
    method = central_fdm(10, 1)
    coeffs = np.random.randn(3)

    def linear(x):
        return np.sum(coeffs * x)

    # Draw order (coeffs, point, direction) matches the original test.
    point = np.random.randn(3)
    direction = np.random.randn(3)

    projected = np.sum(gradient(linear, method)(point) * direction)
    directional = jvp(linear, direction, method)(point)
    allclose(projected, directional)
Example #4
0
def check_grad(f, args, kw_args=None, rtol=1e-8):
    """Check the gradients of a function.

    The LAB function with the same name as `f` is differentiated with respect
    to each positional argument in turn, with AutoGrad, TensorFlow, PyTorch,
    and JAX. Each result is compared against a finite-difference estimate.
    The function's output is reduced with `B.mean` before differentiating.

    Args:
        f (function): Function to check gradients of. Only its `__name__` is
            used, to look up the corresponding function in LAB.
        args (tuple): Arguments to check `f` at.
        kw_args (dict, optional): Keyword arguments to check `f` at. Defaults
            to no keyword arguments.
        rtol (float, optional): Relative tolerance. Defaults to `1e-8`.
    """
    # Default to no keyword arguments.
    if kw_args is None:
        kw_args = {}

    # Get the associated function in LAB.
    lab_f = getattr(B, f.__name__)

    def create_f_i(i, args_):
        # Create a function that only varies the `i`th argument.
        def f_i(x):
            return B.mean(
                lab_f(*(args_[:i] + (x, ) + args_[i + 1:]), **kw_args))

        return f_i

    # The TensorFlow and JAX conversions do not depend on `i` and are not
    # mutated below, so convert once instead of on every iteration.
    tf_args = tuple(as_tf(arg) for arg in args)
    jax_args = tuple(jnp.asarray(arg) for arg in args)

    # Walk through the arguments.
    for i in range(len(args)):
        # Numerically compute gradient.
        f_i = create_f_i(i, args)
        numerical_grad = gradient(f_i)(args[i])

        # Check AutoGrad gradient.
        autograd_grad = grad(f_i)(args[i])
        approx(numerical_grad, autograd_grad, rtol=rtol)

        # Check TensorFlow gradient.
        f_i = tf.function(create_f_i(i, tf_args), autograph=False)
        with tf.GradientTape() as t:
            t.watch(tf_args[i])
            tf_out = f_i(tf_args[i])
        # Take the gradient after leaving the tape context so that the
        # gradient computation itself is not recorded on the tape.
        tf_grad = t.gradient(tf_out, tf_args[i]).numpy()
        approx(numerical_grad, tf_grad, rtol=rtol)

        # Check PyTorch gradient. These tensors are recreated on every
        # iteration because `requires_grad_` below mutates them in place.
        torch_args = tuple(as_torch(arg, grad=False) for arg in args)
        f_i = torch.jit.trace(create_f_i(i, torch_args), torch_args[i])
        torch_arg = torch_args[i].requires_grad_(True)
        f_i(torch_arg).backward()
        approx(numerical_grad, torch_arg.grad, rtol=rtol)

        # Check JAX gradient.
        f_i = create_f_i(i, jax_args)
        jax_grad = jax.jit(jax.grad(f_i))(jax_args[i])
        approx(numerical_grad, jax_grad, rtol=rtol)
Example #5
0
def test_grad():
    """Check per-argument finite-difference gradients against JAX.

    Each single-argument function ``fs0``, ``fs1`` and ``fs2`` is first
    checked against ``jax.grad`` at its matching entry of ``inputs``. The
    same finite-difference gradients are then compared with the partial
    derivatives of the three-argument ``f_scalar`` at ``inputs``.
    NOTE(review): this presumes ``fs_k`` is ``f_scalar`` restricted to
    argument k; confirm. All comparisons run under ``check``.
    """
    numeric = []
    for single_arg_fn, point in zip((fs0, fs1, fs2), inputs):
        estimate = fdm.gradient(single_arg_fn)(point)
        numeric.append(estimate)
        exact = jax.grad(single_arg_fn)(point)

        with check:
            assert np.allclose(estimate, exact)

    partials = jax.grad(f_scalar, (0, 1, 2))(*inputs)
    for k, estimate in enumerate(numeric):
        with check:
            assert np.allclose(estimate, partials[k])
def test_gradient_vector_argument():
    """Check ``gradient`` on scalar, vector and matrix arguments.

    For ``sum(scale * y * y)`` the exact gradient is ``2 * scale * y``. The
    finite-difference estimate is compared with it for 0-, 1- and 2-D
    inputs, using the ``central_fdm(10, 1)`` method.
    """
    method = central_fdm(10, 1)
    shapes = [(), (3, ), (3, 3)]

    # All coefficients are drawn before all evaluation points, matching the
    # original draw order.
    scales = [np.random.randn(*shape) for shape in shapes]
    points = [np.random.randn(*shape) for shape in shapes]

    for scale, point in zip(scales, points):

        def quadratic(y):
            return np.sum(scale * y * y)

        allclose(2 * scale * point, gradient(quadratic, method)(point))