def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, dimension_numbers, rng_factory):
  """Gradient checks for lax.dot_general with contracting/batch dims.

  Also verifies that the precision configuration survives transposition:
  the VJP's jaxpr must still carry precision=HIGHEST.
  """
  rng = rng_factory(self.rng())
  lhs = rng(lhs_shape, dtype)
  rhs = rng(rhs_shape, dtype)
  dot_general = partial(lax.dot_general,
                        dimension_numbers=dimension_numbers,
                        precision=lax.Precision.HIGHEST)
  check_grads_bilinear(dot_general, (lhs, rhs), order=2,
                       modes=["fwd", "rev"])
  # Transpose rules should propagate the precision setting; inspect the
  # pullback's jaxpr text to confirm it was preserved.
  result, pullback = api.vjp(dot_general, lhs, rhs)
  cotangent = lax.zeros_like_array(result)
  jaxpr_text = str(api.make_jaxpr(pullback)(cotangent))
  assert "precision=HIGHEST" in jaxpr_text
def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
  """Gradient checks for lax.dot, plus a check that precision=HIGHEST
  is preserved in the VJP's jaxpr."""
  rng = rng_factory(self.rng())
  # Looser tolerance for low-precision float dtypes.
  tol = {onp.float16: 1e-1, onp.float32: 1e-4}
  lhs = rng(lhs_shape, dtype)
  rhs = rng(rhs_shape, dtype)
  dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
  check_grads_bilinear(dot, (lhs, rhs), order=2,
                       modes=["fwd", "rev"], atol=tol, rtol=tol)
  # The transpose of a HIGHEST-precision dot must itself run at HIGHEST
  # precision; verify via the pullback's jaxpr text.
  result, pullback = api.vjp(dot, lhs, rhs)
  cotangent = lax.zeros_like_array(result)
  jaxpr_text = str(api.make_jaxpr(pullback)(cotangent))
  assert "precision=HIGHEST" in jaxpr_text
def _coo_matmat_jvp_rule(primals_in, tangents_in, **params):
  """JVP rule for coo_matmat.

  For a COO matrix A (given by vals/rows/cols) and a dense matrix B, the
  product rule gives d(A @ B) = dA @ B + A @ dB. The sparsity structure
  (row/col index arrays) is integer-valued and must carry no tangent.
  """
  vals, rows, cols, mat = primals_in
  vals_dot, rows_dot, cols_dot, mat_dot = tangents_in
  # Index arrays are non-differentiable; their tangents must be symbolic zeros.
  assert type(rows_dot) is ad.Zero
  assert type(cols_dot) is ad.Zero

  primals_out = coo_matmat(vals, rows, cols, mat, **params)

  def _materialize(primal, tangent):
    # Replace a symbolic ad.Zero with concrete zeros shaped like the primal.
    if isinstance(tangent, ad.Zero):
      return lax.zeros_like_array(primal)
    return tangent

  left = coo_matmat(_materialize(vals, vals_dot), rows, cols, mat, **params)
  right = coo_matmat(vals, rows, cols, _materialize(mat, mat_dot), **params)
  return primals_out, left + right
def make_zero(x, t):
  """Return tangent ``t``, materializing a symbolic ad.Zero as concrete
  zeros with the shape/dtype of the primal ``x``."""
  if type(t) is ad.Zero:
    return lax.zeros_like_array(x)
  return t
def create_token_value_and_jvp(in_args, tan_args):
  """Primal value and JVP for create_token.

  The token itself is non-differentiable, so the tangent is simply a
  zeros array shaped like the (single) input.
  """
  x, = in_args
  primal_out = create_token(x)
  tangent_out = zeros_like_array(x)
  return (primal_out, tangent_out)
def zero_tangent(tan, val):
  """Return tangent ``tan`` unless it is a symbolic ad.Zero, in which case
  materialize concrete zeros shaped like ``val``."""
  if type(tan) is not ad.Zero:
    return tan
  return lax.zeros_like_array(val)