Ejemplo n.º 1
0
 def testGammaGradType(self):
   # Regression test for https://github.com/google/jax/issues/2130
   key = random.PRNGKey(0)
   a = jnp.array(1., dtype=jnp.float32)
   b = jnp.array(3., dtype=jnp.float32)
   f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
   # Should not crash with a type error.
   api.vjp(f, a, b)
Ejemplo n.º 2
0
 def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                         dimension_numbers):
   """Check dot_general gradients and that the VJP keeps its precision.

   Gradients are checked to second order in forward and reverse mode via
   ``check_grads_bilinear``. The jaxpr of the pullback is then rendered to
   text and must still mention ``precision=HIGHEST``, i.e. the backward pass
   must not drop the precision config of the primal op.

   Args:
     lhs_shape: shape of the left operand.
     rhs_shape: shape of the right operand.
     dtype: dtype used for both operands.
     dimension_numbers: dimension spec forwarded unchanged to
       ``lax.dot_general``.
   """
   rng = jtu.rand_small(self.rng())
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                         precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
   # Check that the precision config is preserved in the backward pass.
   result, pullback = api.vjp(dot_general, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   jaxpr_text = str(api.make_jaxpr(pullback)(gresult))
   # assertIn (unlike a bare `assert`) is not stripped under `python -O` and
   # prints the offending jaxpr text when it fails.
   self.assertIn("precision=HIGHEST", jaxpr_text)
Ejemplo n.º 3
0
 def testDotGrad(self, lhs_shape, rhs_shape, dtype):
   """Check lax.dot gradients and that the VJP keeps its precision.

   Gradients are checked to second order in forward and reverse mode via
   ``check_grads_bilinear`` with per-dtype tolerances. The jaxpr of the
   pullback is then rendered to text and must still mention
   ``precision=HIGHEST``.

   Args:
     lhs_shape: shape of the left operand.
     rhs_shape: shape of the right operand.
     dtype: dtype used for both operands.
   """
   rng = jtu.rand_default(self.rng())
   # Per-dtype atol/rtol; dtypes not listed presumably fall back to the
   # checker's defaults (TODO confirm in check_grads_bilinear).
   tol = {np.float16: 1e-1, np.float32: 1e-4}
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                        atol=tol, rtol=tol)
   # Check that the precision config is preserved in the backward pass.
   result, pullback = api.vjp(dot, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   jaxpr_text = str(api.make_jaxpr(pullback)(gresult))
   # assertIn (unlike a bare `assert`) is not stripped under `python -O` and
   # prints the offending jaxpr text when it fails.
   self.assertIn("precision=HIGHEST", jaxpr_text)
Ejemplo n.º 4
0
 def f_vjp(*args):
   """Apply the pullback of ``f`` at ``args`` to ``f``'s own primal output.

   ``f`` is a free variable taken from the enclosing scope (not visible
   here); the primal output doubles as the cotangent fed to the pullback.
   """
   primal, pullback_fn = api.vjp(f, *args)
   return pullback_fn(primal)