Example #1
0
def elbo(param_map, model, guide, model_args, guide_args, kwargs,
         constrain_fn):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is the
    fundamental objective in Variational Inference. This implementation has various
    limitations (for example it only supports random variables with reparameterized
    samplers) but can be used as a template to build more sophisticated loss
    objectives.

    The guide is evaluated once, and the model is then replayed against the
    guide's trace, so latent sites in the model take the values drawn by the
    guide. The result is therefore a single-sample estimate of
    ``-(log p(x, z) - log q(z))``.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site
        name. These are the (unconstrained) values held by the optimizer;
        ``constrain_fn`` is applied to them before model and guide are run.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary during
        the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary during
        the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide (the same
        dict is passed to both).
    :param constrain_fn: a callable that transforms unconstrained parameter values
        from the optimizer to the specified constrained domain.
    :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
    """
    param_map = constrain_fn(param_map)
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs,
                                                 param_map)
    # Replaying the model against the guide trace makes the model score the
    # latent values sampled by the guide.
    model_log_density, _ = log_density(replay(model, guide_trace), model_args,
                                       kwargs, param_map)
    # log p(x, z) - log q(z): the model density includes observed sites.
    elbo_estimate = model_log_density - guide_log_density
    # Return (-ELBO) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo_estimate
Example #2
0
 def _potential_energy(params):
     """
     Return the potential energy (negative log joint density) at ``params``.

     ``params`` holds unconstrained values keyed by site name. They are mapped
     to the constrained domain with ``transform_fn(inv_transforms, params)``
     before the model's log joint density is evaluated. For every site in
     ``inv_transforms``, the summed log absolute determinant of that
     transform's Jacobian is added to the log joint.

     :param dict params: unconstrained parameter values keyed by site name.
     :return: ``-(log_joint + sum of (optionally scaled) log|det J| terms)``.
     """
     params_constrained = transform_fn(inv_transforms, params)
     log_joint, model_trace = log_density(model, model_args, model_kwargs,
                                          params_constrained)
     for name, t in inv_transforms.items():
         t_log_det = np.sum(
             t.log_abs_det_jacobian(params[name], params_constrained[name]))
         # A site may record a 'scale' (e.g. from a scale handler); the
         # Jacobian correction is scaled like the site's log probability.
         # A missing *or* None scale means "unscaled" -- checking only for
         # key presence would fail on ``None * t_log_det``.
         # NOTE(review): assumes trace sites are dicts (``.get`` available).
         scale = model_trace[name].get('scale')
         if scale is not None:
             t_log_det = scale * t_log_det
         log_joint = log_joint + t_log_det
     return -log_joint
Example #3
0
def test_scale(use_context_manager):
    """
    Check that ``log_density`` honours a scale factor of 10.

    With ``use_context_manager`` set, only the observed site ``obs`` is scaled
    (through a scale handler used as a context manager inside the model).
    Otherwise the whole model is wrapped with ``handlers.scale``, so the log
    probabilities of both ``x`` and ``obs`` are multiplied by 10.
    """
    def model(data):
        x = numpyro.sample('x', dist.Normal(0, 1))
        obs_scaling = optional(use_context_manager,
                               handlers.scale(scale_factor=10))
        with obs_scaling:
            numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    if not use_context_manager:
        # Scale every site of the model, latent and observed alike.
        model = handlers.scale(model, scale_factor=10.)

    data = random.normal(random.PRNGKey(0), (3, ))
    x = random.normal(random.PRNGKey(1))
    log_joint, _ = log_density(model, (data, ), {}, {'x': x})

    prior_lp = dist.Normal(0, 1).log_prob(x)
    likelihood_lp = dist.Normal(x, 1).log_prob(data).sum()
    if use_context_manager:
        expected = prior_lp + 10 * likelihood_lp
    else:
        expected = 10 * (prior_lp + likelihood_lp)
    assert_allclose(log_joint, expected)