Example #1
 def step(key, state, init_key=None):
     transition_key, accept_key = random.split(key)
     next_state = st.init(inner_step)(init_key, transition_key,
                                      state)(transition_key, state)
     # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
     state_log_prob = unnormalized_log_prob(state)
     next_state_log_prob = unnormalized_log_prob(next_state)
     log_unclipped_accept_prob = next_state_log_prob - state_log_prob
     accept_prob = np.clip(np.exp(log_unclipped_accept_prob), 0., 1.)
     u = primitive.tie_in(accept_prob, random.uniform(accept_key))
     accept = np.log(u) < log_unclipped_accept_prob
     return tree_util.tree_multimap(lambda n, s: np.where(accept, n, s),
                                    next_state, state)
Example #2
    def MakeInputSignature(self, *in_shapes):
        """From a pytree of in_shape string specification, make a pytree of tf.TensorSpec.

    Dimension variables are replaced with None.
    """
        def in_shape_to_tensorspec(in_shape: str) -> tf.TensorSpec:
            in_spec = masking.parse_spec(in_shape)
            return tf.TensorSpec(tuple(
                int(dim_spec) if dim_spec.is_constant else None
                for dim_spec in in_spec),
                                 dtype=tf.float32)

        return tree_util.tree_multimap(in_shape_to_tensorspec, in_shapes)
Example #3
        def solve_delta_y(init_y):

            by = tree_util.tree_multimap(
                lambda z, grad_yg: grad_yg + eta_f * z, jvp_yxg(grad_xf),
                grad_yg)

            delta_y = linear_op_solver(
                linear_op=linear_op_y,
                bvec=by,
                init_x=init_y,
            )

            return delta_y
Example #4
  def get_samples(x1, x2, get):
    if x2 is not None:
      assert x1.shape[1:] == x2.shape[1:]

    _key = key
    for n in range(1, max(n_samples) + 1):
      _key, split = random.split(_key)
      one_sample = kernel_fn_sample_once(x1, x2, split, get)
      if n == 1:
        ker_sampled = one_sample
      else:
        ker_sampled = tree_multimap(operator.add, ker_sampled, one_sample)
      yield n, ker_sampled
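A usage note (not part of the original example): each value yielded above is a running sum over n draws, so a caller divides by n to turn it into a Monte Carlo estimate. A minimal sketch with hypothetical inputs, assuming jax.tree_util's tree_map is available alongside tree_multimap:

for n, ker_sum in get_samples(x1, x2, get='nngp'):
    ker_estimate = tree_map(lambda leaf: leaf / n, ker_sum)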
Example #5
 def get_samples(x1, x2, get, **apply_fn_kwargs):
     _key = key
     ker_sampled = None
     for n in range(1, max(n_samples) + 1):
         _key, split = random.split(_key)
         one_sample = kernel_fn_sample_once(x1, x2, split, get,
                                            **apply_fn_kwargs)
         if ker_sampled is None:
             ker_sampled = one_sample
         else:
             ker_sampled = tree_multimap(operator.add, ker_sampled,
                                         one_sample)
         yield n, ker_sampled
Example #6
def private_train_step(model, optimizer, loss_fn, args, data):
    if args.no_vmap:
        x, y = data
        loss, clipped_grads = tree_multimap(
            lambda *xs: tf.stack(xs),
            *(compute_per_eg_grad(model, optimizer, loss_fn, args, (x[i], y[i]))
              for i in range(x.shape[0])))
    else:
        loss, clipped_grads = tf.vectorized_map(
            partial(compute_per_eg_grad, model, optimizer, loss_fn, args),
            data)  # , fallback_to_while_loop=False)
    final_grads = tf.nest.map_structure(partial(reduce_noise_normalize_batch, args), clipped_grads)
    optimizer.apply_gradients(zip(final_grads, model.trainable_variables))
    return loss
Example #7
 def linear_opt_min(min_tree):
     temp = hessian_yx(min_tree)  # returns max_tree type
     temp1 = _tree_apply(_tree_apply(breg_max.inv_D2P,
                                     prev_state.maxPlayer),
                         temp)  # returns max_tree type
     temp2 = hessian_xy(temp1)  # returns min_tree type
     temp3 = tree_util.tree_map(lambda x: eta_max * x,
                                temp2)  # still min_tree type
     temp4 = _tree_apply(_tree_apply(breg_min.D2P, prev_state.minPlayer),
                         min_tree)  # also returns min_tree type
     temp5 = tree_util.tree_map(lambda x: 1 / eta_min * x, temp4)
     # print("linear operator being called! - min")
     out = tree_util.tree_multimap(lambda x, y: x + y, temp3, temp5)
     return out  # min_tree type
Example #8
 def _map_split(nest, indices_or_sections):
     """Splits leaf nodes of nests and returns a list of nests."""
     if isinstance(indices_or_sections, int):
         n_lists = indices_or_sections
     else:
         n_lists = len(indices_or_sections) + 1
     concat = lambda field: np_.split(field, indices_or_sections)
     nest_of_lists = tree.tree_map(concat, nest)
     # pylint: disable=cell-var-from-loop
     list_of_nests = [
         tree.tree_multimap(lambda _, x: x[i], nest, nest_of_lists)
         for i in range(n_lists)
     ]
     return list_of_nests
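A small orientation sketch (not from the source), assuming `np_` is NumPy and `tree` exposes the JAX-style tree_map/tree_multimap API:

nest = {'a': np_.arange(6), 'b': np_.arange(12).reshape(6, 2)}
parts = _map_split(nest, 3)
# parts is a list of 3 nests; each leaf keeps its trailing shape and gets
# leading dimension 6 // 3 == 2, e.g. parts[0]['b'].shape == (2, 2).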
Example #9
    def sum_and_contract(fx, j1, j2):
        ndim = fx.ndim
        size = utils.size_at(fx, trace_axes)

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim)
        _trace_axes = utils.canonicalize_axis(trace_axes, ndim)

        def contract(x, y):
            param_axes = list(range(x.ndim))[ndim:]
            contract_axes = _trace_axes + param_axes
            return utils.dot_general(x, y, contract_axes,
                                     _diagonal_axes) / size

        return tree_reduce(operator.add, tree_multimap(contract, j1, j2))
Example #10
def _test_lqr2(n):
    # pylint: disable=import-outside-toplevel
    import control
    from jax.tree_util import tree_all, tree_multimap

    A = jp.zeros((n, n))
    B = jp.eye(n)
    Q = jp.eye(n)
    R = jp.eye(n)
    N = jp.zeros((n, n))

    actual = lqr_continuous_time_infinite_horizon(A, B, Q, R, N)
    expected = control.lqr(A, B, Q, R, N)
    # A bare `assert` on a pytree is always truthy; reduce the leafwise comparison first.
    assert tree_all(tree_multimap(jp.allclose, actual, expected))
Example #11
def fruity_loops(outer_loop_fn, inner_loop_fn, outer_loop_count, inner_loop_count, init):
  run = jit(lambda carry: lax.scan(inner_loop_fn, carry, jnp.arange(inner_loop_count)))
  last = jit(lambda seq: tree_map(itemgetter(-1), seq))

  history = []
  carry = init
  for _ in range(outer_loop_count):
    t0 = time.time()
    carry, seq = run(carry)
    seq_last = tree_map(lambda x: x.block_until_ready(), last(seq))
    history.append(seq)
    outer_loop_fn(carry, seq_last, elapsed=time.time() - t0)

  return carry, tree_multimap(lambda *args: jnp.concatenate(args), history[0], *history[1:])
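A usage sketch with hypothetical loop bodies (not from the source): the inner body follows the lax.scan convention of mapping (carry, t) to (carry, y), and the outer callback receives the carry, the last per-step outputs, and the elapsed time.

inner = lambda carry, t: (carry + 1, carry)              # scan body: (carry, t) -> (carry, y)
outer = lambda carry, last, elapsed: print(last, elapsed)
final_carry, history = fruity_loops(outer, inner,
                                    outer_loop_count=3, inner_loop_count=5,
                                    init=jnp.zeros(()))
# history concatenates the per-step outputs across outer iterations, shape (15,).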
Example #12
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    out_tree, res_tree = out_trees()
    res, cts_out = split_list(args, [res_tree.num_leaves])
    py_res = tree_unflatten(res_tree, res)
    py_cts_out = tree_unflatten(out_tree, cts_out)
    py_cts_in = yield (py_res, py_cts_out), {}
    # For each None in py_cts_in, indicating an argument for which the rule
    # produces no cotangent, we replace it with a pytree with the structure of the
    # corresponding subtree of in_tree and with leaves of a non-pytree sentinel
    # object, to be replaced with Nones in the final returned result.
    zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
    dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    cts_in_flat = []
    append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
    try:
        if not isinstance(py_cts_in, tuple):
            raise ValueError
        tree_multimap(append_cts,
                      tuple(zero if ct is None else ct for ct in py_cts_in),
                      dummy)
    except ValueError:
        _, in_tree2 = tree_flatten(py_cts_in)
        msg = (
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.")
        raise TypeError(msg.format(in_tree2, in_tree)) from None
    # Ignore any None cotangents, and any corresponding to inputs for which the
    # type doesn't equal the tangent type (i.e. float0s)
    # TODO(mattjj): change this to check if tangent type represents 0dim vspace
    yield [
        Zero(a.at_least_vspace())
        if ct is zero or a != a.at_least_vspace() else ct
        for a, ct in zip(in_avals, cts_in_flat)
    ]
Example #13
def scan(f, a, bs):
    if _DISABLE_CONTROL_FLOW_PRIM:
        length = tree_flatten(bs)[0][0].shape[0]
        for i in range(length):
            b = tree_map(lambda x: x[i], bs)
            a = f(a, b)
            a_out = tree_map(lambda x: x[None, ...], a)
            if i == 0:
                out = a_out
            else:
                out = tree_multimap(lambda x, y: np.concatenate((x, y)), out,
                                    a_out)
        return out
    else:
        return lax.scan(f, a, bs)
Example #14
 def get_samples(x1: np.ndarray, x2: Optional[np.ndarray], get: Get,
                 **apply_fn_kwargs):
     _key = key
     ker_sampled = None
     for n in range(1, max(n_samples) + 1):
         keys = random.split(_key)
         _key, split = keys[0], keys[1]
         one_sample = kernel_fn_sample_once(x1, x2, split, get,
                                            **apply_fn_kwargs)
         if ker_sampled is None:
             ker_sampled = one_sample
         else:
             ker_sampled = tree_multimap(operator.add, ker_sampled,
                                         one_sample)
         yield n, ker_sampled
Example #15
def cmd_step(prev_state,
             updates,
             breg_min=default_breg,
             breg_max=default_breg):
    """Equation (2). Take in the previous player positions and update to the next player position. Return a 1-step cmd update.

    Args:
        prev_state (Named tuples of vectors): The current position of the players given by tuple
                                             with signature 'CMDState(minPlayer maxPlayer minPlayer_dual maxPlayer_dual)'
        updates (Named tuples of vectors): The updates del_x,del_y computed from updates(...) with signature 'UpdateState(del_min, del_max)'
        breg_min (Named tuples of callable): Tuple of unary callables with signature
                                            'BregmanPotential = collections.namedtuple("BregmanPotential", ["DP", "DP_inv", "D2P","D2P_inv"])'
                                            where DP and DP_inv are unary callables with signatures
                                            `DP(x,*args, **kwargs)`,'DP_inv(x,*arg,**kwarg)' and
                                            D2P, D2P_inv are function of functions
                                            (Given an x, returning linear transformation function
                                            that can take in another vector to output hessian-vector product).
        breg_max (Named tuples of callable): Tuple of unary callables as 'breg_min'.

    Returns:
        Named tuple: the states of the players at current iteration - CMDState
    """
    temp_min = _tree_apply(_tree_apply(breg_min.D2P, prev_state.minPlayer),
                           updates.del_min)
    temp_max = _tree_apply(_tree_apply(breg_max.D2P, prev_state.maxPlayer),
                           updates.del_max)

    dual_min = tree_util.tree_multimap(lambda x, y: x + y,
                                       prev_state.minPlayer_dual, temp_min)
    dual_max = tree_util.tree_multimap(lambda x, y: x + y,
                                       prev_state.maxPlayer_dual, temp_max)

    minP = _tree_apply(breg_min.DP_inv, dual_min)
    maxP = _tree_apply(breg_max.DP_inv, dual_max)

    return CMDState(minP, maxP, dual_min, dual_max)
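For context, a sketch of the containers the docstring refers to, reconstructed only from the signatures quoted there (field names as stated; everything else about their use is an assumption):

import collections

CMDState = collections.namedtuple('CMDState',
                                  ['minPlayer', 'maxPlayer', 'minPlayer_dual', 'maxPlayer_dual'])
UpdateState = collections.namedtuple('UpdateState', ['del_min', 'del_max'])
BregmanPotential = collections.namedtuple('BregmanPotential',
                                          ['DP', 'DP_inv', 'D2P', 'D2P_inv'])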
Example #16
        def solve_delta_x(init_x):

            bx = tree_util.tree_multimap(
                lambda grad_xf, z: grad_xf + eta_g * z,
                grad_xf,
                jvp_xyf(grad_yg),
            )

            delta_x = linear_op_solver(
                linear_op=linear_op_x,
                bvec=bx,
                init_x=init_x,
            )

            return delta_x
Example #17
def apply_updates(params, updates):
    """Applies an update to the corresponding parameters.

  This is an (optional) utility functions that applies an update, and returns
  the updated parameters to the caller. The update itself is typically the
  result of applying any number of `chainable` transformations.

  Args:
    params: a tree of parameters.
    updates: a tree of updates, the tree structure and the shape of the leaf
    nodes must match that of `params`.

  Returns:
    Updated parameters, with same structure and shape as `params`.
  """
    return tree_multimap(lambda p, u: p + u, params, updates)
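A minimal usage sketch (hypothetical parameter tree; assumes jax.numpy as jnp and jax.tree_util's tree_map in addition to the tree_multimap used above):

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
updates = tree_map(lambda g: -0.1 * g, params)   # e.g. scaled negative gradients
new_params = apply_updates(params, updates)      # same structure and shapes as params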
Example #18
  def test_grad_of_shared_layer(self):
    def template(x, init_key=None):
      layer = state.init(ScalarMul(2 * jnp.ones(1)), name='scalar_mul')(
          init_key, x)
      x, layer = layer.call_and_update(x)
      x, layer = layer.call_and_update(x)
      state.assign(layer, name='scalar_mul')
      return x[0]
    net = state.init(template)(self._seed, jnp.ones(()))

    def loss(net, x):
      return net(x)
    g = jax.grad(loss)(net, jnp.ones(()))
    def add(x, y):
      return x + y
    net = tree_util.tree_multimap(add, net, g)
    np.testing.assert_array_equal(net(jnp.ones(())), 36.)
Example #19
    def _preprocess(self, data):
        """Reshapes input so that it can be distributed across multiple cores."""
        multi_inputs = data.copy()

        def add_core_dimension(x):
            if np.isscalar(x):
                return x
            if x.shape[0] % self._num_devices != 0:
                raise ValueError(
                    f'The batch size must be a multiple of the number of'
                    f' devices. Got batch size = {x.shape[0]} and number'
                    f' of devices = {self._num_devices}.')
            prefix = (self._num_devices, x.shape[0] // self._num_devices)
            return np.reshape(x, prefix + x.shape[1:])

        multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
        return multi_inputs
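To make the reshape concrete, a hedged illustration with hypothetical shapes (self._num_devices == 2): a leaf of shape (8, 32) becomes (2, 4, 32), a leading device axis followed by the per-device batch.

x = np.zeros((8, 32))
prefix = (2, 8 // 2)                       # (num_devices, per-device batch)
np.reshape(x, prefix + x.shape[1:]).shape  # -> (2, 4, 32)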
Example #20
def consensus(subposteriors, num_draws=None, diagonal=False, rng=None):
    """
    Merges subposteriors following consensus Monte Carlo algorithm.

    **References:**

    1. *Bayes and big data: The consensus Monte Carlo algorithm*,
       Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman,
       Edward I. George, Robert E. McCulloch

    :param list subposteriors: a list in which each element is a collection of samples.
    :param int num_draws: number of draws from the merged posterior.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :param jax.random.PRNGKey rng: source of the randomness, defaults to `jax.random.PRNGKey(0)`.
    :return: if `num_draws` is None, merges subposteriors without resampling; otherwise, returns
        a collection of `num_draws` samples with the same data structure as each subposterior.
    """
    # stack subposteriors
    joined_subposteriors = tree_multimap(lambda *args: np.stack(args), *subposteriors)
    # shape of joined_subposteriors: n_subs x n_samples x sample_shape
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    if num_draws is not None:
        rng = random.PRNGKey(0) if rng is None else rng
        # randomly gets num_draws from subposteriors
        n_subs = len(subposteriors)
        n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
        # shape of draw_idxs: n_subs x num_draws x sample_shape
        draw_idxs = random.randint(rng, shape=(n_subs, num_draws), minval=0, maxval=n_samples)
        joined_subposteriors = vmap(lambda x, idx: x[idx])(joined_subposteriors, draw_idxs)

    if diagonal:
        # compute weights for each subposterior (ref: Section 3.1 of [1])
        weights = vmap(lambda x: 1 / np.var(x, ddof=1, axis=0))(joined_subposteriors)
        normalized_weights = weights / np.sum(weights, axis=0)
        # get weighted samples
        samples_flat = np.einsum('ij,ikj->kj', normalized_weights, joined_subposteriors)
    else:
        weights = vmap(lambda x: np.linalg.inv(np.cov(x.T)))(joined_subposteriors)
        normalized_weights = np.matmul(np.linalg.inv(np.sum(weights, axis=0)), weights)
        samples_flat = np.einsum('ijk,ilk->lj', normalized_weights, joined_subposteriors)

    # unravel_fn acts on 1 sample of a subposterior
    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
    return vmap(lambda x: unravel_fn(x))(samples_flat)
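A usage sketch (hypothetical subposteriors, not from the source): two subposteriors of 1000 samples over the same parameter pytree, merged into 500 consensus draws.

sub_a = {'mu': np.zeros((1000, 3)), 'log_sigma': np.zeros((1000,))}
sub_b = {'mu': np.ones((1000, 3)), 'log_sigma': np.zeros((1000,))}
merged = consensus([sub_a, sub_b], num_draws=500, rng=random.PRNGKey(1))
# merged has the same pytree structure as each subposterior, with leading dimension 500.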
Example #21
    def test_grad_of_function_constant(self):
        def template(x):
            return x + np.ones_like(x)

        net = state.init(template)(self._seed, state.Shape(5))

        def loss(net, x):
            return net(x).sum()

        g = jax.grad(loss)(net, np.ones(5))

        def add(x, y):
            return x + y

        net = tree_util.tree_multimap(add, net, g)
        # w_new = w_old + 5
        onp.testing.assert_array_equal(net(np.ones(5)), 2 * np.ones(5))
Example #22
    def update(i, values):
        old_xstar, opt_state = values
        old_params = get_params(opt_state)

        forward_solution = constraints_solver(old_xstar, old_params)

        grads_x, grads_params = grad_objective(forward_solution.value, get_params(opt_state))

        ybar, _ = adjoint_iteration_vjp(
            grads_x, forward_solution, old_xstar, old_params)

        implicit_grads = tree_util.tree_multimap(
            lax.add, grads_params, ybar)

        opt_state = opt_update(i, implicit_grads, opt_state)

        return forward_solution.value, opt_state
Example #23
def test():
    # We just check we can run the functions and that they return "something"
    print('testing fit-flax')
    N = 3
    D = 5
    C = 10
    model = ModelTest(nhidden=0, nclasses=C)
    rng = jax.random.PRNGKey(0)
    X = np.random.randn(N, D)
    y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
    batch = {'X': X, 'y': y}
    params = model.init(rng, X)['params']

    # test apply
    logprobs = model.apply({'params': params}, batch['X'])
    assert logprobs.shape == (N, C)

    # test loss
    labels = batch['y']
    loss = softmax_cross_entropy(logprobs, labels)
    assert loss.shape == ()

    # test test_fn
    metrics = eval_classifier(model, params, batch)
    assert np.allclose(loss, metrics['loss'])

    # test train_fn
    make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
    optimizer = make_optimizer.create(params)
    optimizer, metrics = update_classifier(model, optimizer, batch)

    # test fit_model
    num_steps = 2
    train_iter = make_iterator_from_batch(batch)
    test_iter = make_iterator_from_batch(batch)
    params_init = params
    params_new, history = fit_model(model, rng, num_steps, train_iter,
                                    test_iter)

    diff = tree_util.tree_multimap(lambda x, y: x - y, params_init, params_new)
    print(diff)
    norm = l2norm_sq(diff)
    assert norm > 0  # check that parameters have changed :)

    print(history)
    print('test passed')
Example #24
def _scan(f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]], init: _Carry,
          xs: Iterable[_Input]) -> Tuple[_Carry, _Output]:
    """Implements an unrolled version of scan.

  Based on `jax.lax.scan` and has a similar API.

  TODO(schsam): We introduce this function because lax.scan currently has a
  higher peak memory usage than the unrolled version. We will aim to swap this
  out for lax.scan when issue #1273 and related have been resolved.
  """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys += [y]

    return carry, tree_multimap(lambda *y: np.stack(y), *ys)
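A usage sketch with a hypothetical step function: as with lax.scan, f maps (carry, x) to (new_carry, y); here each output is the running sum.

def f(carry, x):
    new_carry = carry + x
    return new_carry, new_carry

total, partials = _scan(f, 0., [1., 2., 3.])
# total == 6.0 and partials == np.stack([1., 3., 6.])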
Example #25
def _scan(f, init, xs, store_on_device):
    """Implements an unrolled version of scan.

  Based on `jax.lax.scan` and has an identical API.

  TODO(schsam): We introduce this function because lax.scan currently has a
  higher peak memory usage than the unrolled version. We will aim to swap this
  out for lax.scan when issue #1273 and related have been resolved.
  """

    stack = np.stack if store_on_device else jit(np.stack, backend='cpu')

    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys += [y]
    return carry, tree_multimap(lambda *y: stack(y), *ys)
Example #26
    def test_grad_of_shared_layer(self):
        def template(x, init_key=None):
            layer = state.init(ScalarMul(2 * np.ones(1)),
                               name='scalar_mul')(init_key, x)
            return layer(layer(x)).sum()

        net = state.init(template)(self._seed, state.Shape(()))

        def loss(net, x):
            return net(x)

        g = jax.grad(loss)(net, np.ones(()))

        def add(x, y):
            return x + y

        net = tree_util.tree_multimap(add, net, g)
        onp.testing.assert_array_equal(net(np.ones(())), 36.)
Example #27
    def test_grad_of_function_with_literal(self):
        def template(x, init_key=None):
            # 1.0 behaves like a literal when tracing
            return ScalarMul(1.0)(x, init_key=init_key, name='scalar_mul')

        net = state.init(template)(self._seed, state.Shape(5))

        def loss(net, x):
            return net(x).sum()

        g = jax.grad(loss)(net, np.ones(5))

        def add(x, y):
            return x + y

        net = tree_util.tree_multimap(add, net, g)
        # w_new = w_old + 5
        onp.testing.assert_array_equal(net(np.ones(5)), 6 * np.ones(5))
Example #28
 def get_samples(x1: np.ndarray, x2: Optional[np.ndarray], get: Get,
                 **apply_fn_kwargs):
     _key = stateless_uniform(shape=[2],
                              seed=key,
                              minval=None,
                              maxval=None,
                              dtype=tf.int32)
     ker_sampled = None
     for n in range(1, max(n_samples) + 1):
         _key, split = tf_split(_key)
         one_sample = kernel_fn_sample_once(x1, x2, split, get,
                                            **apply_fn_kwargs)
         if ker_sampled is None:
             ker_sampled = one_sample
         else:
             ker_sampled = tree_multimap(operator.add, ker_sampled,
                                         one_sample)
         yield n, ker_sampled
Example #29
def bijector_ildj_rule(incells, outcells, **params):
    """Inverse/ILDJ rule for bijectors."""
    incells = incells[1:]
    num_consts = len(incells) - params['num_args']
    const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
    flat_inproxies = safe_map(_CellProxy, flat_incells)
    in_tree = params['in_tree']
    bijector_proxies, inproxy = tree_util.tree_unflatten(
        in_tree, flat_inproxies)
    flat_bijector_cells = [
        proxy.cell for proxy in tree_util.tree_leaves(bijector_proxies)
    ]
    if any(cell.is_unknown() for cell in flat_bijector_cells):
        return const_incells + flat_incells, outcells, False, None
    bijector = tree_util.tree_multimap(lambda x: x.cell.val, bijector_proxies)
    direction = params['direction']
    if direction == 'forward':
        forward_func = bijector.forward
        inv_func = bijector.inverse
        ildj_func = bijector.inverse_log_det_jacobian
    elif direction == 'inverse':
        forward_func = bijector.inverse
        inv_func = bijector.forward
        ildj_func = bijector.forward_log_det_jacobian
    else:
        raise ValueError('Bijector direction must be '
                         '"forward" or "inverse".')

    outcell, = outcells
    incell = inproxy.cell
    done = False
    # Default: propagate the output cells unchanged if neither branch below applies.
    new_outcells = outcells
    if incell.is_unknown() and not outcell.is_unknown():
        val, ildj = outcell.val, outcell.ildj
        flat_incells = [
            InverseAndILDJ(inv_func(val), ildj + ildj_func(val, np.ndim(val)))
        ]
        done = True
        new_outcells = outcells
    elif outcell.is_unknown() and not incell.is_unknown():
        new_outcells = [InverseAndILDJ.new(forward_func(incell.val))]
        done = True
    new_incells = flat_bijector_cells + flat_incells
    return const_incells + new_incells, new_outcells, done, None
Example #30
    def update_fn(grads, params, state):
        """Compute the update.

        Args:
          grads: pytree of ndarray
            Gradient values.
          params: pytree of ndarray
            Parameter values.
          state:
            A tuple of (gradient accumulators, squared gradient accumulators, idx).

        Returns:
          step: pytree of ndarray
            The step to be added to the parameter values.
          next_state:
            A tuple of (gradient accumulators, squared gradient accumulators, idx).
        """

        grad_acc, grad_sq_acc, idx = state

        def update_one(g, p, g_acc, g_sq_acc):
            s = jax_common.NAdamWParamState(g_acc, g_sq_acc)
            new_x, new_s = jax_common.nadamw_update(idx, hyper_params, p, s, g)
            return new_x, new_s

        # the following flattens, applies a map, extracts values out via zip,
        # then unflattens.
        flat_gs, tree_def = tree_flatten(grads)
        flat_ps, _ = tree_flatten(params)
        flat_s0, _ = tree_flatten(grad_acc)
        flat_s1, _ = tree_flatten(grad_sq_acc)

        next_param_states = tree_multimap(update_one, flat_gs, flat_ps,
                                          flat_s0, flat_s1)

        flat_step, flat_next_ss = zip(*next_param_states)
        flat_next_grad_acc, flat_next_grad_sq_acc = zip(*flat_next_ss)

        step = tree_unflatten(tree_def, flat_step)
        next_grad_acc = tree_unflatten(tree_def, flat_next_grad_acc)
        next_grad_sq_acc = tree_unflatten(tree_def, flat_next_grad_sq_acc)

        return step, (next_grad_acc, next_grad_sq_acc, idx + 1)
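As a standalone illustration of the flatten -> multimap -> zip -> unflatten pattern used above (hypothetical trees and update rule, not the NAdamW update itself):

params = {'w': 1.0, 'b': 2.0}
grads = {'w': 0.1, 'b': 0.2}

flat_ps, tree_def = tree_flatten(params)
flat_gs, _ = tree_flatten(grads)
# Leafwise function returning two values per leaf.
pairs = tree_multimap(lambda p, g: (p - g, g * g), flat_ps, flat_gs)
flat_new_ps, flat_sq = zip(*pairs)
new_params = tree_unflatten(tree_def, flat_new_ps)     # {'w': 0.9, 'b': 1.8}
grad_sq = tree_unflatten(tree_def, flat_sq)            # {'w': 0.01, 'b': 0.04}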