예제 #1
0
    def __call__(self, *args, **kwargs) -> Any:
        """Calls the module like a plain function and records its new state.

        The module is executed through `call_and_update`. Before the output is
        returned, the updated state is registered with `assign`:

        * if the module has a `name`, the whole updated module is assigned
          under that name;
        * otherwise each entry of the updated module's `variables()` is
          assigned under its own variable name.

        The assigned value(s) are then combined with the output through
        `primitive.tie_in`.

        Args:
          *args: Positional arguments forwarded to the module.
          **kwargs: Keyword arguments forwarded to the module.

        Returns:
          The output of the module, passed through `primitive.tie_in` with the
          assigned state.
        """
        result, updated_module = self.call_and_update(*args, **kwargs)
        if self.name is None:
            # Anonymous module: register each variable individually.
            assigned_state = {
                var_name: assign(var_value, name=var_name)
                for var_name, var_value in updated_module.variables().items()
            }
        else:
            assigned_state = assign(updated_module, name=self.name)
        return primitive.tie_in(assigned_state, result)
예제 #2
0
def sow(value, *, tag: str, name: str, mode: str = 'strict', key=None):
    """Marks a value with a name and a tag.

    Args:
      value: A JAX value (or pytree of values) to be tagged and named; it is
        flattened with `tree_util.tree_flatten` before being bound.
      tag: A string representing the tag of the sown value.
      name: A string representing the name to sow the value with.
      mode: The mode by which to sow the value. One of:
        1. `'strict'` - if another value is sown with the same name and tag in
           the same context, harvest will throw an error.
        2. `'clobber'` - if another value is sown with the same name and tag,
           it will replace this value.
        3. `'append'` - sown values of the same name and tag are appended to a
           growing list. Append mode assumes some ordering on the values being
           sown defined by data-dependence.
      key: An optional JAX value that will be tied into the sown value.

    Returns:
      The original `value` that was passed in (with the same tree structure).
    """
    if key is None:
        tied_value = value
    else:
        # Tie the optional `key` into the value before sowing it.
        tied_value = prim.tie_in(key, value)
    leaves, treedef = tree_util.tree_flatten(tied_value)
    sown_leaves = sow_p.bind(
        *leaves, name=name, tag=tag, mode=mode, tree=treedef)
    return tree_util.tree_unflatten(treedef, sown_leaves)
예제 #3
0
 def _call(self, x, training=True, rng=None):
   """Applies dropout to `x` in training mode; otherwise returns `x` as-is.

   A keep mask is drawn with `random.bernoulli` using `self.info.rate` as its
   probability. Kept entries are divided by `self.info.rate`, and dropped
   entries become 0.

   Args:
     x: The input array.
     training: Whether to apply dropout. When False, `x` is returned unchanged.
     rng: PRNG key used to draw the keep mask; required when `training`.

   Returns:
     The dropped-out (and rescaled) array, or `x` when not training.

   Raises:
     ValueError: If `training` is True and `rng` is None.
   """
   info = self.info
   if not training:
     return x
   if rng is None:
     raise ValueError('rng is required when training is True')
   keep_rate = info.rate
   # Using tie_in to avoid materializing constants.
   keep_mask = primitive.tie_in(x, random.bernoulli(rng, keep_rate, x.shape))
   return np.where(keep_mask, x / keep_rate, 0)
예제 #4
0
def template_build(cls, init_key, *args, name=None, **kwargs):
  """Instantiates a layer object from an RNG key and layer specifications.

  Args:
    cls: The layer class; it must provide `initialize` and `new`.
    init_key: PRNG key passed to `cls.initialize`. Must not be None.
    *args: Positional arguments forwarded to `cls.initialize`.
    name: Optional name forwarded to `cls.new`.
    **kwargs: Keyword arguments forwarded to `cls.initialize`.

  Returns:
    The layer object returned by `cls.new`.

  Raises:
    ValueError: If `init_key` is None.
  """
  if init_key is None:
    raise ValueError('Cannot initialize template with `None` PRNGKey.')
  layer_params = cls.initialize(init_key, *args, **kwargs)
  # `init_key` is guaranteed non-None past the guard above, so every leaf of
  # the params and state is unconditionally tied to it (the former
  # `if init_key is not None:` check here was always true).
  params, state = tree_util.tree_map(
      lambda leaf: primitive.tie_in(init_key, leaf),
      (layer_params.params, layer_params.state))
  layer_params = LayerParams(params=params, state=state,
                             info=layer_params.info)
  return cls.new(layer_params, name=name)
