示例#1
0
    def _adjoint_iteration_vjp(g, ans, init_xs, params):
        """Back-propagate a cotangent through a fixed point by adjoint iteration.

        Iterates ``dout <- (df/dx)^T dout + g`` to convergence. The VJP is taken
        of ``f = partial(flat_func, ans.iterations)`` at ``(ans.value, params)``.
        The converged adjoint is then mapped through the VJP with respect to
        ``params``.

        Args:
            g: Cotangent of the solution value. It is also the initial adjoint
                guess.
            ans: Forward solution. Its ``value`` and ``iterations`` fields fix
                the linearization point.
            init_xs: Initial guess of the forward solve. Unused.
            params: Parameters the fixed-point map is linearized at.

        Returns:
            A pair ``(params_cotangent, dsol)``. ``dsol`` is the solution object
            produced by ``loop.fixed_point_iteration`` for the adjoint iteration.
        """
        dvalue = g
        del init_xs  # the adjoint solve starts from the cotangent instead
        init_dxs = dvalue

        fp_vjp_fn = jax.vjp(partial(flat_func, ans.iterations), ans.value,
                            params)[1]

        # `g` is loop-invariant, so flatten it once. Adding leaf by leaf keeps
        # the update correct for pytree-valued states, where `+` would instead
        # concatenate tuples or fail on dicts. For a single array this is the
        # same as a plain `+`.
        dvalue_leaves, _ = jax.tree_util.tree_flatten(dvalue)

        def dfp_fn(i, dout):
            del i
            vjp_leaves, treedef = jax.tree_util.tree_flatten(fp_vjp_fn(dout)[0])
            return jax.tree_util.tree_unflatten(
                treedef, [a + b for a, b in zip(vjp_leaves, dvalue_leaves)])

        # Look the dtype up tree-wise, because pytree cotangents have no `.dtype`.
        dtype = converge.tree_smallest_float_dtype(init_dxs)
        rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                                   dtype)

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        dsol = loop.fixed_point_iteration(
            init_x=init_dxs,
            func=dfp_fn,
            convergence_test=convergence_test,
            max_iter=default_max_iter,
            batched_iter_size=default_batched_iter_size,
        )

        return fp_vjp_fn(dsol.value)[1], dsol
示例#2
0
    def _default_solver(init_x, params):
        """Find a fixed point of ``param_func(params)`` starting from ``init_x``.

        Uses a max-difference convergence test. Its tolerances
        (``default_rtol`` / ``default_atol``) are adjusted for the smallest
        float dtype found in ``init_x``.

        Args:
            init_x: Initial guess. May be an array or a pytree of arrays.
            params: Parameters passed to ``param_func`` to build the iterated map.

        Returns:
            The full solution object returned by ``loop.fixed_point_iteration``.
        """
        # Look the dtype up tree-wise, because a pytree initial guess has no
        # `.dtype` attribute.
        dtype = converge.tree_smallest_float_dtype(init_x)
        rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                                   dtype)

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        func = param_func(params)
        sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=func,
            convergence_test=convergence_test,
            max_iter=default_max_iter,
            batched_iter_size=default_batched_iter_size,
        )

        return sol
示例#3
0
    def fixed_point_iteration_solver(init_x, params):
        """Iterate ``param_func(params)`` from ``init_x`` until convergence.

        The requested ``default_rtol`` / ``default_atol`` are first adjusted for
        the smallest float dtype present in ``init_x``. They are then used in a
        max-difference test between consecutive iterates.

        Returns:
            Only the ``value`` field of the solution produced by
            ``loop.fixed_point_iteration``.
        """
        smallest_dtype = converge.tree_smallest_float_dtype(init_x)
        rel_tol, abs_tol = converge.adjust_tol_for_dtype(
            default_rtol, default_atol, smallest_dtype)

        def _has_converged(candidate, previous):
            return converge.max_diff_test(candidate, previous, rel_tol, abs_tol)

        solution = loop.fixed_point_iteration(
            init_x=init_x,
            func=param_func(params),
            convergence_test=_has_converged,
            max_iter=default_max_iter,
            batched_iter_size=default_batched_iter_size,
            unroll=unroll,
        )
        return solution.value
