Example #1
0
def random_variable(
    obj,
    *,
    name: Optional[str] = None,
    plate: Optional[str] = None
) -> Program:  # pylint: disable=redefined-outer-name
    """Tags `obj` as a named random variable (default single-dispatch case).

    `random_variable` is a single-dispatch function, so custom types can
    register their own implementations. This fallback validates its arguments
    and sows `obj` under the `RANDOM_VARIABLE` tag in `'strict'` mode.

    Args:
      obj: A JAX value to be tagged.
      name: Name used to tag the value. Must not be `None`.
      plate: Named axis for this random variable's plate. The default
        implementation does not support plates, so this must be `None`.

    Returns:
      The result of `harvest.sow` applied to `obj`.

    Raises:
      ValueError: If `name` is `None`, or if `plate` is not `None`.
    """
    obj_kind = type(obj)
    if name is None:
        raise ValueError(f'Cannot call `random_variable` on {obj_kind} without '
                         'passing in a name.')
    if plate is not None:
        raise ValueError(f'Cannot call `random_variable` on {obj_kind} with a '
                         'plate.')
    return harvest.sow(obj, tag=RANDOM_VARIABLE, name=name, mode='strict')
Example #2
0
def summary(value, *, name: str, mode: str = 'strict'):
    """Tags `value` as a summary by sowing it under the `SUMMARY` tag.

    Args:
      value: JAX value to tag.
      name: Name under which the value is sown.
      mode: Sow mode forwarded to `harvest.sow`. Defaults to `'strict'`.

    Returns:
      The result of `harvest.sow` on `value` (documented as the original value).
    """
    sow_options = dict(tag=SUMMARY, name=name, mode=mode)
    return harvest.sow(value, **sow_options)
Example #3
0
 def wrapped(*args, **kwargs):
     """Runs the conditioned program and splits its latents into outputs.

     Latents named in `output_names` are returned as outputs (unwrapped when
     `single_output` is set). Every other latent is re-sown under
     `RANDOM_VARIABLE` and tied in to the outputs via `primitive.tie_in`.
     """
     reaping_fn = harvest.reap(conditional(f, input_names), tag=RANDOM_VARIABLE)
     reaped = reaping_fn(*args, **kwargs)
     selected = [reaped[output_name] for output_name in output_names]
     remaining = {}
     for latent_name, latent_value in reaped.items():
         if latent_name in output_names:
             continue
         remaining[latent_name] = harvest.sow(
             latent_value, tag=RANDOM_VARIABLE, name=latent_name, mode='strict')
     result = selected[0] if single_output else selected
     return primitive.tie_in(remaining, result)
Example #4
0
 def wrapped(*args, **kwargs):
     """Harvests `f`'s random variables and re-sows them around the result.

     `observations` is passed as the first argument of the harvested function
     (presumably the values to plant — confirm against `harvest.harvest`).
     Each harvested latent is sown again under `RANDOM_VARIABLE` and tied in
     to the returned result via `primitive.tie_in`.
     """
     harvesting_fn = harvest.harvest(f, tag=RANDOM_VARIABLE)
     output, harvested = harvesting_fn(observations, *args, **kwargs)
     resown = {}
     for latent_name, latent_value in harvested.items():
         resown[latent_name] = harvest.sow(
             latent_value, tag=RANDOM_VARIABLE, name=latent_name, mode='strict')
     return primitive.tie_in(resown, output)
Example #5
0
 def step(key, state, init_key=None):
     """Runs one proposal plus a Metropolis accept/reject step.

     A proposal `next_state` is produced by the stateful `inner_step`. It is
     accepted when `log(u) < log p(next_state) - log p(state)` for a uniform
     draw `u`, where `p` is `unnormalized_log_prob` exponentiated. No
     proposal-density correction term is applied. The clipped acceptance
     probability is sown under `MCMC_METRICS` with the name `'accept_prob'`.

     Args:
       key: PRNG key, split into a transition key and an acceptance key.
       state: Current chain state (a pytree).
       init_key: Key forwarded to `st.init` when initializing `inner_step`.

     Returns:
       A pytree shaped like `state`. Each leaf comes from `next_state` if the
       proposal was accepted, otherwise from `state`.
     """
     transition_key, accept_key = random.split(key)
     next_state = st.init(inner_step)(init_key, transition_key,
                                      state)(transition_key, state)
     # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
     state_log_prob = unnormalized_log_prob(state)
     next_state_log_prob = unnormalized_log_prob(next_state)
     log_unclipped_accept_prob = next_state_log_prob - state_log_prob
     # Only the sown metric is clipped; the accept decision below uses the
     # unclipped log ratio, which is equivalent for the `<` comparison.
     accept_prob = harvest.sow(np.clip(np.exp(log_unclipped_accept_prob),
                                       0., 1.),
                               tag=MCMC_METRICS,
                               name='accept_prob')
     # NOTE(review): `lax.tie_in` appears intended to keep `accept_prob` in the
     # dataflow. It is a no-op under omnistaging in newer JAX and may be
     # unavailable there — confirm against the pinned JAX version.
     u = lax.tie_in(accept_prob, random.uniform(accept_key))
     accept = np.log(u) < log_unclipped_accept_prob
     # `tree_util.tree_multimap` was removed from JAX. `tree_map` maps a
     # function leaf-wise over several trees with the same structure.
     return tree_util.tree_map(lambda n, s: np.where(accept, n, s),
                               next_state, state)
