def test_single_chain(self, make_kernel, target_accept_rate):
        """Checks one chain's first two moments and mean acceptance probability.

        Draws `num_samples` steps from a single chain built by `make_kernel`,
        compares the sample mean and covariance against direct draws from
        `self.model`, and checks the mean `accept_prob` metric against
        `target_accept_rate`.
        """
        num_samples = 20000

        ref_key, run_key, start_key = random.split(self._seed, 3)
        log_prob_fn = self._make_unconstrained_log_prob()
        start_state = self._initialize_state(start_key)
        kernel = make_kernel(log_prob_fn)
        # Harvest values tagged MCMC_METRICS so the chain also returns them.
        chain_with_metrics = harvest.harvest(
            kernels.sample_chain(kernel, num_samples),
            tag=kernels.MCMC_METRICS)
        run_chain = jax.jit(chain_with_metrics)

        reference = self.model.sample(sample_shape=4096, seed=ref_key)
        draws, chain_metrics = run_chain({}, run_key, start_state)

        # Looser than the multi-chain check, which pools more draws.
        moment_tol = dict(rtol=0.5, atol=0.1)
        onp.testing.assert_allclose(reference.mean(axis=0),
                                    draws.mean(axis=0), **moment_tol)
        onp.testing.assert_allclose(np.cov(reference.T), np.cov(draws.T),
                                    **moment_tol)
        accept_rate = chain_metrics['kernel']['accept_prob'].mean()
        onp.testing.assert_allclose(target_accept_rate, accept_rate,
                                    atol=1e-2, rtol=1e-2)
    def test_multiple_chains(self, make_kernel, target_accept_rate):
        """Checks pooled moments and mean acceptance across vmapped chains.

        Runs `num_chains` independent chains of `num_samples` steps each,
        pools all draws, compares their mean and covariance against direct
        draws from `self.model`, and checks the mean `accept_prob` metric
        against `target_accept_rate`.
        """
        num_chains = 16
        num_samples = 4000

        ref_key, run_key, start_key = random.split(self._seed, 3)
        log_prob_fn = self._make_unconstrained_log_prob()
        start_keys = random.split(start_key, num_chains)
        start_states = jax.vmap(self._initialize_state)(start_keys)

        kernel = make_kernel(log_prob_fn)
        # Harvest values tagged MCMC_METRICS so each chain also returns them.
        chain_with_metrics = harvest.harvest(
            kernels.sample_chain(kernel, num_samples),
            tag=kernels.MCMC_METRICS)
        run_chains = jax.jit(jax.vmap(chain_with_metrics))

        reference = self.model.sample(sample_shape=4096, seed=ref_key)
        draws, chain_metrics = run_chains({},
                                          random.split(run_key, num_chains),
                                          start_states)

        def pool_chains(per_chain, event_shape):
            # Flatten the chain and step axes into one axis of
            # num_chains * num_samples draws.
            return per_chain.reshape([num_chains * num_samples] +
                                     list(event_shape))

        draws = tf.nest.map_structure(pool_chains, draws,
                                      self.model.event_shape)

        moment_tol = dict(rtol=0.1, atol=0.1)
        onp.testing.assert_allclose(reference.mean(axis=0),
                                    draws.mean(axis=0), **moment_tol)
        onp.testing.assert_allclose(np.cov(reference.T), np.cov(draws.T),
                                    **moment_tol)
        accept_rate = chain_metrics['kernel']['accept_prob'].mean()
        onp.testing.assert_allclose(target_accept_rate, accept_rate,
                                    atol=1e-2, rtol=1e-2)
Пример #3
0
def get_summaries(f):
    """Transforms a function into one that additionally outputs summaries.

    Args:
        f: a callable.

    Returns:
        A function that, when called, returns the original output of `f` and a
        dictionary mapping summary names to their values during execution.
    """
    summarized = harvest.harvest(f, tag=SUMMARY)
    # Bind an empty dict as harvest's first argument (presumably the values
    # to plant), so callers only pass `f`'s own arguments.
    return functools.partial(summarized, {})
Пример #4
0
 def wrapped(*args, **kwargs):
     """Calls `f` under harvest and re-sows every collected latent.

     `f` is harvested with tag RANDOM_VARIABLE and called with `observations`
     as harvest's first argument. Each latent it reports is sown again under
     the same name in 'strict' mode. The re-sown latents are then tied to `f`'s
     result with `primitive.tie_in`.
     """
     harvested_f = harvest.harvest(f, tag=RANDOM_VARIABLE)
     output, collected = harvested_f(observations, *args, **kwargs)
     resown = {}
     for var_name, var_value in collected.items():
         resown[var_name] = harvest.sow(var_value,
                                        tag=RANDOM_VARIABLE,
                                        name=var_name,
                                        mode='strict')
     return primitive.tie_in(resown, output)
