def sample(self, *args, sample_shape=(1000, )) -> "jax.numpy.DeviceArray":
    """Return forward samples from the distribution.

    The graph is compiled to a forward sampler on every call, and
    ``self.rng_key`` is advanced so that successive calls draw fresh samples.

    Parameters
    ----------
    *args
        Extra positional arguments, passed to the compiled sampler after the
        RNG key and ``sample_shape``.
    sample_shape
        Shape of the batch of samples requested; forwarded to the sampler.

    Returns
    -------
    The output of the compiled sampler for these inputs.
    """
    # The return annotation is quoted so it is not evaluated at definition
    # time: ``jax.numpy.DeviceArray`` has been removed from recent JAX releases.
    artifact = compiler.compile_to_sampler(self.graph, self.namespace)
    # Use the named attribute (as the other compiler call sites do) instead of
    # unpacking a fixed number of fields from the artifact.
    sampler = artifact.compiled_fn
    # Keep one half of the split as the new state key and hand the other,
    # fresh half to the sampler, so no key is both consumed and split later.
    self.rng_key, sample_key = jax.random.split(self.rng_key)
    return sampler(sample_key, sample_shape, *args)
def sampler_src(self) -> str:
    """Source code of the forward sampling function emitted by the compiler.

    The model graph is compiled anew each time this is evaluated; only the
    generated source text of the sampler is returned.
    """
    return compiler.compile_to_sampler(self.graph, self.namespace).fn_source
def forward_src(self) -> str:
    """Source code of the forward sampler generated by the compiler.

    Compiles ``self.graph`` within ``self.namespace`` and returns the
    generated function's source text (the same text ``sampler_src`` yields).
    """
    compiled = compiler.compile_to_sampler(self.graph, self.namespace)
    source: str = compiled.fn_source
    return source
def sample_forward(
    rng_key: jax.random.PRNGKey, model, num_samples: int = 1, **observations
) -> Dict:
    """Returns forward samples from the model.

    Parameters
    ----------
    rng_key
        Key used by JAX's random number generator.
    model
        The model from which we want to get forward samples.
    num_samples
        The number of forward samples we want to draw for each variable.
    **observations
        The values of the model's input parameters.

    Returns
    -------
    A dictionary that maps all the (deterministic and random) variables defined
    in the model to samples from the forward prior distribution.

    Raises
    ------
    AttributeError
        If a positional argument of the model has no value in ``observations``.
    """
    keys = jax.random.split(rng_key, num_samples)

    # The sampler is vmapped over the RNG keys only (axis 0); every model
    # argument is passed unchanged to each sample (in_axes entry ``None``).
    sampler_args: Tuple[Any, ...] = (keys,)
    in_axes: Tuple[Union[int, None], ...] = (0,)

    for arg in model.posargs:
        # Only the lookup sits in the try, so a KeyError raised while
        # converting the value is not misreported as a missing argument.
        try:
            value = observations[arg]
        except KeyError:
            raise AttributeError(
                "You need to specify the value of the variable {}".format(arg)
            ) from None
        try:
            sampler_args += (np.atleast_1d(value),)
        except RuntimeError:
            sampler_args += (value,)
        in_axes += (None,)

    # Keyword arguments are the model arguments that are not positional.
    # NOTE(review): set iteration order is not guaranteed to match the order in
    # which the compiled sampler expects these positionally — confirm against
    # the compiler when a model has several keyword arguments.
    model_kwargs = tuple(set(model.arguments).difference(model.posargs))
    for kwarg in model_kwargs:
        if kwarg in observations:
            value = observations[kwarg]
        else:
            # If the kwarg value is not provided, retrieve its default from the
            # graph. The default appears to be an AST literal node: prefer
            # ``.value`` (``ast.Constant``) over the deprecated ``.n`` alias,
            # falling back to ``.n`` for node types that only expose it.
            default = model.nodes[kwarg]["content"].default_value
            value = default.value if hasattr(default, "value") else default.n
        sampler_args += (value,)
        in_axes += (None,)

    # Stack samples along axis 1 of each output. A single variable uses a
    # plain int, presumably because the sampler then returns one array rather
    # than a tuple.
    out_axes: Union[int, Tuple[int, ...]]
    if len(model.variables) == 1:
        out_axes = 1
    else:
        out_axes = (1,) * len(model.variables)

    artifact = compiler.compile_to_sampler(model.graph, model.namespace)
    sampler = jax.jit(artifact.compiled_fn)
    samples = jax.vmap(sampler, in_axes, out_axes)(*sampler_args)

    # A bare (non-tuple) output would otherwise be iterated along its first
    # axis by ``zip`` below, keeping only ``samples[0]`` for the variable.
    if not isinstance(samples, tuple):
        samples = (samples,)

    forward_trace = {
        arg: numpy.asarray(arg_samples).squeeze()
        for arg, arg_samples in zip(model.variables, samples)
    }
    return forward_trace