def __call__(self, name, fn, obs):
    """Reparameterize a loc-scale site by partially decentering it.

    Draws an auxiliary ``"{name}_decentered"`` sample from a distribution whose
    ``loc``/``scale`` are interpolated by ``centered`` (0 = fully decentered,
    1 = fully centered), then differentiably maps that draw back to the
    original site's value.

    :param str name: name of the sample site being reparameterized.
    :param fn: the original loc-scale distribution at the site.
    :param obs: observed value; must be ``None`` (observe statements unsupported).
    :return: ``(None, value)`` — no replacement distribution, plus the
        deterministically transformed value for the site.
    """
    assert obs is None, "LocScaleReparam does not support observe statements"
    centered = self.centered
    # Fully centered means identity reparameterization: leave the site as-is.
    # NOTE(review): this early return has arity 3 while the final return has
    # arity 2 — presumably the caller special-cases it; confirm against the
    # reparam handler's contract.
    if is_identically_one(centered):
        return name, fn, obs
    event_shape = fn.event_shape
    # Peel off wrappers (e.g. Independent/expanded dims) so loc/scale are
    # accessible; remember how to re-wrap below.
    fn, batch_shape, event_dim = self._unwrap(fn)
    # Apply a partial decentering transform.
    params = {key: getattr(fn, key) for key in self.shape_params}
    if self.centered is None:
        # Learn the degree of centering per event element, constrained to [0, 1].
        centered = numpyro.param("{}_centered".format(name), jnp.full(event_shape, 0.5), constraint=constraints.unit_interval)
    # Interpolated parameters: loc scaled linearly, scale geometrically.
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale**centered
    decentered_fn = self._wrap(type(fn)(**params), batch_shape, event_dim)
    # Draw decentered noise.
    decentered_value = numpyro.sample("{}_decentered".format(name), decentered_fn)
    # Differentiably transform.
    delta = decentered_value - centered * fn.loc
    value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta
    # Simulate a pyro.deterministic() site.
    return None, value
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    substituted_model = substitute(model, data=params)
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        # Only sample sites contribute to the joint density.
        if site["type"] != "sample":
            continue
        value = site["value"]
        # Reparameterized sites carry intermediates needed by log_prob.
        if site["intermediates"]:
            site_log_prob = site["fn"].log_prob(value, site["intermediates"])
        else:
            site_log_prob = site["fn"].log_prob(value)
        scale = site["scale"]
        if scale is not None and not is_identically_one(scale):
            site_log_prob = scale * site_log_prob
        log_joint = log_joint + jnp.sum(site_log_prob)
    return log_joint, model_trace
