示例#1
0
 def testCholeskyGradPrecision(self):
     """Run the dot-precision assertion on the JVP of ``np.linalg.cholesky``.

     A random 3x3 float32 matrix is multiplied by its own transpose, and the
     resulting symmetric product is passed as both the primal and the tangent
     to ``jvp``.  ``jtu.assert_dot_precision`` is invoked with
     ``lax.Precision.HIGHEST``; presumably it checks that the dot operations
     in the traced computation carry that precision (helper is defined
     outside this view -- confirm).
     """
     sample = jtu.rand_default()
     factor = sample((3, 3), onp.float32)
     # factor @ factor.T is symmetric by construction.
     gram = onp.dot(factor, factor.T)
     cholesky_jvp = partial(jvp, np.linalg.cholesky)
     primals = (gram,)
     tangents = (gram,)
     jtu.assert_dot_precision(
         lax.Precision.HIGHEST, cholesky_jvp, primals, tangents)
示例#2
0
 def testTriangularSolveGradPrecision(self):
     """Run the dot-precision assertion on the JVP of ``triangular_solve``.

     The operands are the lower triangle of a random 3x3 float32 matrix and
     a random 1x3 float32 right-hand side; the same pair is used as both the
     primals and the tangents.  ``jtu.assert_dot_precision`` is invoked with
     ``lax.Precision.HIGHEST``; presumably it checks that the dot operations
     in the traced computation carry that precision (helper is defined
     outside this view -- confirm).
     """
     sample = jtu.rand_default()
     # Draw the matrix first and the right-hand side second so the sequence
     # of random draws is unchanged.
     lower = np.tril(sample((3, 3), onp.float32))
     rhs = sample((1, 3), onp.float32)
     solve_jvp = partial(jvp, lax_linalg.triangular_solve)
     primals = (lower, rhs)
     tangents = (lower, rhs)
     jtu.assert_dot_precision(
         lax.Precision.HIGHEST, solve_jvp, primals, tangents)
示例#3
0
 def testEighGradPrecision(self):
   """Run the dot-precision assertion on the JVP of ``np.linalg.eigh``.

   A random 3x3 float32 matrix is used as both the primal and the tangent.
   ``jtu.assert_dot_precision`` is invoked with ``lax.Precision.HIGHEST``;
   presumably it checks that the dot operations in the traced computation
   carry that precision (helper is defined outside this view -- confirm).
   """
   sample = jtu.rand_default()
   # NOTE(review): the matrix is not symmetrized before being handed to
   # eigh; presumably only the traced precision matters here -- confirm.
   matrix = sample((3, 3), onp.float32)
   eigh_jvp = partial(jvp, np.linalg.eigh)
   primals = (matrix,)
   tangents = (matrix,)
   jtu.assert_dot_precision(
       lax.Precision.HIGHEST, eigh_jvp, primals, tangents)