コード例 #1
0
ファイル: model.py プロジェクト: TuanNguyen27/mcx
 def sample(self, *args, sample_shape=(1000, )) -> jax.numpy.DeviceArray:
     """Return forward samples from the distribution.

     The model graph is compiled to a sampling function on every call. A
     fresh subkey is split off ``self.rng_key`` for this draw, and the other
     half of the split is stored back, so successive calls draw different
     samples and no key is ever consumed twice.

     Args:
         *args: Extra positional arguments forwarded to the compiled sampler
             after the RNG key and ``sample_shape``.
         sample_shape: Passed to the compiled sampler as its second argument
             (presumably the batch shape of the draws — confirm against the
             compiler).

     Returns:
         Whatever the compiled sampler returns for these inputs.
     """
     sampler, _, _, _ = core.compile_to_sampler(self.graph, self.namespace)
     # Advance the stored key and sample with the *other* half of the split.
     # Sampling with the stored half would reuse it: it would be consumed
     # here and then split again on the next call.
     self.rng_key, subkey = jax.random.split(self.rng_key)
     samples = sampler(subkey, sample_shape, *args)
     return samples
コード例 #2
0
def sample_forward(rng_key, model: model, num_samples=1, **kwargs) -> Dict:
    """Returns forward samples from the model.

    The samples are returned in a dictionary, with the names of
    the variables as keys.

    Args:
        rng_key: JAX PRNG key. It is split once into a key that seeds the
            per-sample sampler keys and a key used to resample array
            arguments, so no key is consumed twice.
        model: The model to sample from. Every name in ``model.posargs``
            must be supplied in ``kwargs``. Other names in
            ``model.arguments`` fall back to the ``default_value`` stored on
            the corresponding graph node when absent from ``kwargs``.
        num_samples: Number of forward samples to draw (size of the vmapped
            batch axis).
        **kwargs: Values for the model's arguments. A
            ``jax.numpy.DeviceArray`` given for a positional argument is
            resampled with replacement along its first axis, one row per
            forward sample. Any other value, and every non-positional
            argument, is broadcast unchanged to all samples.

    Returns:
        A dict mapping each name in ``model.variables`` to a numpy array of
        its samples (transposed, with singleton axes squeezed).

    Raises:
        AttributeError: If a positional argument of the model is missing
            from ``kwargs``.
    """
    model_posargs = model.posargs
    # NOTE(review): a set difference has no guaranteed iteration order, yet
    # these values are passed *positionally* to the compiled sampler below.
    # This relies on the sampler expecting its keyword arguments in the same
    # order — confirm against core.compile_to_sampler.
    model_kwargs = tuple(set(model.arguments).difference(model.posargs))

    # Split once so the keys fed to the sampler and the key used to resample
    # array arguments are independent (a JAX key must not be reused).
    sample_key, index_key = jax.random.split(rng_key)
    keys = jax.random.split(sample_key, num_samples)
    sampler_args: Tuple[Any, ...] = (keys, )
    # Each entry is 0 (map over the batch axis) or None (broadcast).
    in_axes: Tuple[Any, ...] = (0, )

    for arg in model_posargs:
        try:
            value = kwargs[arg]
        except KeyError:
            raise AttributeError(
                f"You need to specify the value of the variable {arg}"
            ) from None
        if isinstance(value, jax.numpy.DeviceArray):
            # Every array argument is indexed with the same key, so arrays of
            # equal length are resampled at the same rows (presumably to keep
            # draws aligned across arguments — confirm).
            idx = jax.random.randint(index_key, (num_samples, ), 0,
                                     value.shape[0])
            sampler_args += (value[idx], )
            in_axes += (0, )
        else:
            sampler_args += (value, )
            in_axes += (None, )

    for kwarg in model_kwargs:
        if kwarg in kwargs:
            value = kwargs[kwarg]
        else:
            # Fall back to the default recorded on the model's graph node.
            value = model.nodes[kwarg]["content"].default_value.n
        sampler_args += (value, )
        in_axes += (None, )

    # One output per model variable, each batched along axis 0.
    out_axes = (0, ) * len(model.variables)

    artifact = core.compile_to_sampler(model.graph, model.namespace)
    sampler_fn = jax.jit(artifact.compiled_fn)
    samples = jax.vmap(sampler_fn, in_axes=in_axes,
                       out_axes=out_axes)(*sampler_args)

    trace = {
        arg: numpy.asarray(arg_samples).T.squeeze()
        for arg, arg_samples in zip(model.variables, samples)
    }

    return trace
コード例 #3
0
 def sampler_src(self) -> str:
     """Return the generated source code of the forward-sampling function.

     The model graph is compiled with ``core.compile_to_sampler`` and the
     source text kept on the resulting artifact (``fn_source``) is returned.
     """
     return core.compile_to_sampler(self.graph, self.namespace).fn_source