Example #6
0
def variable(value, *, name: str, key=None, mode: str = 'strict'):
    """Tags `value` as a state variable.

    Use `variable` to declare state inside stateful functions, typically on a
    value computed downstream of an initialization key. The `init`
    transformation collects every value tagged this way from the function
    body and stores it in a `Module`.

    Args:
      value: JAX value to tag as a variable.
      name: Name of the variable.
      key: JAX value that `value` is tied in to. Defaults to `None`.
      mode: Sow mode (see the `harvest` documentation). Defaults to
        `'strict'`.

    Returns:
      The result of `harvest.sow`, documented as the value passed in.
    """
    return harvest.sow(
        value,
        mode=mode,
        key=key,
        name=name,
        tag=VARIABLE,
    )
Example #7
0
def assign(value, *, name: str, key=None, mode: str = 'clobber'):
    """Assigns a value to a variable.

    In a stateful function, `assign` is used to define state updates. When a
    function containing an `assign` is transformed with `init`, the resulting
    Module's `call_and_update` returns the values tagged with `assign` as its
    second output. `init` requires every assigned value to have a matching
    variable, matched by `name`.

    Args:
      value: JAX value to assign.
      name: Name of the variable being assigned.
      key: JAX value that `value` is tied in to. Defaults to `None`.
      mode: Sow mode (see the `harvest` documentation). Defaults to
        `'clobber'`.

    Returns:
      The result of `harvest.sow`, documented as the value passed in.
    """
    return harvest.sow(
        value,
        mode=mode,
        key=key,
        name=name,
        tag=ASSIGN,
    )
 def f(x, y):
     """Returns `x` together with (`x` sown as 'x' under tag 'foo') times `y`."""
     scaled = harvest.sow(x, name='x', tag='foo') * y
     return x, scaled
 def f(x):
     """Sows `x` as 'x' under tag 'foo' and returns what `harvest.sow` gives."""
     tagged = harvest.sow(x, tag='foo', name='x')
     return tagged
