def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    """Return a JAX-callable function evaluating the model's joint log-density.

    The summed ``model.logp()`` graph is converted with ``get_jaxified_graph``,
    using ``model.value_vars`` as inputs.

    Parameters
    ----------
    model
        The model whose joint log-probability graph is compiled.
    negative_logp
        Sign selector. When truthy (the default), the graph is ``model.logp()``
        unchanged. When falsy, the graph is negated, giving ``-model.logp()``.
        NOTE(review): the name reads inverted relative to the behavior.
        Presumably callers that need a potential energy (-logp) pass
        ``negative_logp=False``; confirm with call sites before renaming.

    Returns
    -------
    Callable
        A function taking a single sequence ``x``, presumably one array per
        entry of ``model.value_vars`` in the same order (verify against
        callers). The sequence is unpacked into the jaxified graph and the
        function returns the graph's only output.
    """
    logp_graph = model.logp()
    # Negate only when the flag is falsy; see NOTE(review) above.
    signed_logp = logp_graph if negative_logp else -logp_graph
    jax_logp = get_jaxified_graph(inputs=model.value_vars, outputs=[signed_logp])

    def logp_fn_wrap(x):
        # The jaxified graph returns a sequence of outputs; we built it with one.
        outputs = jax_logp(*x)
        return outputs[0]

    return logp_fn_wrap
def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
    """Compute log-likelihood for all observations.

    For every observed random variable, the element-wise (unsummed) log
    probability graph is converted to a JAX function of ``model.value_vars``.
    That function is wrapped in two nested ``jax.vmap`` calls, jitted on the
    requested ``backend``, and applied to ``samples``.

    Parameters
    ----------
    model
        Model whose ``observed_RVs`` are evaluated.
    samples
        Sequence unpacked positionally into the jaxified graph, presumably one
        array per value variable. The double ``vmap`` implies two leading batch
        axes, likely (chain, draw); TODO confirm against callers.
    backend
        Forwarded unchanged to ``jax.jit(..., backend=...)``.

    Returns
    -------
    Dict
        Maps each observed RV's ``name`` to the first output of its batched
        element-wise log-probability function. Observed RVs are visited in
        ``model.observed_RVs`` order.
    """

    def pointwise_loglik(rv):
        # Element-wise logp, so each observation keeps its own term.
        elemwise_graph = model.logp(rv, sum=False)
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_graph)
        # One vmap per leading batch axis of ``samples``.
        batched_fn = jax.vmap(jax.vmap(jax_fn))
        # A separate jit/compile happens for each observed RV.
        return jax.jit(batched_fn, backend=backend)(*samples)[0]

    return {rv.name: pointwise_loglik(rv) for rv in model.observed_RVs}