def sample(self, *args, sample_shape=(1000, )) -> jax.numpy.DeviceArray:
    """Return forward samples from the distribution.

    Args:
        *args: Extra positional arguments forwarded to the compiled sampler,
            after the PRNG key and ``sample_shape``.
        sample_shape: Forwarded to the compiled sampler as its second
            argument.

    Returns:
        Whatever the compiled sampler returns for these arguments.

    Side effects:
        Advances ``self.rng_key`` so that successive calls draw different
        samples.
    """
    artifact = core.compile_to_sampler(self.graph, self.namespace)
    sampler = artifact.compiled_fn

    # Keep one half of the split on the instance (it is only ever split again)
    # and consume the other half for sampling. Sampling with the stored key
    # itself would reuse a key that is later split, which JAX forbids.
    self.rng_key, sample_key = jax.random.split(self.rng_key)
    return sampler(sample_key, sample_shape, *args)
def sample_forward(rng_key, model: model, num_samples=1, **kwargs) -> Dict:
    """Return forward samples from the model.

    The samples are returned in a dictionary, with the names of the
    variables as keys.

    Args:
        rng_key: JAX PRNG key. It is split internally into independent keys
            for row resampling and for the per-sample sampler calls.
        model: The model to draw forward samples from.
        num_samples: Number of forward samples to draw.
        **kwargs: Values of the model's arguments.
            - Every positional argument of the model must be provided.
            - Keyword arguments that are not provided fall back to the
              default stored on the model's node.
            - A ``jax.numpy.DeviceArray`` passed for a positional argument
              is resampled with replacement along its first axis, one row
              per sample.
            - Any other value is passed unchanged to every sample.

    Returns:
        A dict mapping each name in ``model.variables`` to a numpy array of
        its samples (transposed, then squeezed).

    Raises:
        AttributeError: If a positional argument of the model is missing
            from ``kwargs``.
    """
    model_posargs = tuple(model.posargs)
    # Preserve declaration order. The sampler receives these positionally,
    # and iterating a set of strings yields an order that can change between
    # interpreter runs (hash randomization). Assumes the compiled sampler
    # expects its arguments in ``model.arguments`` order -- TODO confirm.
    model_kwargs = tuple(
        arg for arg in model.arguments if arg not in model_posargs
    )

    # Separate keys for row resampling and for the samplers, so that no key
    # is consumed twice.
    idx_key, sampler_key = jax.random.split(rng_key)
    keys = jax.random.split(sampler_key, num_samples)

    sampler_args: Tuple[Any, ...] = (keys, )
    in_axes: Tuple[int, ...] = (0, )

    for arg in model_posargs:
        try:
            value = kwargs[arg]
        except KeyError:
            raise AttributeError(
                f"You need to specify the value of the variable {arg}"
            ) from None

        if isinstance(value, jax.numpy.DeviceArray):
            # Draw one row per sample, with replacement. Every array uses the
            # same key, so arrays of equal length get identical indices and
            # stay aligned row-for-row.
            idx = jax.random.randint(
                idx_key, (num_samples, ), 0, value.shape[0]
            )
            sampler_args += (value[idx], )
            in_axes += (0, )
        else:
            sampler_args += (value, )
            in_axes += (None, )

    for kwarg in model_kwargs:
        if kwarg in kwargs:
            value = kwargs[kwarg]
        else:
            value = model.nodes[kwarg]["content"].default_value.n
        sampler_args += (value, )
        in_axes += (None, )

    out_axes = (0, ) * len(model.variables)
    artifact = core.compile_to_sampler(model.graph, model.namespace)
    sampler_fn = jax.jit(artifact.compiled_fn)
    samples = jax.vmap(sampler_fn, in_axes=in_axes, out_axes=out_axes)(
        *sampler_args
    )

    return {
        arg: numpy.asarray(arg_samples).T.squeeze()
        for arg, arg_samples in zip(model.variables, samples)
    }
def sampler_src(self) -> str:
    """Return the source code of the compiler-generated forward sampler.

    This method compiles the object's graph and namespace with
    ``core.compile_to_sampler``. It returns the ``fn_source`` text of the
    resulting artifact.
    """
    return core.compile_to_sampler(self.graph, self.namespace).fn_source