Exemplo n.º 1
0
    def conjugate_gradient(self,
                           A,
                           y,
                           x0,
                           solve_params=None,
                           callback=None):
        """Solve ``A x = y`` with conjugate gradients, one batch entry at a time.

        The first axis of ``y`` and ``x0`` is the batch axis. ``A`` is treated
        as batched when it is a tuple/list or has three dimensions. Operands
        whose batch size is smaller than the combined batch size are indexed
        with a clamped index, so their last entry is reused for the remaining
        batches.

        Args:
            A: Linear operator passed to ``cg``; sliced per batch when batched.
            y: Batched right-hand sides.
            x0: Batched initial guesses.
            solve_params: Object providing ``relative_tolerance``,
                ``absolute_tolerance`` and ``max_iterations``. Its ``result``
                attribute is overwritten by this call. A fresh
                ``LinearSolve()`` is created when omitted, so the outcome of
                one call is never stored on an instance shared with others.
            callback: Accepted but not used by this implementation.

        Returns:
            The per-batch solutions combined with ``self.stack``.
        """
        if solve_params is None:
            solve_params = LinearSolve()

        bs_y = self.staticshape(y)[0]
        bs_x0 = self.staticshape(x0)[0]
        batch_size = combined_dim(bs_y, bs_x0)

        A_is_batched = isinstance(A, (tuple, list)) or self.ndims(A) == 3
        if A_is_batched:
            bs_A = self.staticshape(A)[0]
            batch_size = combined_dim(batch_size, bs_A)

        results = []

        for batch in range(batch_size):
            # Clamp the index so an operand with a smaller batch size is reused.
            y_ = y[min(batch, bs_y - 1)]
            x0_ = x0[min(batch, bs_x0 - 1)]
            # A batched operator must be sliced like y and x0; handing the
            # whole stack to the solver would not match the per-batch vectors.
            A_ = A[min(batch, bs_A - 1)] if A_is_batched else A
            x, ret_val = cg(A_,
                            y_,
                            x0_,
                            tol=solve_params.relative_tolerance,
                            atol=solve_params.absolute_tolerance,
                            maxiter=solve_params.max_iterations)

            results.append(x)
        # NOTE(review): the solver status (`ret_val`) is discarded and success
        # is reported unconditionally, with the iteration count set to -1.
        solve_params.result = SolveResult(success=True, iterations=-1)
        return self.stack(results)
Exemplo n.º 2
0
def cg_fixed_point_solve(linear_op, bvec, init_x, max_iter=1000, tol=1e-10):
    """Solve ``x - linear_op(x) = bvec`` for ``x`` with conjugate gradients.

    Args:
        linear_op: Callable mapping a pytree ``x`` to a pytree of the same
            structure; the system operator is ``x -> x - linear_op(x)``.
        bvec: Right-hand side pytree.
        init_x: Initial guess, same structure as ``bvec``.
        max_iter: Maximum number of CG iterations.
        tol: Tolerance forwarded to ``linalg.cg`` as ``tol``.

    Returns:
        The solution returned by ``linalg.cg``; the solver's second return
        value is discarded.
    """

    def identity_minus_op(x):
        # `tree_map` accepts several trees; `tree_multimap` was removed from JAX.
        return tree_util.tree_map(jax.lax.sub, x, linear_op(x))

    sol, _ = linalg.cg(
        identity_minus_op,
        bvec,
        init_x,
        tol=tol,
        maxiter=max_iter,
    )
    return sol
def _jax_cg_solve(x0, mat_vec, oks, grad, diag_shift, sparse_tol, sparse_maxiter):
    r"""
    Solves the SR flow equation using the conjugate gradient method
    """

    _mat_vec = partial(mat_vec, oks=oks, diag_shift=diag_shift)

    out, _ = cg(_mat_vec, grad, x0=x0, tol=sparse_tol, maxiter=sparse_maxiter)

    return out
