def potential_energy(model, model_args, model_kwargs, params, enum=False):
    """
    (EXPERIMENTAL INTERFACE) Computes the potential energy of a model given
    unconstrained params. Under the hood, these unconstrained parameters are
    transformed to values that belong to the supports of the corresponding
    priors in `model`.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: unconstrained parameters of `model`.
    :param bool enum: whether to enumerate over discrete latent sites.
    :return: potential energy given unconstrained parameters.
    """
    if not enum:
        density_fn = log_density
    else:
        # Imported on demand, only when enumeration is requested.
        from numpyro.contrib.funsor import log_density as density_fn

    # Wrap the model so that site values are supplied by
    # `_unconstrain_reparam` bound to the given `params`.
    reparam_fn = partial(_unconstrain_reparam, params)
    model_with_params = substitute(model, substitute_fn=reparam_fn)

    # The values are already substituted above, so the density computation
    # receives an empty param dict; the second returned item is not needed.
    log_joint, _ = density_fn(model_with_params, model_args, model_kwargs, {})

    # The potential energy is the negated log joint.
    return -log_joint
def loglik_fn(**params):
    """
    Evaluate ``log_density_`` on ``reparam_model`` with the given values.

    ``log_density_``, ``reparam_model``, ``args`` and ``kwargs`` are read from
    the surrounding scope, which is not visible here (presumably an enclosing
    function -- TODO confirm).

    :param params: keyword arguments, collected into a dict and passed as the
        last positional argument of ``log_density_``.
    :return: the first element of the result of ``log_density_`` (presumably
        the log joint density -- verify against ``log_density_``).
    """
    density_result = log_density_(reparam_model, args, kwargs, params)
    return density_result[0]