Example #1
    def testUnrollFixedpointLoop(self):
        max_steps = 10

        def step(x):
            return x - 1

        init_x = np.zeros(())
        unroll_sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            convergence_test=lambda *args: False,
            max_iter=max_steps,
            batched_iter_size=1,
            unroll=True,
        )

        loop_sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            convergence_test=lambda *args: False,
            max_iter=max_steps,
            batched_iter_size=1,
            unroll=False,
        )

        testing.assert_array_equal(unroll_sol, loop_sol)
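
The two calls above should agree because `unroll=True` merely swaps the `while_loop` for a `scan` over the same step function. For orientation, a minimal convergent call might look like the following sketch; it assumes `loop` is `fax.loop` (as the imports implied by these examples suggest) and uses a contraction map whose fixed point is known in closed form.

import jax.numpy as jnp
from fax import loop

# Iterating x -> (x + c) / 2 converges to the fixed point x* = c.
c = 3.0

def step(x):
    return (x + c) / 2.0

sol = loop.fixed_point_iteration(
    init_x=jnp.zeros(()),
    func=step,
    # stop once successive iterates agree to within a small tolerance
    convergence_test=lambda x_new, x_old: jnp.abs(x_new - x_old) < 1e-6,
    max_iter=1000,
    batched_iter_size=1,
)
# sol.value should be close to c, with sol.converged == True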
Example #2
    def testBatchedLoop(self, unroll):
        max_steps = 10

        def step(i, x):
            del i
            return x - 1

        init_x = np.zeros(())
        batched_sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            convergence_test=lambda *args: False,
            max_iter=max_steps,
            batched_iter_size=5,
            unroll=unroll,
        )

        loop_sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            convergence_test=lambda *args: False,
            max_iter=max_steps,
            batched_iter_size=1,
            unroll=unroll,
        )

        testing.assert_array_equal(batched_sol, loop_sol)
Example #3
    def testTermination(self, unroll):
        if unroll:
            self.skipTest(
                ("Can't terminate early when unrolling until `jax.lax.cond` "
                 "supports differentiation."))
        max_steps = 10

        def step(x):
            return x - 1

        init_x = np.zeros(())
        term_sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            convergence_test=lambda x, *args: x <= -5,
            max_iter=max_steps,
            batched_iter_size=1,
            unroll=unroll,
        )

        last_i, fixed_value = loop.unrolled(
            0,
            init_x=init_x,
            func=step,
            num_iter=5,
            return_last_two=False,
        )

        self.assertEqual(fixed_value, term_sol.value)
        self.assertEqual(last_i, term_sol.iterations)
Example #4
    def _adjoint_iteration_vjp(g, ans, init_xs, params):
        dvalue = g
        del init_xs
        init_dxs = dvalue

        fp_vjp_fn = jax.vjp(partial(flat_func, ans.iterations), ans.value,
                            params)[1]

        def dfp_fn(i, dout):
            del i
            dout = fp_vjp_fn(dout)[0] + dvalue
            return dout

        rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                                   init_dxs.dtype)

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        dsol = loop.fixed_point_iteration(
            init_x=init_dxs,
            func=dfp_fn,
            convergence_test=convergence_test,
            max_iter=default_max_iter,
            batched_iter_size=default_batched_iter_size,
        )

        return fp_vjp_fn(dsol.value)[1], dsol
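
The vjp above is the adjoint form of the implicit function theorem: for a fixed point x* of f(x, params), the backward pass solves the linear fixed-point equation dx = (∂f/∂x)^⊤ dx + ḡ, whose solution is dx = (I - ∂f/∂x)^{-⊤} ḡ, and then pushes dx through ∂f/∂params. A self-contained scalar sketch in plain JAX (the function `f` here is hypothetical):

import jax

# f is a contraction in x with fixed point x*(p) = 2 * p.
def f(x, p):
    return 0.5 * x + p

p = 1.5
x_star = 2.0 * p

# Direct route: implicit function theorem, dx*/dp = (1 - df/dx)^{-1} df/dp.
dfdx = jax.grad(f, 0)(x_star, p)   # 0.5
dfdp = jax.grad(f, 1)(x_star, p)   # 1.0
ift_grad = dfdp / (1.0 - dfdx)     # 2.0

# Adjoint route: iterate dx <- dfdx * dx + g, the scalar analogue of dfp_fn.
g = 1.0
dx = 0.0
for _ in range(100):
    dx = dfdx * dx + g             # converges to g / (1 - dfdx) = 2.0
adjoint_grad = dfdp * dx           # matches ift_grad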
Example #5
def none_divisible_batch():
    init_x = np.zeros(())
    return loop.fixed_point_iteration(
        init_x=init_x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=3,
    )
Example #6
def run_unrolled(x):
    return loop.fixed_point_iteration(
        init_x=x,
        func=step,
        convergence_test=converge_test,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=True,
    ).value
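
Because `unroll=True` executes the iterations with `jax.lax.scan`, wrappers like `run_unrolled` can be differentiated end-to-end. A small sketch with a hypothetical step function, assuming the same `fax.loop` API used above:

import jax
import jax.numpy as jnp
from fax import loop

def run_unrolled(x0):
    # 10 unrolled applications of x -> 0.5 * x, so the result is 0.5**10 * x0
    return loop.fixed_point_iteration(
        init_x=x0,
        func=lambda x: 0.5 * x,
        convergence_test=lambda *args: False,
        max_iter=10,
        batched_iter_size=1,
        unroll=True,
    ).value

grad_fn = jax.grad(run_unrolled)
# grad_fn(jnp.ones(())) should be close to 0.5 ** 10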
Example #7
def run_loop(x):
    return loop.fixed_point_iteration(
        init_x=x,
        func=step,
        convergence_test=lambda *args: False,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=False,
    )
Example #8
def solve(x):
    return loop.fixed_point_iteration(
        init_x=x,
        func=step,
        convergence_test=convergence_test,
        max_iter=max_steps,
        batched_iter_size=1,
        unroll=unroll,
    )
Example #9
        def default_solver(linear_op, bvec, init_x=None):
            if init_x is None:
                init_x = bvec

            def _step_default_solver(i, x):
                del i
                return tree_util.tree_multimap(lax.add, linear_op(x), bvec)

            return loop.fixed_point_iteration(
                init_x=init_x,
                func=_step_default_solver,
                convergence_test=default_convergence_test,
                max_iter=default_max_iter,
            )
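
The default solver above is plain fixed-point (Richardson) iteration on x_{k+1} = A x_k + b, which converges to x* = (I - A)^{-1} b whenever the spectral radius of A is below 1. A scalar check with a hypothetical operator:

# A = 0.5, b = 1.0: the solution of x = A x + b is b / (1 - A) = 2.0
A, b = 0.5, 1.0
x = b  # mirrors the `init_x = bvec` default above
for _ in range(50):
    x = A * x + b
assert abs(x - 2.0) < 1e-6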
Example #10
    def testNoneMaxIter(self):
        max_steps = None

        def step(x):
            return x + 1

        init_x = np.zeros(())
        loop_sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            # step increments x from 0, so the first argument (x_new) equals
            # the iteration count; stop once it exceeds 10
            convergence_test=lambda x_new, *args: x_new > 10,
            max_iter=max_steps,
            batched_iter_size=3,
        )
        self.assertIsNotNone(loop_sol)
        loop_sol.value.block_until_ready()
Example #11
    def _default_solver(init_x, params):
        rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                                   init_x.dtype)

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        func = param_func(params)
        sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=func,
            convergence_test=convergence_test,
            max_iter=default_max_iter,
            batched_iter_size=default_batched_iter_size,
        )

        return sol
Example #12
    def _default_solve(param_func, init_x, params):

        _convergence_test = convergence_test
        if convergence_test is None:
            _convergence_test = default_convergence_test(
                dtype=converge.tree_smallest_float_dtype(init_x),
            )

        func = param_func(params)
        sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=func,
            convergence_test=_convergence_test,
            max_iter=max_iter,
            batched_iter_size=batched_iter_size,
        )

        return sol.value
Example #13
    def fixed_point_iteration_solver(init_x, params):
        dtype = converge.tree_smallest_float_dtype(init_x)
        rtol, atol = converge.adjust_tol_for_dtype(default_rtol, default_atol,
                                                   dtype)

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        func = param_func(params)
        sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=func,
            convergence_test=convergence_test,
            max_iter=default_max_iter,
            batched_iter_size=default_batched_iter_size,
            unroll=unroll,
        )

        return sol.value
Example #14
    def testFixedPointDiverge(self, unroll):
        rtol = atol = 1e-10
        max_steps = 10

        init_x = np.zeros(())

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        def step(x_old):
            return x_old + 1

        sol = loop.fixed_point_iteration(
            init_x=init_x,
            func=step,
            convergence_test=convergence_test,
            max_iter=max_steps,
            batched_iter_size=1,
            unroll=unroll,
        )
        self.assertFalse(sol.converged)
        self.assertEqual(sol.iterations, max_steps)
Example #15
def conjugate_gradient_solve(linear_op,
                             bvec,
                             init_x,
                             max_iter=1000,
                             atol=1e-10):
    dtype = converge.tree_smallest_float_dtype(bvec)
    _, atol = converge.adjust_tol_for_dtype(0., atol=atol, dtype=dtype)
    init_r = tree_util.tree_multimap(lax.sub, bvec, linear_op(init_x))
    init_p = init_r
    init_r_sqr = math.pytree_dot(init_r, init_r)

    def convergence_test(state_new, state_old):
        del state_old
        return state_new[2] < atol

    solution = loop.fixed_point_iteration((init_x, init_r, init_r_sqr, init_p),
                                          func=partial(cg_step, linear_op),
                                          convergence_test=convergence_test,
                                          max_iter=max_iter)
    return solution._replace(
        value=solution.value[0],
        previous_value=solution.previous_value[0],
    )
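
A hypothetical usage of `conjugate_gradient_solve`, assuming it and the `cg_step` helper it calls are importable: solve M x = b for a small symmetric positive-definite M supplied as a linear operator.

import jax.numpy as jnp

M = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([1.0, 0.0])

sol = conjugate_gradient_solve(
    linear_op=lambda x: M @ x,
    bvec=b,
    init_x=jnp.zeros_like(b),
)
# sol.value should be close to jnp.linalg.solve(M, b)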
Example #16
def cga_iteration(init_values,
                  f,
                  g,
                  convergence_test,
                  max_iter,
                  step_size_f,
                  step_size_g=None,
                  linear_op_solver=None,
                  batched_iter_size=1,
                  unroll=False,
                  use_full_matrix=False,
                  solve_order='both'):
    """Run competitive gradient ascent until convergence or some max iteration.

    Use this function to find a fixed point of the competitive gradient
    ascent (CGA) update by repeatedly applying CGA to a candidate solution.
    This is done until the solution converges or until the maximum number of
    iterations, `max_iter`, is reached.

    NOTE: if the maximum number of iterations is reached, the convergence
    will not be checked on the final application of `func` and the solution
    will always be marked as not converged when `unroll` is `False`.

    Args:
        init_values: a tuple of type `(a, b)` corresponding to the types
            accepted by `f` and `g`.
        f (callable): The function we wish to maximize, with type
            `a, b -> float`.
        g (callable): The "opposing" function which is also maximized with type
            `a, b -> float`.
        convergence_test (callable): A two argument function of type
            `(a, b), (a, b) -> bool` that takes in the newest solution and the
            previous solution and returns `True` if they have converged. The
            optimization will stop and return when `True` is returned.
        max_iter (int or None): The maximum number of iterations.
        step_size_f: The step size used by CGA for `f`. This can be a scalar or
            a callable taking in the current iteration and returning a scalar.
            If no step size is given for `g`, then `step_size_f` is also used
            for `g`.
        step_size_g (optional): The step size used by CGA for `g`. Like
            `step_size_f`, this can be a scalar or a callable. If no step size
            is given for `g`, then `step_size_f` is used.
        linear_op_solver (callable, optional): This is a function which outputs
            the solution to `x = Ax + b` when given a callable linear operator
            representing the matrix-vector product `Ax` and an array `b`. If
            `None` is given, then a simple fixed point iteration solver is used.
        batched_iter_size (int, optional): The number of iterations to be
            unrolled and executed per iteration of the `while_loop` op. Convergence
            is only tested at the beginning of each batch. Set this to a number
            larger than 1 to reduce the number of times convergence is checked
            and to potentially allow for the graph of the unrolled batch to be
            more aggressively optimized.
        unroll (bool, optional): If True, use `jax.lax.scan` instead of
            `jax.lax.while_loop`. This enables back-propagating through the iterations.

            NOTE: due to current limitations in `JAX`, when `unroll` is `True`,
            convergence is ignored and the loop always runs for the maximum
            number of iterations.
        use_full_matrix (bool, optional): Use a CGA implementation which uses
            full Hessians instead of potentially more efficient Jacobian-vector
            products. This is useful for debugging and might provide a small
            performance boost when the dimensions are small. If set to `True`,
            then `linear_op_solver`, if provided, is ignored.
        solve_order (str, optional): Specifies how the updates for each player are solved for.
            Should be one of

            - 'both' (default): Solves the linear system for each player (eq. 3 of Schaefer 2019)
            - 'yx': Solves for the player behind `y`, then updates `x`
            - 'xy': Solves for the player behind `x`, then updates `y`
            - 'alternate': Solves for `x` and updates `y`; the next iteration solves for `y` and updates `x`

            Defaults to 'both'.

    Returns:
        FixedPointSolution: A named tuple containing the results of the
            optimization. The tuple contains the attributes `value`
            (the final solution tuple), `converged` (a bool indicating whether
            convergence was achieved), `iterations` (the number of iterations
            used), and `previous_value` (the value of the solution on the
            previous iteration). The previous value satisfies
            `sol.value=step_cga(sol.previous_value)` and allows us to log the
            size of the last step if desired.
    """

    if use_full_matrix:
        cga_init, cga_update, get_params = full_solve_cga(
            step_size_f=step_size_f,
            step_size_g=step_size_g or step_size_f,
            f=f,
            g=g,
        )
    else:
        cga_init, cga_update, get_params = cga(
            step_size_f=step_size_f,
            step_size_g=step_size_g or step_size_f,
            f=f,
            g=g,
            linear_op_solver=linear_op_solver,
            solve_order=solve_order)

    grad_yg = jax.grad(g, 1)
    grad_xf = jax.grad(f, 0)

    def step(i, inputs):
        x, y = inputs[:2]
        grads = (grad_xf(x, y), grad_yg(x, y))
        return cga_update(i, grads, inputs)

    def cga_convergence_test(x_new, x_old):
        return convergence_test(x_new[:2], x_old[:2])

    solution = loop.fixed_point_iteration(
        init_x=cga_init(init_values),
        func=step,
        convergence_test=cga_convergence_test,
        max_iter=max_iter,
        batched_iter_size=batched_iter_size,
        unroll=unroll,
    )
    return solution._replace(
        value=get_params(solution.value),
        previous_value=get_params(solution.previous_value),
    )
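
A hypothetical call on the bilinear zero-sum game f(x, y) = x y, g = -f, whose unique equilibrium is (0, 0); plain simultaneous gradient ascent cycles on this game, while CGA is expected to converge for small enough step sizes:

import jax.numpy as jnp

def f(x, y):
    return x * y

def g(x, y):
    return -x * y

def small_change(new, old):
    # new and old are (x, y) tuples
    return (jnp.abs(new[0] - old[0]) + jnp.abs(new[1] - old[1])) < 1e-8

sol = cga_iteration(
    init_values=(jnp.array(1.0), jnp.array(1.0)),
    f=f,
    g=g,
    convergence_test=small_change,
    max_iter=1000,
    step_size_f=0.1,
)
# sol.value should be close to (0.0, 0.0)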
Example #17
def implicit_ecp(
        objective, equality_constraints, initial_values, lr_func, max_iter=500,
        convergence_test=default_convergence_test, batched_iter_size=1, optimizer=optimizers.sgd,
        tol=1e-6):
    """Use implicit differentiation to solve a nonlinear equality-constrained program of the form:

    max f(x, θ) subject to h(x, θ) = 0.

    We perform a change of variable via the implicit function theorem and obtain the unconstrained
    program:

    max f(φ(θ), θ),

    where φ is an implicit function of the parameters θ such that h(φ(θ), θ) = 0.

    Args:
        objective (callable): Binary callable with signature `f(x, θ)`
        equality_constraints (callable): Binary callable with signature `h(x, θ)`
        initial_values (tuple): Tuple of initial values `(x_0, θ_0)`
        lr_func (scalar or callable): The step size used by the unconstrained optimizer. This can
            be a scalar or a callable taking in the current iteration and returning a scalar.
        max_iter (int, optional): Maximum number of outer iterations. Defaults to 500.
        convergence_test (callable): Binary callable with signature `callback(new_state, old_state)`
            where `new_state` and `old_state` are tuples of the form `(x_k^*, θ_k)` such that
            `h(x_k^*, θ_k) = 0` (and with `k-1` for `old_state`). The default convergence test
            returns `True` if both elements of the tuple have not changed within some tolerance.
        batched_iter_size (int, optional): The number of iterations to be
            unrolled and executed per iteration of the `while_loop` op for the forward iteration
            and the fixed-point adjoint iteration. Defaults to 1.
        optimizer (callable, optional): Unary callable taking a `lr_func` as an argument and
            returning an unconstrained optimizer. Defaults to `jax.experimental.optimizers.sgd`.
        tol (float, optional): Tolerance for the forward and backward iterations. Defaults to 1e-6.

    Returns:
        fax.loop.FixedPointSolution: A named tuple containing the solution `(x, θ)` as the
            `value` attribute, `converged` (a bool indicating whether convergence was achieved),
            `iterations` (the number of iterations used), and `previous_value`
            (the value of the solution on the previous iteration). The previous value satisfies
            `sol.value=func(sol.previous_value)` and allows us to log the size
            of the last step if desired.
    """
    def _objective(*args):
        return -objective(*args)

    def make_fp_operator(params):
        def _fp_operator(i, x):
            del i
            return x + equality_constraints(x, params)
        return _fp_operator

    constraints_solver = make_forward_fixed_point_iteration(
        make_fp_operator, default_max_iter=max_iter, default_batched_iter_size=batched_iter_size,
        default_atol=tol, default_rtol=tol)

    adjoint_iteration_vjp = make_adjoint_fixed_point_iteration(
        make_fp_operator, default_max_iter=max_iter, default_batched_iter_size=batched_iter_size,
        default_atol=tol, default_rtol=tol)

    opt_init, opt_update, get_params = optimizer(step_size=lr_func)

    grad_objective = grad(_objective, (0, 1))

    def update(i, values):
        old_xstar, opt_state = values
        old_params = get_params(opt_state)

        forward_solution = constraints_solver(old_xstar, old_params)

        grads_x, grads_params = grad_objective(forward_solution.value, get_params(opt_state))

        ybar, _ = adjoint_iteration_vjp(
            grads_x, forward_solution, old_xstar, old_params)

        implicit_grads = tree_util.tree_multimap(
            lax.add, grads_params, ybar)

        opt_state = opt_update(i, implicit_grads, opt_state)

        return forward_solution.value, opt_state

    def _convergence_test(new_state, old_state):
        x_new, params_new = new_state[0], get_params(new_state[1])
        x_old, params_old = old_state[0], get_params(old_state[1])
        return convergence_test((x_new, params_new), (x_old, params_old))

    x0, init_params = initial_values
    opt_state = opt_init(init_params)

    solution = fixed_point_iteration(init_x=(x0, opt_state),
                                     func=update,
                                     convergence_test=jit(_convergence_test),
                                     max_iter=max_iter,
                                     batched_iter_size=batched_iter_size,
                                     unroll=False)
    return solution._replace(
        value=(solution.value[0], get_params(solution.value[1])),
        previous_value=(solution.previous_value[0], get_params(solution.previous_value[1])),
    )
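
A hypothetical toy problem for `implicit_ecp`: maximize f(x, θ) = -(x² + (θ - 1)²) subject to h(x, θ) = θ - x = 0. The constrained optimum is x = θ = 0.5, and x ↦ x + h(x, θ) is a contraction in x, so the forward fixed-point iteration converges:

import jax.numpy as jnp

def objective(x, theta):
    return -(x ** 2 + (theta - 1.0) ** 2)

def equality_constraints(x, theta):
    return theta - x

sol = implicit_ecp(
    objective=objective,
    equality_constraints=equality_constraints,
    initial_values=(jnp.array(0.0), jnp.array(0.0)),
    lr_func=0.1,
)
# sol.value should approach the constrained optimum (0.5, 0.5)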
Example #18
def cga_ecp(
        objective, equality_constraints, initial_values, lr_func, lr_multipliers=None,
        linear_op_solver=None, solve_order='alternating', max_iter=500,
        convergence_test=default_convergence_test, batched_iter_size=1,
):
    """Use CGA to solve a nonlinear equality-constrained program of the form:

    max f(x, θ) subject to h(x, θ) = 0.

    We form the Lagrangian L(x, θ, λ) = f(x, θ) - λ^⊤ h(x, θ) and look for a saddle point of:

    max_{x, θ} min_λ L(x, θ, λ)

    Args:
        objective (callable): Binary callable with signature `f(x, θ)`
        equality_constraints (callable): Binary callable with signature `h(x, θ)`
        initial_values (tuple): Tuple of initial values `(x_0, θ_0)`
        lr_func (scalar or callable): The step size used by CGA for `f`. This can be a scalar or
            a callable taking in the current iteration and returning a scalar.
        lr_multipliers (scalar or callable, optional): Step size for the dual updates.
            Defaults to None. If no step size is given for `lr_multipliers`, then
            `lr_func` is also used for `lr_multipliers`.
        linear_op_solver (callable, optional): This is a function which outputs
            the solution to `x = Ax + b` when given a callable linear operator
            representing the matrix-vector product `Ax` and an array `b`. If
            `None` is given, then a simple fixed point iteration solver is used. This is used to
            solve for the matrix inverses in the CGA algorithm.
        solve_order (str, optional): Specifies how the updates for each player are solved for.
            Should be one of

            - 'both': Solves the linear system for each player (eq. 3 of Schaefer 2019)
            - 'yx': Solves for the player behind `y`, then updates `x`
            - 'xy': Solves for the player behind `x`, then updates `y`
            - 'alternate': Solves for `x` and updates `y`; the next iteration solves for `y` and updates `x`

            Defaults to 'alternating'.
        max_iter (int): Maximum number of outer iterations. Defaults to 500.
        convergence_test (callable): Binary callable with signature `callback(new_state, old_state)`
            where `new_state` and `old_state` are nested tuples of the form `((x_k, θ_k), λ_k)`.
            The default convergence test returns `True` if all elements of the tuple have not
            changed within some tolerance.
        batched_iter_size (int, optional): The number of iterations to be
            unrolled and executed per iteration of the `while_loop` op for the forward iteration
            and the fixed-point adjoint iteration. Defaults to 1.

    Returns:
        fax.loop.FixedPointSolution: A named tuple containing the solution `(x, θ)` as the
            `value` attribute, `converged` (a bool indicating whether convergence was achieved),
            `iterations` (the number of iterations used), and `previous_value`
            (the value of the solution on the previous iteration). The previous value satisfies
            `sol.value=func(sol.previous_value)` and allows us to log the size
            of the last step if desired.
    """
    def _objective(variables):
        return -objective(*variables)

    def _equality_constraints(variables):
        return -equality_constraints(*variables)

    init_mult, lagrangian, _ = make_lagrangian(_objective, _equality_constraints)
    lagrangian_variables = init_mult(initial_values)

    if lr_multipliers is None:
        lr_multipliers = lr_func

    opt_init, opt_update, get_params = cga_lagrange_min(
        lagrangian, lr_func, lr_multipliers, linear_op_solver, solve_order)

    @jit
    def update(i, opt_state):
        grads = grad(lagrangian, (0, 1))(*get_params(opt_state))
        return opt_update(i, grads, opt_state)

    solution = fixed_point_iteration(init_x=opt_init(lagrangian_variables),
                                     func=update,
                                     convergence_test=jit(convergence_test),
                                     max_iter=max_iter,
                                     batched_iter_size=batched_iter_size,
                                     unroll=False)
    return solution._replace(
        value=get_params(solution.value)[0],
        previous_value=get_params(solution.previous_value)[0],
    )
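
The same toy problem can be handed to `cga_ecp`, which attacks the Lagrangian saddle point directly instead of differentiating through the constraint solve (hypothetical usage, same caveats as above):

import jax.numpy as jnp

sol = cga_ecp(
    objective=lambda x, theta: -(x ** 2 + (theta - 1.0) ** 2),
    equality_constraints=lambda x, theta: theta - x,
    initial_values=(jnp.array(0.0), jnp.array(0.0)),
    lr_func=0.1,
)
# sol.value should approach the constrained optimum (0.5, 0.5)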