def __call__(self, *args, **kwargs) -> Any:
  """Calls the module like a plain function and records its new state.

  The computation runs through `call_and_update`, which yields the output and
  an updated copy of the module. The updated state is then passed to
  `assign` in one of two ways:

  * as one named module when `self.name` is set;
  * otherwise variable by variable, each under its own variable name.

  The assigned values are tied into the output with `primitive.tie_in`
  before it is returned.

  Args:
    *args: Positional arguments forwarded to `call_and_update`.
    **kwargs: Keyword arguments forwarded to `call_and_update`.

  Returns:
    The output of `call_and_update`, tied in to the assigned state.
  """
  result, updated = self.call_and_update(*args, **kwargs)
  if self.name is None:
    # Anonymous module: assign every variable individually by its own name.
    assigned = {
        var_name: assign(var_value, name=var_name)
        for var_name, var_value in updated.variables().items()
    }
  else:
    assigned = assign(updated, name=self.name)
  return primitive.tie_in(assigned, result)
def sow(value, *, tag: str, name: str, mode: str = 'strict', key=None):
  """Tags and names `value` so that a later harvest can find it.

  Args:
    value: A JAX value (or pytree of values) to be tagged and named.
    tag: A string tag attached to the sown value.
    name: A string name attached to the sown value.
    mode: How repeated sows of the same name and tag are handled:
      1. `'strict'` - if another value is sown with the same name and tag in
         the same context, harvest throws an error.
      2. `'clobber'` - if another value is sown with the same name and tag,
         it replaces this value.
      3. `'append'` - sown values with the same name and tag are appended to
         a growing list. Append mode assumes an ordering on the sown values
         defined by data dependence.
    key: An optional JAX value that is tied into the sown value.

  Returns:
    The original `value` that was passed in.
  """
  if key is not None:
    value = prim.tie_in(key, value)
  # `sow_p` binds on flat leaves; the tree structure travels as a parameter so
  # the result can be rebuilt into the caller's pytree shape.
  leaves, treedef = tree_util.tree_flatten(value)
  sown_leaves = sow_p.bind(*leaves, name=name, tag=tag, mode=mode, tree=treedef)
  return tree_util.tree_unflatten(treedef, sown_leaves)
def _call(self, x, training=True, rng=None): info = self.info if training: if rng is None: raise ValueError('rng is required when training is True') # Using tie_in to avoid materializing constants keep = primitive.tie_in(x, random.bernoulli(rng, info.rate, x.shape)) return np.where(keep, x / info.rate, 0) else: return x
def template_build(cls, init_key, *args, name=None, **kwargs):
  """Instantiates a layer object from a PRNG key and layer specifications.

  Args:
    cls: The layer class; its `initialize` and `new` methods are used.
    init_key: PRNG key passed to `cls.initialize`. Must not be None.
    *args: Extra positional arguments forwarded to `cls.initialize`.
    name: Optional name passed to `cls.new`.
    **kwargs: Extra keyword arguments forwarded to `cls.initialize`.

  Returns:
    The result of `cls.new` applied to the initialized `LayerParams`, whose
    params and state leaves have been tied in to `init_key`.

  Raises:
    ValueError: If `init_key` is None.
  """
  if init_key is None:
    raise ValueError('Cannot initialize template with `None` PRNGKey.')
  layer_params = cls.initialize(init_key, *args, **kwargs)
  # `init_key` is known to be non-None past the guard above, so every params
  # and state leaf is unconditionally tied in to it.
  params, state = tree_util.tree_map(
      lambda leaf: primitive.tie_in(init_key, leaf),
      (layer_params.params, layer_params.state))
  layer_params = LayerParams(
      params=params, state=state, info=layer_params.info)
  return cls.new(layer_params, name=name)
def wrapped(*args, **kwargs):
  """Returns the `output_names` random variables of the conditioned `f`.

  All `RANDOM_VARIABLE`-tagged values are reaped from
  `conditional(f, input_names)`. The values named in `output_names` are
  returned: a single value when `single_output` is set, otherwise a list in
  `output_names` order. Every other reaped value is re-sown in 'strict' mode
  and tied into the returned outputs.
  """
  reaped = harvest.reap(
      conditional(f, input_names), tag=RANDOM_VARIABLE)(*args, **kwargs)
  selected = [reaped[out_name] for out_name in output_names]
  resown = {}
  for var_name, var_value in reaped.items():
    if var_name in output_names:
      continue
    resown[var_name] = harvest.sow(
        var_value, tag=RANDOM_VARIABLE, name=var_name, mode='strict')
  result = selected[0] if single_output else selected
  return primitive.tie_in(resown, result)
