def apply(self, msg):
    """Replace a latent sample site with a Delta taken from the structured guide.

    Sites that do not appear in ``self.guide.prototype_trace.nodes`` are passed
    through unchanged. The first time a covered site is seen (i.e. its name is
    not yet a key of ``self.deltas``), all deltas are computed in one call to
    ``self.guide.get_deltas()`` while the guide's plates are blocked; each
    covered site then pops its own entry from ``self.deltas``.

    :param dict msg: sample message; the keys ``"name"``, ``"fn"``,
        ``"value"`` and ``"is_observed"`` are read.
    :returns: dict with keys ``"fn"``, ``"value"`` and ``"is_observed"``.
    :raises NotImplementedError: if a covered site is observed.
    """
    site_name, site_fn, site_value, observed = (
        msg["name"],
        msg["fn"],
        msg["value"],
        msg["is_observed"],
    )

    # Guard clause: sites the guide does not know about are left untouched.
    if site_name not in self.guide.prototype_trace.nodes:
        return {"fn": site_fn, "value": site_value, "is_observed": observed}

    if observed:
        raise NotImplementedError(
            f"At pyro.sample({repr(site_name)},...), "
            "StructuredReparam does not support observe statements"
        )

    # On the first covered site, compute every delta at once with the guide's
    # plates blocked.
    if site_name not in self.deltas:
        with ExitStack() as stack:
            for guide_plate in self.guide.plates.values():
                stack.enter_context(block_plate(dim=guide_plate.dim, strict=False))
            self.deltas = self.guide.get_deltas()

    delta = self.deltas.pop(site_name)
    sampled = delta.v

    # Unless the mask is explicitly False, fold the model's log-density at the
    # sampled value into the Delta's log_density.
    if poutine.get_mask() is not False:
        delta = dist.Delta(
            sampled,
            delta.log_density + site_fn.log_prob(sampled),
            delta.event_dim,
        )

    return {"fn": delta, "value": sampled, "is_observed": True}
def apply(self, msg):
    """Reparametrize a latent site through the guide's shared transform (NeuTra).

    Sites absent from ``self.guide.prototype_trace.nodes`` pass through
    unchanged.

    On the first covered site (its name is not yet a key of
    ``self.x_unconstrained``), the method does the following:

    1. It fetches ``self.guide.get_transform()``.
    2. It samples a single shared latent ``"<name>_shared_latent"`` from the
       guide's masked base distribution, with the guide plates blocked.
    3. It pushes that latent through the transform.
    4. It unpacks the result into per-site unconstrained values keyed by
       site name.

    Each covered site then pops its unconstrained value, maps it into
    ``fn.support`` via ``biject_to`` and returns a Delta at that value. When
    the mask is not False, the Delta carries a log-density made of the
    model's ``log_prob`` plus the bijection's log-Jacobian. For the first site
    only, it also includes the shared transform's log-Jacobian.

    :param dict msg: sample message; the keys ``"name"``, ``"fn"``,
        ``"value"`` and ``"is_observed"`` are read.
    :returns: dict with keys ``"fn"``, ``"value"`` and ``"is_observed"``.
    :raises NotImplementedError: if a covered site is observed.
    :raises ValueError: if ``get_transform`` raises ``NotImplementedError``
        or ``TypeError``.
    """
    site_name, site_fn, site_value, observed = (
        msg["name"],
        msg["fn"],
        msg["value"],
        msg["is_observed"],
    )

    if site_name not in self.guide.prototype_trace.nodes:
        return {"fn": site_fn, "value": site_value, "is_observed": observed}
    if observed:
        raise NotImplementedError(
            f"At pyro.sample({repr(site_name)},...), "
            "NeuTraReparam does not support observe statements."
        )

    log_density = 0.0
    want_density = poutine.get_mask() is not False

    if site_name not in self.x_unconstrained:
        # First covered site: draw one latent shared by every guide site.
        try:
            self.transform = self.guide.get_transform()
        except (NotImplementedError, TypeError) as err:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from err

        with ExitStack() as stack:
            for guide_plate in self.guide.plates.values():
                stack.enter_context(block_plate(dim=guide_plate.dim, strict=False))
            z_shared = pyro.sample(
                f"{site_name}_shared_latent",
                self.guide.get_base_dist().mask(False),
            )

        # Differentiably transform the shared latent.
        x_shared = self.transform(z_shared)
        if want_density:
            log_density = self.transform.log_abs_det_jacobian(z_shared, x_shared)
        self.x_unconstrained = {
            guide_site["name"]: (guide_site, unconstrained)
            for guide_site, unconstrained in self.guide._unpack_latent(x_shared)
        }

    # Pull this site's slice out of the shared latent and constrain it.
    _, unconstrained = self.x_unconstrained.pop(site_name)
    to_support = biject_to(site_fn.support)
    value = to_support(unconstrained)
    if want_density:
        logdet = to_support.log_abs_det_jacobian(unconstrained, value)
        # Reduce the Jacobian term over the trailing dims not covered by
        # fn's batch shape.
        logdet = sum_rightmost(logdet, logdet.dim() - value.dim() + site_fn.event_dim)
        log_density = log_density + site_fn.log_prob(value) + logdet

    return {
        "fn": dist.Delta(value, log_density, event_dim=site_fn.event_dim),
        "value": value,
        "is_observed": True,
    }
