def model(): x = pyro.sample("x", dist.Normal(0, 1)) pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.0)) called.add("model-always") if poutine.get_mask() is not False: called.add("model-sometimes") pyro.factor("f", x + 1)
def _transform_values(
    self,
    aux_values: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Union[float, torch.Tensor]]]:
    # Learnably transform auxiliary values to user-facing values.
    values = {}
    log_densities = defaultdict(float)
    compute_density = am_i_wrapped() and poutine.get_mask() is not False
    for name, site in self._factors.items():
        if site["is_observed"]:
            continue
        loc = deep_getattr(self.locs, name)
        scale = deep_getattr(self.scales, name)
        unconstrained = aux_values[name] * scale + loc

        # Transform to constrained space.
        transform = biject_to(site["fn"].support)
        values[name] = transform(unconstrained)
        if compute_density:
            assert transform.codomain.event_dim == site["fn"].event_dim
            log_densities[name] = transform.inv.log_abs_det_jacobian(
                values[name], unconstrained
            ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1)

    return values, log_densities
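# Context sketch: this method belongs to Pyro's AutoGaussian guide. Density
# terms are computed only when the guide runs inside a poutine stack with a
# non-False mask, so a prediction-time call can skip the Jacobian bookkeeping.
# Assumes `model` is defined elsewhere.
import pyro.poutine as poutine
from pyro.infer.autoguide import AutoGaussian

guide = AutoGaussian(model)
with poutine.mask(mask=False):  # compute_density evaluates to False here
    samples = guide()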
def __call__(self, name, fn, obs):
    assert fn.event_dim >= self.event_dim
    assert obs is None, "SplitReparam does not support observe statements"

    # Draw independent parts.
    dim = fn.event_dim - self.event_dim
    left_shape = fn.event_shape[:dim]
    right_shape = fn.event_shape[1 + dim:]
    parts = []
    for i, size in enumerate(self.sections):
        event_shape = left_shape + (size,) + right_shape
        parts.append(
            pyro.sample(
                "{}_split_{}".format(name, i),
                dist.ImproperUniform(fn.support, fn.batch_shape, event_shape),
            )
        )

    # Combine parts.
    value = torch.cat(parts, dim=-self.event_dim)
    if poutine.get_mask() is False:
        log_density = 0.0
    else:
        log_density = fn.log_prob(value)
    new_fn = dist.Delta(value, event_dim=fn.event_dim, log_density=log_density)
    return new_fn, value
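# Usage sketch for the reparameterizer above: split a 5-dim event site into a
# 2-part and a 3-part site via poutine.reparam. Site names are illustrative.
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.reparam import SplitReparam

def split_model():
    pyro.sample("x", dist.Normal(torch.zeros(5), 1).to_event(1))

rep_model = poutine.reparam(
    split_model, config={"x": SplitReparam(sections=[2, 3], event_dim=1)}
)
trace = poutine.trace(rep_model).get_trace()
assert {"x_split_0", "x_split_1"} <= set(trace.nodes)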
def apply(self, msg): name = msg["name"] fn = msg["fn"] value = msg["value"] is_observed = msg["is_observed"] if name not in self.guide.prototype_trace.nodes: return {"fn": fn, "value": value, "is_observed": is_observed} if is_observed: raise NotImplementedError( f"At pyro.sample({repr(name)},...), " "StructuredReparam does not support observe statements") if name not in self.deltas: # On first sample site. with ExitStack() as stack: for plate in self.guide.plates.values(): stack.enter_context( block_plate(dim=plate.dim, strict=False)) self.deltas = self.guide.get_deltas() new_fn = self.deltas.pop(name) value = new_fn.v if poutine.get_mask() is not False: log_density = new_fn.log_density + fn.log_prob(value) new_fn = dist.Delta(value, log_density, new_fn.event_dim) return {"fn": new_fn, "value": value, "is_observed": True}
def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
    funsor = _import_funsor()

    # Convert torch to funsor.
    particle_plates = frozenset(get_plates())
    plate_to_dim = self._funsor_plate_to_dim.copy()
    plate_to_dim.update({f.name: f.dim for f in particle_plates})
    factors = {}
    for d, inputs in self._funsor_factor_inputs.items():
        batch_shape = torch.Size(
            p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
        )
        white_vec = deep_getattr(self.white_vecs, d)
        prec_sqrt = deep_getattr(self.prec_sqrts, d)
        factors[d] = funsor.gaussian.Gaussian(
            white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
            prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
            inputs=inputs,
        )

    # Perform Gaussian tensor variable elimination.
    if temperature == 1:
        samples, log_prob = _try_possibly_intractable(
            funsor.recipes.forward_filter_backward_rsample,
            factors=factors,
            eliminate=self._funsor_eliminate,
            plates=frozenset(plate_to_dim),
            sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
        )
    else:
        samples, log_prob = _try_possibly_intractable(
            funsor.recipes.forward_filter_backward_precondition,
            factors=factors,
            eliminate=self._funsor_eliminate,
            plates=frozenset(plate_to_dim),
        )

        # Substitute noise.
        sample_shape = torch.Size(f.size for f in particle_plates)
        noise = torch.randn(sample_shape + log_prob.inputs["aux"].shape)
        noise.mul_(temperature)
        aux = funsor.Tensor(noise)[tuple(f.name for f in particle_plates)]
        with funsor.interpretations.memoize():
            samples = {k: v(aux=aux) for k, v in samples.items()}
            log_prob = log_prob(aux=aux)

    # Convert funsor to torch.
    if am_i_wrapped() and poutine.get_mask() is not False:
        log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
        pyro.factor(f"_{self._pyro_name}_latent", log_prob, has_rsample=True)
    samples = {
        k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items()
    }
    return samples
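# Context sketch: this funsor-backed sampler is selected by constructing
# AutoGaussian with the optional funsor backend (requires the funsor package).
# Assumes `model` is defined elsewhere.
from pyro.infer.autoguide import AutoGaussian

guide = AutoGaussian(model, backend="funsor")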
def apply(self, msg): name = msg["name"] fn = msg["fn"] value = msg["value"] is_observed = msg["is_observed"] # Compute a guide distribution, either static or dependent. guide_dist = self.guide if not isinstance(guide_dist, dist.Distribution): args, kwargs = self.args_kwargs guide_dist = guide_dist(*args, **kwargs) assert isinstance(guide_dist, dist.Distribution) # Draw a sample from the updated distribution. fn, log_normalizer = fn.conjugate_update(guide_dist) assert isinstance(guide_dist, dist.Distribution) if not fn.has_rsample: # Note supporting non-reparameterized sites would require more delicate # handling of traced sites than the crude _do_not_trace flag below. raise NotImplementedError( "ConjugateReparam inference supports only reparameterized " "distributions, but got {}".format(type(fn)) ) value = pyro.sample( f"{name}_updated", fn, obs=value, infer={ "is_observed": is_observed, "is_auxiliary": True, "_do_not_trace": True, }, ) # Compute importance weight. Let p(z) be the original fn, q(z|x) be # the guide, and u(z) be the conjugate_updated distribution. Then # normalizer = p(z) q(z|x) / u(z). # Since we've sampled from u(z) instead of p(z), we # need an importance weight # p(z) / u(z) = normalizer / q(z|x) (Eqn 1) # Note that q(z|x) is often approximate; in the exact case # q(z|x) = p(x|z) / integral p(x|z) dz # so this site and the downstream likelihood site will have combined density # (p(z) / u(z)) p(x|z) = (normalizer / q(z|x)) p(x|z) # = normalizer integral p(x|z) dz # Hence in the exact case, downstream probability does not depend on the sampled z, # permitting this reparameterizer to be used in HMC. if poutine.get_mask() is False: log_density = 0.0 else: log_density = log_normalizer - guide_dist.log_prob(value) # By Eqn 1. # Return an importance-weighted point estimate. new_fn = dist.Delta(value, log_density=log_density, event_dim=fn.event_dim) return {"fn": new_fn, "value": value, "is_observed": True}
def forward(self, *args, **kwargs): """ An automatic guide with the same ``*args, **kwargs`` as the base ``model``. .. note:: This method is used internally by :class:`~torch.nn.Module`. Users should instead use :meth:`~torch.nn.Module.__call__`. :return: A dict mapping sample site name to sampled value. :rtype: dict """ # if we've never run the model before, do so now so we can inspect the model structure if self.prototype_trace is None: self._setup_prototype(*args, **kwargs) plates = self._create_plates(*args, **kwargs) result = {} for name, site in self.prototype_trace.iter_stochastic_nodes(): transform = biject_to(site["fn"].support) with ExitStack() as stack: for frame in site["cond_indep_stack"]: if frame.vectorized: stack.enter_context(plates[frame.name]) site_loc, site_scale = self._get_loc_and_scale(name) unconstrained_latent = pyro.sample( name + "_unconstrained", dist.Normal( site_loc, site_scale, ).to_event(self._event_dims[name]), infer={"is_auxiliary": True}, ) value = transform(unconstrained_latent) if poutine.get_mask() is False: log_density = 0.0 else: log_density = transform.inv.log_abs_det_jacobian( value, unconstrained_latent, ) log_density = sum_rightmost( log_density, log_density.dim() - value.dim() + site["fn"].event_dim, ) delta_dist = dist.Delta( value, log_density=log_density, event_dim=site["fn"].event_dim, ) result[name] = pyro.sample(name, delta_dist) return result
def apply(self, msg): name = msg["name"] fn = msg["fn"] value = msg["value"] is_observed = msg["is_observed"] if name not in self.guide.prototype_trace.nodes: return {"fn": fn, "value": value, "is_observed": is_observed} if is_observed: raise NotImplementedError( f"At pyro.sample({repr(name)},...), " "NeuTraReparam does not support observe statements.") log_density = 0.0 compute_density = poutine.get_mask() is not False if name not in self.x_unconstrained: # On first sample site. # Sample a shared latent. try: self.transform = self.guide.get_transform() except (NotImplementedError, TypeError) as e: raise ValueError( "NeuTraReparam only supports guides that implement " "`get_transform` method that does not depend on the " "model's `*args, **kwargs`") from e with ExitStack() as stack: for plate in self.guide.plates.values(): stack.enter_context( block_plate(dim=plate.dim, strict=False)) z_unconstrained = pyro.sample( f"{name}_shared_latent", self.guide.get_base_dist().mask(False)) # Differentiably transform. x_unconstrained = self.transform(z_unconstrained) if compute_density: log_density = self.transform.log_abs_det_jacobian( z_unconstrained, x_unconstrained) self.x_unconstrained = { site["name"]: (site, unconstrained_value) for site, unconstrained_value in self.guide._unpack_latent( x_unconstrained) } # Extract a single site's value from the shared latent. site, unconstrained_value = self.x_unconstrained.pop(name) transform = biject_to(fn.support) value = transform(unconstrained_value) if compute_density: logdet = transform.log_abs_det_jacobian(unconstrained_value, value) logdet = sum_rightmost(logdet, logdet.dim() - value.dim() + fn.event_dim) log_density = log_density + fn.log_prob(value) + logdet new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim) return {"fn": new_fn, "value": value, "is_observed": True}
def forward(self, *args, **kwargs): """ An automatic guide with the same ``*args, **kwargs`` as the base ``model``. .. note:: This method is used internally by :class:`~torch.nn.Module`. Users should instead use :meth:`~torch.nn.Module.__call__`. :return: A dict mapping sample site name to sampled value. :rtype: dict """ # if we've never run the model before, do so now so we can inspect the model structure if self.prototype_trace is None: self._setup_prototype(*args, **kwargs) latent = self.sample_latent(*args, **kwargs) plates = self._create_plates(*args, **kwargs) # unpack continuous latent samples result = {} for site, unconstrained_value in self._unpack_latent(latent): name = site["name"] transform = biject_to(site["fn"].support) value = transform(unconstrained_value) if poutine.get_mask() is False: log_density = 0.0 else: log_density = transform.inv.log_abs_det_jacobian( value, unconstrained_value, ) log_density = sum_rightmost( log_density, log_density.dim() - value.dim() + site["fn"].event_dim, ) delta_dist = dist.Delta( value, log_density=log_density, event_dim=site["fn"].event_dim, ) with ExitStack() as stack: for frame in self._cond_indep_stacks[name]: stack.enter_context(plates[frame.name]) result[name] = pyro.sample(name, delta_dist) return result
def __call__(self, name, fn, obs):
    if name not in self.guide.prototype_trace.nodes:
        return fn, obs
    assert obs is None, "NeuTraReparam does not support observe statements"
    log_density = 0.0
    compute_density = poutine.get_mask() is not False
    if not self.x_unconstrained:  # On first sample site.
        # Sample a shared latent.
        try:
            self.transform = self.guide.get_transform()
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from e
        z_unconstrained = pyro.sample(
            "{}_shared_latent".format(name),
            self.guide.get_base_dist().mask(False),
        )

        # Differentiably transform.
        x_unconstrained = self.transform(z_unconstrained)
        if compute_density:
            log_density = self.transform.log_abs_det_jacobian(
                z_unconstrained, x_unconstrained
            )
        self.x_unconstrained = list(
            reversed(list(self.guide._unpack_latent(x_unconstrained)))
        )

    # Extract a single site's value from the shared latent.
    site, unconstrained_value = self.x_unconstrained.pop()
    assert name == site["name"], "model structure changed"
    transform = biject_to(fn.support)
    value = transform(unconstrained_value)
    if compute_density:
        logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
        logdet = sum_rightmost(logdet, logdet.dim() - value.dim() + fn.event_dim)
        log_density = log_density + fn.log_prob(value) + logdet
    new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
    return new_fn, value
def apply(self, msg): name = msg["name"] fn = msg["fn"] value = msg["value"] is_observed = msg["is_observed"] assert fn.event_dim >= self.event_dim # Split value into parts. value_split = [None] * len(self.sections) if value is not None: value_split[:] = value.split(self.sections, -self.event_dim) # Draw independent parts. dim = fn.event_dim - self.event_dim left_shape = fn.event_shape[:dim] right_shape = fn.event_shape[1 + dim:] for i, size in enumerate(self.sections): event_shape = left_shape + (size, ) + right_shape value_split[i] = pyro.sample( f"{name}_split_{i}", dist.ImproperUniform(fn.support, fn.batch_shape, event_shape), obs=value_split[i], infer={"is_observed": is_observed}, ) # Combine parts into value. if value is None: value = torch.cat(value_split, dim=-self.event_dim) if poutine.get_mask() is False: log_density = 0.0 else: log_density = fn.log_prob(value) new_fn = dist.Delta(value, event_dim=fn.event_dim, log_density=log_density) return {"fn": new_fn, "value": value, "is_observed": True}
def guide():
    x = pyro.sample("x", dist.Normal(0, 1))
    called.add("guide-always")
    if poutine.get_mask() is not False:
        called.add("guide-sometimes")
        pyro.factor("g", 2 - x)
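# Sketch: with no mask handler applied, get_mask() returns None, so both guide
# branches record; compare the masked model run shown after the model
# definition above. Assumes the shared `called` set.
import pyro.poutine as poutine

called.clear()
poutine.trace(guide).get_trace()
assert {"guide-always", "guide-sometimes"} <= called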
def get_deltas(self, save_params=None):
    deltas = {}
    aux_values = {}
    compute_density = poutine.get_mask() is not False
    for name, site in self._sorted_sites:
        if save_params is not None and name not in save_params:
            continue

        # Sample zero-mean blockwise independent Delta/Normal/MVN.
        log_density = 0.0
        loc = deep_getattr(self.locs, name)
        zero = torch.zeros_like(loc)
        conditional = self.conditionals[name]
        if callable(conditional):
            aux_value = deep_getattr(self.conds, name)()
        elif conditional == "delta":
            aux_value = zero
        elif conditional == "normal":
            aux_value = pyro.sample(
                name + "_aux",
                dist.Normal(zero, 1).to_event(1),
                infer={"is_auxiliary": True},
            )
            scale = deep_getattr(self.scales, name)
            aux_value = aux_value * scale
            if compute_density:
                log_density = (-scale.log()).expand_as(aux_value)
        elif conditional == "mvn":
            # This overparametrizes by learning (scale,scale_tril),
            # enabling faster learning of the more-global scale parameter.
            aux_value = pyro.sample(
                name + "_aux",
                dist.Normal(zero, 1).to_event(1),
                infer={"is_auxiliary": True},
            )
            scale = deep_getattr(self.scales, name)
            scale_tril = deep_getattr(self.scale_trils, name)
            aux_value = aux_value @ scale_tril.T * scale
            if compute_density:
                log_density = (
                    -scale_tril.diagonal(dim1=-2, dim2=-1).log() - scale.log()
                ).expand_as(aux_value)
        else:
            raise ValueError(f"Unsupported conditional type: {conditional}")

        # Accumulate upstream dependencies.
        # Note: by accumulating upstream dependencies before updating the
        # aux_values dict, we encode a block-sparse structure of the
        # precision matrix; if we had instead accumulated after updating
        # aux_values, we would encode a block-sparse structure of the
        # covariance matrix.
        # Note: these shear transforms have no effect on the Jacobian
        # determinant, and can therefore be excluded from the log_density
        # computation below, even for nonlinear dep().
        deps = deep_getattr(self.deps, name)
        for upstream in self.dependencies.get(name, {}):
            dep = deep_getattr(deps, upstream)
            aux_value = aux_value + dep(aux_values[upstream])
        aux_values[name] = aux_value

        # Shift by loc and reshape.
        batch_shape = torch.broadcast_shapes(
            aux_value.shape[:-1], self._batch_shapes[name]
        )
        unconstrained = (aux_value + loc).reshape(
            batch_shape + self._unconstrained_event_shapes[name]
        )
        if not is_identically_zero(log_density):
            log_density = log_density.reshape(batch_shape + (-1,)).sum(-1)

        # Transform to constrained space.
        transform = biject_to(site["fn"].support)
        value = transform(unconstrained)
        if compute_density and conditional != "delta":
            assert transform.codomain.event_dim == site["fn"].event_dim
            log_density = log_density + transform.inv.log_abs_det_jacobian(
                value, unconstrained
            )

        # Create a reparametrized Delta distribution.
        deltas[name] = dist.Delta(value, log_density, site["fn"].event_dim)

    return deltas
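# Usage sketch: get_deltas() is driven by the conditionals/dependencies
# structure declared when building the guide. Site names ("x", "y") and
# `model` are illustrative; here "y" depends linearly on "x".
from pyro.infer.autoguide import AutoStructured

guide = AutoStructured(
    model,
    conditionals={"x": "normal", "y": "mvn"},
    dependencies={"y": {"x": "linear"}},
)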