def conjugate_gradient(self, A, y, x0, solve_params=LinearSolve(), callback=None):
    """Solve ``A x = y`` with conjugate gradient, one batch entry at a time.

    The leading dimension of ``y``, ``x0`` and (when batched) ``A`` is treated
    as the batch dimension. An operand whose batch size is smaller than the
    combined batch size is reused by clamping the index to its last entry.

    Args:
        A: Linear operator. It is treated as batched when it is a tuple/list
            or when ``self.ndims(A) == 3``; in that case entry ``A[i]`` is used
            for batch ``i``. Otherwise the same ``A`` is used for every batch.
        y: Right-hand sides, indexed along the first dimension.
        x0: Initial guesses, indexed along the first dimension.
        solve_params: Supplies ``relative_tolerance``, ``absolute_tolerance``
            and ``max_iterations``; its ``result`` attribute is overwritten
            with a ``SolveResult``.
        callback: Accepted for interface compatibility; not used here.

    Returns:
        The per-batch solutions combined with ``self.stack``.
    """
    # NOTE(review): the default `LinearSolve()` is created once at definition
    # time, so callers that omit `solve_params` all write `result` onto the
    # same shared instance. Pass an explicit object to read the result.
    bs_y = self.staticshape(y)[0]
    bs_x0 = self.staticshape(x0)[0]
    batch_size = combined_dim(bs_y, bs_x0)

    batched_A = isinstance(A, (tuple, list)) or self.ndims(A) == 3
    if batched_A:
        bs_A = self.staticshape(A)[0]
        batch_size = combined_dim(batch_size, bs_A)

    results = []
    converged = True
    for batch in range(batch_size):
        # Select the operator belonging to this batch entry. Previously the
        # whole batched `A` was handed to the solver on every iteration.
        A_ = A[min(batch, bs_A - 1)] if batched_A else A
        y_ = y[min(batch, bs_y - 1)]
        x0_ = x0[min(batch, bs_x0 - 1)]
        x, ret_val = cg(
            A_,
            y_,
            x0_,
            tol=solve_params.relative_tolerance,
            atol=solve_params.absolute_tolerance,
            maxiter=solve_params.max_iterations,
        )
        # Assumes `cg` follows the scipy.sparse.linalg.cg convention where the
        # second return value is 0 on successful convergence — TODO confirm
        # against this file's imports.
        converged = converged and bool(ret_val == 0)
        results.append(x)

    # The iteration count is not available from the solver call, hence -1.
    solve_params.result = SolveResult(success=converged, iterations=-1)
    return self.stack(results)
def cg_fixed_point_solve(linear_op, bvec, init_x, max_iter=1000, tol=1e-10):
    """Solve ``x - linear_op(x) = bvec`` for ``x`` using conjugate gradient.

    Args:
        linear_op: Callable mapping a pytree ``x`` to a pytree with the same
            structure. The system actually solved uses ``x - linear_op(x)``.
        bvec: Right-hand side pytree.
        init_x: Initial guess, passed to the solver as its starting point.
        max_iter: Maximum number of CG iterations (forwarded as ``maxiter``).
        tol: Tolerance forwarded to the solver as ``tol``.

    Returns:
        The solution returned by ``linalg.cg``; the solver's second return
        value is discarded.
    """

    def residual_op(x):
        # (I - linear_op) applied leaf by leaf. `tree_map` accepts several
        # trees; it replaces `tree_multimap`, which newer JAX releases removed.
        return tree_util.tree_map(jax.lax.sub, x, linear_op(x))

    sol, _ = linalg.cg(residual_op, bvec, init_x, tol=tol, maxiter=max_iter)
    return sol
def _jax_cg_solve(x0, mat_vec, oks, grad, diag_shift, sparse_tol, sparse_maxiter): r""" Solves the SR flow equation using the conjugate gradient method """ _mat_vec = partial(mat_vec, oks=oks, diag_shift=diag_shift) out, _ = cg(_mat_vec, grad, x0=x0, tol=sparse_tol, maxiter=sparse_maxiter) return out
def _jax_cg_solve_onthefly(
    x0,
    forward_fn,
    params,
    samples,
    grad,
    diag_shift,
    sparse_tol,
    sparse_maxiter,
):
    """Solve for ``grad`` by conjugate gradient using ``_mat_vec_onthefly``.

    The matrix-vector product is ``_mat_vec_onthefly`` with ``forward_fn``,
    ``params``, ``samples`` and ``diag_shift`` bound as keyword arguments.
    ``x0`` is the starting point; the solver's second return value is dropped.
    """

    def apply_matrix(vec):
        # Only the vector changes between solver calls; the rest is fixed.
        return _mat_vec_onthefly(
            vec,
            forward_fn=forward_fn,
            params=params,
            samples=samples,
            diag_shift=diag_shift,
        )

    solution, _info = cg(
        apply_matrix,
        grad,
        x0=x0,
        tol=sparse_tol,
        maxiter=sparse_maxiter,
    )
    return solution
