def elbo(param_map, model, guide, model_args, guide_args, kwargs, constrain_fn):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is
    the fundamental objective in Variational Inference. This implementation has
    various limitations (for example it only supports random variables with
    reparameterized samplers) but can be used as a template to build more
    sophisticated loss objectives.

    The guide is executed first and its trace is recorded; the model is then
    replayed against that trace, so both log densities are evaluated at the
    latent values drawn by the guide.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site name.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary during
        the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary during
        the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide.
    :param constrain_fn: a callable that transforms unconstrained parameter values
        from the optimizer to the specified constrained domain.
    :return: negative of the Evidence Lower Bound (ELBo) to be minimized.
    """
    constrained_params = constrain_fn(param_map)

    # log q(z): run the guide and keep its trace so its latent draws can be reused.
    log_q, guide_trace = log_density(guide, guide_args, kwargs, constrained_params)

    # log p(z): replay the model against the guide trace before scoring it.
    replayed_model = replay(model, guide_trace)
    log_p, _ = log_density(replayed_model, model_args, kwargs, constrained_params)

    # The ELBO (log p - log q) is maximized; negate it so optimizers can
    # minimize the returned value by gradient descent.
    return -(log_p - log_q)
def _potential_energy(params):
    """Evaluate the potential energy, i.e. the negated log joint density.

    ``params`` are mapped through ``inv_transforms`` with ``transform_fn`` to
    obtain constrained values, at which the model's log density is computed.
    For every site in ``inv_transforms``, the summed log absolute Jacobian
    determinant of its transform is added to the log joint (multiplied by the
    site's ``'scale'`` entry when the model trace records one).

    ``transform_fn``, ``inv_transforms``, ``model``, ``model_args`` and
    ``model_kwargs`` are not arguments; they are looked up from the enclosing
    scope (presumably the factory that builds this closure -- confirm).

    :param dict params: parameter values keyed by site name, presumably in
        unconstrained space -- TODO confirm against callers.
    :return: ``-(log joint + Jacobian corrections)``.
    """
    def _jacobian_term(site_name, transform, constrained_values, trace):
        # Change-of-variables correction between params and constrained values.
        log_det = np.sum(
            transform.log_abs_det_jacobian(params[site_name],
                                           constrained_values[site_name]))
        site = trace[site_name]
        # Apply the same scaling factor the site carries in the model trace.
        if 'scale' in site:
            log_det = site['scale'] * log_det
        return log_det

    constrained_values = transform_fn(inv_transforms, params)
    total, trace = log_density(model, model_args, model_kwargs, constrained_values)
    for site_name, transform in inv_transforms.items():
        total = total + _jacobian_term(site_name, transform, constrained_values, trace)
    return -total
def test_scale(use_context_manager):
    """Check that ``log_density`` honours ``handlers.scale``.

    With ``use_context_manager`` set, the factor 10 is applied through a context
    manager around the observed site only. Otherwise the whole model is wrapped
    with ``handlers.scale``, so every site's log probability is multiplied by 10.
    """
    def _model(observations):
        latent = numpyro.sample('x', dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale_factor=10)):
            numpyro.sample('obs', dist.Normal(latent, 1), obs=observations)

    if use_context_manager:
        scaled_model = _model
    else:
        scaled_model = handlers.scale(_model, 10.)

    observations = random.normal(random.PRNGKey(0), (3, ))
    x_value = random.normal(random.PRNGKey(1))
    actual = log_density(scaled_model, (observations, ), {}, {'x': x_value})[0]

    prior_lp = dist.Normal(0, 1).log_prob(x_value)
    likelihood_lp = dist.Normal(x_value, 1).log_prob(observations).sum()
    if use_context_manager:
        # Only the observed site is inside the scaled context.
        expected = prior_lp + 10 * likelihood_lp
    else:
        # The whole model is scaled, so both terms are multiplied.
        expected = 10 * (prior_lp + likelihood_lp)
    assert_allclose(actual, expected)