def __call__(self, name, fn, obs):
    """Legacy reparam entry point: sample transformed noise, return a Delta.

    This method samples a site named ``f"{name}_{self.suffix}"`` from
    ``TransformedDistribution(fn, transform)``. Here ``transform`` composes
    the inverse of ``biject_to(fn.support)`` with ``self.transform``. The
    sample is mapped back through ``transform.inv``, and the method returns a
    Delta at that value.

    If ``self.transform.event_dim`` exceeds ``fn.event_dim``, the extra
    ("shifted") batch dims are folded into the event. This is only allowed
    when ``self.experimental_allow_batch`` is set: singleton dims are
    inserted, the plates at those dims are blocked, and the singleton dims
    are removed from the result afterwards.

    :param str name: name of the original sample site.
    :param fn: the site's distribution.
    :param obs: must be ``None``; observed sites are not supported.
    :returns: tuple ``(new_fn, x)``, where ``new_fn`` is ``Delta(x)`` with
        ``fn``'s original event_dim.
    :raises ValueError: if a batch dim would need transforming and
        ``experimental_allow_batch`` is false.
    """
    assert obs is None, "TransformReparam does not support observe statements"
    event_dim = fn.event_dim
    transform = self.transform
    with ExitStack() as stack:
        shift = max(0, transform.event_dim - event_dim)
        if shift:
            if not self.experimental_allow_batch:
                # Fix: the original literals lacked a space, so the message
                # read "try eitherconverting".
                raise ValueError(
                    "Cannot transform along batch dimension; try either "
                    "converting a batch dimension to an event dimension, or "
                    "setting experimental_allow_batch=True."
                )

            # Reshape and mute plates using block_plate.
            from pyro.contrib.forecast.util import (
                reshape_batch,
                reshape_transform_batch,
            )

            old_shape = fn.batch_shape
            new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
            fn = reshape_batch(fn, new_shape).to_event(shift)
            # NOTE(review): this reshaped transform is overwritten below by a
            # ComposeTransform built from self.transform, so the result here
            # is never used. Confirm whether that is intended.
            transform = reshape_transform_batch(
                transform, old_shape + fn.event_shape, new_shape + fn.event_shape
            )
            for dim in range(-shift, 0):
                stack.enter_context(block_plate(dim=dim, strict=False))

        # Draw noise from the base distribution.
        transform = ComposeTransform(
            [biject_to(fn.support).inv.with_cache(), self.transform]
        )
        x_trans = pyro.sample(
            f"{name}_{self.suffix}", dist.TransformedDistribution(fn, transform)
        )

    # Differentiably transform.
    x = transform.inv(x_trans)  # should be free due to transform cache
    if shift:
        # Drop the singleton dims inserted above.
        x = x.reshape(x.shape[: -2 * shift - event_dim] + x.shape[-shift - event_dim :])

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(x, event_dim=event_dim)
    return new_fn, x