def test_cg():
    # also tests if matvec can be jitted and be differentiated with AD
    diag_shift = 0.001
    sparse_tol = 1.0e-5
    sparse_maxiter = None

    def mv_real(vec):
        # Dense reference product: S_real @ vec plus the diagonal shift.
        return S_real @ vec + diag_shift * vec

    # Solution from the on-the-fly solver under test.
    actual = _jax_cg_solve_onthefly(
        v, f, params, samples, grad, diag_shift, sparse_tol, sparse_maxiter
    )

    # Reference solution from the explicit real-valued system.
    reference = cg(
        mv_real,
        grad_real_flat,
        x0=v_real_flat,
        tol=sparse_tol,
        maxiter=sparse_maxiter,
    )
    expected = reassemble_complex(reference[0])

    assert tree_allclose(actual, expected)
def updates(prev_state, hessian_xy=None, hessian_yx=None, grad_min=None,
            grad_max=None, breg_min=default_breg, breg_max=default_breg,
            eta_min=1., eta_max=1., objective_func=None,
            precond_b_min=False, precond_b_max=False):
    """Equation (4). Compute the CMD updates for both players.

    Given the current position (``prev_state``), two linear systems are solved
    with conjugate gradient (at most 1000 iterations each, zero initial guess)
    to obtain the updates ``(del_min, del_max)``:

    * min player: ``A_min(u) = grad_min + eta_max * H_xy(inv_D2P_max(grad_max))``
      with ``A_min(v) = eta_max * H_xy(inv_D2P_max(H_yx(v))) + D2P_min(v) / eta_min``;
      the returned update is ``-u``.
    * max player: ``A_max(u) = grad_max - eta_min * H_yx(inv_D2P_min(grad_min))``
      with ``A_max(v) = eta_min * H_yx(inv_D2P_min(H_xy(v))) + D2P_max(v) / eta_max``;
      the returned update is ``u``.

    Args:
        prev_state (named tuple of vectors): The current position of the
            players, with signature
            ``CMDState(minPlayer maxPlayer minPlayer_dual maxPlayer_dual)``.
            Only ``minPlayer`` and ``maxPlayer`` are read here.
        hessian_xy (callable): The (estimated) mixed hessian at the current
            position, as a matrix-vector operator (e.g. built from jax.jvp).
            Takes a max-player-shaped tree and returns a min-player-shaped
            tree. Replaced when ``objective_func`` is given.
        hessian_yx (callable): Counterpart of ``hessian_xy``: takes a
            min-player-shaped tree and returns a max-player-shaped tree.
            Replaced when ``objective_func`` is given.
        grad_min (vector): The (estimated) gradient of the cost function
            w.r.t. the min player parameters at the current position.
            Recomputed when ``objective_func`` is given.
        grad_max (vector): The (estimated) gradient of the cost function
            w.r.t. the max player parameters at the current position.
            Recomputed when ``objective_func`` is given.
        breg_min (named tuple of callables): Bregman potential of the min
            player. This function reads its ``D2P`` and ``inv_D2P``
            attributes. NOTE(review): earlier documentation listed the fields
            as ``["DP", "DP_inv", "D2P", "D2P_inv"]`` — confirm the actual
            field names of ``BregmanPotential``.
        breg_max (named tuple of callables): Same as ``breg_min`` for the
            max player.
        eta_min (scalar): Step size for the min player. Default 1.0.
        eta_max (scalar): Step size for the max player. Default 1.0.
        objective_func (callable): Optional objective ``f(min, max)``. When
            given, both gradients are computed with ``jacfwd`` and both mixed
            hessian operators are built with ``make_mixed_jvp``, overriding
            the corresponding arguments.
        precond_b_min (bool): If True, the min player's right-hand side is
            divided by the largest per-leaf infinity norm before the solve and
            the solution is scaled back afterwards.
        precond_b_max (bool): Same for the max player, but using the default
            ``jnp.linalg.norm`` of each leaf. NOTE(review): this differs from
            the infinity norm used for the min player — confirm it is
            intentional.

    Returns:
        UpdateState(del_min, del_max), a named tuple for the updates.
    """
    if objective_func is not None:
        # Compute current gradient for min and max players
        grad_min = jacfwd(objective_func, 0)(prev_state.minPlayer,
                                             prev_state.maxPlayer)
        grad_max = jacfwd(objective_func, 1)(prev_state.minPlayer,
                                             prev_state.maxPlayer)

        # Define the mixed hessian-vector product linear operator at current
        # position
        def hessian_xy(tangent):
            return make_mixed_jvp(objective_func, prev_state.minPlayer,
                                  prev_state.maxPlayer)(tangent)

        def hessian_yx(tangent):
            return make_mixed_jvp(objective_func, prev_state.minPlayer,
                                  prev_state.maxPlayer, True)(tangent)

    def linear_opt_min(min_tree):
        """Left-hand-side operator of the min player's system."""
        max_side = hessian_yx(min_tree)  # max_tree type
        max_side = _tree_apply(
            _tree_apply(breg_max.inv_D2P, prev_state.maxPlayer), max_side)
        coupling = hessian_xy(max_side)  # min_tree type
        coupling = tree_util.tree_map(lambda x: eta_max * x, coupling)
        metric = _tree_apply(
            _tree_apply(breg_min.D2P, prev_state.minPlayer), min_tree)
        metric = tree_util.tree_map(lambda x: 1 / eta_min * x, metric)
        return tree_util.tree_map(lambda x, y: x + y, coupling, metric)

    def linear_opt_max(max_tree):
        """Left-hand-side operator of the max player's system."""
        min_side = hessian_xy(max_tree)  # min_tree type
        min_side = _tree_apply(
            _tree_apply(breg_min.inv_D2P, prev_state.minPlayer), min_side)
        coupling = hessian_yx(min_side)  # max_tree type
        coupling = tree_util.tree_map(lambda x: eta_min * x, coupling)
        metric = _tree_apply(
            _tree_apply(breg_max.D2P, prev_state.maxPlayer), max_tree)
        metric = tree_util.tree_map(lambda x: 1 / eta_max * x, metric)
        return tree_util.tree_map(lambda x, y: x + y, coupling, metric)

    # calculate the vectors in equation (4)
    cross_min = hessian_xy(
        _tree_apply(_tree_apply(breg_max.inv_D2P, prev_state.maxPlayer),
                    grad_max))
    cross_min = tree_util.tree_map(lambda x: eta_max * x, cross_min)
    vec_min = tree_util.tree_map(lambda x, y: x + y, grad_min, cross_min)
    if precond_b_min:
        # Scale the right-hand side by its largest per-leaf infinity norm.
        cond_min = max(jnp.linalg.norm(leaf, jnp.inf)
                       for leaf in tree_util.tree_leaves(vec_min))
        vec_min = tree_util.tree_map(lambda x: x / cond_min, vec_min)

    cross_max = hessian_yx(
        _tree_apply(_tree_apply(breg_min.inv_D2P, prev_state.minPlayer),
                    grad_min))
    cross_max = tree_util.tree_map(lambda x: eta_min * x, cross_max)
    vec_max = tree_util.tree_map(lambda x, y: x - y, grad_max, cross_max)
    if precond_b_max:
        # Scale the right-hand side by its largest per-leaf (default) norm.
        cond_max = max(jnp.linalg.norm(leaf)
                       for leaf in tree_util.tree_leaves(vec_max))
        vec_max = tree_util.tree_map(lambda x: x / cond_max, vec_max)

    update_min, _ = linalg.cg(linear_opt_min, vec_min, maxiter=1000)
    if precond_b_min:
        # Undo the right-hand-side scaling (the system is linear).
        update_min = tree_util.tree_map(lambda x: cond_min * x, update_min)
    update_min = tree_util.tree_map(lambda x: -x, update_min)  # negation here!

    update_max, _ = linalg.cg(linear_opt_max, vec_max, maxiter=1000)
    if precond_b_max:
        update_max = tree_util.tree_map(lambda x: cond_max * x, update_max)

    return UpdateState(update_min, update_max)