def testGammaGradType(self):
    # Regression test for https://github.com/google/jax/issues/2130
    #
    # Builds a float32 function that divides a gamma sample by a second
    # argument and takes its VJP. The test passes as long as constructing
    # the VJP does not raise; no numerical values are checked.
    prng_key = random.PRNGKey(0)
    shape_param = jnp.array(1., dtype=jnp.float32)
    divisor = jnp.array(3., dtype=jnp.float32)

    def gamma_over(x, y):
        sample = random.gamma(key=prng_key, a=x, dtype=jnp.float32)
        return sample / y

    # Should not crash with a type error.
    api.vjp(gamma_over, shape_param, divisor)
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers):
    # Runs check_grads_bilinear (order 2, forward and reverse modes) on
    # lax.dot_general with the given dimension numbers, then verifies that
    # the HIGHEST precision setting shows up in the jaxpr of the pullback.
    sampler = jtu.rand_small(self.rng())
    # Draw lhs before rhs: the sampler is built from a shared RNG, so the
    # draw order determines the operand values.
    lhs_operand = sampler(lhs_shape, dtype)
    rhs_operand = sampler(rhs_shape, dtype)

    contract = partial(
        lax.dot_general,
        dimension_numbers=dimension_numbers,
        precision=lax.Precision.HIGHEST,
    )
    check_grads_bilinear(
        contract,
        (lhs_operand, rhs_operand),
        order=2,
        modes=["fwd", "rev"],
    )

    # check that precision config is preserved
    primal_out, pullback = api.vjp(contract, lhs_operand, rhs_operand)
    cotangent = lax.zeros_like_array(primal_out)
    pullback_jaxpr = api.make_jaxpr(pullback)(cotangent)
    jaxpr_text = str(pullback_jaxpr)
    assert "precision=HIGHEST" in jaxpr_text
def testDotGrad(self, lhs_shape, rhs_shape, dtype):
    # Runs check_grads_bilinear (order 2, forward and reverse modes) on
    # lax.dot at HIGHEST precision with per-dtype tolerances, then verifies
    # that the precision setting shows up in the jaxpr of the pullback.
    sampler = jtu.rand_default(self.rng())
    # The same per-dtype table is used for both absolute and relative
    # tolerance.
    tolerances = {np.float16: 1e-1, np.float32: 1e-4}
    # Draw lhs before rhs: the sampler is built from a shared RNG, so the
    # draw order determines the operand values.
    lhs_operand = sampler(lhs_shape, dtype)
    rhs_operand = sampler(rhs_shape, dtype)

    precise_dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(
        precise_dot,
        (lhs_operand, rhs_operand),
        order=2,
        modes=["fwd", "rev"],
        atol=tolerances,
        rtol=tolerances,
    )

    # check that precision config is preserved
    primal_out, pullback = api.vjp(precise_dot, lhs_operand, rhs_operand)
    cotangent = lax.zeros_like_array(primal_out)
    pullback_jaxpr = api.make_jaxpr(pullback)(cotangent)
    jaxpr_text = str(pullback_jaxpr)
    assert "precision=HIGHEST" in jaxpr_text
def f_vjp(*args):
    """Apply the VJP of ``f`` at ``args`` to ``f``'s own primal output.

    ``f`` is taken from the enclosing scope. The primal output returned by
    ``api.vjp`` is passed straight back to the pullback as the cotangent,
    and the pullback's result is returned.
    """
    primal_out, pullback = api.vjp(f, *args)
    return pullback(primal_out)