def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
    """Compute the elementwise log-likelihood of every observed variable.

    For each entry of ``model.observed_RVs``, the unsummed log-density graph
    (``model.logpt(rv, sum=False)``) is converted to a JAX function of
    ``model.value_vars``. It is then vectorised over the two leading axes of
    ``samples`` and JIT-compiled.

    Parameters
    ----------
    model : Model
        Model whose observed random variables are evaluated.
    samples
        Sequence of arrays, one per entry of ``model.value_vars``, unpacked
        positionally into the compiled function. Two nested ``jax.vmap`` calls
        map over their first two axes (presumably chains and draws -- confirm
        against caller).
    backend : optional
        Forwarded unchanged to ``jax.jit(..., backend=backend)``.

    Returns
    -------
    dict
        Maps each observed variable's ``name`` to the first output of its
        compiled elementwise log-likelihood function.
    """

    def _pointwise_loglik(obs_rv):
        # Unsummed log-density terms for this observed variable.
        elemwise_terms = model.logpt(obs_rv, sum=False)
        compiled = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_terms)
        # Outer vmap maps axis 0, inner vmap maps axis 1 of every sample array.
        batched = jax.vmap(jax.vmap(compiled))
        # Keep only the first element of the returned outputs.
        return jax.jit(batched, backend=backend)(*samples)[0]

    return {obs_rv.name: _pointwise_loglik(obs_rv) for obs_rv in model.observed_RVs}
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    """Return a JAX-callable that evaluates the model's joint log-probability.

    NOTE(review): the flag reads inverted relative to its effect. With the
    default ``negative_logp=True`` the graph from ``model.logpt()`` is used
    as-is. Passing ``negative_logp=False`` negates it. Behavior is kept as-is
    because existing callers may depend on it.

    Parameters
    ----------
    model : Model
        Model whose ``logpt()`` graph is compiled over ``model.value_vars``.
    negative_logp : bool, default True
        When falsy, the log-probability graph is negated before compilation.

    Returns
    -------
    Callable
        A function taking a single sequence ``x`` (one entry per
        ``model.value_vars``, unpacked positionally). It returns the first
        output of the compiled graph.
    """
    joint_logp = model.logpt()
    target = joint_logp if negative_logp else -joint_logp
    jaxified = get_jaxified_graph(inputs=model.value_vars, outputs=[target])

    def logp_fn_wrap(x):
        # The compiled graph returns a sequence of outputs; only one was requested.
        return jaxified(*x)[0]

    return logp_fn_wrap