コード例 #1
0
ファイル: util.py プロジェクト: jatentaki/numpyro
def potential_energy(model, model_args, model_kwargs, params, enum=False):
    """
    (EXPERIMENTAL INTERFACE) Evaluate the potential energy of `model` at a
    point expressed in unconstrained space.

    The unconstrained `params` are substituted into the model through
    `_unconstrain_reparam`, which maps them onto values lying in the supports
    of the corresponding priors. The log joint density of the substituted
    model is then computed and its negation is returned.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: positional arguments forwarded to the model.
    :param dict model_kwargs: keyword arguments forwarded to the model.
    :param dict params: unconstrained parameters of `model`.
    :param bool enum: if true, use the funsor-based `log_density` so that
        discrete latent sites are enumerated.
    :return: the negated log joint density evaluated at `params`.
    """
    # Choose the density evaluator; the funsor variant is imported lazily so
    # the dependency is only touched when enumeration is requested.
    if not enum:
        density_fn = log_density
    else:
        from numpyro.contrib.funsor import log_density as density_fn

    reparam_fn = partial(_unconstrain_reparam, params)
    conditioned_model = substitute(model, substitute_fn=reparam_fn)

    # The params dict handed to the evaluator is empty on purpose: every value
    # has already been injected by the `substitute` handler above.
    log_joint, _ = density_fn(conditioned_model, model_args, model_kwargs, {})
    return -log_joint
コード例 #2
0
 def loglik_fn(**params):
     """
     Evaluate `log_density_` on `reparam_model` with the given keyword
     arguments as the parameter dict, and return the first item of the result.

     NOTE(review): the first item is presumably the log density value; confirm
     against the `log_density_` in the enclosing scope.
     """
     # `reparam_model`, `args` and `kwargs` are captured from the enclosing scope.
     density_result = log_density_(reparam_model, args, kwargs, params)
     return density_result[0]