def zoneout_decoder(inputs, prev_state):
    x, mask = inputs
    x, state = self.decoder(x, prev_state)
    state = jax.tree_multimap(lambda m, s1, s2: s1 * m + s2 * (1 - m),
                              mask, prev_state, state)
    return x, state
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: int = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩, divided by √n

    In a somewhat opaque way this also internally splits all parameters to real
    in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C),
    resulting in the respective ΔOⱼₖ, which is only compatible with
    split-to-real pytree vectors.

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        pdf: |ψ(x)|^2 if exact optimization is being used, else None
        chunk_size: an int specifying the size of the chunks the gradient should
            be computed in (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the
            samples σ, divided by √n;
            None
        else:
            the same pytree, but with the entries for each parameter normalised
            to unit norm;
            a pytree containing the norms that were divided out (same shape as params)
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing
        # the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or '
            '"holomorphic", got {}'.format(mode))

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        oks = jacobian_fun(f, params, samples)
        oks_mean = jax.tree_map(partial(sum, axis=0), _multiply_by_pdf(oks, pdf))
        centered_oks = jax.tree_multimap(lambda x, y: x - y, oks, oks_mean)
        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
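A minimal, self-contained sketch of the quantity described in the docstring above: the per-sample Jacobian of ln Ψ with its column means subtracted, divided by √n. The model `toy_logpsi`, the parameter pytree, and the sample values are invented for illustration and do not reproduce the internal `centered_jacobian_*` helpers.

import jax
import jax.numpy as jnp

def toy_logpsi(params, sigma):
    # a made-up real-parameter log-amplitude ln Ψ(σ)
    return jnp.dot(params["w"], sigma) + params["b"] * jnp.sum(sigma)

params = {"w": jnp.array([0.1, -0.2, 0.3]), "b": jnp.array(0.5)}
samples = jnp.array([[1., -1., 1.], [1., 1., -1.], [-1., 1., 1.], [1., 1., 1.]])
n = samples.shape[0]

# per-sample Jacobian Oⱼₖ = ∂ ln Ψ(σⱼ)/∂pₖ, with a leading sample axis on every leaf
oks = jax.vmap(jax.grad(toy_logpsi), in_axes=(None, 0))(params, samples)
# centre each column and divide by √n, i.e. ΔOⱼₖ / √n
centered_oks = jax.tree_map(lambda o: (o - o.mean(axis=0)) / jnp.sqrt(n), oks)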
def test_frozen_dict_partially_maps(self):
    x = jax.tree_multimap(lambda a, b: (a, b),
                          freeze({'a': 2}),
                          freeze({'a': {'b': 1}}))
    self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
def test_prng_sequence_split(self):
    k = jax.random.PRNGKey(42)
    s = base.PRNGSequence(k)
    hk_keys = s.take(10)
    jax_keys = tuple(jax.random.split(k, num=11)[1:])
    jax.tree_multimap(np.testing.assert_array_equal, hk_keys, jax_keys)
def incremental_update(new_tensors, old_tensors, tau: Numeric):
    """Incrementally update all elements from a nested struct."""
    return jax.tree_multimap(lambda new, old: tau * new + (1.0 - tau) * old,
                             new_tensors, old_tensors)
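A small usage sketch of the Polyak/soft-update pattern this helper implements, assuming the `incremental_update` definition above is in scope; the parameter pytrees and `tau` value are invented for illustration.

import jax.numpy as jnp

online_params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
target_params = {"w": jnp.zeros((2, 2)), "b": jnp.ones((2,))}

# target ← tau * online + (1 - tau) * target, leaf by leaf
target_params = incremental_update(online_params, target_params, tau=0.01)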
def dequeue():
    batch = []
    for _ in range(batch_size):
        batch.append(q.get())
    batch = jax.tree_multimap(lambda *ts: np.stack(ts, axis=1), *batch)
    return jax.device_put(batch)
def conditional_update(new_tensors: Any, old_tensors: Any, is_time: Numeric):
    """Checks whether to update the params and returns the correct params."""
    return jax.tree_multimap(
        lambda new, old: jax.lax.select(is_time, new, old),
        new_tensors, old_tensors)
def inner(scope_fn, repack_fn, variable_groups, rng_groups, init, *args):
    def find_length(axis, x):
        if axis is not axes_scan.broadcast:
            leaves = jax.tree_leaves(x)
            if leaves:
                return leaves[0].shape[axis]
        return ()

    # split rngs
    lengths = jax.tree_multimap(find_length, in_axes, args)
    lengths = set(jax.tree_leaves(lengths))
    if length is None and len(lengths) == 1:
        d_length, = lengths
    elif len(lengths) > 1:
        raise ValueError(f'Inconsistent scan lengths: {lengths}')
    elif length is None:
        raise ValueError('length should be specified manually.')
    else:
        d_length = length
    split_fn = lambda rng: random.split(rng, d_length)

    rng_groups = tuple(
        jax.tree_map(split_fn, rng_group) if split else rng_group
        for rng_group, split in zip(rng_groups, rng_splits))

    @functools.partial(axes_scan.scan,
                       in_axes=(variable_in_axes, rng_axes, in_axes),
                       out_axes=(out_axes, variable_out_axes),
                       length=length,
                       reverse=reverse)
    def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
        carry_vars, c = carry
        variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups
        if data_transform is not None:
            variable_groups, rng_groups = data_transform(variable_groups,
                                                         rng_groups)
        scope = scope_fn(variable_groups, rng_groups)
        c, y = fn(scope, c, *args)
        out_vars = repack_fn(scope)
        broadcast_vars_out = out_vars[0]
        carry_vars = out_vars[1]
        scan_vars = out_vars[2:]
        # add immutable broadcast vars back to broadcast output
        # otherwise they won't be fed to the actual scan body
        for in_group, out_group in zip(broadcast_vars, broadcast_vars_out):
            for col in in_group:
                if col not in out_group:
                    out_group[col] = in_group[col]
        return broadcast_vars_out, (carry_vars, c), (y, scan_vars)

    broadcast_vars = variable_groups[0]
    carry_vars = variable_groups[1]
    scan_vars = variable_groups[2:]
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
        broadcast_vars, (carry_vars, init), scan_vars, rng_groups, args)
    # remove immutable broadcast vars otherwise they will be updated
    # with their own value which will cause an error
    for out_group in broadcast_vars:
        for name, col in tuple(out_group.items()):
            if isinstance(col, FrozenDict):
                del out_group[name]
    out_vars = (
        broadcast_vars,
        carry_vars,
    ) + scan_vars
    return (c, ys), out_vars
def tree_equal(x, y):
    return jax.tree_util.tree_all(jax.tree_multimap(lambda x, y: x == y, x, y))
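For example, with the definition above (the toy pytrees are invented; note this variant is suited to scalar leaves, since `tree_all` collapses the pytree structure but each leaf comparison must already reduce to a single boolean):

assert tree_equal({"a": 1, "b": (2.0, 3)}, {"a": 1, "b": (2.0, 3)})
assert not tree_equal({"a": 1, "b": (2.0, 3)}, {"a": 1, "b": (2.0, 4)})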
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None,
              history=1, first_available_dim=None):
    from numpyro.contrib.funsor import config_enumerate, enum, markov
    from numpyro.contrib.funsor import trace as packed_trace

    # number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {'_scan_current_index': i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # handle the name to match the pattern of sakkar_bilmes product
            with handlers.scope(prefix='_PREV_' * (unroll_steps - i), divider=''):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
            trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension in the correct position,
            # we need to promote shapes to make `fn` and `value` at each site have
            # the same batch dims (e.g. if `fn.batch_shape = (2, 3)` and value's
            # batch_shape is (3,), then we promote the shape of value so that its
            # batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted
            # after scanning. We don't promote `value` shape here because we need to
            # store the carry shape at this step. If we reshaped the `value` here,
            # the output carry might get the wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry in a global variable
            if len(carry_shapes) < (history + 1):
                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")), \
            enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 steps, where the last step is used for rolling
        # with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(wrapped_carry,
                                                 tree_map(lambda z: z[i], x0))
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(lambda z0, z: jnp.reshape(z, jnp.shape(z0)),
                                       y0s[0], y0)
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret
                # the last carry shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]])
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(
                    body_fn, wrapped_carry, xs_, length - unroll_steps, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic sites in the trace;
        # we don't need to adjust `dim_to_name` for deterministic sites
        if site['type'] not in ('sample',):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values yet during `lax.scan`, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine the leftmost
        # dimension because we don't record 1-size dimensions in this field
        time_dim = -min(len(site['fn'].batch_shape),
                        jnp.ndim(site['value']) - site['fn'].event_dim)
        site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]),
                       y0s, ys)
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [jnp.reshape(x, t1_shape)
                     for x, t1_shape in zip(flatten_carry, carry_shape)]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def init_fn(params):
    mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
    return ScaleByLarsState(mu=mu)
def tree_add(a, b):
    return jax.tree_multimap(lambda e1, e2: e1 + e2, a, b)
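For instance, with two small parameter-like pytrees (values invented for illustration), the result has the same structure as the inputs:

import jax.numpy as jnp

a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
b = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(1.0)}
c = tree_add(a, b)  # {"w": [1.1, 2.2], "b": 1.5}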
def trees2_map(fun, tree1, tree2, *args, **kwargs):
    return tree_multimap(lambda x, y: fun(x, y, *args, **kwargs), tree1, tree2)
def __call__(self, inputs, state):
    """Run one step of the wrapped core, handling state reset.

    Args:
      inputs: Tuple with two elements, ``inputs, should_reset``, where
        ``should_reset`` is the signal used to reset the wrapped core's state.
        ``should_reset`` can be either tensor or nest. If nest, ``should_reset``
        must match the state structure, and its components' shapes must be
        prefixes of the corresponding entries tensors' shapes in the state nest.
        If tensor, supported shapes are all common shape prefixes of the state
        component tensors, e.g. ``[batch_size]``.
      state: Previous wrapped core state.

    Returns:
      Tuple of the wrapped core's ``output, next_state``.
    """
    inputs, should_reset = inputs
    if jax.treedef_is_leaf(jax.tree_structure(should_reset)):
      # Equivalent to not tree.is_nested, but with support for Jax extensible
      # pytrees.
      should_reset = jax.tree_map(lambda _: should_reset, state)

    # We now need to manually pad 'on the right' to ensure broadcasting operates
    # correctly.
    # Automatic broadcasting would in fact implicitly pad 'on the left',
    # resulting in the signal to trigger resets for parts of the state
    # across batch entries. For example:
    #
    # import jax
    # import jax.numpy as jnp
    #
    # shape = (2, 2, 2)
    # x = jnp.zeros(shape)
    # y = jnp.ones(shape)
    # should_reset = jnp.array([False, True])
    # v = jnp.where(should_reset, x, y)
    # for batch_entry in range(shape[0]):
    #   print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
    #
    # >> batch_entry 0:
    # >> [[1. 0.]
    # >>  [1. 0.]]
    # >> batch_entry 1:
    # >> [[1. 0.]
    # >>  [1. 0.]]
    #
    # Note how manually padding the should_reset tensor yields the desired
    # behavior.
    #
    # import jax
    # import jax.numpy as jnp
    #
    # shape = (2, 2, 2)
    # x = jnp.zeros(shape)
    # y = jnp.ones(shape)
    # should_reset = jnp.array([False, True])
    # dims_to_add = x.ndim - should_reset.ndim
    # should_reset = should_reset.reshape(should_reset.shape + (1,) * dims_to_add)
    # v = jnp.where(should_reset, x, y)
    # for batch_entry in range(shape[0]):
    #   print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
    #
    # >> batch_entry 0:
    # >> [[1. 1.]
    # >>  [1. 1.]]
    # >> batch_entry 1:
    # >> [[0. 0.]
    # >>  [0. 0.]]
    should_reset = jax.tree_multimap(_validate_and_conform, should_reset, state)
    if self._is_batched(state):
      batch_size = jax.tree_leaves(inputs)[0].shape[0]
    else:
      batch_size = None
    initial_state = jax.tree_multimap(lambda s, i: i.astype(s.dtype), state,
                                      self.initial_state(batch_size))
    state = jax.tree_multimap(jnp.where, should_reset, initial_state, state)
    return self.core(inputs, state)
def stack_forest(forest):
    stack_args = lambda *args: onp.stack(args)
    return jax.tree_multimap(stack_args, *forest)
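A usage sketch, assuming the `stack_forest` definition above and a list of per-step metric pytrees (the values are invented): every leaf gains a new leading axis of length len(forest).

import numpy as onp

forest = [
    {"loss": onp.array(0.9), "acc": onp.array(0.4)},
    {"loss": onp.array(0.7), "acc": onp.array(0.6)},
    {"loss": onp.array(0.5), "acc": onp.array(0.8)},
]
stacked = stack_forest(forest)
# stacked["loss"].shape == (3,), stacked["acc"].shape == (3,)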
def update_fn(updates, state, params=None):
    del params
    multiplied_updates = jax.tree_multimap(
        lambda m, update: jax.tree_map(lambda u: u * m, update),
        state.multipliers, updates)
    return multiplied_updates, state
def concatenate_leaves(pytrees):
    return jax.tree_multimap(
        lambda *leaves: onp.concatenate(leaves, axis=0), *pytrees)
# inv_D2P_eq_min = lambda v: jax.tree_map(lambda x: x, v)
inv_D2P_eq_min = lambda v: jax.tree_map(id_func, v)
inv_D2P_ineq_min = lambda v: jax.tree_map(inv_D2P_pd, v)
min_augmented_D2P_inv = (inv_D2P_eq_min, inv_D2P_ineq_min)

key1 = random.PRNGKey(0)
key = random.PRNGKey(1)
x1 = jnp.array([1., 2., 3., 4., 5.])
x2 = random.normal(key1, (5,))
W1 = random.normal(key, (3, 3))
W2 = random.normal(key1, (3, 3))
x = ((x1, x2), (x1, x1, x2))
W = ((W1, W2), (W1, W1, W2))

print(jax.tree_multimap(lambda f, x: f(x), min_augmented_DP, x))
print(DP_pd(x1))
print(DP_pd(x2))

# Check if the inv(D2P) match the closed form.
print(inv_D2P_pd(W1)(jnp.identity(W1.shape[0])))
print(jnp.linalg.matrix_power(W1, 2).T)
print(inv_D2P_pd(x2)(x1))
print(jnp.dot(jnp.diag(x2), x1))


def make_bound_breg_original(lb=-1.0, ub=1.0):
    def breg_bound_internal(lb, ub, *args, **kwargs):
        return lambda vec: jnp.sum((-vec + ub) * jnp.log(-vec + ub) +
                                   (vec - lb) * jnp.log(vec - lb))
def warmup(
    self,
    rng_key: jax.random.PRNGKey,
    initial_state: HMCState,
    kernel_factory: Callable,
    num_chains,
    num_warmup_steps: int = 1000,
    accelerate=False,
    initial_step_size: float = 0.1,
) -> Tuple[HMCState, HMCParameters, Optional[StanWarmupState]]:
    """I don't like having a ton of warmup logic in here."""
    if not self.needs_warmup:
        parameters = HMCParameters(
            jnp.ones(initial_state.position.shape[0], dtype=jnp.int32)
            * self.parameters.num_integration_steps,
            jnp.ones(initial_state.position.shape[0]) * self.parameters.step_size,
            jnp.array([
                self.parameters.inverse_mass_matrix
                for _ in range(initial_state.position.shape[0])
            ]),
        )
        return initial_state, parameters, None

    hmc_factory = jax.partial(kernel_factory, self.parameters.num_integration_steps)
    init, update, final = stan_hmc_warmup(hmc_factory, self.is_mass_matrix_diagonal)

    rng_keys = jax.random.split(rng_key, num_chains)
    chain_state = initial_state
    warmup_state = jax.vmap(init, in_axes=(0, 0, None))(rng_keys, chain_state,
                                                        initial_step_size)

    schedule = jnp.array(stan_warmup_schedule(num_warmup_steps))

    if accelerate:
        print(
            f"sampler: warmup {num_chains:,} chains for {num_warmup_steps:,} iterations.",
            end=" ",
        )
        start = datetime.now()

        @jax.jit
        def update_chain(carry, interval):
            rng_key, chain_state, warmup_state = carry
            stage, is_middle_window_end = interval

            _, rng_key = jax.random.split(rng_key)
            keys = jax.random.split(rng_key, num_chains)
            chain_state, warmup_state, chain_info = jax.vmap(
                update, in_axes=(0, None, None, 0, 0))(keys, stage,
                                                       is_middle_window_end,
                                                       chain_state, warmup_state)
            return (
                (rng_key, chain_state, warmup_state),
                (chain_state, warmup_state, chain_info),
            )

        last_state, warmup_chain = jax.lax.scan(
            update_chain, (rng_key, chain_state, warmup_state), schedule)
        _, last_chain_state, last_warmup_state = last_state

        print(f"Done in {(datetime.now()-start).total_seconds():.1f} seconds.")

    else:

        @jax.jit
        def update_fn(rng_key, interval, chain_state, warmup_state):
            rng_keys = jax.random.split(rng_key, num_chains)
            stage, is_middle_window_end = interval
            chain_state, warmup_state, chain_info = jax.vmap(
                update, in_axes=(0, None, None, 0, 0))(rng_keys, stage,
                                                       is_middle_window_end,
                                                       chain_state, warmup_state)
            return chain_state, warmup_state, chain_info

        chain = []
        with tqdm(schedule, unit="samples") as progress:
            progress.set_description(
                f"Warming up {num_chains} chains for {num_warmup_steps} steps",
                refresh=False,
            )
            for interval in progress:
                _, rng_key = jax.random.split(rng_key)
                chain_state, warmup_state, chain_info = update_fn(
                    rng_key, interval, chain_state, warmup_state)
                chain.append((chain_state, warmup_state, chain_info))

        last_chain_state, last_warmup_state, _ = chain[-1]

        # The sampling process, the composition between scan and for loop
        # is identical for the warmup and the sampling. Should we
        # generalize this to only call a single `scan` function?
        stack = lambda y, *ys: jnp.stack((y, *ys))
        warmup_chain = jax.tree_multimap(stack, *chain)

    step_size, inverse_mass_matrix = jax.vmap(final, in_axes=(0,))(last_warmup_state)
    num_integration_steps = self.parameters.num_integration_steps

    parameters = HMCParameters(
        jnp.ones(initial_state.position.shape[0], dtype=jnp.int32)
        * num_integration_steps,
        step_size,
        inverse_mass_matrix,
    )

    return last_chain_state, parameters, warmup_chain
def test_subsample_gradient(scale, subsample):
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(
            lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x),
                                    svi.model, svi.guide, subsample))(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(
            lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x),
                                    svi.model, svi.guide, subsample))(params)
        grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(
            lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x),
                                    svi.model, svi.guide, subsample))(params)

    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(
        lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x),
                                svi.model, svi.guide, None))(params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)

    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        'loc': scale * jnp.array([0.5, -2.0]),
        'scale': scale * jnp.array([2.0])
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(actual_grads[name], expected_grads[name],
                        rtol=precision, atol=precision)
def tree_allclose(t1, t2):
    t = jax.tree_multimap(jnp.allclose, t1, t2)
    return all(jax.tree_util.tree_flatten(t)[0])
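For example, with toy pytrees (values invented); a perturbation below the default `jnp.allclose` tolerance still compares equal:

import jax.numpy as jnp

t1 = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
t2 = {"w": jnp.array([1.0, 2.0]) + 1e-9, "b": jnp.array(3.0)}
assert tree_allclose(t1, t2)
assert not tree_allclose(t1, {"w": jnp.array([1.0, 2.5]), "b": jnp.array(3.0)})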
def list_dot(x: Vector, y: Vector) -> Scalar:
    return np.sum(
        tree_multimap(lambda arr_x, arr_y: np.sum(arr_x * arr_y), x, y))
def periodic_update(new_tensors: Any, old_tensors: Any, is_time: Numeric):
    """Periodically switch all elements from a nested struct with new elements."""
    return jax.tree_multimap(
        lambda new, old: jax.lax.select(is_time, new, old),
        new_tensors, old_tensors)
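A usage sketch for the target-network pattern this helper typically serves, assuming the `periodic_update` definition above is in scope; `update_period`, `step`, and the parameter pytrees are invented for illustration.

import jax.numpy as jnp

update_period = 100
online_params = {"w": jnp.ones((2,))}
target_params = {"w": jnp.zeros((2,))}

step = jnp.asarray(300)
# copy the online params into the target params only every `update_period` steps
target_params = periodic_update(online_params, target_params,
                                step % update_period == 0)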
def list_add_prefactor(x: Vector, a: Scalar, y: Vector) -> Vector:
    # mimics x + a*y
    return tree_multimap(lambda arr_x, arr_y: arr_x + a * arr_y, x, y)
def _jvp(oks: PyTree, v: PyTree) -> Array:
    """
    Compute the matrix-vector product between the pytree jacobian oks and the pytree vector v
    """
    td = lambda x, y: jnp.tensordot(x, y, axes=y.ndim)
    return jax.tree_util.tree_reduce(jnp.add, jax.tree_multimap(td, oks, v))
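A toy illustration of the contraction above, with invented shapes: each Jacobian leaf carries a leading sample axis, each vector leaf matches the parameter shape, and the result is a single vector of length n_samples.

import jax.numpy as jnp

n_samples = 4
oks = {"w": jnp.ones((n_samples, 3, 2)), "b": jnp.ones((n_samples, 2))}
v = {"w": jnp.full((3, 2), 0.5), "b": jnp.full((2,), 0.25)}

out = _jvp(oks, v)  # assumes the definition above is in scope
# out.shape == (n_samples,); every entry is 6 * 0.5 + 2 * 0.25 = 3.5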
def test_routing_reduce_correct(self, reduction):
    """Compare JAX implementations to a (slow but correct) iterative one."""
    n_variants = 2
    n_states = 4

    def make_range_shaped(shape):
        return np.arange(np.prod(shape)).reshape(shape).astype("float32")

    schema = self.build_simple_schema()
    builder = automaton_builder.AutomatonBuilder(schema)
    routing_params = automaton_builder.RoutingParams(
        move=make_range_shaped([
            n_variants,
            len(builder.in_out_route_types),
            n_states,
            n_states,
        ]),
        special=make_range_shaped([
            n_variants,
            len(builder.in_route_types),
            n_states,
            len(builder.special_actions),
        ]),
    )

    # Compute aggregates with JAX
    if reduction == "softmax":
        routing_aggregates = builder.routing_softmax(routing_params)
    else:
        routing_aggregates = builder.routing_reduce(routing_params,
                                                    reduction=reduction)
        routing_aggregates = jax.tree_multimap(
            lambda s, p: np.array(jnp.broadcast_to(s, p.shape)),
            routing_aggregates, routing_params)

    # Manual looping aggregates
    for variant in range(n_variants):
        for current_state in range(n_states):
            for in_route_type in builder.in_route_types:
                # Compute aggregates
                distn_vals = []
                iroute_idx = builder.in_route_type_to_index[in_route_type]
                for out_edge_type in schema[in_route_type.node_type].out_edges:
                    ioroute_idx = builder.in_out_route_type_to_index[
                        automaton_builder.InOutRouteType(
                            in_route_type.node_type, in_route_type.in_edge,
                            out_edge_type)]
                    for next_state in range(n_states):
                        distn_vals.append(
                            routing_params.move[variant, ioroute_idx,
                                                current_state, next_state])
                for action_idx in range(len(builder.special_actions)):
                    distn_vals.append(
                        routing_params.special[variant, iroute_idx,
                                               current_state, action_idx])

                if reduction == "sum":
                    distn_aggregate = [sum(distn_vals)] * len(distn_vals)
                elif reduction == "max":
                    distn_aggregate = [max(distn_vals)] * len(distn_vals)
                elif reduction == "softmax":
                    distn_aggregate = list(jax.nn.softmax(jnp.array(distn_vals)))
                else:
                    raise ValueError(f"Invalid reduction {reduction}")

                i = 0
                # Check them with the JAX version
                for out_edge_type in schema[in_route_type.node_type].out_edges:
                    ioroute_idx = builder.in_out_route_type_to_index[
                        automaton_builder.InOutRouteType(
                            in_route_type.node_type, in_route_type.in_edge,
                            out_edge_type)]
                    for next_state in range(n_states):
                        np.testing.assert_allclose(
                            routing_aggregates.move[variant, ioroute_idx,
                                                    current_state, next_state],
                            distn_aggregate[i],
                            rtol=1e-6)
                        i += 1
                for action_idx in range(len(builder.special_actions)):
                    np.testing.assert_allclose(
                        routing_aggregates.special[variant, iroute_idx,
                                                   current_state, action_idx],
                        distn_aggregate[i],
                        rtol=1e-6)
                    i += 1
def _sum_combinator(*args):
    return functools.reduce(
        lambda x, y: jax.tree_multimap(lambda i, j: i + j, x, y), args)
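For example, summing three gradient-like pytrees leaf-wise (values invented), assuming the definition above is in scope:

import jax.numpy as jnp

g1 = {"w": jnp.array([1.0, 1.0])}
g2 = {"w": jnp.array([2.0, 2.0])}
g3 = {"w": jnp.array([3.0, 3.0])}
total = _sum_combinator(g1, g2, g3)  # {"w": [6.0, 6.0]}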
def train_step(state,
               inputs,
               outputs,
               programs,
               pretrain,
               bos_token,
               eos_token,
               learning_rate_fn,
               config,
               lp_config,
               train_rng=None):
    """Train on batch of program tasks."""
    # We handle PRNG splitting inside the top pmap, rather
    # than handling it outside in the training loop - doing the
    # latter can add some stalls to the devices.
    train_rng, new_train_rng = jax.random.split(train_rng)

    weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)

    # Embedding mask for autoencoding.
    emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                        jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

    def ae_loss_fn(params):
        """Loss function used for training autoencoder."""
        (logits, vq), new_variables = models.LatentProgramTransformer(config).apply(
            {'params': params, 'vqvae': state.model_state},
            inputs,
            outputs,
            programs,
            emb_mask,
            pretrain=pretrain,
            mutable=['vqvae'],
            rngs={'dropout': train_rng})
        loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)

        # Add EOS token for latent predictor loss.
        vq_weight_sum = jnp.sum(
            jnp.where(vq['latent_indices'] > 0, 1, 0).astype(jnp.float32))
        latent_indices = add_eos_token(vq['latent_indices'], eos_token)

        mean_loss = loss / weight_sum + vq['loss'] / vq_weight_sum
        return mean_loss, (new_variables['vqvae'], logits, latent_indices)

    step = state.step
    optimizer = state.optimizer
    lp_optimizer = state.lp_optimizer
    lr = learning_rate_fn(step)
    grad_fn = jax.value_and_grad(ae_loss_fn, has_aux=True)
    (_, (new_model_state, ae_logits, latent_indices)), ae_grad = grad_fn(
        optimizer.target)
    ae_grad = jax.lax.pmean(ae_grad, 'batch')

    latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)
    encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)

    # Additionally mask out eos token in latents.
    latents_mask = jnp.where(
        jnp.logical_and(latent_indices > 0, latent_indices != eos_token),
        1, 0).astype(jnp.float32)

    def loss_fn(params, lp_params):
        """Loss function used for training."""
        latent_logits = models.ProgramTransformer(lp_config).apply(
            {'params': lp_params},
            inputs,
            outputs,
            latent_indices,
            rngs={'dropout': train_rng})
        latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
            latent_logits, latent_indices, latent_weights)

        # End-to-end prediction.
        encoded = models.LatentProgramTransformer(config).apply(
            {'params': params, 'vqvae': state.model_state},
            inputs,
            outputs,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.encode)
        latents = models.LatentProgramTransformer(config).apply(
            {'params': params, 'vqvae': state.model_state},
            latent_logits,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.quantize)
        logits = models.LatentProgramTransformer(config).apply(
            {'params': params, 'vqvae': state.model_state},
            programs,
            latents,
            encoded,
            latents_mask,
            encoded_mask,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.decode)
        loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)

        mean_loss = latent_loss / latent_weight_sum
        if not pretrain:
            mean_loss += loss / weight_sum
        return mean_loss, (logits, latent_logits)

    grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
    (_, (logits, latent_logits)), grads = grad_fn(optimizer.target,
                                                  lp_optimizer.target)
    grads = jax.lax.pmean(grads, 'batch')
    new_optimizer = optimizer.apply_gradient(
        jax.tree_multimap(jnp.add, grads[0], ae_grad), learning_rate=lr)
    new_lp_optimizer = lp_optimizer.apply_gradient(grads[1], learning_rate=lr)

    metrics = compute_metrics(logits, programs, weights)
    metrics['learning_rate'] = lr
    metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
    latent_metrics = compute_metrics(
        latent_logits, latent_indices, latent_weights, prefix='latent_')

    new_state = state.replace(
        step=step + 1,
        optimizer=new_optimizer,
        model_state=jax.lax.pmean(new_model_state, 'batch'),
        lp_optimizer=new_lp_optimizer)

    return new_state, metrics, latent_metrics, new_train_rng
def soft_update_func(old, new, tau):
    return jax.tree_multimap(lambda a, b: (1 - tau) * a + tau * b, old, new)
def update(params, grads):
    return jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
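A short usage sketch: one step of plain gradient descent on a toy quadratic loss, with the fixed 0.05 step size baked into `update` above; the loss and parameter pytree are invented for illustration.

import jax
import jax.numpy as jnp

def loss(params):
    # toy quadratic objective
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(3.0)}
grads = jax.grad(loss)(params)
params = update(params, grads)  # each leaf moves opposite its gradient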