示例#1
0
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    """Build a JAX-callable evaluating the model's joint log-probability graph.

    Parameters
    ----------
    model : Model
        Model whose ``model.logp()`` graph is compiled via
        ``get_jaxified_graph``; the graph's inputs are ``model.value_vars``.
    negative_logp : bool, default True
        NOTE: the name reads inverted relative to the behaviour. When truthy,
        the returned function evaluates ``model.logp()`` as-is; when falsy,
        the graph is negated and the function evaluates ``-model.logp()``.

    Returns
    -------
    Callable
        A function ``logp_fn_wrap(x)`` that unpacks the sequence ``x``
        positionally (in the order of ``model.value_vars``) into the compiled
        graph and returns the first element of its outputs. The graph is
        built with exactly one output.
    """
    base_logp = model.logp()
    # Falsy flag -> negate the graph (see the docstring note on naming).
    target_logp = base_logp if negative_logp else -base_logp
    compiled_graph = get_jaxified_graph(
        inputs=model.value_vars,
        outputs=[target_logp],
    )

    def logp_fn_wrap(x):
        graph_outputs = compiled_graph(*x)
        # Single-output graph: unwrap the lone result.
        return graph_outputs[0]

    return logp_fn_wrap
示例#2
0
def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
    """Compute the elementwise log-likelihood of every observed variable.

    For each variable in ``model.observed_RVs`` a separate JAX function is
    built from ``model.logp(var, sum=False)`` (inputs: ``model.value_vars``),
    wrapped in two nested ``jax.vmap`` calls, jit-compiled with the requested
    ``backend`` and evaluated on ``samples``.

    Parameters
    ----------
    model : Model
        Model providing ``observed_RVs``, ``value_vars`` and ``logp``.
    samples : sequence
        Positional arguments for the compiled function, one per entry of
        ``model.value_vars``. The two nested vmaps map over the two leading
        axes of each array. Presumably these are (chain, draw); confirm
        against the caller.
    backend : str or None, default None
        Forwarded unchanged to ``jax.jit``.

    Returns
    -------
    Dict
        Mapping of each observed variable's ``name`` to the first output of
        its vmapped log-likelihood function.
    """
    log_likelihood = {}
    for observed in model.observed_RVs:
        # sum=False keeps the per-element terms instead of a scalar total.
        pointwise_logp = model.logp(observed, sum=False)
        graph_fn = get_jaxified_graph(
            inputs=model.value_vars,
            outputs=pointwise_logp,
        )
        # Inner vmap over the second leading axis, outer over the first.
        batched_fn = jax.vmap(jax.vmap(graph_fn))
        compiled_fn = jax.jit(batched_fn, backend=backend)
        # Keep only the first graph output. Assumes pointwise_logp yields a
        # single relevant term per observed variable; TODO confirm.
        log_likelihood[observed.name] = compiled_fn(*samples)[0]
    return log_likelihood