def _adjoint_iteration_vjp(g, ans, init_xs, params):
    """Reverse-mode rule: solve the adjoint system by fixed-point iteration.

    Args:
        g: Incoming cotangent for the solver's output value.
        ans: Forward solution record (provides `iterations` and `value`).
        init_xs: Unused; the adjoint depends only on the converged solution.
        params: Parameters the fixed-point map was differentiated against.

    Returns:
        A pair `(param_cotangent, dsol)` where `dsol` is the adjoint
        fixed-point solution record.
    """
    del init_xs  # unused by the adjoint computation
    cotangent = g
    # Linearize the fixed-point map at the converged solution; keep only
    # the VJP closure (w.r.t. value and params).
    vjp_at_solution = jax.vjp(
        partial(flat_func, ans.iterations), ans.value, params)[1]

    def adjoint_step(step, dx):
        # One iteration of the adjoint fixed-point equation:
        # dx <- J^T dx + g, where J is the Jacobian of the forward map.
        del step
        return vjp_at_solution(dx)[0] + cotangent

    rtol, atol = converge.adjust_tol_for_dtype(
        default_rtol, default_atol, cotangent.dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    dsol = loop.fixed_point_iteration(
        init_x=cotangent,
        func=adjoint_step,
        convergence_test=convergence_test,
        max_iter=default_max_iter,
        batched_iter_size=default_batched_iter_size,
    )
    # Pull the converged adjoint back through the params slot of the VJP.
    return vjp_at_solution(dsol.value)[1], dsol
def _default_solver(init_x, params):
    """Run the forward fixed-point iteration with module-default settings.

    Args:
        init_x: Initial guess for the fixed point; its `dtype` sets the
            achievable tolerance.
        params: Parameters passed to `param_func` to build the iteration map.

    Returns:
        The full solution record from `loop.fixed_point_iteration`.
    """
    rtol, atol = converge.adjust_tol_for_dtype(
        default_rtol, default_atol, init_x.dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    return loop.fixed_point_iteration(
        init_x=init_x,
        func=param_func(params),
        convergence_test=convergence_test,
        max_iter=default_max_iter,
        batched_iter_size=default_batched_iter_size,
    )
def fixed_point_iteration_solver(init_x, params):
    """Solve the fixed point of `param_func(params)` and return only its value.

    Tolerances are adjusted for the smallest floating dtype found anywhere
    in the `init_x` pytree.

    Args:
        init_x: Initial guess (possibly a pytree) for the fixed point.
        params: Parameters passed to `param_func` to build the iteration map.

    Returns:
        The converged value (``solution.value``), not the full record.
    """
    tol_dtype = converge.tree_smallest_float_dtype(init_x)
    rtol, atol = converge.adjust_tol_for_dtype(
        default_rtol, default_atol, tol_dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    solution = loop.fixed_point_iteration(
        init_x=init_x,
        func=param_func(params),
        convergence_test=convergence_test,
        max_iter=default_max_iter,
        batched_iter_size=default_batched_iter_size,
        unroll=unroll,
    )
    return solution.value
def default_convergence_test(rtol=1e-10, atol=1e-10, dtype=np.float32):
    """Build a convergence predicate with tolerances adjusted for `dtype`.

    Args:
        rtol (float, optional): Requested relative tolerance.
        atol (float, optional): Requested absolute tolerance.
        dtype (optional): Dtype used to clamp the tolerances to what is
            actually achievable at that precision.

    Returns:
        A callable `(x_new, x_old) -> bool` that reports whether the
        iteration has converged.
    """
    # Unpack once at creation time; the closure reuses the adjusted values.
    rtol_adj, atol_adj = converge.adjust_tol_for_dtype(rtol, atol, dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol_adj, atol_adj)

    return convergence_test
def conjugate_gradient_solve(linear_op, bvec, init_x, max_iter=1000, atol=1e-10):
    """Solve `linear_op(x) = bvec` with the conjugate-gradient method.

    Args:
        linear_op: Callable applying the (presumably symmetric positive
            definite — standard CG requirement, not checked here) linear
            operator to a pytree.
        bvec: Right-hand-side pytree; its smallest float dtype sets the
            achievable absolute tolerance.
        init_x: Initial guess for the solution.
        max_iter: Maximum number of CG iterations.
        atol: Requested absolute tolerance on the squared residual norm.

    Returns:
        The solution record from `loop.fixed_point_iteration` with `value`
        (and `previous_value`) replaced by the solution pytree alone.
    """
    dtype = converge.tree_smallest_float_dtype(bvec)
    _, atol = converge.adjust_tol_for_dtype(0., atol=atol, dtype=dtype)

    # r0 = b - A x0; the first search direction equals the first residual.
    # NOTE(review): `tree_multimap` is deprecated in newer JAX in favor of
    # `tree_map` — keeping the original call to preserve compatibility.
    residual = tree_util.tree_multimap(lax.sub, bvec, linear_op(init_x))
    residual_sqr = math.pytree_dot(residual, residual)
    state = (init_x, residual, residual_sqr, residual)

    def convergence_test(state_new, state_old):
        del state_old
        # state_new[2] holds the squared residual norm.
        return state_new[2] < atol

    result = loop.fixed_point_iteration(
        state,
        func=partial(cg_step, linear_op),
        convergence_test=convergence_test,
        max_iter=max_iter,
    )
    return result._replace(
        value=result.value[0],
        previous_value=result.value[0],
    )
def default_convergence_test(x_new, x_old):
    """Report convergence between successive iterates.

    Tolerances (1e-10 relative and absolute) are clamped to what the
    smallest floating dtype present in `x_new` can achieve.
    """
    smallest_dtype = converge.tree_smallest_float_dtype(x_new)
    rtol, atol = converge.adjust_tol_for_dtype(1e-10, 1e-10, smallest_dtype)
    return converge.max_diff_test(x_new, x_old, rtol, atol)