Example #10
0
    def handle_call_primitive(self, call_primitive, f, tracers, params,
                              is_map):
        """Handler for call primitives, like jit or layer_call.

        When an UnzipTracer hits a call primitive, there are two cases:
        - The called function contains a variable. The function is unzipped
          into two: an init part and an apply part.
        - It contains no variables. The call primitive is recorded in the
          trace as-is.

        `unzip_eval` wraps `f` and reports through its auxiliary output
        (`aux`) whether the unzip succeeded. On success, two new Jaxprs are
        recorded in the trace (one for init, one for apply), and the init
        outputs are re-tagged with `harvest.sow` under the trace's tag.
        Otherwise, the single Jaxpr for the function call is recorded.

        If a custom rule is registered for `call_primitive` in
        `current_custom_rules()`, that rule handles the call instead.

        Args:
          call_primitive: a call primitive like xla_call.
          f: a jax.linear_util wrapped function to be called.
          tracers: inputs to the function.
          params: parameters of the primitive.
          is_map: whether or not the primitive is a map primitive
            (e.g. xla_pmap).

        Returns:
          A list of output tracers.

        Raises:
          NotImplementedError: if `call_primitive` has an entry in
            `pe.call_partial_eval_rules`.
        """
        name = params.get('name', f.__name__)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const_abstracted, tracers)
        # A registered custom rule takes over handling of the whole primitive.
        if call_primitive in current_custom_rules():
            return current_custom_rules()[call_primitive](self, f, *tracers,
                                                          **params)
        # Primitives with dedicated partial-eval rules are not supported here.
        if call_primitive in pe.call_partial_eval_rules:
            raise NotImplementedError
        in_pvals = [t.pval for t in tracers]
        if is_map:
            # Unknown inputs that have an `in_axis` are given the aval
            # returned by `mapped_aval` for that axis.
            # NOTE(review): `in_pvals` is never read after this branch. The
            # `unzip2` call below rebuilds `pvs` from `t.pval`, so these
            # mapped avals are discarded. Confirm whether `unzip_eval` should
            # receive them instead.
            unknown = pe.PartialVal.unknown
            in_pvals = [
                pval if pval.is_known() or in_axis is None else unknown(
                    mapped_aval(params['axis_size'], in_axis, pval[0]))
                for pval, in_axis in zip(in_pvals, params['in_axes'])
            ]
            out_axes_thunk = params['out_axes_thunk']

            # Replacement out_axes thunk. It asserts the original out axes
            # are all 0 and reports axis 0 for every output of the unzipped
            # function. The output count is read from `aux`, which is bound
            # further below. The thunk must therefore only be called after
            # that assignment (presumably during `bind`).
            @jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                assert all(out_axis == 0 for out_axis in out_axes)
                _, num_outputs, _ = aux()
                return (0, ) * num_outputs

            new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        else:
            new_params = params
        pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
        keys = tuple(t.is_key() for t in tracers)
        # The second field records whether this primitive is in
        # `block_registry`.
        new_settings = UnzipSettings(settings.tag, call_primitive
                                     in block_registry)
        fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
        out_flat = call_primitive.bind(fun, *in_consts, **new_params)
        success, _, results = aux()
        if not success:
            # Unzip failed: record the call as one Jaxpr. The first
            # len(out_pvs) entries of `out_flat` pair with `out_pvs`; the
            # rest are the Jaxpr's consts.
            out_pvs, out_keys, jaxpr, env = results
            out_pv_consts, consts = jax_util.split_list(
                out_flat, [len(out_pvs)])
            out_tracers = self._bound_output_tracers(call_primitive,
                                                     new_params, jaxpr, consts,
                                                     env, tracers, out_pvs,
                                                     out_pv_consts, out_keys,
                                                     name, is_map)
            return out_tracers
        # Unzip succeeded: `results` holds paired init/apply pieces.
        init_name = jax_util.wrap_name(name, 'init')
        apply_name = jax_util.wrap_name(name, 'apply')
        init_pvs, num_init_consts, apply_pvs = results[0]
        init_jaxpr, apply_jaxpr = results[1]
        init_env, apply_env = results[2]
        variable_names, variable_tree, apply_keys = results[3]

        # Key inputs feed the init Jaxpr; the remaining inputs feed apply.
        key_tracers = [t for t in tracers if t.is_key()]
        abstract_tracers = [t for t in tracers if not t.is_key()]
        # `out_flat` is laid out as
        # [init pv consts | init consts | apply pv consts | apply consts].
        all_init_consts, all_apply_consts = jax_util.split_list(
            out_flat, [len(init_pvs) + num_init_consts])
        init_pv_consts, init_consts = jax_util.split_list(
            all_init_consts, [len(init_pvs)])
        apply_pv_consts, apply_consts = jax_util.split_list(
            all_apply_consts, [len(apply_pvs)])

        # Every init output is flagged True in the `out_keys` position used
        # by the failure branch above.
        variable_tracers = self._bound_output_tracers(
            call_primitive, new_params, init_jaxpr, init_consts, init_env,
            key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
            init_name, is_map)

        unflat_variables = tree_util.tree_unflatten(variable_tree,
                                                    variable_tracers)
        # Re-tag the init outputs (the variables) under `settings.tag`:
        # - For `harvest.nest_p`, sow them together as one dict named by the
        #   nest `scope`.
        # - Otherwise, sow each variable under its own name.
        if call_primitive is harvest.nest_p:
            variable_dict = harvest.sow(dict(
                safe_zip(variable_names, unflat_variables)),
                                        tag=settings.tag,
                                        name=new_params['scope'],
                                        mode='strict')
            unflat_variables = tuple(variable_dict[name]
                                     for name in variable_names)
        else:
            unflat_variables = [
                harvest.sow(  # pylint: disable=g-complex-comprehension
                    unflat_variable,
                    tag=settings.tag,
                    name=name,
                    mode='strict') for unflat_variable, name in safe_zip(
                        unflat_variables, variable_names)
            ]
        variable_tracers = tree_util.tree_leaves(unflat_variables)

        # The apply Jaxpr consumes the re-sown variables followed by the
        # non-key inputs.
        out_tracers = self._bound_output_tracers(
            call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
            variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
            apply_keys, apply_name, is_map)
        return out_tracers