Пример #5
0
    def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
        """Traces `f` with `init_key`, `args` and `kwargs` into a `FunctionModule`.

        Args:
          init_key: key used to compute the initial variables; forwarded to `f`
            under the `init_keyword` keyword when `f` accepts it.
          *args: positional arguments used to trace `f`.
          **kwargs: keyword arguments used to trace `f`.

        Returns:
          If `name` is None, a `FunctionModule` whose variables are each wrapped
          with `module.variable`. Otherwise, the `FunctionModule` itself wrapped
          with `module.variable` under `name`.
        """
        has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
        if not has_init_key:
            # `f` does not take the init keyword: initialization yields no
            # variables, and `cau_f` returns the variables it is given as-is.

            def init_f(init_key, *args, **kwargs):
                del init_key, args, kwargs
                return {}

            def cau_f(variables, *args, **kwargs):
                return f(*args, **kwargs), variables
        else:
            # Split `f` into an initializer (`init_f`) and an apply function
            # (`apply_f`) by unzipping on values tagged `module.VARIABLE`,
            # with `custom_unzip_rules` active during the trace.
            with unzip.new_custom_rules(custom_unzip_rules):

                def fun(init_key, *args, **kwargs):
                    # Forward the key to `f` under its init keyword (overrides
                    # any caller-supplied value for that keyword).
                    kwargs = {**kwargs, init_keyword: init_key}
                    return f(*args, **kwargs)

                init_f, apply_f = unzip.unzip(fun,
                                              tag=module.VARIABLE)(init_key,
                                                                   *args,
                                                                   **kwargs)
            # Bind an empty dict as harvest's first argument; the harvested
            # function presumably also returns values tagged `module.ASSIGN`.
            cau_f = functools.partial(
                harvest.harvest(apply_f, tag=module.ASSIGN), {})
        if name is not None:
            # Presumably namespaces the harvested names under `name`.
            init_f = harvest.nest(init_f, scope=name)
            cau_f = harvest.nest(cau_f, scope=name)
        variables = init_f(init_key)
        # Stage `cau_f` into a jaxpr plus its input/output tree structures.
        cau_jaxpr, (in_tree,
                    out_tree) = trace_util.stage(cau_f,
                                                 dynamic=True)(variables,
                                                               *args, **kwargs)
        if name is None:
            # Unnamed: register each variable individually.
            variables = {
                k: module.variable(val, name=k, key=init_key)
                for k, val in variables.items()
            }
            return FunctionModule(variables,
                                  cau_jaxpr,
                                  in_tree,
                                  out_tree,
                                  name=name)
        else:
            # Named: register the whole module as a single variable.
            return module.variable(FunctionModule(variables,
                                                  cau_jaxpr,
                                                  in_tree,
                                                  out_tree,
                                                  name=name),
                                   name=name,
                                   key=init_key)
Пример #6
0
  def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Traces `f` with `init_key`, `args` and `kwargs` into a `FunctionModule`.

    Args:
      init_key: key used to compute the initial variables; forwarded to `f`
        under the `init_keyword` keyword when `f` accepts it.
      *args: positional arguments used to trace `f`.
      **kwargs: keyword arguments used to trace `f`.

    Returns:
      A `FunctionModule`.
        - If `name` is None, each of its variables is wrapped with
          `module.variable`.
        - If `name` is set and there are variables, the module itself is
          wrapped with `module.variable` under `name`.
        - Otherwise it is returned as-is.
    """
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:
      # `f` does not take the init keyword: initialization yields no
      # variables, and `cau_f` returns the variables it is given as-is.

      def init_f(init_key, *args, **kwargs):
        del init_key, args, kwargs
        return {}

      def cau_f(variables, *args, **kwargs):
        return f(*args, **kwargs), variables
    else:

      def f_(init_key, *args, **kwargs):
        # Pass the key under the same keyword `has_init_key` checked for, not
        # a hard-coded `init_key=`. The separate `**{...}` keeps the original
        # TypeError if the caller also supplies that keyword.
        return f(*args, **kwargs, **{init_keyword: init_key})

      def init_f(init_key, *args, **kwargs):
        # Run `f_` and collect (reap) the values it tags `module.VARIABLE`.
        return harvest.reap(
            f_, tag=module.VARIABLE, exclusive=True)(init_key, *args, **kwargs)

      def apply_f(variables, *args, **kwargs):
        # Inject (plant) `variables` in place of the tagged values. The key is
        # a fixed placeholder; presumably unused once variables are planted.
        return harvest.plant(
            f_, tag=module.VARIABLE)(variables, random.PRNGKey(0), *args,
                                     **kwargs)

      # Bind an empty dict as harvest's first argument; the harvested function
      # presumably also returns values tagged `module.ASSIGN`.
      cau_f = functools.partial(harvest.harvest(apply_f, tag=module.ASSIGN), {})
    variables = init_f(init_key, *args, **kwargs)
    # Stage `cau_f` into a jaxpr plus its input/output tree structures.
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(
        cau_f, dynamic=True)(variables, *args, **kwargs)
    if name is None:
      # Unnamed: register each variable individually.
      variables = {
          k: module.variable(val, name=k, key=init_key)
          for k, val in variables.items()
      }
      return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    else:
      mod = FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
      # Only register the named module as a variable when it has variables.
      if variables:
        return module.variable(mod, name=name, key=init_key)
      return mod
Пример #7
0
 def wrapped(*args, **kwargs):
     """Calls `f` in 'plant_only' harvest mode and returns only its output.

     `observations` is passed as harvest's first argument. The second element
     of the harvested result is discarded.
     """
     plant_only_f = harvest.harvest(f, tag=RANDOM_VARIABLE, mode='plant_only')
     outputs = plant_only_f(observations, *args, **kwargs)
     return outputs[0]
Пример #8
0
 def program(key, *args, **kwargs):
     """Calls `f` in PLANT_ONLY harvest mode with `names` blocklisted.

     An empty dict is passed as harvest's first argument, followed by `key` and
     the caller's arguments. Only the first element of the harvested result is
     returned.
     """
     plant_only_f = harvest.harvest(f,
                                    tag=RANDOM_VARIABLE,
                                    blocklist=names,
                                    mode=harvest.HarvestMode.PLANT_ONLY)
     outputs = plant_only_f({}, key, *args, **kwargs)
     return outputs[0]