def _jax_cg_solve_onthefly(
    x0,
    forward_fn,
    params,
    samples,
    grad,
    diag_shift,
    sparse_tol,
    sparse_maxiter,
):
    """Run ``cg`` using ``_mat_vec_onthefly`` as the linear operator.

    ``forward_fn``, ``params``, ``samples`` and ``diag_shift`` are bound to
    ``_mat_vec_onthefly`` as keyword arguments. ``grad`` is the right-hand
    side, ``x0`` the initial guess, and ``sparse_tol`` / ``sparse_maxiter``
    are forwarded to ``cg`` as ``tol`` / ``maxiter``. Only the solution (the
    first element of the solver's return value) is returned.
    """
    bound_kwargs = {
        "forward_fn": forward_fn,
        "params": params,
        "samples": samples,
        "diag_shift": diag_shift,
    }
    operator = partial(_mat_vec_onthefly, **bound_kwargs)
    solution, _ = cg(
        operator,
        grad,
        x0=x0,
        tol=sparse_tol,
        maxiter=sparse_maxiter,
    )
    return solution
Exemplo n.º 5
0
def test_cg():
    """Compare ``_jax_cg_solve_onthefly`` with a reference ``cg`` solve.

    The reference applies ``S_real @ vec + shift * vec`` as its operator and
    its solution is passed through ``reassemble_complex`` before comparison.
    """
    # also tests if matvec can be jitted and be differentiated with AD
    shift = 0.001
    tol = 1.0e-5
    max_iter = None

    def reference_matvec(vec):
        return S_real @ vec + shift * vec

    actual = _jax_cg_solve_onthefly(
        v, f, params, samples, grad, shift, tol, max_iter)

    reference = cg(
        reference_matvec,
        grad_real_flat,
        x0=v_real_flat,
        tol=tol,
        maxiter=max_iter,
    )
    expected = reassemble_complex(reference[0])

    assert tree_allclose(actual, expected)