예제 #5
0
 def wrapped(*args, **kwargs):
   """Runs `f` conditioned on `input_names` and returns the `output_names` RVs.

   All `RANDOM_VARIABLE` values are reaped from `conditional(f, input_names)`.
   Those named in `output_names` form the return value (a single value when
   `single_output`, otherwise a list in `output_names` order). Every other
   reaped value is re-sown under the same tag and name in `'strict'` mode,
   and the re-sown values are tied into the result with `primitive.tie_in`.
   """
   reaped = harvest.reap(
       conditional(f, input_names), tag=RANDOM_VARIABLE)(*args, **kwargs)
   outputs = [reaped[name] for name in output_names]
   resown = {}
   for name, value in reaped.items():
     if name in output_names:
       continue
     resown[name] = harvest.sow(
         value, tag=RANDOM_VARIABLE, name=name, mode='strict')
   result = outputs[0] if single_output else outputs
   return primitive.tie_in(resown, result)
예제 #6
0
 def wrapped(*args, **kwargs):
     """Runs `f` under `harvest.harvest` with `observations` as the plants.

     Every harvested `RANDOM_VARIABLE` latent is re-sown under the same tag and
     name in `'strict'` mode, and the re-sown latents are tied into `f`'s
     result with `primitive.tie_in`.
     """
     harvested_f = harvest.harvest(f, tag=RANDOM_VARIABLE)
     result, latents = harvested_f(observations, *args, **kwargs)
     resown = {}
     for name, value in latents.items():
         resown[name] = harvest.sow(
             value, tag=RANDOM_VARIABLE, name=name, mode='strict')
     return primitive.tie_in(resown, result)
예제 #7
0
 def step(key, state, init_key=None):
     """Performs one Metropolis-Hastings step around a stateful `inner_step`.

     `inner_step` is initialized with `st.init` and then called to propose
     `next_state`. The proposal is accepted when
     `log(u) < log_prob(next_state) - log_prob(state)` for a uniform draw `u`;
     otherwise `state` is kept. No forward/backward transition correction is
     applied (i.e. the proposal is assumed symmetric).

     Args:
       key: PRNG key, split into a transition key and an acceptance key.
       state: The current state (a pytree).
       init_key: Optional key forwarded to `st.init` to initialize `inner_step`.

     Returns:
       A pytree selecting, leaf by leaf, the proposal if accepted else `state`.
     """
     transition_key, accept_key = random.split(key)
     kernel = st.init(inner_step)(init_key, transition_key, state)
     proposal = kernel(transition_key, state)
     # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
     current_lp = unnormalized_log_prob(state)
     proposal_lp = unnormalized_log_prob(proposal)
     log_ratio = proposal_lp - current_lp
     accept_prob = np.clip(np.exp(log_ratio), 0., 1.)
     # `accept_prob` is only tied into the uniform draw; the decision itself
     # compares against the unclipped log ratio.
     uniform = primitive.tie_in(accept_prob, random.uniform(accept_key))
     accepted = np.log(uniform) < log_ratio
     return tree_util.tree_multimap(
         lambda proposed, current: np.where(accepted, proposed, current),
         proposal, state)
예제 #8
0
    def step(key, state, init_key=None):
        """Runs `num_steps` iterations of a stateful kernel with `lax.scan`.

        The kernel is built from `kernel_fn` via `st.init` (named `'kernel'`).
        It is then scanned over `num_steps` keys split from `key`. Each
        iteration calls `call_and_update` on the kernel, then runs every
        function in `callbacks` on the (kernel, state) pair, combining its
        result with both via `primitive.tie_all`. The final kernel is assigned
        back under the name `'kernel'` and tied into the returned states.

        Args:
          key: PRNG key used to initialize the kernel and split into one key
            per scan iteration.
          state: The initial state.
          init_key: Optional key forwarded to `st.init` for initialization.

        Returns:
          The per-iteration states stacked by `lax.scan`.
        """
        kernel = st.init(kernel_fn, name='kernel')(init_key, key, state)

        def scan_body(carry, step_key):
            current_kernel, current_state = carry
            current_state, current_kernel = current_kernel.call_and_update(
                step_key, current_state)
            for callback in callbacks:
                current_kernel, current_state, _ = primitive.tie_all(
                    current_kernel, current_state,
                    callback(current_kernel, current_state))
            return (current_kernel, current_state), current_state

        step_keys = random.split(key, num_steps)
        (final_kernel, _), states = lax.scan(scan_body, (kernel, state),
                                             step_keys)
        return primitive.tie_in(st.assign(final_kernel, name='kernel'), states)
