def testUnrollFixedpointLoop(self):
    max_steps = 10

    def step(x):
        return x - 1

    init_x = np.zeros(())
    unroll_sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=True,
    )
    loop_sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=False,
    )
    testing.assert_array_equal(unroll_sol, loop_sol)

def testBatchedLoop(self, unroll):
    max_steps = 10

    def step(i, x):
        del i
        return x - 1

    init_x = np.zeros(())
    batched_sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=5,
        unroll=unroll,
    )
    loop_sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=unroll,
    )
    testing.assert_array_equal(batched_sol, loop_sol)

def testTermination(self, unroll):
    if unroll:
        self.skipTest(
            ("Can't terminate early when unrolling until `jax.lax.cond` "
             "supports differentiation."))
    max_steps = 10

    def step(x):
        return x - 1

    init_x = np.zeros(())
    term_sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda x, *args: x <= -5,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=unroll,
    )
    last_i, fixed_value = loop.unrolled(
        0,
        init_x=init_x,
        func=step,
        num_iter=5,
        return_last_two=False,
    )
    self.assertEqual(fixed_value, term_sol.value)
    self.assertEqual(last_i, term_sol.iterations)

def _adjoint_iteration_vjp(g, ans, init_xs, params):
    dvalue = g
    del init_xs

    # Solve the adjoint fixed-point equation w = (dF/dx)^T w + g by
    # fixed-point iteration, reusing the VJP of the forward operator at
    # the converged solution.
    init_dxs = dvalue
    fp_vjp_fn = jax.vjp(partial(flat_func, ans.iterations),
                        ans.value, params)[1]

    def dfp_fn(i, dout):
        del i
        dout = fp_vjp_fn(dout)[0] + dvalue
        return dout

    rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                               init_dxs.dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    dsol = loop.fixed_point_iteration(
        init_x=init_dxs,
        func=dfp_fn,
        convergence_test=convergence_test,
        max_iter=default_max_iter,
        batched_iter_size=default_batched_iter_size,
    )

    # Propagate the adjoint solution onto the parameters.
    return fp_vjp_fn(dsol.value)[1], dsol

def none_divisible_batch():
    init_x = np.zeros(())
    return loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=3,
    )

def run_unrolled(x):
    return loop.fixed_point_iteration(
        init_x=x,
        func=step,
        convergence_test=converge_test,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=True,
    ).value

def run_loop(x):
    return loop.fixed_point_iteration(
        init_x=x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=False,
    )

def solve(x):
    return loop.fixed_point_iteration(
        init_x=x,
        func=step,
        convergence_test=convergence_test,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=unroll,
    )

def default_solver(linear_op, bvec, init_x=None):
    if init_x is None:
        init_x = bvec

    def _step_default_solver(i, x):
        del i
        return tree_util.tree_multimap(lax.add, linear_op(x), bvec)

    return loop.fixed_point_iteration(
        init_x=init_x,
        func=_step_default_solver,
        convergence_test=default_convergence_test,
        max_iter=default_max_iter,
    )

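# A minimal usage sketch for `default_solver` (the matrix `A` and vector `b`
# below are illustrative, and `default_convergence_test`/`default_max_iter`
# are assumed to be in scope from the enclosing module). The iteration
# x <- A x + b converges to the solution of x = A x + b whenever the
# spectral radius of A is below 1.
import jax.numpy as jnp

A = jnp.array([[0.5, 0.1],
               [0.0, 0.4]])
b = jnp.array([1.0, 2.0])

sol = default_solver(linear_op=lambda x: A @ x, bvec=b)
# sol.value should be close to jnp.linalg.solve(jnp.eye(2) - A, b).
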
def testNoneMaxIter(self):
    max_steps = None

    def step(x):
        return x + 1

    init_x = np.zeros(())
    loop_sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda i, *args: i > 10,
        max_iter=max_steps,
        batched_iter_size=3,
    )
    self.assertIsNotNone(loop_sol)
    loop_sol.value.block_until_ready()

def _default_solver(init_x, params):
    rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                               init_x.dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    func = param_func(params)
    sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=func,
        convergence_test=convergence_test,
        max_iter=default_max_iter,
        batched_iter_size=default_batched_iter_size,
    )
    return sol

def _default_solve(param_func, init_x, params):
    _convergence_test = convergence_test
    if convergence_test is None:
        _convergence_test = default_convergence_test(
            dtype=converge.tree_smallest_float_dtype(init_x),
        )
    func = param_func(params)
    sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=func,
        convergence_test=_convergence_test,
        max_iter=max_iter,
        batched_iter_size=batched_iter_size,
    )
    return sol.value

def fixed_point_iteration_solver(init_x, params):
    dtype = converge.tree_smallest_float_dtype(init_x)
    rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                               dtype)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    func = param_func(params)
    sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=func,
        convergence_test=convergence_test,
        max_iter=default_max_iter,
        batched_iter_size=default_batched_iter_size,
        unroll=unroll,
    )
    return sol.value

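# A sketch of the `param_func` pattern the solver above expects. The
# `make_contraction` factory is a hypothetical example: the parameter is the
# coefficient of the contraction x -> c * x + 1, whose fixed point is
# 1 / (1 - c). The enclosing defaults (`default_rtol`, `default_max_iter`,
# etc.) are assumed to be in scope.
import jax.numpy as jnp

def make_contraction(c):
    def step(x):
        return c * x + 1.0
    return step

# With `param_func = make_contraction`, the solver should return a value
# near 2.0 for c = 0.5:
#     x_star = fixed_point_iteration_solver(jnp.zeros(()), 0.5)
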
def testFixedPointDiverge(self, unroll):
    rtol = atol = 1e-10
    max_steps = 10
    init_x = np.zeros(())

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    def step(x_old):
        return x_old + 1

    sol = loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=convergence_test,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=unroll,
    )
    self.assertFalse(sol.converged)
    self.assertEqual(sol.iterations, max_steps)

def conjugate_gradient_solve(linear_op, bvec, init_x, max_iter=1000,
                             atol=1e-10):
    dtype = converge.tree_smallest_float_dtype(bvec)
    _, atol = converge.adjust_tol_for_dtype(0., atol=atol, dtype=dtype)

    init_r = tree_util.tree_multimap(lax.sub, bvec, linear_op(init_x))
    init_p = init_r
    init_r_sqr = math.pytree_dot(init_r, init_r)

    def convergence_test(state_new, state_old):
        del state_old
        # The third state element is the squared residual norm.
        return state_new[2] < atol

    solution = loop.fixed_point_iteration(
        (init_x, init_r, init_r_sqr, init_p),
        func=partial(cg_step, linear_op),
        convergence_test=convergence_test,
        max_iter=max_iter,
    )
    # Keep only the iterate x from the CG state tuple.
    return solution._replace(
        value=solution.value[0],
        previous_value=solution.previous_value[0],
    )

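# Hypothetical usage sketch for `conjugate_gradient_solve`. Judging from the
# residual computation `r = b - linear_op(x)`, this routine runs standard CG
# for `linear_op(x) = b`, so `linear_op` should represent a symmetric
# positive definite operator. `A` and `b` below are illustrative.
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

sol = conjugate_gradient_solve(
    linear_op=lambda x: A @ x,
    bvec=b,
    init_x=jnp.zeros_like(b),
)
# sol.value should be close to jnp.linalg.solve(A, b).
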
def cga_iteration(init_values, f, g, convergence_test, max_iter, step_size_f,
                  step_size_g=None, linear_op_solver=None, batched_iter_size=1,
                  unroll=False, use_full_matrix=False, solve_order='both'):
    """Run competitive gradient ascent until convergence or `max_iter`.

    Use this function to find a fixed point of the competitive gradient
    ascent (CGA) update by repeatedly applying CGA to a candidate solution.
    This is done until the solution converges or until the maximum number of
    iterations, `max_iter`, is reached.

    NOTE: if the maximum number of iterations is reached, convergence is not
    checked on the final application of `func` and the solution is always
    marked as not converged when `unroll` is `False`.

    Args:
        init_values: a tuple of type `(a, b)` corresponding to the types
            accepted by `f` and `g`.
        f (callable): The function we wish to maximize with type
            `a, b -> float`.
        g (callable): The "opposing" function which is also maximized with
            type `a, b -> float`.
        convergence_test (callable): A two-argument function of type
            `(a, b), (a, b) -> bool` that takes in the newest solution and the
            previous solution and returns `True` if they have converged. The
            optimization will stop and return when `True` is returned.
        max_iter (int or None): The maximum number of iterations.
        step_size_f: The step size used by CGA for `f`. This can be a scalar
            or a callable taking in the current iteration and returning a
            scalar. If no step size is given for `g`, then `step_size_f` is
            also used for `g`.
        step_size_g (optional): The step size used by CGA for `g`. Like
            `step_size_f`, this can be a scalar or a callable. If no step size
            is given for `g`, then `step_size_f` is used.
        linear_op_solver (callable, optional): A function which outputs the
            solution to `x = Ax + b` when given a callable linear operator
            representing the matrix-vector product `Ax` and an array `b`. If
            `None` is given, then a simple fixed point iteration solver is
            used.
        batched_iter_size (int, optional): The number of iterations to be
            unrolled and executed per iteration of the `while_loop` op.
            Convergence is only tested at the beginning of each batch. Set
            this to a number larger than 1 to reduce the number of times
            convergence is checked and to potentially allow the graph of the
            unrolled batch to be more aggressively optimized.
        unroll (bool, optional): If `True`, use `jax.lax.scan` instead of
            `jax.lax.while_loop`. This enables back-propagating through the
            iterations. NOTE: due to current limitations in JAX, when `unroll`
            is `True`, convergence is ignored and the loop always runs for the
            maximum number of iterations.
        use_full_matrix (bool, optional): Use a CGA implementation which uses
            full Hessians instead of potentially more efficient
            Jacobian-vector products. This is useful for debugging and might
            provide a small performance boost when the dimensions are small.
            If `True`, the `linear_op_solver`, if provided, is ignored.
        solve_order (str, optional): Specifies how the updates for each player
            are solved for. Should be one of

            - 'both' (default): Solves the linear system for each player
              (eq. 3 of Schaefer 2019).
            - 'yx': Solves for the player behind `y`, then updates `x`.
            - 'xy': Solves for the player behind `x`, then updates `y`.
            - 'alternate': Solves for `x` and updates `y`; the next iteration
              solves for `y` and updates `x`.

            Defaults to 'both'.

    Returns:
        FixedPointSolution: A named tuple containing the results of the
            optimization. The tuple contains the attributes `value` (the final
            solution tuple), `converged` (a bool indicating whether
            convergence was achieved), `iterations` (the number of iterations
            used), and `previous_value` (the value of the solution on the
            previous iteration). The previous value satisfies
            `sol.value == step_cga(sol.previous_value)` and allows us to log
            the size of the last step if desired.
    """
    if use_full_matrix:
        cga_init, cga_update, get_params = full_solve_cga(
            step_size_f=step_size_f,
            step_size_g=step_size_g or step_size_f,
            f=f,
            g=g,
        )
    else:
        cga_init, cga_update, get_params = cga(
            step_size_f=step_size_f,
            step_size_g=step_size_g or step_size_f,
            f=f,
            g=g,
            linear_op_solver=linear_op_solver,
            solve_order=solve_order,
        )

    grad_yg = jax.grad(g, 1)
    grad_xf = jax.grad(f, 0)

    def step(i, inputs):
        x, y = inputs[:2]
        grads = (grad_xf(x, y), grad_yg(x, y))
        return cga_update(i, grads, inputs)

    def cga_convergence_test(x_new, x_old):
        return convergence_test(x_new[:2], x_old[:2])

    solution = loop.fixed_point_iteration(
        init_x=cga_init(init_values),
        func=step,
        convergence_test=cga_convergence_test,
        max_iter=max_iter,
        batched_iter_size=batched_iter_size,
        unroll=unroll,
    )
    return solution._replace(
        value=get_params(solution.value),
        previous_value=get_params(solution.previous_value),
    )

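# Hypothetical usage sketch for `cga_iteration` on the bilinear game
# f(x, y) = x * y with g = -f, whose unique saddle point is (0, 0). The
# convergence test below assumes `converge.max_diff_test` accepts pytrees
# such as tuples, as it does in the solvers above; step sizes are
# illustrative.
import jax.numpy as jnp

def f(x, y):
    return x * y

def g(x, y):
    return -x * y

sol = cga_iteration(
    init_values=(jnp.array(1.0), jnp.array(1.0)),
    f=f,
    g=g,
    convergence_test=lambda new, old: converge.max_diff_test(
        new, old, 1e-8, 1e-8),
    max_iter=1000,
    step_size_f=0.1,
)
x_star, y_star = sol.value  # both should be near 0
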
def implicit_ecp(
        objective, equality_constraints, initial_values, lr_func, max_iter=500,
        convergence_test=default_convergence_test, batched_iter_size=1,
        optimizer=optimizers.sgd, tol=1e-6):
    """Use implicit differentiation to solve a nonlinear equality-constrained
    program of the form:

        max f(x, θ) subject to h(x, θ) = 0 .

    We perform a change of variables via the implicit function theorem and
    obtain the unconstrained program:

        max f(φ(θ), θ) ,

    where φ is an implicit function of the parameters θ such that
    h(φ(θ), θ) = 0.

    Args:
        objective (callable): Binary callable with signature `f(x, θ)`.
        equality_constraints (callable): Binary callable with signature
            `h(x, θ)`.
        initial_values (tuple): Tuple of initial values `(x_0, θ_0)`.
        lr_func (scalar or callable): The step size used by the unconstrained
            optimizer. This can be a scalar or a callable taking in the
            current iteration and returning a scalar.
        max_iter (int, optional): Maximum number of outer iterations. Defaults
            to 500.
        convergence_test (callable): Binary callable with signature
            `callback(new_state, old_state)` where `new_state` and `old_state`
            are tuples of the form `(x_k^*, θ_k)` such that
            `h(x_k^*, θ_k) = 0` (and with `k-1` for `old_state`). The default
            convergence test returns `True` if both elements of the tuple have
            not changed within some tolerance.
        batched_iter_size (int, optional): The number of iterations to be
            unrolled and executed per iteration of the `while_loop` op for the
            forward iteration and the fixed-point adjoint iteration. Defaults
            to 1.
        optimizer (callable, optional): Unary callable taking `lr_func` as an
            argument and returning an unconstrained optimizer. Defaults to
            `jax.experimental.optimizers.sgd`.
        tol (float, optional): Tolerance for the forward and backward
            iterations. Defaults to 1e-6.

    Returns:
        fax.loop.FixedPointSolution: A named tuple containing the solution
            `(x, θ)` as the `value` attribute, `converged` (a bool indicating
            whether convergence was achieved), `iterations` (the number of
            iterations used), and `previous_value` (the value of the solution
            on the previous iteration). The previous value satisfies
            `sol.value == func(sol.previous_value)` and allows us to log the
            size of the last step if desired.
    """
    def _objective(*args):
        # Negate the objective so that the (descent) optimizer maximizes f.
        return -objective(*args)

    def make_fp_operator(params):
        def _fp_operator(i, x):
            del i
            return x + equality_constraints(x, params)
        return _fp_operator

    constraints_solver = make_forward_fixed_point_iteration(
        make_fp_operator,
        default_max_iter=max_iter,
        default_batched_iter_size=batched_iter_size,
        default_atol=tol,
        default_rtol=tol,
    )
    adjoint_iteration_vjp = make_adjoint_fixed_point_iteration(
        make_fp_operator,
        default_max_iter=max_iter,
        default_batched_iter_size=batched_iter_size,
        default_atol=tol,
        default_rtol=tol,
    )

    opt_init, opt_update, get_params = optimizer(step_size=lr_func)
    grad_objective = grad(_objective, (0, 1))

    def update(i, values):
        old_xstar, opt_state = values
        old_params = get_params(opt_state)

        # Solve the constraints h(x, θ) = 0 for x at the current parameters.
        forward_solution = constraints_solver(old_xstar, old_params)

        grads_x, grads_params = grad_objective(forward_solution.value,
                                               get_params(opt_state))

        # Back-propagate through the fixed point via the adjoint iteration.
        ybar, _ = adjoint_iteration_vjp(grads_x, forward_solution, old_xstar,
                                        old_params)
        implicit_grads = tree_util.tree_multimap(lax.add, grads_params, ybar)

        opt_state = opt_update(i, implicit_grads, opt_state)
        return forward_solution.value, opt_state

    def _convergence_test(new_state, old_state):
        x_new, params_new = new_state[0], get_params(new_state[1])
        x_old, params_old = old_state[0], get_params(old_state[1])
        return convergence_test((x_new, params_new), (x_old, params_old))

    x0, init_params = initial_values
    opt_state = opt_init(init_params)
    solution = fixed_point_iteration(
        init_x=(x0, opt_state),
        func=update,
        convergence_test=jit(_convergence_test),
        max_iter=max_iter,
        batched_iter_size=batched_iter_size,
        unroll=False,
    )
    return solution._replace(
        value=(solution.value[0], get_params(solution.value[1])),
        previous_value=(solution.previous_value[0],
                        get_params(solution.previous_value[1])),
    )

def cga_ecp(
        objective, equality_constraints, initial_values, lr_func,
        lr_multipliers=None, linear_op_solver=None, solve_order='alternating',
        max_iter=500, convergence_test=default_convergence_test,
        batched_iter_size=1):
    """Use CGA to solve a nonlinear equality-constrained program of the form:

        max f(x, θ) subject to h(x, θ) = 0 .

    We form the Lagrangian L(x, θ, λ) = f(x, θ) - λ^⊤ h(x, θ) and try to find
    a saddle point of:

        max_{x, θ} min_λ L(x, θ, λ) .

    Args:
        objective (callable): Binary callable with signature `f(x, θ)`.
        equality_constraints (callable): Binary callable with signature
            `h(x, θ)`.
        initial_values (tuple): Tuple of initial values `(x_0, θ_0)`.
        lr_func (scalar or callable): The step size used by CGA for `f`. This
            can be a scalar or a callable taking in the current iteration and
            returning a scalar.
        lr_multipliers (scalar or callable, optional): Step size for the dual
            updates. Defaults to `None`, in which case `lr_func` is also used
            for the multipliers.
        linear_op_solver (callable, optional): A function which outputs the
            solution to `x = Ax + b` when given a callable linear operator
            representing the matrix-vector product `Ax` and an array `b`. If
            `None` is given, then a simple fixed point iteration solver is
            used. Used to solve for the matrix inverses in the CGA algorithm.
        solve_order (str, optional): Specifies how the updates for each player
            are solved for. Should be one of

            - 'both': Solves the linear system for each player (eq. 3 of
              Schaefer 2019).
            - 'yx': Solves for the player behind `y`, then updates `x`.
            - 'xy': Solves for the player behind `x`, then updates `y`.
            - 'alternate': Solves for `x` and updates `y`; the next iteration
              solves for `y` and updates `x`.

            Defaults to 'alternating'.
        max_iter (int): Maximum number of outer iterations. Defaults to 500.
        convergence_test (callable): Binary callable with signature
            `callback(new_state, old_state)` where `new_state` and `old_state`
            are nested tuples of the form `((x_k, θ_k), λ_k)`. The default
            convergence test returns `True` if all elements of the tuple have
            not changed within some tolerance.
        batched_iter_size (int, optional): The number of iterations to be
            unrolled and executed per iteration of the `while_loop` op for the
            forward iteration and the fixed-point adjoint iteration. Defaults
            to 1.

    Returns:
        fax.loop.FixedPointSolution: A named tuple containing the solution
            `(x, θ)` as the `value` attribute, `converged` (a bool indicating
            whether convergence was achieved), `iterations` (the number of
            iterations used), and `previous_value` (the value of the solution
            on the previous iteration). The previous value satisfies
            `sol.value == func(sol.previous_value)` and allows us to log the
            size of the last step if desired.
    """
    def _objective(variables):
        return -objective(*variables)

    def _equality_constraints(variables):
        return -equality_constraints(*variables)

    init_mult, lagrangian, _ = make_lagrangian(_objective,
                                               _equality_constraints)
    lagrangian_variables = init_mult(initial_values)

    if lr_multipliers is None:
        lr_multipliers = lr_func

    opt_init, opt_update, get_params = cga_lagrange_min(
        lagrangian, lr_func, lr_multipliers, linear_op_solver, solve_order)

    @jit
    def update(i, opt_state):
        grads = grad(lagrangian, (0, 1))(*get_params(opt_state))
        return opt_update(i, grads, opt_state)

    solution = fixed_point_iteration(
        init_x=opt_init(lagrangian_variables),
        func=update,
        convergence_test=jit(convergence_test),
        max_iter=max_iter,
        batched_iter_size=batched_iter_size,
        unroll=False,
    )
    return solution._replace(
        value=get_params(solution.value)[0],
        previous_value=get_params(solution.previous_value)[0],
    )
