def test_gradient():
    jax_grads = jax.grad(hh, (0, 1, 2))(*inputs)
    fdm_grad0 = fdm.gradient(hh0)(inputs[0])
    fdm_grad1 = fdm.gradient(hh1)(inputs[1])
    fdm_grad2 = fdm.gradient(hh2)(inputs[2])
    with check:
        assert np.allclose(fdm_grad0, jax_grads[0])
    with check:
        assert np.allclose(fdm_grad1, jax_grads[1])
    with check:
        assert np.allclose(fdm_grad2, jax_grads[2])

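# The bare `with check:` blocks in these tests look like the soft-assertion
# context manager from the pytest-check plugin, which records a failed assert
# and keeps going instead of aborting the test at the first failure. A minimal
# sketch of that pattern, assuming pytest-check is installed:
from pytest_check import check

def test_soft_asserts():
    with check:
        assert 1 + 1 == 2  # a failure here would be recorded, not raised...
    with check:
        assert 2 * 3 == 6  # ...so this assert still runs
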
def test_grad():
    fdm_grad = fdm.gradient(h_scalar)(x_input)
    jax_grad = jax.grad(h_scalar)(x_input)
    with check:
        assert np.allclose(fdm_grad, jax_grad)
    jax_grads = jax.grad(f_scalar, (0, 1))(x_input, y_input)
    # `h_scalar` is `f_scalar` with the second argument fixed, so the gradient
    # computed above doubles as the finite-difference gradient w.r.t. `x_input`.
    fdm_grad0 = fdm_grad
    fdm_grad1 = fdm.gradient(lambda y: f_scalar(x_input, y))(y_input)
    with check:
        assert np.allclose(fdm_grad0, jax_grads[0])
    with check:
        assert np.allclose(fdm_grad1, jax_grads[1])

def test_jvp_directional():
    m = central_fdm(10, 1)
    a = np.random.randn(3)

    def f(x):
        return np.sum(a * x)

    x = np.random.randn(3)
    v = np.random.randn(3)
    # The JVP in direction `v` must equal the directional derivative <grad f, v>.
    allclose(np.sum(gradient(f, m)(x) * v), jvp(f, v, m)(x))

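# `central_fdm`, `gradient`, `jvp`, and `allclose` are used unqualified above,
# so these tests presumably sit under imports along the following lines. The
# `allclose` helper and its tolerance are assumptions, not part of the source:
import numpy as np
from fdm import central_fdm, gradient, jvp

def allclose(x, y):
    np.testing.assert_allclose(x, y, rtol=1e-8)  # tolerance is illustrative
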
def check_grad(f, args, kw_args=None, rtol=1e-8):
    """Check the gradients of a function.

    Args:
        f (function): Function to check gradients of.
        args (tuple): Arguments to check `f` at.
        kw_args (dict, optional): Keyword arguments to check `f` at. Defaults
            to no keyword arguments.
        rtol (float, optional): Relative tolerance. Defaults to `1e-8`.
    """
    # Default to no keyword arguments.
    if kw_args is None:
        kw_args = {}

    # Get the associated function in LAB.
    lab_f = getattr(B, f.__name__)

    def create_f_i(i, args_):
        # Create a scalar-valued function that only varies the `i`th argument.
        def f_i(x):
            return B.mean(lab_f(*(args_[:i] + (x,) + args_[i + 1:]), **kw_args))

        return f_i

    # Walk through the arguments.
    for i in range(len(args)):
        # Numerically compute the gradient.
        f_i = create_f_i(i, args)
        numerical_grad = gradient(f_i)(args[i])

        # Check AutoGrad gradient.
        autograd_grad = grad(f_i)(args[i])
        approx(numerical_grad, autograd_grad, rtol=rtol)

        # Check TensorFlow gradient.
        tf_args = tuple(as_tf(arg) for arg in args)
        f_i = tf.function(create_f_i(i, tf_args), autograph=False)
        with tf.GradientTape() as t:
            t.watch(tf_args[i])
            tf_grad = t.gradient(f_i(tf_args[i]), tf_args[i]).numpy()
        approx(numerical_grad, tf_grad, rtol=rtol)

        # Check PyTorch gradient.
        torch_args = tuple(as_torch(arg, grad=False) for arg in args)
        f_i = torch.jit.trace(create_f_i(i, torch_args), torch_args[i])
        arg = torch_args[i].requires_grad_(True)
        f_i(arg).backward()
        approx(numerical_grad, arg.grad, rtol=rtol)

        # Check JAX gradient.
        jax_args = tuple(jnp.asarray(arg) for arg in args)
        f_i = create_f_i(i, jax_args)
        jax_grad = jax.jit(jax.grad(f_i))(jax_args[i])
        approx(numerical_grad, jax_grad, rtol=rtol)

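# A hypothetical invocation of `check_grad`. Since the function only reads
# `f.__name__` to look up the matching LAB function, passing `np.sin` resolves
# to `B.sin` (assuming LAB exposes `sin`, which dispatches across NumPy,
# TensorFlow, PyTorch, and JAX arrays); the argument shape is illustrative:
check_grad(np.sin, (np.random.randn(10),))
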
def test_grad():
    fdm_grads = []
    for func, inp in zip((fs0, fs1, fs2), inputs):
        fdm_grad = fdm.gradient(func)(inp)
        fdm_grads.append(fdm_grad)
        jax_grad = jax.grad(func)(inp)
        with check:
            assert np.allclose(fdm_grad, jax_grad)
    jax_grads = jax.grad(f_scalar, (0, 1, 2))(*inputs)
    for i, fdm_grad in enumerate(fdm_grads):
        with check:
            assert np.allclose(fdm_grad, jax_grads[i])

def test_gradient_vector_argument():
    m = central_fdm(10, 1)
    for a, x in zip(
        [np.random.randn(), np.random.randn(3), np.random.randn(3, 3)],
        [np.random.randn(), np.random.randn(3), np.random.randn(3, 3)],
    ):

        def f(y):
            return np.sum(a * y * y)

        # The gradient of sum(a * y * y) at y = x is 2 * a * x.
        allclose(2 * a * x, gradient(f, m)(x))

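# A self-contained sketch of the pattern shared by the tests above: compare a
# finite-difference gradient from `fdm` against JAX's autodiff gradient. The
# test function and shapes are illustrative, not from the source; float64 is
# enabled so the finite-difference estimate is not limited by float32 noise:
import fdm
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = np.random.randn(5)
assert np.allclose(fdm.gradient(f)(x), jax.grad(f)(x))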