def get_importance_trace(model, guide, args, kwargs, params):
    """
    (EXPERIMENTAL) Returns traces from the guide and the model that is run
    against it. The returned traces also store the log probability at each site.

    .. note:: Gradients are blocked at latent sites which do not have
        reparametrized samplers.
    """

    def _attach_log_prob(site):
        # Compute the (possibly scaled) log-probability once and cache it on the site.
        if site["intermediates"]:
            lp = site["fn"].log_prob(site["value"], site["intermediates"])
        else:
            lp = site["fn"].log_prob(site["value"])
        scale = site["scale"]
        if scale is not None and not is_identically_one(scale):
            lp = scale * lp
        site["log_prob"] = lp

    guide = substitute(guide, data=params)
    # Block gradients through non-reparameterized guide samples.
    with _without_rsample_stop_gradient():
        guide_trace = trace(guide).get_trace(*args, **kwargs)
    # Replay the model against the guide's sampled values.
    model = substitute(replay(model, guide_trace), data=params)
    model_trace = trace(model).get_trace(*args, **kwargs)
    for tr in (guide_trace, model_trace):
        for site in tr.values():
            if site["type"] == "sample" and "log_prob" not in site:
                _attach_log_prob(site)
    return model_trace, guide_trace
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    conditioned_model = substitute(model, param_map=params)
    model_trace = trace(conditioned_model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        # Skip non-sample sites and PRNG bookkeeping sites.
        if site['type'] != 'sample' or isinstance(site['fn'], dist.PRNGIdentity):
            continue
        value = site['value']
        intermediates = site['intermediates']
        if intermediates:
            site_log_prob = site['fn'].log_prob(value, intermediates)
        else:
            site_log_prob = site['fn'].log_prob(value)
        scale = site['scale']
        if scale is not None and not is_identically_one(scale):
            site_log_prob = scale * site_log_prob
        log_joint = log_joint + jnp.sum(site_log_prob)
    return log_joint, model_trace
def log_prob_sum(trace):
    """Return the scalar sum of log-probabilities over all sample sites in ``trace``."""
    total = jnp.zeros(())
    for site in trace.values():
        if site["type"] != "sample":
            continue
        # Reparameterized sites carry intermediates needed by log_prob.
        if site["intermediates"]:
            lp = site["fn"].log_prob(site["value"], site["intermediates"])
        else:
            lp = site["fn"].log_prob(site["value"])
        scale = site["scale"]
        if scale is not None and not is_identically_one(scale):
            lp = scale * lp
        total = total + jnp.sum(lp)
    return total
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    :raises ValueError: if a substituted value cannot broadcast against the
        site distribution's shape.
    """
    conditioned = substitute(model, data=params)
    model_trace = trace(conditioned).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] != "sample":
            continue
        value = site["value"]
        intermediates = site["intermediates"]
        if intermediates:
            site_log_prob = site["fn"].log_prob(value, intermediates)
        else:
            # Validate that the provided value broadcasts with the site shape.
            guide_shape = jnp.shape(value)
            # TensorShape from tfp needs casting to tuple
            model_shape = tuple(site["fn"].shape())
            try:
                broadcast_shapes(guide_shape, model_shape)
            except ValueError:
                raise ValueError(
                    "Model and guide shapes disagree at site: '{}': {} vs {}"
                    .format(site["name"], model_shape, guide_shape))
            site_log_prob = site["fn"].log_prob(value)
        scale = site["scale"]
        if scale is not None and not is_identically_one(scale):
            site_log_prob = scale * site_log_prob
        log_joint = log_joint + jnp.sum(site_log_prob)
    return log_joint, model_trace
def terms_from_trace(tr):
    """Helper function to extract elbo components from execution traces."""
    log_factors = {}
    log_measures = {}
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in tr.values():
        if site["type"] != "sample":
            continue
        # Evaluate the (possibly scaled) log-probability at the site.
        if site["intermediates"]:
            lp = site["fn"].log_prob(site["value"], site["intermediates"])
        else:
            lp = site["fn"].log_prob(site["value"])
        scale = site["scale"]
        if scale is not None and not is_identically_one(scale):
            lp = scale * lp
        # Lift to a funsor keyed by the site's dim-to-name bindings.
        factor = funsor.to_funsor(
            lp, output=funsor.Real, dim_to_name=site["infer"]["dim_to_name"])
        if site["is_observed"]:
            log_factors[site["name"]] = factor
        else:
            log_measures[site["name"]] = factor
            sum_vars |= frozenset({site["name"]})
        prod_vars |= frozenset(
            f.name for f in site["cond_indep_stack"] if f.dim is not None)
    return {
        "log_factors": log_factors,
        "log_measures": log_measures,
        "measure_vars": sum_vars,
        "plate_vars": prod_vars,
    }
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    # Trace with packed (funsor-friendly) plate dimensions.
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors, grouped by time dim
    time_to_init_vars = defaultdict(frozenset)  # "_init/..." variables per time dim
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            # Reparameterized sites carry intermediates needed by log_prob.
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            # Lift the log-prob array to a funsor over this site's named dims.
            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(
                log_prob, output=funsor.reals(), dim_to_name=dim_to_name)

            # Sites inside a markov time dimension ("_time...") are collected
            # separately for the markov-product computation below.
            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(
                        name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            # Latent (non-observed) sites are eliminated by sum; plate dims by product.
            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(
                f.name for f in site['cond_indep_stack'] if f.dim is not None)

    # Map each "_init/<name>" variable back to its current-step site and record
    # the dims that participate in the markov dependency.
    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors, time_to_init_vars, time_to_markov_dims,
            sum_vars, prod_vars)
        log_factors = log_factors + markov_factors

    # Build the sum-product expression lazily, then optimize the contraction order.
    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(
            funsor.ops.logaddexp, funsor.ops.add, log_factors,
            eliminate=sum_vars | prod_vars, plates=prod_vars)
        result = funsor.optimizer.apply_optimizer(lazy_result)
    # A correct model reduces to a scalar; leftover inputs indicate un-eliminated sites.
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape,
                {k.split("__BOUND")[0] for k in result.inputs}))
    return result.data, model_trace
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """Helper function to compute elbo and extract its components from execution traces.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :param sum_op: funsor op used to eliminate measure variables (e.g. logaddexp).
    :param prod_op: funsor op used to combine factors across plates (e.g. add).
    :return: tuple of (scalar funsor result, model trace, dict of log measures
        for non-observed, non-markov sites).
    :raises ValueError: if the contracted result is not a scalar.
    """
    model = substitute(model, data=params)
    # Trace with packed (funsor-friendly) plate dimensions.
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors, grouped by time dim
    time_to_init_vars = defaultdict(frozenset)  # "_PREV_..." variables per time dim
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    history = 1  # longest markov dependency length observed so far
    log_measures = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            # Reparameterized sites carry intermediates needed by log_prob.
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            # Lift the log-prob array to a funsor over this site's named dims.
            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            # Sites inside a markov time dimension ("_time...") are collected
            # separately for the markov-product computation below.
            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                    time_to_factors[time_dim].append(log_prob_factor)
                    # Track the deepest "_PREV_" shift to size the markov history.
                    history = max(
                        history, max(_get_shift(s) for s in dim_to_name.values())
                    )
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_PREV_")
                    )
                    break
            if time_dim is None:
                log_factors.append(log_prob_factor)

            # Latent (non-observed) sites contribute measures and are eliminated
            # by sum_op; plate dims are combined with prod_op.
            if not site["is_observed"]:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})
            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None
            )

    # Map each "_PREV_*" variable back to its current-step site and record
    # the dims that participate in the markov dependency.
    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = _shift_name(var, -_get_shift(var))
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _PREV_* (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values()
                )

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )
        log_factors = log_factors + markov_factors

    # Build the sum-product expression lazily, then optimize the contraction order.
    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
        result = funsor.optimizer.apply_optimizer(lazy_result)
    # A correct model reduces to a scalar; leftover inputs indicate un-eliminated sites.
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
            )
        )
    return result, model_trace, log_measures