コード例 #1
0
  def test_custom_root_errors(self):
    with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
      lax.custom_root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
    with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
      lax.custom_root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)

    def dummy_root_usage(x):
      f = lambda y: x - y
      return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))

    with self.assertRaisesRegex(
        TypeError, re.escape("tangent_solve() output pytree")):
      api.jvp(dummy_root_usage, (0.0,), (0.0,))
コード例 #2
0
 def linear_solve(a, b):
   f = lambda x: np.dot(a, x) - b
   factors = jsp.linalg.cho_factor(a)
   cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
   def pos_def_solve(g, b):
     return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
   return lax.custom_root(f, b, cho_solve, pos_def_solve)
コード例 #3
0
        def root_aux(a, b):
            f = lambda x: high_precision_dot(a, x) - b
            factors = jsp.linalg.cho_factor(a)
            cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b),
                                      orig_aux)

            def pos_def_solve(g, b):
                # prune aux to allow use as tangent_solve
                cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
                return lax.custom_linear_solve(g,
                                               b,
                                               cho_solve_noaux,
                                               symmetric=True)

            return lax.custom_root(f,
                                   b,
                                   cho_solve,
                                   pos_def_solve,
                                   has_aux=True)
コード例 #4
0
 def dummy_root_usage(x):
   f = lambda y: x - y
   return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))
コード例 #5
0
 def linear_solve(a, b):
   f = lambda y: high_precision_dot(a, y) - b
   x0 = np.zeros_like(b)
   solution = np.linalg.solve(a, b)
   oracle = lambda func, x0: solution
   return lax.custom_root(f, x0, oracle, vector_solve)
コード例 #6
0
 def sqrt_cubed(x, tangent_solve=scalar_solve):
   f = lambda y: y ** 2. - np.array(x) ** 3.
   return lax.custom_root(f, 0.0, binary_search, tangent_solve)
コード例 #7
0
ファイル: errors_test.py プロジェクト: uafpdivad/jax
 def f3():
     return lax.custom_root(err, 0., solve, solve)
コード例 #8
0
ファイル: errors_test.py プロジェクト: uafpdivad/jax
 def f2():
     return lax.custom_root(g, 0., solve, err)
コード例 #9
0
ファイル: errors_test.py プロジェクト: uafpdivad/jax
 def f1():
     return lax.custom_root(g, 0., err, solve)
コード例 #10
0
 def sqrt_cubed(x, tangent_solve=scalar_solve):
     f = lambda y: y**2 - x**3
     return lax.custom_root(f, 0.0, binary_search, tangent_solve)
コード例 #11
0
 def sqrt_cubed(x, tangent_solve=scalar_solve):
     f = lambda y: y**2 - x**3
     # Note: Nonzero derivative at x0 required for newton_raphson
     return lax.custom_root(f, 1.0, solve_method, tangent_solve)
コード例 #12
0
 def nonlinear_solve(y):
     f = lambda x: nonlinear_func(x, y)
     x0 = -jnp.ones_like(y)
     return lax.custom_root(f, x0, newton_raphson, tangent_solve)