def wrapped(*args, **kwargs):
  """Runs the harvested `f` and re-sows every reaped random variable.

  `observations` is passed as the first argument to the harvested function;
  presumably these are the values to plant — confirm against
  `harvest.harvest`. Each reaped `RANDOM_VARIABLE` value is sown again in
  'strict' mode and tied into the function's result.
  """
  result, harvested = harvest.harvest(f, tag=RANDOM_VARIABLE)(
      observations, *args, **kwargs)
  resown = {}
  for var_name, var_value in harvested.items():
    resown[var_name] = harvest.sow(
        var_value, tag=RANDOM_VARIABLE, name=var_name, mode='strict')
  return primitive.tie_in(resown, result)
def step(key, state, init_key=None):
  """Performs one accept/reject step around a stateful `inner_step`.

  A proposal is drawn from `inner_step` (initialized with `st.init`). It is
  accepted when log(u) < log p(proposal) - log p(state) for u ~ Uniform(0, 1).
  No proposal-density correction is applied, so this presumably assumes a
  symmetric `inner_step` — confirm.

  Args:
    key: PRNG key, split into a proposal key and an acceptance key.
    state: The current state (a pytree).
    init_key: Optional key used to initialize `inner_step`.

  Returns:
    A pytree choosing, leaf by leaf, the proposal if accepted, else `state`.
  """
  proposal_key, decision_key = random.split(key)
  proposal = st.init(inner_step)(init_key, proposal_key, state)(
      proposal_key, state)
  # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
  current_lp = unnormalized_log_prob(state)
  proposal_lp = unnormalized_log_prob(proposal)
  log_ratio = proposal_lp - current_lp
  accept_prob = np.clip(np.exp(log_ratio), 0., 1.)
  uniform = primitive.tie_in(accept_prob, random.uniform(decision_key))
  accepted = np.log(uniform) < log_ratio
  return tree_util.tree_multimap(
      lambda proposed, current: np.where(accepted, proposed, current),
      proposal, state)
def step(key, state, init_key=None):
  """Runs a stateful kernel `num_steps` times with `lax.scan`.

  `kernel_fn` is initialized as a module named 'kernel'. Each iteration
  advances the state via the kernel's `call_and_update`. It then threads
  kernel and state through every function in `callbacks`, tying their
  results together with `primitive.tie_all`.

  Args:
    key: PRNG key, split into one key per scan iteration.
    state: The initial state.
    init_key: Optional key used to initialize the kernel.

  Returns:
    The per-iteration states stacked by `lax.scan`, tied in to the final
    kernel assigned under the name 'kernel'.
  """
  kernel = st.init(kernel_fn, name='kernel')(init_key, key, state)

  def body(carry, step_key):
    current_kernel, current_state = carry
    current_state, current_kernel = current_kernel.call_and_update(
        step_key, current_state)
    for callback in callbacks:
      current_kernel, current_state, _ = primitive.tie_all(
          current_kernel, current_state,
          callback(current_kernel, current_state))
    return (current_kernel, current_state), current_state

  step_keys = random.split(key, num_steps)
  (final_kernel, _), states = lax.scan(body, (kernel, state), step_keys)
  return primitive.tie_in(st.assign(final_kernel, name='kernel'), states)