def apply(self, msg):
    """Reparametrize a site by sampling in the space of ``self.transform``.

    This method samples a site named ``f"{name}_{self.suffix}"`` from
    ``TransformedDistribution(fn, transform)``. Here ``transform`` composes
    the inverse of ``biject_to(fn.support)`` with ``self.transform``.

    - If ``msg["value"]`` is given, it is mapped forward through
      ``transform`` and passed as ``obs``.
    - Otherwise the sampled value is mapped back through ``transform.inv``.

    Either way the method returns a Delta at the value in the original space.

    If ``self.transform.event_dim`` exceeds ``fn.event_dim``, the extra
    ("shifted") batch dims are folded into the event. This is only allowed
    when ``self.experimental_allow_batch`` is set: singleton dims are
    inserted into ``fn`` (and into ``value``, if given), the plates at those
    dims are blocked, and the singleton dims are removed from the result.

    :param dict msg: sample message; the keys ``"name"``, ``"fn"``,
        ``"value"`` and ``"is_observed"`` are read.
    :returns: dict with keys ``"fn"`` (a Delta with ``fn``'s original
        event_dim), ``"value"`` and ``"is_observed"`` (always True).
    :raises ValueError: if a batch dim would need transforming and
        ``experimental_allow_batch`` is false.
    """
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]

    event_dim = fn.event_dim
    transform = self.transform
    with ExitStack() as stack:
        shift = max(0, transform.event_dim - event_dim)
        if shift:
            if not self.experimental_allow_batch:
                # Fix: the original literals lacked a space, so the message
                # read "try eitherconverting".
                raise ValueError(
                    "Cannot transform along batch dimension; try either "
                    "converting a batch dimension to an event dimension, or "
                    "setting experimental_allow_batch=True."
                )

            # Reshape and mute plates using block_plate.
            from pyro.contrib.forecast.util import (
                reshape_batch,
                reshape_transform_batch,
            )

            old_shape = fn.batch_shape
            new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
            fn = reshape_batch(fn, new_shape).to_event(shift)
            # NOTE(review): this reshaped transform is overwritten below by a
            # ComposeTransform built from self.transform, so the result here
            # is never used. Confirm whether that is intended.
            transform = reshape_transform_batch(
                transform, old_shape + fn.event_shape, new_shape + fn.event_shape
            )
            if value is not None:
                # Insert the same singleton dims into the provided value.
                value = value.reshape(
                    value.shape[: -shift - event_dim]
                    + (1,) * shift
                    + value.shape[-shift - event_dim :]
                )
            for dim in range(-shift, 0):
                stack.enter_context(block_plate(dim=dim, strict=False))

        # Differentiably invert transform.
        transform = ComposeTransform(
            [biject_to(fn.support).inv.with_cache(), self.transform]
        )
        value_trans = None
        if value is not None:
            value_trans = transform(value)

        # Draw noise from the base distribution.
        value_trans = pyro.sample(
            f"{name}_{self.suffix}",
            dist.TransformedDistribution(fn, transform),
            obs=value_trans,
            infer={"is_observed": is_observed},
        )

    # Differentiably transform. This should be free due to transform cache.
    if value is None:
        value = transform.inv(value_trans)
    if shift:
        # Drop the singleton dims inserted above.
        value = value.reshape(
            value.shape[: -2 * shift - event_dim] + value.shape[-shift - event_dim :]
        )

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(value, event_dim=event_dim)
    return {"fn": new_fn, "value": value, "is_observed": True}