def step(key, state, init_key=None):
  transition_key, accept_key = random.split(key)
  next_state = st.init(inner_step)(init_key, transition_key, state)(
      transition_key, state)
  # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
  state_log_prob = unnormalized_log_prob(state)
  next_state_log_prob = unnormalized_log_prob(next_state)
  log_unclipped_accept_prob = next_state_log_prob - state_log_prob
  accept_prob = np.clip(np.exp(log_unclipped_accept_prob), 0., 1.)
  # Metropolis-style accept/reject: keep next_state with probability
  # min(1, p(next_state) / p(state)). tie_in makes the uniform draw depend on
  # accept_prob (a staging hint for jit).
  u = primitive.tie_in(accept_prob, random.uniform(accept_key))
  accept = np.log(u) < log_unclipped_accept_prob
  return tree_util.tree_multimap(lambda n, s: np.where(accept, n, s),
                                 next_state, state)
def MakeInputSignature(self, *in_shapes):
  """From a pytree of in_shape string specifications, make a pytree of tf.TensorSpec.

  Dimension variables are replaced with None.
  """
  def in_shape_to_tensorspec(in_shape: str) -> tf.TensorSpec:
    in_spec = masking.parse_spec(in_shape)
    return tf.TensorSpec(
        tuple(int(dim_spec) if dim_spec.is_constant else None
              for dim_spec in in_spec),
        dtype=tf.float32)

  return tree_util.tree_multimap(in_shape_to_tensorspec, in_shapes)
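# Hypothetical usage sketch (not from the original source; shape-spec syntax
# assumed): a spec containing a dimension variable maps that dimension to None.
#   self.MakeInputSignature('(b, 4)')
#   # -> (tf.TensorSpec(shape=(None, 4), dtype=tf.float32),)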
def solve_delta_y(init_y):
  by = tree_util.tree_multimap(
      lambda z, grad_yg: grad_yg + eta_f * z,
      jvp_yxg(grad_xf), grad_yg)
  delta_y = linear_op_solver(
      linear_op=linear_op_y,
      bvec=by,
      init_x=init_y,
  )
  return delta_y
def get_samples(x1, x2, get):
  if x2 is not None:
    assert x1.shape[1:] == x2.shape[1:]
  _key = key
  for n in range(1, max(n_samples) + 1):
    _key, split = random.split(_key)
    one_sample = kernel_fn_sample_once(x1, x2, split, get)
    if n == 1:
      ker_sampled = one_sample
    else:
      ker_sampled = tree_multimap(operator.add, ker_sampled, one_sample)
    yield n, ker_sampled
def get_samples(x1, x2, get, **apply_fn_kwargs):
  _key = key
  ker_sampled = None
  for n in range(1, max(n_samples) + 1):
    _key, split = random.split(_key)
    one_sample = kernel_fn_sample_once(x1, x2, split, get, **apply_fn_kwargs)
    if ker_sampled is None:
      ker_sampled = one_sample
    else:
      ker_sampled = tree_multimap(operator.add, ker_sampled, one_sample)
    yield n, ker_sampled
def private_train_step(model, optimizer, loss_fn, args, data):
  if args.no_vmap:
    x, y = data
    loss, clipped_grads = tree_multimap(
        lambda *xs: tf.stack(xs),
        *(compute_per_eg_grad(model, optimizer, loss_fn, args, (x[i], y[i]))
          for i in range(x.shape[0])))
  else:
    loss, clipped_grads = tf.vectorized_map(
        partial(compute_per_eg_grad, model, optimizer, loss_fn, args),
        data)  # , fallback_to_while_loop=False
  final_grads = tf.nest.map_structure(
      partial(reduce_noise_normalize_batch, args), clipped_grads)
  optimizer.apply_gradients(zip(final_grads, model.trainable_variables))
  return loss
def linear_opt_min(min_tree):
  temp = hessian_yx(min_tree)  # returns max_tree type
  temp1 = _tree_apply(_tree_apply(breg_max.inv_D2P, prev_state.maxPlayer),
                      temp)  # returns max_tree type
  temp2 = hessian_xy(temp1)  # returns min_tree type
  temp3 = tree_util.tree_map(lambda x: eta_max * x, temp2)  # still min_tree type
  temp4 = _tree_apply(_tree_apply(breg_min.D2P, prev_state.minPlayer),
                      min_tree)  # also returns min_tree type
  temp5 = tree_util.tree_map(lambda x: 1 / eta_min * x, temp4)
  # print("linear operator being called! - min")
  out = tree_util.tree_multimap(lambda x, y: x + y, temp3, temp5)
  return out  # min_tree type
def _map_split(nest, indices_or_sections):
  """Splits leaf nodes of nests and returns a list of nests."""
  if isinstance(indices_or_sections, int):
    n_lists = indices_or_sections
  else:
    n_lists = len(indices_or_sections) + 1
  concat = lambda field: np_.split(field, indices_or_sections)
  nest_of_lists = tree.tree_map(concat, nest)
  # pylint: disable=cell-var-from-loop
  list_of_nests = [
      tree.tree_multimap(lambda _, x: x[i], nest, nest_of_lists)
      for i in range(n_lists)
  ]
  return list_of_nests
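# Minimal usage sketch for _map_split on a hypothetical nest of batched arrays:
#   batch = {'obs': np_.zeros((4, 3)), 'act': np_.zeros((4,))}
#   first, second = _map_split(batch, 2)
#   # first['obs'].shape == (2, 3); second['act'].shape == (2,)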
def sum_and_contract(fx, j1, j2):
  ndim = fx.ndim
  size = utils.size_at(fx, trace_axes)
  _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim)
  _trace_axes = utils.canonicalize_axis(trace_axes, ndim)

  def contract(x, y):
    param_axes = list(range(x.ndim))[ndim:]
    contract_axes = _trace_axes + param_axes
    return utils.dot_general(x, y, contract_axes, _diagonal_axes) / size

  return tree_reduce(operator.add, tree_multimap(contract, j1, j2))
def _test_lqr2(n):
  # pylint: disable=import-outside-toplevel
  import control
  from jax.tree_util import tree_leaves, tree_multimap

  A = jp.zeros((n, n))
  B = jp.eye(n)
  Q = jp.eye(n)
  R = jp.eye(n)
  N = jp.zeros((n, n))
  actual = lqr_continuous_time_infinite_horizon(A, B, Q, R, N)
  expected = control.lqr(A, B, Q, R, N)
  # tree_multimap returns a pytree of per-leaf booleans; asserting on the
  # pytree itself would always pass, so reduce over the leaves instead.
  assert all(tree_leaves(tree_multimap(jp.allclose, actual, expected)))
def fruity_loops(outer_loop_fn, inner_loop_fn, outer_loop_count,
                 inner_loop_count, init):
  run = jit(lambda carry: lax.scan(inner_loop_fn, carry,
                                   jnp.arange(inner_loop_count)))
  last = jit(lambda seq: tree_map(itemgetter(-1), seq))
  history = []
  carry = init
  for _ in range(outer_loop_count):
    t0 = time.time()
    carry, seq = run(carry)
    seq_last = tree_map(lambda x: x.block_until_ready(), last(seq))
    history.append(seq)
    outer_loop_fn(carry, seq_last, elapsed=time.time() - t0)
  return carry, tree_multimap(lambda *args: jnp.concatenate(args),
                              history[0], *history[1:])
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
  out_tree, res_tree = out_trees()
  res, cts_out = split_list(args, [res_tree.num_leaves])
  py_res = tree_unflatten(res_tree, res)
  py_cts_out = tree_unflatten(out_tree, cts_out)
  py_cts_in = yield (py_res, py_cts_out), {}
  # For each None in py_cts_in, indicating an argument for which the rule
  # produces no cotangent, we replace it with a pytree with the structure of
  # the corresponding subtree of in_tree and with leaves of a non-pytree
  # sentinel object, to be replaced with Nones in the final returned result.
  zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
  dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
  cts_in_flat = []
  append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
  try:
    if not isinstance(py_cts_in, tuple):
      raise ValueError
    tree_multimap(append_cts,
                  tuple(zero if ct is None else ct for ct in py_cts_in),
                  dummy)
  except ValueError:
    _, in_tree2 = tree_flatten(py_cts_in)
    msg = ("Custom VJP rule must produce an output with the same container "
           "(pytree) structure as the args tuple of the primal function, "
           "and in particular must produce a tuple of length equal to the "
           "number of arguments to the primal function, but got VJP output "
           "structure {} for primal input structure {}.")
    raise TypeError(msg.format(in_tree2, in_tree)) from None
  # Ignore any None cotangents, and any corresponding to inputs for which the
  # type doesn't equal the tangent type (i.e. float0s).
  # TODO(mattjj): change this to check if tangent type represents 0dim vspace
  yield [Zero(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace()
         else ct for a, ct in zip(in_avals, cts_in_flat)]
def scan(f, a, bs):
  if _DISABLE_CONTROL_FLOW_PRIM:
    length = tree_flatten(bs)[0][0].shape[0]
    out = None
    for i in range(length):
      b = tree_map(lambda x: x[i], bs)
      # Mirror lax.scan's contract below: f maps (carry, b) -> (carry, y),
      # and the per-step ys are stacked along a new leading axis.
      a, y = f(a, b)
      y_out = tree_map(lambda x: x[None, ...], y)
      if out is None:
        out = y_out
      else:
        out = tree_multimap(lambda acc, new: np.concatenate((acc, new)),
                            out, y_out)
    return a, out
  else:
    return lax.scan(f, a, bs)
def get_samples(x1: np.ndarray,
                x2: Optional[np.ndarray],
                get: Get,
                **apply_fn_kwargs):
  _key = key
  ker_sampled = None
  for n in range(1, max(n_samples) + 1):
    _key, split = random.split(_key)
    one_sample = kernel_fn_sample_once(x1, x2, split, get, **apply_fn_kwargs)
    if ker_sampled is None:
      ker_sampled = one_sample
    else:
      ker_sampled = tree_multimap(operator.add, ker_sampled, one_sample)
    yield n, ker_sampled
def cmd_step(prev_state, updates, breg_min=default_breg, breg_max=default_breg):
  """Equation (2). Takes the previous player positions and returns the 1-step CMD update.

  Args:
    prev_state (named tuple of vectors): the current positions of the players,
      with signature CMDState(minPlayer, maxPlayer, minPlayer_dual,
      maxPlayer_dual).
    updates (named tuple of vectors): the updates del_x, del_y computed from
      updates(...), with signature UpdateState(del_min, del_max).
    breg_min (named tuple of callables): a
      BregmanPotential = collections.namedtuple("BregmanPotential",
      ["DP", "DP_inv", "D2P", "D2P_inv"]), where DP and DP_inv are unary
      callables with signatures DP(x, *args, **kwargs) and
      DP_inv(x, *args, **kwargs), and D2P, D2P_inv are functions of functions
      (given an x, they return a linear map that takes another vector and
      outputs the Hessian-vector product).
    breg_max (named tuple of callables): same structure as breg_min.

  Returns:
    Named tuple: the states of the players at the current iteration - CMDState.
  """
  temp_min = _tree_apply(_tree_apply(breg_min.D2P, prev_state.minPlayer),
                         updates.del_min)
  temp_max = _tree_apply(_tree_apply(breg_max.D2P, prev_state.maxPlayer),
                         updates.del_max)
  dual_min = tree_util.tree_multimap(lambda x, y: x + y,
                                     prev_state.minPlayer_dual, temp_min)
  dual_max = tree_util.tree_multimap(lambda x, y: x + y,
                                     prev_state.maxPlayer_dual, temp_max)
  minP = _tree_apply(breg_min.DP_inv, dual_min)
  maxP = _tree_apply(breg_max.DP_inv, dual_max)
  return CMDState(minP, maxP, dual_min, dual_max)
def solve_delta_x(init_x):
  bx = tree_util.tree_multimap(
      lambda grad_xf, z: grad_xf + eta_g * z,
      grad_xf, jvp_xyf(grad_yg),
  )
  delta_x = linear_op_solver(
      linear_op=linear_op_x,
      bvec=bx,
      init_x=init_x,
  )
  return delta_x
def apply_updates(params, updates):
  """Applies an update to the corresponding parameters.

  This is an (optional) utility function that applies an update and returns
  the updated parameters to the caller. The update itself is typically the
  result of applying any number of `chainable` transformations.

  Args:
    params: a tree of parameters.
    updates: a tree of updates; the tree structure and the shape of the leaf
      nodes must match those of `params`.

  Returns:
    Updated parameters, with the same structure and shape as `params`.
  """
  return tree_multimap(lambda p, u: p + u, params, updates)
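# Minimal usage sketch for apply_updates on hypothetical toy trees (assumes
# jax.numpy is available and tree_multimap is already in scope, as above):
import jax.numpy as jnp

params = {'w': jnp.ones((2,)), 'b': jnp.zeros(())}
updates = {'w': jnp.full((2,), -0.1), 'b': jnp.array(0.5)}
new_params = apply_updates(params, updates)
# new_params == {'w': [0.9, 0.9], 'b': 0.5}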
def test_grad_of_shared_layer(self):
  def template(x, init_key=None):
    layer = state.init(ScalarMul(2 * jnp.ones(1)), name='scalar_mul')(
        init_key, x)
    x, layer = layer.call_and_update(x)
    x, layer = layer.call_and_update(x)
    state.assign(layer, name='scalar_mul')
    return x[0]

  net = state.init(template)(self._seed, jnp.ones(()))

  def loss(net, x):
    return net(x)

  g = jax.grad(loss)(net, jnp.ones(()))

  def add(x, y):
    return x + y

  net = tree_util.tree_multimap(add, net, g)
  np.testing.assert_array_equal(net(jnp.ones(())), 36.)
def _preprocess(self, data):
  """Reshapes input so that it can be distributed across multiple cores."""
  multi_inputs = data.copy()

  def add_core_dimension(x):
    if np.isscalar(x):
      return x
    if x.shape[0] % self._num_devices != 0:
      raise ValueError(
          f'The batch size must be a multiple of the number of'
          f' devices. Got batch size = {x.shape[0]} and number'
          f' of devices = {self._num_devices}.')
    prefix = (self._num_devices, x.shape[0] // self._num_devices)
    return np.reshape(x, prefix + x.shape[1:])

  multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
  return multi_inputs
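# Sketch of the per-leaf reshape above, assuming (hypothetically) 8 devices
# and a leaf of shape (32, 128):
#   prefix = (8, 32 // 8)
#   np.reshape(x, prefix + x.shape[1:])  # -> shape (8, 4, 128)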
def consensus(subposteriors, num_draws=None, diagonal=False, rng=None):
  """Merges subposteriors following the consensus Monte Carlo algorithm.

  **References:**

  1. *Bayes and big data: The consensus Monte Carlo algorithm*,
     Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi,
     Hugh A. Chipman, Edward I. George, Robert E. McCulloch

  :param list subposteriors: a list in which each element is a collection of samples.
  :param int num_draws: number of draws from the merged posterior.
  :param bool diagonal: whether to compute weights using variance or covariance,
      defaults to `False` (using covariance).
  :param jax.random.PRNGKey rng: source of the randomness, defaults to
      `jax.random.PRNGKey(0)`.
  :return: if `num_draws` is None, merges subposteriors without resampling;
      otherwise, returns a collection of `num_draws` samples with the same
      data structure as each subposterior.
  """
  # stack subposteriors; shape of joined_subposteriors:
  # n_subs x n_samples x sample_shape
  joined_subposteriors = tree_multimap(lambda *args: np.stack(args),
                                       *subposteriors)
  joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
      joined_subposteriors)

  if num_draws is not None:
    rng = random.PRNGKey(0) if rng is None else rng
    # randomly draw num_draws samples from each subposterior
    n_subs = len(subposteriors)
    n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
    # shape of draw_idxs: n_subs x num_draws
    draw_idxs = random.randint(rng, shape=(n_subs, num_draws),
                               minval=0, maxval=n_samples)
    joined_subposteriors = vmap(lambda x, idx: x[idx])(joined_subposteriors,
                                                       draw_idxs)

  if diagonal:
    # compute weights for each subposterior (ref: Section 3.1 of [1])
    weights = vmap(lambda x: 1 / np.var(x, ddof=1, axis=0))(joined_subposteriors)
    normalized_weights = weights / np.sum(weights, axis=0)
    # get weighted samples
    samples_flat = np.einsum('ij,ikj->kj', normalized_weights,
                             joined_subposteriors)
  else:
    weights = vmap(lambda x: np.linalg.inv(np.cov(x.T)))(joined_subposteriors)
    normalized_weights = np.matmul(np.linalg.inv(np.sum(weights, axis=0)),
                                   weights)
    samples_flat = np.einsum('ijk,ilk->lj', normalized_weights,
                             joined_subposteriors)

  # unravel_fn acts on one sample of a subposterior
  _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
  return vmap(lambda x: unravel_fn(x))(samples_flat)
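# Hypothetical usage sketch: merging two toy subposteriors of 100 samples each.
#   subposteriors = [{'mu': np.ones((100, 3))}, {'mu': 2 * np.ones((100, 3))}]
#   merged = consensus(subposteriors, num_draws=50)
#   # merged['mu'].shape == (50, 3)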
def test_grad_of_function_constant(self):
  def template(x):
    return x + np.ones_like(x)

  net = state.init(template)(self._seed, state.Shape(5))

  def loss(net, x):
    return net(x).sum()

  g = jax.grad(loss)(net, np.ones(5))

  def add(x, y):
    return x + y

  # template has no trainable parameters, so adding the gradient leaves net
  # unchanged and the output stays x + 1.
  net = tree_util.tree_multimap(add, net, g)
  onp.testing.assert_array_equal(net(np.ones(5)), 2 * np.ones(5))
def update(i, values):
  old_xstar, opt_state = values
  old_params = get_params(opt_state)
  forward_solution = constraints_solver(old_xstar, old_params)
  grads_x, grads_params = grad_objective(forward_solution.value, old_params)
  ybar, _ = adjoint_iteration_vjp(
      grads_x, forward_solution, old_xstar, old_params)
  implicit_grads = tree_util.tree_multimap(lax.add, grads_params, ybar)
  opt_state = opt_update(i, implicit_grads, opt_state)
  return forward_solution.value, opt_state
def test():
  # We just check we can run the functions and that they return "something".
  print('testing fit-flax')
  N = 3
  D = 5
  C = 10
  model = ModelTest(nhidden=0, nclasses=C)
  rng = jax.random.PRNGKey(0)
  X = np.random.randn(N, D)
  y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
  batch = {'X': X, 'y': y}
  params = model.init(rng, X)['params']

  # test apply
  logprobs = model.apply({'params': params}, batch['X'])
  assert logprobs.shape == (N, C)

  # test loss
  labels = batch['y']
  loss = softmax_cross_entropy(logprobs, labels)
  assert loss.shape == ()

  # test test_fn
  metrics = eval_classifier(model, params, batch)
  assert np.allclose(loss, metrics['loss'])

  # test train_fn
  make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
  optimizer = make_optimizer.create(params)
  optimizer, metrics = update_classifier(model, optimizer, batch)

  # test fit_model
  num_steps = 2
  train_iter = make_iterator_from_batch(batch)
  test_iter = make_iterator_from_batch(batch)
  params_init = params
  params_new, history = fit_model(model, rng, num_steps, train_iter, test_iter)
  diff = tree_util.tree_multimap(lambda x, y: x - y, params_init, params_new)
  print(diff)
  norm = l2norm_sq(diff)
  assert norm > 0  # check that parameters have changed :)
  print(history)
  print('test passed')
def _scan(f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]],
          init: _Carry,
          xs: Iterable[_Input]) -> Tuple[_Carry, _Output]:
  """Implements an unrolled version of scan.

  Based on `jax.lax.scan` and has a similar API.

  TODO(schsam): We introduce this function because lax.scan currently has a
  higher peak memory usage than the unrolled version. We will aim to swap this
  out for lax.scan when issue #1273 and related have been resolved.
  """
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys += [y]
  return carry, tree_multimap(lambda *y: np.stack(y), *ys)
def _scan(f, init, xs, store_on_device):
  """Implements an unrolled version of scan.

  Based on `jax.lax.scan` and has an identical API.

  TODO(schsam): We introduce this function because lax.scan currently has a
  higher peak memory usage than the unrolled version. We will aim to swap this
  out for lax.scan when issue #1273 and related have been resolved.
  """
  stack = np.stack if store_on_device else jit(np.stack, backend='cpu')
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys += [y]
  return carry, tree_multimap(lambda *y: stack(y), *ys)
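# Usage sketch for the unrolled scan above, with a hypothetical running-sum
# step function (the (carry, y) contract matches jax.lax.scan):
def _running_sum(carry, x):
  total = carry + x
  return total, total

final, totals = _scan(_running_sum, 0., [1., 2., 3.], store_on_device=True)
# final == 6.0; totals == np.stack([1., 3., 6.])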
def test_grad_of_shared_layer(self):
  def template(x, init_key=None):
    layer = state.init(ScalarMul(2 * np.ones(1)), name='scalar_mul')(
        init_key, x)
    return layer(layer(x)).sum()

  net = state.init(template)(self._seed, state.Shape(()))

  def loss(net, x):
    return net(x)

  g = jax.grad(loss)(net, np.ones(()))

  def add(x, y):
    return x + y

  net = tree_util.tree_multimap(add, net, g)
  onp.testing.assert_array_equal(net(np.ones(())), 36.)
def test_grad_of_function_with_literal(self):
  def template(x, init_key=None):
    # 1.0 behaves like a literal when tracing
    return ScalarMul(1.0)(x, init_key=init_key, name='scalar_mul')

  net = state.init(template)(self._seed, state.Shape(5))

  def loss(net, x):
    return net(x).sum()

  g = jax.grad(loss)(net, np.ones(5))

  def add(x, y):
    return x + y

  net = tree_util.tree_multimap(add, net, g)
  # w_new = w_old + 5
  onp.testing.assert_array_equal(net(np.ones(5)), 6 * np.ones(5))
def get_samples(x1: np.ndarray,
                x2: Optional[np.ndarray],
                get: Get,
                **apply_fn_kwargs):
  _key = stateless_uniform(shape=[2], seed=key, minval=None, maxval=None,
                           dtype=tf.int32)
  ker_sampled = None
  for n in range(1, max(n_samples) + 1):
    _key, split = tf_split(_key)
    one_sample = kernel_fn_sample_once(x1, x2, split, get, **apply_fn_kwargs)
    if ker_sampled is None:
      ker_sampled = one_sample
    else:
      ker_sampled = tree_multimap(operator.add, ker_sampled, one_sample)
    yield n, ker_sampled
def bijector_ildj_rule(incells, outcells, **params):
  """Inverse/ILDJ rule for bijectors."""
  incells = incells[1:]
  num_consts = len(incells) - params['num_args']
  const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
  flat_inproxies = safe_map(_CellProxy, flat_incells)
  in_tree = params['in_tree']
  bijector_proxies, inproxy = tree_util.tree_unflatten(in_tree, flat_inproxies)
  flat_bijector_cells = [
      proxy.cell for proxy in tree_util.tree_leaves(bijector_proxies)
  ]
  if any(cell.is_unknown() for cell in flat_bijector_cells):
    return const_incells + flat_incells, outcells, False, None
  bijector = tree_util.tree_multimap(lambda x: x.cell.val, bijector_proxies)
  direction = params['direction']
  if direction == 'forward':
    forward_func = bijector.forward
    inv_func = bijector.inverse
    ildj_func = bijector.inverse_log_det_jacobian
  elif direction == 'inverse':
    forward_func = bijector.inverse
    inv_func = bijector.forward
    ildj_func = bijector.forward_log_det_jacobian
  else:
    raise ValueError('Bijector direction must be "forward" or "inverse".')
  outcell, = outcells
  incell = inproxy.cell
  done = False
  # Default: propagate outcells unchanged (avoids an unbound new_outcells when
  # neither cell is known).
  new_outcells = outcells
  if incell.is_unknown() and not outcell.is_unknown():
    val, ildj = outcell.val, outcell.ildj
    flat_incells = [
        InverseAndILDJ(inv_func(val), ildj + ildj_func(val, np.ndim(val)))
    ]
    done = True
  elif outcell.is_unknown() and not incell.is_unknown():
    new_outcells = [InverseAndILDJ.new(forward_func(incell.val))]
    done = True
  new_incells = flat_bijector_cells + flat_incells
  return const_incells + new_incells, new_outcells, done, None
def update_fn(grads, params, state):
  """Compute the update.

  Args:
    grads: pytree of ndarray. Gradient values.
    params: pytree of ndarray. Parameter values.
    state: a tuple of (gradient accumulators, squared gradient accumulators,
      idx).

  Returns:
    step: pytree of ndarray. The step to be added to the parameter values.
    next_state: a tuple of (gradient accumulators, squared gradient
      accumulators, idx).
  """
  grad_acc, grad_sq_acc, idx = state

  def update_one(g, p, g_acc, g_sq_acc):
    s = jax_common.NAdamWParamState(g_acc, g_sq_acc)
    new_x, new_s = jax_common.nadamw_update(idx, hyper_params, p, s, g)
    return new_x, new_s

  # The following flattens, applies a map, extracts values out via zip,
  # then unflattens.
  flat_gs, tree_def = tree_flatten(grads)
  flat_ps, _ = tree_flatten(params)
  flat_s0, _ = tree_flatten(grad_acc)
  flat_s1, _ = tree_flatten(grad_sq_acc)

  next_param_states = tree_multimap(update_one, flat_gs, flat_ps, flat_s0,
                                    flat_s1)
  flat_step, flat_next_ss = zip(*next_param_states)
  flat_next_grad_acc, flat_next_grad_sq_acc = zip(*flat_next_ss)

  step = tree_unflatten(tree_def, flat_step)
  next_grad_acc = tree_unflatten(tree_def, flat_next_grad_acc)
  next_grad_sq_acc = tree_unflatten(tree_def, flat_next_grad_sq_acc)

  return step, (next_grad_acc, next_grad_sq_acc, idx + 1)
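# Minimal sketch (hypothetical toy tree) of the flatten -> multimap -> zip ->
# unflatten pattern used above; it works because Python lists of leaves are
# themselves pytrees, so tree_multimap can map over the flattened lists:
leaves, treedef = tree_flatten({'a': 1.0, 'b': 2.0})
pairs = tree_multimap(lambda x: (x + 1.0, 2.0 * x), leaves)  # [(2.0, 2.0), (3.0, 4.0)]
steps, states = zip(*pairs)
step_tree = tree_unflatten(treedef, steps)    # {'a': 2.0, 'b': 3.0}
state_tree = tree_unflatten(treedef, states)  # {'a': 2.0, 'b': 4.0}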