def step(key, state):
  """Performs one Metropolis-Hastings step with proposal `inner_step`.

  The log acceptance ratio includes the forward and backward proposal
  log-probabilities computed with `ppl.log_prob(inner_step)`. The clipped
  acceptance probability is sown under `MCMC_METRICS` as 'accept_prob'.

  Args:
    key: PRNG key, split into a proposal key and an acceptance key.
    state: The current state (a pytree).

  Returns:
    A pytree choosing, leaf by leaf, the proposal if accepted, else `state`.
  """
  proposal_key, decision_key = random.split(key)
  proposal = inner_step(proposal_key, state)
  log_q_forward = ppl.log_prob(inner_step)(state, proposal)
  log_q_backward = ppl.log_prob(inner_step)(proposal, state)
  # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
  log_p_current = unnormalized_log_prob(state)
  log_p_proposal = unnormalized_log_prob(proposal)
  log_ratio = (log_p_proposal + log_q_backward - log_p_current -
               log_q_forward)
  accept_prob = harvest.sow(
      np.clip(np.exp(log_ratio), 0., 1.),
      tag=MCMC_METRICS,
      name='accept_prob')
  uniform = primitive.tie_in(accept_prob, random.uniform(decision_key))
  accepted = np.log(uniform) < log_ratio
  return tree_util.tree_multimap(
      lambda proposed, current: np.where(accepted, proposed, current),
      proposal, state)
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
  """Collects and injects values into/from the scan body.

  The scan body is re-traced under `harvest` so that sown values inside it
  can be planted and reaped:

  * 'clobber' plants are closed over and therefore given to every iteration.
  * 'append' plants are scanned over as extra `xs`, one slice per iteration.
  * Reaped values come back stacked by scan. 'append' reaps are concatenated
    and 'clobber' reaps keep the last iteration. Each reap is then re-sown
    in 'strict' mode in the enclosing context.

  Args:
    trace: The active HarvestTrace.
    *tracers: The scan operands, laid out as consts, then the `num_carry`
      initial carry values, then xs (see `split_list` below).
    length, reverse, jaxpr, num_consts, num_carry, linear, unroll: The
      original `scan_p` parameters.

  Returns:
    The list of output carry values followed by the stacked ys.

  Raises:
    ValueError: If any sow of the active tag in the body uses 'strict' mode.
  """
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  values = [t.val for t in tracers]
  consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

  active_sows = _find_sows(jaxpr, settings.tag)
  active_modes = [params['mode'] for params in active_sows]
  # Strict sows would fire once per iteration with the same name, so they are
  # rejected up front.
  if any(mode == 'strict' for mode in active_modes):
    raise ValueError('Cannot use strict mode in a scan.')
  active_names = [params['name'] for params in active_sows]
  sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
  # Clobber plants: one value reused by every iteration (closed over in `body`).
  carry_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'clobber'
  }
  # Append plants: scanned over alongside xs, so each iteration sees a slice.
  xs_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'append'
  }

  def jaxpr_fun(carry, x):
    # Evaluates the original scan body on (consts, carry, x).
    body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                   *(consts + carry + x))
    carry, y = jax_util.split_list(body_out, [num_carry])
    return carry, y

  harvest_body = harvest(
      jaxpr_fun,
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist)

  def body(carry, x):
    # `x` is (per-iteration append plants, per-iteration xs); the reaps are
    # emitted as part of ys so scan stacks them across iterations.
    x_plants, x_vals = x
    (carry, y), reaps = harvest_body({
        **carry_plants,
        **x_plants
    }, carry, x_vals)
    return carry, (y, reaps)

  # Per-iteration abstract values for the scanned inputs: drop the leading
  # (scanned) axis from each leaf's shape.
  xs_flat = tree_util.tree_leaves((xs_plants, xs))
  x_avals = []
  for x in xs_flat:
    x_aval = jax_core.get_aval(x)
    if x_aval is jax_core.abstract_unit:
      x_avals.append(x_aval)
    else:
      x_shape, x_dtype = masking.padded_shape_as_value(
          x.shape[1:]), x.dtype
      x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
  x_avals = tuple(x_avals)
  init_avals = tuple(
      abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
  in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
  body_jaxpr, new_consts, out_tree = (
      jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
          body, in_tree, init_avals + x_avals))
  # New operand layout: new_consts, then init, then xs_plants leaves, then xs.
  new_values = list(new_consts) + in_flat
  num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
  # Rebuild `linear` for the new layout; the added consts and plant inputs
  # are marked non-linear.
  remaining_linear = linear[num_consts:]
  new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                (False, ) * num_xs_plants + remaining_linear[len(init):])
  assert len(new_linear) == len(new_values)

  outs = lax.scan_p.bind(
      *new_values,
      length=length,
      reverse=reverse,
      jaxpr=body_jaxpr,
      num_consts=len(new_consts),
      num_carry=num_carry,
      linear=new_linear,
      unroll=unroll)
  outs = safe_map(trace.pure, outs)
  carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
  out_reaps = {}
  for k, val in reaps.items():
    # Names not found by `_find_sows` default to 'strict' and are re-sown as
    # stacked by scan, without concatenation or selection.
    mode = sow_modes.get(k, 'strict')
    if mode == 'append':
      # Merge the stacked leading (iteration) axis into one sequence —
      # presumably each iteration's reap is already a sequence along axis 0;
      # confirm.
      val = tree_util.tree_map(np.concatenate, val)
    elif mode == 'clobber':
      # Keep only the value reaped by the final iteration.
      val = tree_util.tree_map(lambda x: x[-1], val)
    out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
  (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
  return carry + ys
def step(params, init_key=None):
  """Takes one optimization step on `params`.

  The step proceeds as follows:

  1. Evaluate `objective` and its gradient at `params`.
  2. Turn the gradient into updates via `update` (forwarding `init_key`).
  3. Tie those updates to the objective value.
  4. Apply them with `apply_updates`.

  Args:
    params: The current parameters.
    init_key: Optional key forwarded to `update`.

  Returns:
    The result of `apply_updates(params, updates)`.
  """
  loss, grads = jax.value_and_grad(objective)(params)
  raw_updates = update(params, grads, init_key=init_key)
  tied_updates = primitive.tie_in(loss, raw_updates)
  return apply_updates(params, tied_updates)
def f(x, init_key=None):
  """Reads and increments a module variable named 'y'.

  'y' is declared with zeros of `x`'s shape as its initial value, and y + 1
  is assigned back to 'y'.

  Args:
    x: Input array; its shape determines the initial value of 'y'.
    init_key: Optional key forwarded to `module.variable`.

  Returns:
    `x` (tied in to the assigned value) plus the pre-assignment `y`.
  """
  y = module.variable(np.zeros(x.shape), name='y', key=init_key)
  incremented = module.assign(y + 1., name='y')
  tied_x = primitive.tie_in(incremented, x)
  return tied_x + y