Esempio n. 1
0
 # The return annotation is quoted so it is not evaluated at class-definition
 # time: ``jax.numpy.DeviceArray`` is deprecated in favour of ``jax.Array`` and
 # missing from newer JAX releases, which would otherwise break the import.
 def sample(self, *args, sample_shape=(1000, )) -> "jax.numpy.DeviceArray":
     """Return forward samples from the distribution.

     Parameters
     ----------
     *args
         Values forwarded positionally to the compiled sampler, after the
         random key and ``sample_shape``.
     sample_shape
         Passed unchanged to the compiled sampler as its second argument.

     Returns
     -------
     Whatever the compiled sampler returns for these inputs.

     Notes
     -----
     ``self.rng_key`` is advanced on every call, so successive calls draw
     from different random streams.
     """
     sampler, _, _, _ = compiler.compile_to_sampler(self.graph,
                                                    self.namespace)
     # Keep one half of the split as the new carry key and hand the other,
     # fresh half to the sampler. The key used for sampling is therefore never
     # also the parent of a later split (JAX keys must not be reused).
     self.rng_key, sample_key = jax.random.split(self.rng_key)
     return sampler(sample_key, sample_shape, *args)
Esempio n. 2
0
 def sampler_src(self) -> str:
     """Return the source code of the forward sampling function
     generated by the compiler.

     The model graph is recompiled on every call; only the generated source
     text (the artifact's ``fn_source``) is returned.
     """
     return compiler.compile_to_sampler(self.graph, self.namespace).fn_source
Esempio n. 3
0
 def forward_src(self) -> str:
     """Return the generated source text of the forward sampler.

     Compiles ``self.graph`` within ``self.namespace`` on each call and
     returns the ``fn_source`` attribute of the resulting artifact.
     """
     graph, namespace = self.graph, self.namespace
     return compiler.compile_to_sampler(graph, namespace).fn_source
Esempio n. 4
0
def sample_forward(
    rng_key: jax.random.PRNGKey, model, num_samples: int = 1, **observations
) -> Dict:
    """Returns forward samples from the model.

    Parameters
    ----------
    rng_key
        Key used by JAX's random number generator.
    model
        The model from which we want to get forward samples.
    num_samples
        The number of forward samples we want to draw for each variable.
    **observations
        The values of the model's input parameters.

    Returns
    -------
    A dictionary that maps all the (deterministic and random) variables defined
    in the model to samples from the forward prior distribution.

    Raises
    ------
    AttributeError
        If a value for one of ``model.posargs`` is missing from
        ``observations``.

    """
    # One key per sample; vmap maps the sampler over axis 0 of ``keys``.
    keys = jax.random.split(rng_key, num_samples)
    sampler_args: Tuple[Any, ...] = (keys,)
    in_axes: Tuple[Union[int, None], ...] = (0,)

    # Positional arguments have no default: the caller must supply them.
    for arg in model.posargs:
        try:
            value = observations[arg]
        except KeyError:
            raise AttributeError(
                "You need to specify the value of the variable {}".format(arg)
            ) from None
        # NOTE(review): ``np`` and ``numpy`` are distinct file-level aliases
        # (presumably jax.numpy and NumPy) -- confirm against the imports.
        try:
            sampler_args += (np.atleast_1d(value),)
        except RuntimeError:
            sampler_args += (value,)
        in_axes += (None,)  # same value shared by every sample

    # The values below are passed to the sampler positionally, so their order
    # matters. Iterating a ``set`` gave an arbitrary, run-dependent order. Keep
    # the declaration order of ``model.arguments`` instead (presumably the
    # order the compiled sampler declares them in -- confirm).
    posargs = set(model.posargs)
    model_kwargs = [name for name in model.arguments if name not in posargs]
    for kwarg in model_kwargs:
        if kwarg in observations:
            value = observations[kwarg]
        else:
            # if the kwarg value is not provided retrieve it from the graph.
            # NOTE(review): ``.n`` looks like the legacy ast number alias;
            # confirm what ``default_value`` is before changing it.
            value = model.nodes[kwarg]["content"].default_value.n
        sampler_args += (value,)
        in_axes += (None,)

    # Stack each variable's per-sample outputs along axis 1.
    out_axes: Union[int, Tuple[int, ...]]
    if len(model.variables) == 1:
        out_axes = 1
    else:
        out_axes = (1,) * len(model.variables)

    artifact = compiler.compile_to_sampler(model.graph, model.namespace)
    sampler = jax.jit(artifact.compiled_fn)
    samples = jax.vmap(sampler, in_axes, out_axes)(*sampler_args)

    # With a single variable the sampler may return a bare array instead of a
    # tuple. Zipping over a bare array would iterate its leading axis and keep
    # only the first slice, so wrap it to match the multi-variable case.
    if not isinstance(samples, (tuple, list)):
        samples = (samples,)

    forward_trace = {
        arg: numpy.asarray(arg_samples).squeeze()
        for arg, arg_samples in zip(model.variables, samples)
    }

    return forward_trace