Exemplo n.º 6
0
def updates(prev_state,
            hessian_xy=None,
            hessian_yx=None,
            grad_min=None,
            grad_max=None,
            breg_min=default_breg,
            breg_max=default_breg,
            eta_min=1.,
            eta_max=1.,
            objective_func=None,
            precond_b_min=False,
            precond_b_max=False):
    """Equation (4). Given the current position (prev_state), compute the
    updates (del_x, del_y) to the players in the cmd algorithm.

    Either ``objective_func`` or all of ``hessian_xy``, ``hessian_yx``,
    ``grad_min`` and ``grad_max`` must be supplied. When ``objective_func`` is
    given, the four derivative arguments are recomputed from it and any
    passed-in values are ignored.

    Args:
        prev_state: Current position of the players; its ``minPlayer`` and
            ``maxPlayer`` attributes are read here.
        hessian_xy (callable): Mixed Hessian-vector product that takes a
            max-player tree and returns a min-player tree.
        hessian_yx (callable): Mixed Hessian-vector product that takes a
            min-player tree and returns a max-player tree.
        grad_min: Gradient of the objective w.r.t. the min player at the
            current position.
        grad_max: Gradient of the objective w.r.t. the max player at the
            current position.
        breg_min: Bregman potential of the min player. Only its ``D2P`` and
            ``inv_D2P`` attributes are read here; each is applied to the
            player position via ``_tree_apply`` and the result is applied to
            a vector.
        breg_max: Bregman potential of the max player, used like ``breg_min``.
        eta_min (scalar): Step size for the min player. Default 1.0.
        eta_max (scalar): Step size for the max player. Default 1.0.
        objective_func (callable): Optional objective
            ``f(min_player, max_player)``. When given, gradients are computed
            with ``jacfwd`` and the mixed Hessian-vector products with
            ``make_mixed_jvp``.
        precond_b_min (bool): If True, divide the min player's right-hand
            side by its largest per-leaf ``jnp.linalg.norm(leaf, jnp.inf)``
            before the CG solve and multiply the solution by it afterwards.
        precond_b_max (bool): Same for the max player, but using the largest
            per-leaf ``jnp.linalg.norm(leaf)``.

    Returns:
        UpdateState(del_min, del_max), a named tuple for the updates.

    Raises:
        TypeError: If ``objective_func`` is None and any of ``hessian_xy``,
            ``hessian_yx``, ``grad_min`` or ``grad_max`` is missing.
    """
    if objective_func is not None:
        # Compute current gradient for min and max players
        grad_min = jacfwd(objective_func, 0)(prev_state.minPlayer,
                                             prev_state.maxPlayer)
        grad_max = jacfwd(objective_func, 1)(prev_state.minPlayer,
                                             prev_state.maxPlayer)

        # Define the mixed hessian-vector product linear operator at current position
        def hessian_xy(tangent):
            return make_mixed_jvp(objective_func, prev_state.minPlayer,
                                  prev_state.maxPlayer)(tangent)

        def hessian_yx(tangent):
            return make_mixed_jvp(objective_func, prev_state.minPlayer,
                                  prev_state.maxPlayer, True)(tangent)
    else:
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not callable" further down.
        required = {
            "hessian_xy": hessian_xy,
            "hessian_yx": hessian_yx,
            "grad_min": grad_min,
            "grad_max": grad_max,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise TypeError(
                "updates() needs either objective_func or all of hessian_xy, "
                "hessian_yx, grad_min and grad_max; missing: "
                + ", ".join(missing))

    def scale_tree(factor, tree):
        # Multiply every leaf of `tree` by `factor`.
        return tree_util.tree_map(lambda leaf: factor * leaf, tree)

    def add_trees(tree_a, tree_b):
        # Leaf-wise sum of two trees with the same structure.
        return tree_util.tree_map(lambda a, b: a + b, tree_a, tree_b)

    def linear_opt_min(min_tree):
        # eta_max * H_xy(inv_D2P_max(H_yx(v))) + (1 / eta_min) * D2P_min(v)
        max_side = hessian_yx(min_tree)  # max_tree type
        max_side = _tree_apply(
            _tree_apply(breg_max.inv_D2P, prev_state.maxPlayer),
            max_side)  # max_tree type
        coupled = scale_tree(eta_max, hessian_xy(max_side))  # min_tree type
        bregman_term = _tree_apply(
            _tree_apply(breg_min.D2P, prev_state.minPlayer),
            min_tree)  # min_tree type
        return add_trees(coupled, scale_tree(1 / eta_min, bregman_term))

    def linear_opt_max(max_tree):
        # eta_min * H_yx(inv_D2P_min(H_xy(v))) + (1 / eta_max) * D2P_max(v)
        min_side = hessian_xy(max_tree)  # min_tree type
        min_side = _tree_apply(
            _tree_apply(breg_min.inv_D2P, prev_state.minPlayer),
            min_side)  # min_tree type
        coupled = scale_tree(eta_min, hessian_yx(min_side))  # max_tree type
        bregman_term = _tree_apply(
            _tree_apply(breg_max.D2P, prev_state.maxPlayer),
            max_tree)  # max_tree type
        return add_trees(coupled, scale_tree(1 / eta_max, bregman_term))

    # calculate the vectors in equation (4)
    cross_min = hessian_xy(
        _tree_apply(_tree_apply(breg_max.inv_D2P, prev_state.maxPlayer),
                    grad_max))
    vec_min = add_trees(grad_min, scale_tree(eta_max, cross_min))

    if precond_b_min:
        # Normalise the right-hand side by its largest per-leaf inf-norm.
        cond_min = max(
            jnp.linalg.norm(leaf, jnp.inf)
            for leaf in tree_util.tree_leaves(vec_min))
        vec_min = tree_util.tree_map(lambda leaf: leaf / cond_min, vec_min)

    cross_max = hessian_yx(
        _tree_apply(_tree_apply(breg_min.inv_D2P, prev_state.minPlayer),
                    grad_min))
    vec_max = tree_util.tree_map(lambda a, b: a - b, grad_max,
                                 scale_tree(eta_min, cross_max))

    if precond_b_max:
        # Normalise the right-hand side by its largest per-leaf default norm.
        cond_max = max(
            jnp.linalg.norm(leaf)
            for leaf in tree_util.tree_leaves(vec_max))
        vec_max = tree_util.tree_map(lambda leaf: leaf / cond_max, vec_max)

    # The solver's second return value is not used.
    update_min, _ = linalg.cg(linear_opt_min, vec_min, maxiter=1000)
    if precond_b_min:
        # Undo the right-hand-side scaling applied before the solve.
        update_min = scale_tree(cond_min, update_min)

    update_min = tree_util.tree_map(lambda leaf: -leaf,
                                    update_min)  # negation here!

    update_max, _ = linalg.cg(linear_opt_max, vec_max, maxiter=1000)
    if precond_b_max:
        update_max = scale_tree(cond_max, update_max)

    return UpdateState(update_min, update_max)