示例#4
0
def default_convergence_test(rtol=1e-10, atol=1e-10, dtype=np.float32):
    """Build a max-difference convergence test with dtype-adjusted tolerances.

    Args:
        rtol (float, optional): Requested relative tolerance.
        atol (float, optional): Requested absolute tolerance.
        dtype (optional): Dtype the tolerances are adjusted for, so that the
            requested precision is attainable in that dtype.

    Returns:
        A callable ``convergence_test(x_new, x_old)``. It compares the current
        and previous iterates with ``converge.max_diff_test`` and returns that
        result, a boolean indicating whether convergence is achieved.
    """
    effective_rtol, effective_atol = converge.adjust_tol_for_dtype(
        rtol, atol, dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(
            x_new, x_old, effective_rtol, effective_atol)

    return convergence_test
示例#5
0
文件: cg.py 项目: niklasschmitz/fax
def conjugate_gradient_solve(linear_op,
                             bvec,
                             init_x,
                             max_iter=1000,
                             atol=1e-10):
    """Solve ``linear_op(x) = bvec`` with conjugate-gradient iterations.

    Args:
        linear_op: Callable applying the linear operator to a value shaped like
            ``init_x``. NOTE(review): CG presumes a symmetric positive-definite
            operator; this is not checked here.
        bvec: Right-hand side.
        init_x: Initial guess for ``x``.
        max_iter: Maximum number of CG steps.
        atol: Absolute tolerance, adjusted for the smallest float dtype in
            ``bvec``. It is compared against the squared residual norm
            ``<r, r>``, not the norm itself.

    Returns:
        The solution object from ``loop.fixed_point_iteration`` with the internal
        CG state stripped. ``value`` is the final iterate ``x`` and
        ``previous_value`` is the iterate ``x`` from the step before it.
    """
    dtype = converge.tree_smallest_float_dtype(bvec)
    _, atol = converge.adjust_tol_for_dtype(0., atol=atol, dtype=dtype)
    # NOTE(review): `tree_multimap` was removed in JAX 0.3.16; newer JAX
    # spells this `tree_util.tree_map(lax.sub, ...)`.
    init_r = tree_util.tree_multimap(lax.sub, bvec, linear_op(init_x))
    init_p = init_r
    init_r_sqr = math.pytree_dot(init_r, init_r)

    def convergence_test(state_new, state_old):
        del state_old
        # The state is laid out as (x, r, r_sqr, p), assuming `cg_step` keeps
        # the layout of the initial state. Slot 2 holds the squared residual norm.
        return state_new[2] < atol

    solution = loop.fixed_point_iteration((init_x, init_r, init_r_sqr, init_p),
                                          func=partial(cg_step, linear_op),
                                          convergence_test=convergence_test,
                                          max_iter=max_iter)
    # Keep only the `x` slot of both the final and the previous CG state.
    # The previous value must come from `previous_value`; taking it from
    # `value` again would make it identical to the final iterate.
    return solution._replace(
        value=solution.value[0],
        previous_value=solution.previous_value[0],
    )
示例#6
0
文件: cga.py 项目: niklasschmitz/fax
 def default_convergence_test(x_new, x_old):
     """Return whether ``x_new`` and ``x_old`` agree to within tolerance.

     Both the relative and absolute tolerances start at 1e-10. They are adjusted
     for the smallest float dtype in ``x_new`` before the max-difference test.
     """
     smallest_dtype = converge.tree_smallest_float_dtype(x_new)
     rel_tol, abs_tol = converge.adjust_tol_for_dtype(1e-10, 1e-10,
                                                      smallest_dtype)
     return converge.max_diff_test(x_new, x_old, rel_tol, abs_tol)