예제 #9
0
 def step(key, state):
     """Performs one Metropolis-Hastings step with a transition correction.

     `inner_step` proposes `next_state`. The acceptance log ratio is
     `log p(next) + log q(state | next) - log p(state) - log q(next | state)`,
     where `q` is evaluated with `ppl.log_prob(inner_step)`. The clipped
     acceptance probability is sown under `MCMC_METRICS` as `'accept_prob'`.

     Args:
       key: PRNG key, split into a transition key and an acceptance key.
       state: The current state (a pytree).

     Returns:
       A pytree selecting, leaf by leaf, the proposal if accepted else `state`.
     """
     transition_key, accept_key = random.split(key)
     proposal = inner_step(transition_key, state)
     forward_lp = ppl.log_prob(inner_step)(state, proposal)
     backward_lp = ppl.log_prob(inner_step)(proposal, state)
     # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
     current_lp = unnormalized_log_prob(state)
     proposal_lp = unnormalized_log_prob(proposal)
     # Kept in left-to-right form to match the original float evaluation order.
     log_ratio = proposal_lp + backward_lp - current_lp - forward_lp
     clipped_prob = np.clip(np.exp(log_ratio), 0., 1.)
     accept_prob = harvest.sow(clipped_prob,
                               tag=MCMC_METRICS,
                               name='accept_prob')
     uniform = primitive.tie_in(accept_prob, random.uniform(accept_key))
     accepted = np.log(uniform) < log_ratio
     return tree_util.tree_multimap(
         lambda proposed, current: np.where(accepted, proposed, current),
         proposal, state)
예제 #10
0
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body.

    This is the harvest rule for `scan_p`. The scan body is wrapped with
    `harvest` (using the current context's tag, allowlist and blocklist) and
    re-traced, so that:

    * plants for `'clobber'`-mode sows are closed over by the new body and
      therefore see the same value on every iteration;
    * plants for `'append'`-mode sows are threaded through the scanned `xs`,
      so each iteration receives one slice along the leading axis;
    * values reaped inside the body are emitted as extra scan outputs and,
      after the scan, post-processed by mode and re-sown (in `'strict'` mode)
      under the context's tag.

    Args:
      trace: The active `HarvestTrace`.
      *tracers: Tracers for the scan operands (consts, then carry, then xs).
      length: `scan_p` param; forwarded unchanged.
      reverse: `scan_p` param; forwarded unchanged and also used to locate
        the last-executed iteration for `'clobber'` reaps.
      jaxpr: The closed jaxpr of the original scan body.
      num_consts: Number of leading operands that are constants.
      num_carry: Number of operands (after the consts) that are carry values.
      linear: Per-operand linearity flags of the original scan.
      unroll: `scan_p` param; forwarded unchanged.

    Returns:
      The flat list of final carry values followed by the stacked `ys`.

    Raises:
      ValueError: If any sow inside the scan body uses `'strict'` mode.
    """
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

    # Inspect the body's sows (for this tag) to decide how each plant/reap is
    # routed through the scan.
    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    # 'clobber' plants are loop-invariant: the body closes over them.
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    # 'append' plants are scanned over alongside `xs`.
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    def jaxpr_fun(carry, x):
        # Evaluates the original body; returns (new carry, per-iteration ys).
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist)

    def body(carry, x):
        # `x` is (append-plant slices, original xs slice); reaps become ys.
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    # Per-iteration avals of the scanned inputs: drop the leading (scanned)
    # axis of every xs leaf; abstract units are passed through as-is.
    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat
    # Operand layout is [new_consts, init, xs_plants leaves, xs]; the number
    # of append-plant leaves is whatever remains.
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    remaining_linear = linear[num_consts:]
    # New consts and append plants are marked non-linear; the original
    # carry/xs operands keep their flags.
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    # Reaps come back stacked along the leading axis, with index i matching
    # xs[i]. For a reverse scan, index 0 is executed last, so the final
    # 'clobber' value lives at index 0 rather than -1.
    # NOTE(review): 'append' reaps are concatenated in index order, which for
    # reverse scans is the opposite of execution order — confirm intended.
    last_index = 0 if reverse else -1
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        if mode == 'append':
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            val = tree_util.tree_map(lambda x: x[last_index], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys
예제 #11
0
 def step(params, init_key=None):
     """Takes one optimization step on `params`.

     Computes the value and gradient of `objective` at `params`, transforms the
     gradient with `update` (forwarding `init_key`), ties the objective value
     into the transformed updates with `primitive.tie_in`, and applies them
     with `apply_updates`.
     """
     objective_value, grads = jax.value_and_grad(objective)(params)
     transformed = update(params, grads, init_key=init_key)
     transformed = primitive.tie_in(objective_value, transformed)
     return apply_updates(params, transformed)
예제 #12
0
 def f(x, init_key=None):
     """Declares a zero variable `y` shaped like `x` and increments it.

     `y + 1.` is assigned back under the name `'y'`; the result is
     `x + y`, where `x` is first tied to the assigned value and `y` is the
     variable's value before the increment.
     """
     initial = np.zeros(x.shape)
     y = module.variable(initial, name='y', key=init_key)
     y_next = module.assign(y + 1., name='y')
     tied_x = primitive.tie_in(y_next, x)
     return tied_x + y