示例#1
0
 def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                         dimension_numbers, rng_factory):
   """Check gradients of ``lax.dot_general`` and its precision propagation.

   Gradients are checked up to ``order=2`` in forward and reverse mode via
   ``check_grads_bilinear``. The VJP pullback is then traced to a jaxpr to
   verify that the ``HIGHEST`` precision requested on the forward op also
   appears in the backward computation.
   """
   rng = rng_factory(self.rng())
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                         precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
   # check that precision config is preserved
   result, pullback = api.vjp(dot_general, lhs, rhs)
   # Only the traced structure of the pullback is inspected, so a zero
   # cotangent of the right shape/dtype is sufficient.
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   # TestCase assertion instead of a bare ``assert``: it is not stripped under
   # ``python -O`` and prints the offending jaxpr text on failure.
   self.assertIn("precision=HIGHEST", s)
示例#2
0
 def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
   """Check gradients of ``lax.dot`` and its precision propagation.

   Gradients are checked up to ``order=2`` in forward and reverse mode, with
   a per-dtype tolerance (looser for float16 than for float32). The VJP
   pullback is then traced to a jaxpr to verify that the ``HIGHEST``
   precision requested on the forward op also appears in the backward pass.
   """
   rng = rng_factory(self.rng())
   tol = {onp.float16: 1e-1, onp.float32: 1e-4}
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                        atol=tol, rtol=tol)
   # check that precision config is preserved
   result, pullback = api.vjp(dot, lhs, rhs)
   # Only the traced structure of the pullback is inspected, so a zero
   # cotangent of the right shape/dtype is sufficient.
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   # TestCase assertion instead of a bare ``assert``: it is not stripped under
   # ``python -O`` and prints the offending jaxpr text on failure.
   self.assertIn("precision=HIGHEST", s)
示例#3
0
def _coo_matmat_jvp_rule(primals_in, tangents_in, **params):
  """JVP rule for ``coo_matmat``.

  Args:
    primals_in: ``(vals, rows, cols, mat)`` primal inputs.
    tangents_in: tangents matching ``primals_in``. The ``rows`` and ``cols``
      tangents must be symbolic zeros (``ad.Zero``).
    **params: forwarded unchanged to every ``coo_matmat`` call.

  Returns:
    ``(primals_out, tangents_out)``. The tangent applies the product rule to
    ``(vals, mat) -> coo_matmat(vals, rows, cols, mat)``:
    ``coo_matmat(d_vals, rows, cols, mat) + coo_matmat(vals, rows, cols, d_mat)``.
    A term whose input tangent is a symbolic zero is skipped rather than
    evaluated on a materialized zero array.
  """
  vals, rows, cols, mat = primals_in
  sparse_mat_dot, rows_dot, cols_dot, mat_dot = tangents_in
  assert type(rows_dot) is ad.Zero
  assert type(cols_dot) is ad.Zero

  primals_out = coo_matmat(vals, rows, cols, mat, **params)

  tangents_out = None
  if not isinstance(sparse_mat_dot, ad.Zero):
    tangents_out = coo_matmat(sparse_mat_dot, rows, cols, mat, **params)
  if not isinstance(mat_dot, ad.Zero):
    mat_term = coo_matmat(vals, rows, cols, mat_dot, **params)
    tangents_out = mat_term if tangents_out is None else tangents_out + mat_term
  if tangents_out is None:
    # Both value tangents are symbolic zeros: the tangent output is all zeros
    # with the primal output's shape and dtype.
    tangents_out = lax.zeros_like_array(primals_out)
  return primals_out, tangents_out
示例#4
0
文件: ops.py 项目: zeta1999/celerite2
 def make_zero(x, t):
     """Materialize a symbolic-zero tangent.

     Returns ``lax.zeros_like_array(x)`` when ``t`` is exactly an ``ad.Zero``
     instance; otherwise returns ``t`` unchanged.
     """
     if type(t) is not ad.Zero:
         return t
     return lax.zeros_like_array(x)
示例#5
0
def create_token_value_and_jvp(in_args, tan_args):
    """Return the primal output and tangent for ``create_token``.

    Args:
        in_args: a one-element sequence holding ``x``; unpacking raises
            ``ValueError`` if it has any other length.
        tan_args: tangents of ``in_args``; ignored.

    Returns:
        ``(create_token(x), zeros_like_array(x))`` -- the tangent is always
        a dense zero array built from ``x``.
    """
    # NOTE(review): the tangent is derived from the input ``x`` rather than
    # from the token output; confirm this matches the primitive's output aval.
    [x] = in_args
    token = create_token(x)
    return token, zeros_like_array(x)
示例#6
0
 def zero_tangent(tan, val):
     """Replace a symbolic-zero tangent with a dense zero array.

     If ``tan`` is exactly an ``ad.Zero`` instance, returns
     ``lax.zeros_like_array(val)``; any other tangent is returned as-is.
     """
     if type(tan) is ad.Zero:
         tan = lax.zeros_like_array(val)
     return tan