def autoguide(self, name, dist_constructor):
    """
    Sets an autoguide for an existing parameter with name ``name`` (mimicking
    the behavior of module :mod:`pyro.infer.autoguide`).

    .. note:: `dist_constructor` should be one of
        :class:`~pyro.distributions.Delta`,
        :class:`~pyro.distributions.Normal`, and
        :class:`~pyro.distributions.MultivariateNormal`. More distribution
        constructors will be supported in the future if needed.

    :param str name: Name of the parameter.
    :param dist_constructor: A
        :class:`~pyro.distributions.distribution.Distribution` constructor.
    """
    if name not in self._priors:
        raise ValueError("There is no prior for parameter: {}".format(name))

    if dist_constructor not in [dist.Delta, dist.Normal, dist.MultivariateNormal]:
        raise NotImplementedError(
            "Unsupported distribution type: {}".format(dist_constructor))

    # delete old guide
    if name in self._guides:
        dist_args = self._guides[name][1]
        for arg in dist_args:
            delattr(self, "{}_{}".format(name, arg))

    p = self._priors[name]()  # init_to_sample strategy
    if dist_constructor is dist.Delta:
        support = self._priors[name].support
        if _is_real_support(support):
            p_map = Parameter(p.detach())
        else:
            p_map = PyroParam(p.detach(), support)
        setattr(self, "{}_map".format(name), p_map)
        dist_args = ("map",)
    elif dist_constructor is dist.Normal:
        loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
        scale = PyroParam(loc.new_ones(loc.shape), constraints.positive)
        setattr(self, "{}_loc".format(name), loc)
        setattr(self, "{}_scale".format(name), scale)
        dist_args = ("loc", "scale")
    elif dist_constructor is dist.MultivariateNormal:
        loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
        identity = eye_like(loc, loc.size(-1))
        scale_tril = PyroParam(identity.repeat(loc.shape[:-1] + (1, 1)),
                               constraints.lower_cholesky)
        setattr(self, "{}_loc".format(name), loc)
        setattr(self, "{}_scale_tril".format(name), scale_tril)
        dist_args = ("loc", "scale_tril")
    else:
        raise NotImplementedError

    self._guides[name] = (dist_constructor, dist_args)
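# Hedged usage sketch (not from the source): the Delta branch above amounts to
# registering a constrained point estimate and sampling a Delta at it.
# `lengthscale` and the init value are illustrative names only.
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroParam
from torch.distributions import constraints

class MapGuide(PyroModule):
    def __init__(self):
        super().__init__()
        # Plays the role of the "{name}_map" attribute set by autoguide().
        self.lengthscale_map = PyroParam(torch.tensor(1.0),
                                         constraint=constraints.positive)

    def forward(self):
        # All posterior mass sits at the learned point estimate (MAP).
        return pyro.sample("lengthscale", dist.Delta(self.lengthscale_map))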
def templates_guide_mvn(self):
    """
    Multivariate normal guide for template parameters.
    """
    loc = _deep_getattr(self, "mvn.loc")
    scale_tril = _deep_getattr(self, "mvn.scale_tril")

    dt = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    states = pyro.sample("states_" + self.name_prefix, dt,
                         infer={"is_auxiliary": True})

    result = {}
    for i_poiss in torch.arange(self.n_poiss):
        transform = biject_to(self.poiss_priors[i_poiss].support)
        value = transform(states[i_poiss])
        log_density = transform.inv.log_abs_det_jacobian(value, states[i_poiss])
        log_density = sum_rightmost(
            log_density,
            log_density.dim() - value.dim() + self.poiss_priors[i_poiss].event_dim)
        result[self.poiss_labels[i_poiss]] = pyro.sample(
            self.poiss_labels[i_poiss],
            dist.Delta(value, log_density=log_density,
                       event_dim=self.poiss_priors[i_poiss].event_dim))

    i_param = self.n_poiss
    for i_ps in torch.arange(self.n_ps):
        for i_ps_param in torch.arange(self.n_ps_params):
            transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)
            value = transform(states[i_param])
            log_density = transform.inv.log_abs_det_jacobian(value, states[i_param])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim()
                + self.ps_priors[i_ps][i_ps_param].event_dim)
            label = self.ps_param_labels[i_ps_param] + "_" + self.ps_labels[i_ps]
            result[label] = pyro.sample(
                label,
                dist.Delta(value, log_density=log_density,
                           event_dim=self.ps_priors[i_ps][i_ps_param].event_dim))
            i_param += 1

    return result
def guide(self):
    a_locs = pyro.param("a_locs", torch.full((self.n_params,), 0.0))
    a_scales_tril = pyro.param(
        "a_scales",
        lambda: 0.1 * eye_like(a_locs, self.n_params),
        constraint=constraints.lower_cholesky)

    dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
    states = pyro.sample("states", dt, infer={"is_auxiliary": True})

    result = {}
    for i_poiss in torch.arange(self.n_poiss):
        transform = biject_to(self.poiss_priors[i_poiss].support)
        value = transform(states[i_poiss])
        log_density = transform.inv.log_abs_det_jacobian(value, states[i_poiss])
        log_density = sum_rightmost(
            log_density,
            log_density.dim() - value.dim() + self.poiss_priors[i_poiss].event_dim)
        result[self.labels_poiss[i_poiss]] = pyro.sample(
            self.labels_poiss[i_poiss],
            dist.Delta(value, log_density=log_density,
                       event_dim=self.poiss_priors[i_poiss].event_dim))

    i_param = self.n_poiss
    for i_ps in torch.arange(self.n_ps):
        for i_ps_param in torch.arange(self.n_ps_params):
            transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)
            value = transform(states[i_param])
            log_density = transform.inv.log_abs_det_jacobian(value, states[i_param])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim()
                + self.ps_priors[i_ps][i_ps_param].event_dim)
            label = self.labels_ps_params[i_ps_param] + "_" + self.labels_ps[i_ps]
            result[label] = pyro.sample(
                label,
                dist.Delta(value, log_density=log_density,
                           event_dim=self.ps_priors[i_ps][i_ps_param].event_dim))
            i_param += 1

    return result
def setup(self, *args, **kwargs):
    self._args = args
    self._kwargs = kwargs
    # set the trace prototype to inter-convert between trace object
    # and dict object used by the integrator
    trace = poutine.trace(self.model).get_trace(*args, **kwargs)
    self._prototype_trace = trace
    if self._automatic_transform_enabled:
        self.transforms = {}
    for name, node in sorted(trace.iter_stochastic_nodes(), key=lambda x: x[0]):
        site_value = node["value"]
        if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
            self.transforms[name] = biject_to(node["fn"].support).inv
            site_value = self.transforms[name](node["value"])
        r_loc = site_value.new_zeros(site_value.shape)
        r_scale = site_value.new_ones(site_value.shape)
        self._r_dist[name] = dist.Normal(loc=r_loc, scale=r_scale)
    self._validate_trace(trace)

    if self.adapt_step_size:
        self._adapt_phase = True
        z = {name: node["value"] for name, node in trace.iter_stochastic_nodes()}
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])
        self.step_size = self._find_reasonable_step_size(z)
        self.num_steps = max(1, int(self.trajectory_length / self.step_size))
        # make prox-center for Dual Averaging scheme
        loc = math.log(10 * self.step_size)
        self._adapted_scheme = DualAveraging(prox_center=loc)
def get_posterior(
    self,
    name: str,
    prior: Distribution,
) -> Union[Distribution, torch.Tensor]:
    if self._computing_median:
        return self._get_posterior_median(name, prior)
    if self._computing_quantiles:
        return self._get_posterior_quantiles(name, prior)
    if self._computing_mi:
        # The messenger autoguide needs the output to fit certain dimensions,
        # so this is a hack which saves MI to self.mi but returns
        # cheap-to-compute medians.
        self.mi[name] = self._get_mutual_information(name, prior)
        return self._get_posterior_median(name, prior)

    with helpful_support_errors({"name": name, "fn": prior}):
        transform = biject_to(prior.support)
    # If hierarchical_sites is not specified, all sites are assumed to be hierarchical.
    if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
        loc, scale, weight = self._get_params(name, prior)
        loc = loc + transform.inv(prior.mean) * weight
        posterior = dist.TransformedDistribution(
            dist.Normal(loc, scale).to_event(transform.domain.event_dim),
            transform.with_cache(),
        )
        return posterior
    else:
        # Fall back to mean field when the hierarchical_sites list is not empty
        # and the site is not in the list.
        loc, scale = self._get_params(name, prior)
        posterior = dist.TransformedDistribution(
            dist.Normal(loc, scale).to_event(transform.domain.event_dim),
            transform.with_cache(),
        )
        return posterior
def _initialize_model_properties(self):
    if self.max_plate_nesting is None:
        self._guess_max_plate_nesting()
    # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    # No-op if model does not have any discrete latents.
    self.model = poutine.enum(config_enumerate(self.model),
                              first_available_dim=-1 - self.max_plate_nesting)
    if self._automatic_transform_enabled:
        self.transforms = {}
    trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
    for name, node in trace.iter_stochastic_nodes():
        if isinstance(node["fn"], _Subsample):
            continue
        if node["fn"].has_enumerate_support:
            self._has_enumerable_sites = True
            continue
        site_value = node["value"]
        if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
            self.transforms[name] = biject_to(node["fn"].support).inv
            site_value = self.transforms[name](node["value"])
        self._r_shapes[name] = site_value.shape
        self._r_numels[name] = site_value.numel()
    self._trace_prob_evaluator = TraceEinsumEvaluator(
        trace, self._has_enumerable_sites, self.max_plate_nesting)
    mass_matrix_size = sum(self._r_numels.values())
    if self.full_mass:
        initial_mass_matrix = eye_like(site_value, mass_matrix_size)
    else:
        initial_mass_matrix = site_value.new_ones(mass_matrix_size)
    self._adapter.inverse_mass_matrix = initial_mass_matrix
def test_biject_to(constraint_fn, args, is_cuda):
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        t = biject_to(constraint)
    except NotImplementedError:
        pytest.skip('`biject_to` not implemented.')
    assert t.bijective, "biject_to({}) is not bijective".format(constraint)
    if constraint_fn is constraints.corr_cholesky:
        # (D * (D - 1)) / 2 (where D = 4) = 6 (size of last dim)
        x = torch.randn(6, 6, dtype=torch.double)
    else:
        x = torch.randn(5, 5, dtype=torch.double)
    if is_cuda:
        x = x.cuda()
    y = t(x)
    assert constraint.check(y).all(), '\n'.join([
        "Failed to biject_to({})".format(constraint),
        "x = {}".format(x),
        "biject_to(...)(x) = {}".format(y),
    ])
    x2 = t.inv(y)
    assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)
    j = t.log_abs_det_jacobian(x, y)
    assert j.shape == x.shape[:x.dim() - t.input_event_dim]
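# Hedged standalone illustration of the round trip this test checks, using
# the simplex constraint (pure torch, no pyro required):
import torch
from torch.distributions import biject_to, constraints

t = biject_to(constraints.simplex)
x = torch.randn(5, 3, dtype=torch.double)   # unconstrained input
y = t(x)                                    # rows now lie on the simplex
assert constraints.simplex.check(y).all()
x2 = t.inv(y)                               # inverse recovers the input
assert torch.allclose(x, x2)
j = t.log_abs_det_jacobian(x, y)            # one correction per batch element
assert j.shape == (5,)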
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # if we've never run the model before, do so now so we can inspect the model structure
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    latent = self.sample_latent(*args, **kwargs)
    iaranges = self._create_iaranges()

    # unpack continuous latent samples
    result = {}
    for site, unconstrained_value in self._unpack_latent(latent):
        name = site["name"]
        transform = biject_to(site["fn"].support)
        value = transform(unconstrained_value)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + site["fn"].event_dim)
        delta_dist = dist.Delta(value, log_density=log_density,
                                event_dim=site["fn"].event_dim)

        with ExitStack() as stack:
            for frame in self._cond_indep_stacks[name]:
                stack.enter_context(iaranges[frame.name])
            result[name] = pyro.sample(name, delta_dist)

    return result
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # if we've never run the model before, do so now so we can inspect the model structure
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    latent = self.sample_latent(*args, **kwargs)
    plates = self._create_plates()

    # unpack continuous latent samples
    result = {}
    for site, unconstrained_value in self._unpack_latent(latent):
        name = site["name"]
        transform = biject_to(site["fn"].support)
        value = transform(unconstrained_value)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + site["fn"].event_dim)
        delta_dist = dist.Delta(value, log_density=log_density,
                                event_dim=site["fn"].event_dim)

        with ExitStack() as stack:
            for frame in self._cond_indep_stacks[name]:
                stack.enter_context(plates[frame.name])
            result[name] = pyro.sample(name, delta_dist)

    return result
def get_transforms(model: Callable, *model_args: Any, **model_kwargs: Any):
    """Get automatic transforms to unbounded space.

    Args:
        model: Pyro model
        model_args: Arguments passed to model
        model_kwargs: Keyword arguments passed to model

    Example:
        ```python
        def prior():
            return pyro.sample("theta", pyro.distributions.Uniform(0., 1.))

        transform_to_unbounded = get_transforms(prior)["theta"]
        ```
    """
    transforms = {}
    model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    for name, node in model_trace.iter_stochastic_nodes():
        fn = node["fn"]
        transforms[name] = biject_to(fn.support).inv
    return transforms
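# Hedged usage example extending the docstring's sketch; the value 0.25 is
# illustrative:
import torch
import pyro
import pyro.distributions as dist

def prior():
    return pyro.sample("theta", dist.Uniform(0.0, 1.0))

transform_to_unbounded = get_transforms(prior)["theta"]
theta = torch.tensor(0.25)
unbounded = transform_to_unbounded(theta)   # maps (0, 1) onto the real line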
def test_auto_dirichlet(auto_class, Elbo):
    num_steps = 2000
    prior = torch.tensor([0.5, 1.0, 1.5, 3.0])
    data = torch.tensor([0] * 4 + [1] * 2 + [2] * 5).long()
    posterior = torch.tensor([4.5, 3.0, 6.5, 3.0])

    def model(data):
        p = pyro.sample("p", dist.Dirichlet(prior))
        with pyro.plate("data_plate"):
            pyro.sample("data", dist.Categorical(p).expand_by(data.shape), obs=data)

    guide = auto_class(model)
    svi = SVI(model, guide, optim.Adam({"lr": 0.003}), loss=Elbo())
    for _ in range(num_steps):
        loss = svi.step(data)
        assert np.isfinite(loss), loss

    expected_mean = posterior / posterior.sum()
    if isinstance(guide, (AutoIAFNormal, AutoNormalizingFlow)):
        loc = guide.transform(torch.zeros(guide.latent_dim))
    else:
        loc = guide.loc
    actual_mean = biject_to(constraints.simplex)(loc)
    assert_equal(
        actual_mean,
        expected_mean,
        prec=0.2,
        msg="".join([
            "\nexpected {}".format(expected_mean.detach().cpu().numpy()),
            "\n  actual {}".format(actual_mean.detach().cpu().numpy()),
        ]),
    )
def quantiles(self, quantiles, *args, **kwargs):
    """
    Returns posterior quantiles for each latent variable. Example::

        print(guide.quantiles([0.05, 0.5, 0.95]))

    :param quantiles: A list of requested quantiles between 0 and 1.
    :type quantiles: torch.Tensor or list
    :return: A dict mapping sample site name to a list of quantile values.
    :rtype: dict
    """
    results = {}

    for name, site in self.prototype_trace.iter_stochastic_nodes():
        site_loc, site_scale = self._get_loc_and_scale(name)

        site_quantiles = torch.tensor(quantiles, dtype=site_loc.dtype,
                                      device=site_loc.device)
        site_quantiles_values = dist.Normal(site_loc, site_scale).icdf(site_quantiles)
        constrained_site_quantiles = biject_to(site["fn"].support)(site_quantiles_values)
        results[name] = constrained_site_quantiles

    return results
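# Hedged usage sketch: a method like this is exposed on fitted AutoNormal
# guides; the model and step count here are illustrative.
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal

def model():
    pyro.sample("rate", dist.Gamma(2.0, 2.0))

guide = AutoNormal(model)
svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()

# Median and 90% interval, reported in the constrained (positive) space
# of "rate" thanks to the biject_to call above.
print(guide.quantiles([0.05, 0.5, 0.95]))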
def _get_sample_fn(module, name):
    if module.mode == "model":
        return module._priors[name]

    dist_constructor, dist_args = module._guides[name]

    if dist_constructor is dist.Delta:
        p_map = getattr(module, "{}_map".format(name))
        return dist.Delta(p_map, event_dim=p_map.dim())

    # create guide
    dist_args = {arg: getattr(module, "{}_{}".format(name, arg))
                 for arg in dist_args}
    guide = dist_constructor(**dist_args)

    # no need to do transforms when support is real (for mean field ELBO)
    support = module._priors[name].support
    if _is_real_support(support):
        return guide.to_event()

    # otherwise, we do inference in unconstrained space and transform the value
    # back to original space
    # TODO: move this logic to infer.autoguide or somewhere else
    unconstrained_value = pyro.sample(
        module._pyro_get_fullname("{}_latent".format(name)),
        guide.to_event(),
        infer={"is_auxiliary": True})
    transform = biject_to(support)
    value = transform(unconstrained_value)
    log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
    return dist.Delta(value, log_density.sum(), event_dim=value.dim())
def sample(self, guide_name, fn, infer=None):
    """
    Wrapper around ``pyro.sample()`` to create a single auxiliary sample site
    and then unpack to multiple sample sites for model replay.

    :param str guide_name: The name of the auxiliary guide site.
    :param callable fn: A distribution with shape ``self.event_shape``.
    :param dict infer: Optional inference configuration dict.
    :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the single
        concatenated blob and ``model_zs`` is a dict mapping site name to
        constrained model sample.
    :rtype: tuple
    """
    # Sample a packed tensor.
    if fn.event_shape != self.event_shape:
        raise ValueError(
            "Invalid fn.event_shape for group: expected {}, actual {}".format(
                tuple(self.event_shape), tuple(fn.event_shape)))
    if infer is None:
        infer = {}
    infer["is_auxiliary"] = True
    guide_z = pyro.sample(guide_name, fn, infer=infer)
    common_batch_shape = guide_z.shape[:-1]

    model_zs = {}
    pos = 0
    for site in self.prototype_sites:
        name = site["name"]
        fn = site["fn"]

        # Extract slice from packed sample.
        size = self._site_sizes[name]
        batch_shape = broadcast_shape(common_batch_shape,
                                      self._site_batch_shapes[name])
        unconstrained_z = guide_z[..., pos:pos + size]
        unconstrained_z = unconstrained_z.reshape(batch_shape + fn.event_shape)
        pos += size

        # Transform to constrained space.
        transform = biject_to(fn.support)
        z = transform(unconstrained_z)
        log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
        log_density = sum_rightmost(
            log_density, log_density.dim() - z.dim() + fn.event_dim)
        delta_dist = dist.Delta(z, log_density=log_density, event_dim=fn.event_dim)

        # Replay model sample statement.
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                plate = self.guide.plate(frame.name)
                if plate not in runtime._PYRO_STACK:
                    stack.enter_context(plate)
            model_zs[name] = pyro.sample(name, delta_dist)

    return guide_z, model_zs
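# Hedged sketch of how Group.sample is typically reached via the easy_guide
# decorator; site names, shapes, and the model itself are illustrative, and
# this assumes the pyro.contrib.easyguide API.
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import easy_guide
from torch.distributions import constraints

def model():
    pyro.sample("state_0", dist.Normal(0.0, 1.0))
    pyro.sample("state_1", dist.LogNormal(0.0, 1.0))

@easy_guide(model)
def guide(self):
    # Pack both sites into one group; group.event_shape is their packed size.
    group = self.group(match="state_.*")
    loc = pyro.param("state_loc", lambda: torch.zeros(group.event_shape))
    scale = pyro.param("state_scale", lambda: torch.ones(group.event_shape),
                       constraint=constraints.positive)
    # This invokes the Group.sample method above.
    group.sample("states", dist.Normal(loc, scale).to_event(1))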
def _get_group(self, match='.*'):
    """Return group and unconstrained initial values."""
    group = self.group(match=match)
    z = []
    for site in group.prototype_sites:
        constrained_z = self.init(site)
        transform = biject_to(site['fn'].support)
        z.append(transform.inv(constrained_z).reshape(-1))
    z_init = torch.cat(z, 0)

    event_mask = []
    for site in group.prototype_sites:
        site_shape = site['fn'].batch_shape + site['fn'].event_shape
        if isinstance(site['fn'],
                      pyro.distributions.torch_distribution.MaskedDistribution):
            mask = site['fn']._mask.expand(site_shape).flatten()
        else:
            mask = torch.full([site_shape.numel()], True, dtype=torch.bool,
                              device=z_init.device)
        event_mask.append(mask)
    group.event_mask = torch.cat(event_mask)

    return group, z_init
def _transform_values(
    self,
    aux_values: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]:
    # Learnably transform auxiliary values to user-facing values.
    values = {}
    log_densities = defaultdict(float)
    compute_density = am_i_wrapped() and poutine.get_mask() is not False
    for name, site in self._factors.items():
        if site["is_observed"]:
            continue
        loc = deep_getattr(self.locs, name)
        scale = deep_getattr(self.scales, name)
        unconstrained = aux_values[name] * scale + loc

        # Transform to constrained space.
        transform = biject_to(site["fn"].support)
        values[name] = transform(unconstrained)
        if compute_density:
            assert transform.codomain.event_dim == site["fn"].event_dim
            log_densities[name] = transform.inv.log_abs_det_jacobian(
                values[name], unconstrained
            ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1)

    return values, log_densities
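# Hedged standalone sketch of the transformation above: a white-noise
# auxiliary value is scaled and shifted, then mapped to the constrained
# support; the density correction combines the bijection's Jacobian with
# -log(scale) from the affine map. Values are illustrative.
import torch
import pyro.distributions as dist
from torch.distributions import biject_to

support = dist.Gamma(2.0, 2.0).support        # positive reals
transform = biject_to(support)

aux = torch.randn(4)                          # auxiliary white noise
loc, scale = torch.zeros(4), 0.5 * torch.ones(4)
unconstrained = aux * scale + loc
value = transform(unconstrained)              # constrained sample
log_density = (transform.inv.log_abs_det_jacobian(value, unconstrained)
               - scale.log())                 # Jacobian of the affine map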
def _get_params(self, name: str, prior: Distribution):
    try:
        loc = deep_getattr(self.locs, name)
        scale = deep_getattr(self.scales, name)
        return loc, scale
    except AttributeError:
        pass

    # Initialize.
    with torch.no_grad():
        transform = biject_to(prior.support)
        event_dim = transform.domain.event_dim
        constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
        unconstrained = transform.inv(constrained)
        # Initialize the distribution to be an affine combination:
        #   init_scale * prior + (1 - init_scale) * init_loc
        init_loc = self._adjust_plates(unconstrained, event_dim)
        init_loc = init_loc * (1 - self._init_scale)
        init_scale = torch.full_like(init_loc, self._init_scale)

    deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
    deep_setattr(
        self,
        "scales." + name,
        PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
    )
    return self._get_params(name, prior)
def guide(self): """Approximate posterior for the horseshoe prior. We assume posterior in the form of the multivariate normal distriburtion for the global mean and standard deviation and multivariate normal distribution for the parameters of each subject independently. """ nsub = self.runs # number of subjects npar = self.npar # number of parameters trns = biject_to(constraints.positive) m_hyp = param('m_hyp', zeros(2 * npar)) st_hyp = param('scale_tril_hyp', torch.eye(2 * npar), constraint=constraints.lower_cholesky) hyp = sample('hyp', dist.MultivariateNormal(m_hyp, scale_tril=st_hyp), infer={'is_auxiliary': True}) unc_mu = hyp[..., :npar] unc_tau = hyp[..., npar:] c_tau = trns(unc_tau) ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau) ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1) sample("mu", dist.Delta(unc_mu, event_dim=1)) sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1)) m_locs = param('m_locs', zeros(nsub, npar)) st_locs = param('scale_tril_locs', torch.eye(npar).repeat(nsub, 1, 1), constraint=constraints.lower_cholesky) with plate('runs', nsub): sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
def _setup_prototype(self, *args, **kwargs):
    super()._setup_prototype(*args, **kwargs)

    self._event_dims = {}
    self._cond_indep_stacks = {}
    self.locs = PyroModule()
    self.scales = PyroModule()

    # Initialize guide params
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        # Collect unconstrained event_dims, which may differ from constrained event_dims.
        with helpful_support_errors(site):
            init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
        event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
        self._event_dims[name] = event_dim

        # Collect independence contexts.
        self._cond_indep_stacks[name] = site["cond_indep_stack"]

        # If subsampling, repeat init_value to full size.
        for frame in site["cond_indep_stack"]:
            full_size = getattr(frame, "full_size", frame.size)
            if full_size != frame.size:
                dim = frame.dim - event_dim
                init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
        init_scale = torch.full_like(init_loc, self._init_scale)

        _deep_setattr(self.locs, name,
                      PyroParam(init_loc, constraints.real, event_dim))
        _deep_setattr(self.scales, name,
                      PyroParam(init_scale, self.scale_constraint, event_dim))
def test_auto_dirichlet(auto_class, Elbo):
    num_steps = 2000
    prior = torch.tensor([0.5, 1.0, 1.5, 3.0])
    data = torch.tensor([0] * 4 + [1] * 2 + [2] * 5).long()
    posterior = torch.tensor([4.5, 3.0, 6.5, 3.0])

    def model(data):
        p = pyro.sample("p", dist.Dirichlet(prior))
        with pyro.plate("data_plate"):
            pyro.sample("data", dist.Categorical(p).expand_by(data.shape), obs=data)

    guide = auto_class(model)
    svi = SVI(model, guide, optim.Adam({"lr": .003}), loss=Elbo())

    for _ in range(num_steps):
        loss = svi.step(data)
        assert np.isfinite(loss), loss

    expected_mean = posterior / posterior.sum()
    actual_mean = biject_to(constraints.simplex)(pyro.param("auto_loc"))
    assert_equal(actual_mean, expected_mean, prec=0.2, msg=''.join([
        '\nexpected {}'.format(expected_mean.detach().cpu().numpy()),
        '\n  actual {}'.format(actual_mean.detach().cpu().numpy()),
    ]))
def bijection(self) -> Transform:
    """
    Returns the bijection from unconstrained space to the constrained
    support of the prior.
    """
    if not self.trainable:
        raise ValueError('Prior is not a `Distribution` instance!')

    return biject_to(self._prior.support)
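# Hedged illustration (hypothetical prior) of what the returned bijection does:
import torch
import torch.distributions as dist
from torch.distributions import biject_to

prior = dist.Gamma(2.0, 2.0)            # support is constraints.positive
bijection = biject_to(prior.support)    # maps all of R onto (0, inf)
theta = bijection(torch.tensor(-1.5))   # strictly positive result
assert prior.support.check(theta)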
def median(self, *args, **kwargs):
    result = {}
    for name, site in self._sorted_sites:
        loc = deep_getattr(self.locs, name).detach()
        shape = self._batch_shapes[name] + self._unconstrained_event_shapes[name]
        loc = loc.reshape(shape)
        result[name] = biject_to(site["fn"].support)(loc)
    return result
def forward(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    .. note:: This method is used internally by :class:`~torch.nn.Module`.
        Users should instead use :meth:`~torch.nn.Module.__call__`.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # if we've never run the model before, do so now so we can inspect the model structure
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    encoded_hidden = self.encode(*args, **kwargs)

    plates = self._create_plates(*args, **kwargs)
    result = {}
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        transform = biject_to(site["fn"].support)

        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    stack.enter_context(plates[frame.name])

            site_loc, site_scale = self._get_loc_and_scale(name, encoded_hidden)
            unconstrained_latent = pyro.sample(
                name + "_unconstrained",
                dist.Normal(site_loc, site_scale).to_event(self._event_dims[name]),
                infer={"is_auxiliary": True},
            )

            value = transform(unconstrained_latent)
            if pyro.poutine.get_mask() is False:
                log_density = 0.0
            else:
                log_density = transform.inv.log_abs_det_jacobian(
                    value, unconstrained_latent)
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() + site["fn"].event_dim,
                )
            delta_dist = dist.Delta(
                value,
                log_density=log_density,
                event_dim=site["fn"].event_dim,
            )

            result[name] = pyro.sample(name, delta_dist)

    return result
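# Hedged sketch of the Delta-with-log_density idiom used above: the Delta
# carries the Jacobian correction so the ELBO sees the correct guide density
# for the constrained value (names illustrative, event_dim = 0 for brevity).
import torch
import pyro.distributions as dist
from pyro.distributions.util import sum_rightmost
from torch.distributions import biject_to, constraints

transform = biject_to(constraints.positive)
unconstrained_latent = torch.randn(3)          # e.g. a Normal auxiliary draw
value = transform(unconstrained_latent)        # constrained value
log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_latent)
log_density = sum_rightmost(log_density, log_density.dim() - value.dim() + 0)
delta_dist = dist.Delta(value, log_density=log_density, event_dim=0)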
def _get_posterior_median(self, name, prior):
    transform = biject_to(prior.support)
    if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
        loc, scale, weight = self._get_params(name, prior)
        loc = loc + transform.inv(prior.mean) * weight
    else:
        loc, scale = self._get_params(name, prior)
    return transform(loc)
def _get_mutual_information(self, name, prior):
    """Approximate the mutual information between data x and latent variable z:

        I(x, z) = E_x E_{q(z|x)} log(q(z|x)) - E_x E_{q(z|x)} log(q(z))

    Returns: Float
    """
    # get posterior mean and variance
    transform = biject_to(prior.support)
    if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
        loc, scale, weight = self._get_params(name, prior)
        loc = loc + transform.inv(prior.mean) * weight
    else:
        loc, scale = self._get_params(name, prior)
    if name not in self.amortised_plate_sites["sites"].keys():
        # if amortisation is not used for a particular site, return MI = 0
        return 0

    # create tensors with useful numbers
    one = torch.ones((), dtype=loc.dtype, device=loc.device)
    two = torch.tensor(2, dtype=loc.dtype, device=loc.device)
    pi = torch.tensor(3.14159265359, dtype=loc.dtype, device=loc.device)

    # get sample from posterior
    z_samples = self.samples_for_mi[name]

    # compute MI
    x_batch, nz = loc.size()
    x_batch = torch.tensor(x_batch, dtype=loc.dtype, device=loc.device)
    nz = torch.tensor(nz, dtype=loc.dtype, device=loc.device)

    # E_{q(z|x)} log(q(z|x)) = -0.5 * nz * log(2 * pi) - 0.5 * (1 + (scale ** 2).log()).sum(-1)
    neg_entropy = (-nz * torch.log(pi * two) * (one / two)
                   - ((scale ** two).log() + one).sum(-1) * (one / two)).mean()

    # [1, x_batch, nz]
    loc, scale = loc.unsqueeze(0), scale.unsqueeze(0)
    var = scale ** two
    # (z_batch, x_batch, nz)
    dev = z_samples - loc
    # (z_batch, x_batch)
    log_density = (-((dev ** two) / var).sum(dim=-1) * (one / two)
                   - (nz * torch.log(pi * two)
                      + (scale ** two).log().sum(-1)) * (one / two))
    # log q(z): aggregate posterior, [z_batch]
    log_qz = log_sum_exp(log_density, dim=1) - torch.log(x_batch)
    return (neg_entropy - log_qz.mean(-1)).item()
def get_named_particles(self):
    """
    Create a dictionary mapping name to vectorized value, of the form
    ``{name: tensor}``. The leading dimension of each tensor corresponds to
    particles, i.e. this creates a struct of arrays.
    """
    return {
        site["name"]: biject_to(site["fn"].support)(unconstrained_value)
        for site, unconstrained_value in self.guide._unpack_latent(
            pyro.param("svgd_particles"))
    }
def median(self, *args, **kwargs):
    """
    Returns the posterior median value of each latent variable.

    :return: A dict mapping sample site name to median tensor.
    :rtype: dict
    """
    loc, _ = self._loc_scale(*args, **kwargs)
    return {site["name"]: biject_to(site["fn"].support)(unconstrained_value)
            for site, unconstrained_value in self._unpack_latent(loc)}
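# Hedged usage sketch: a median() method like this is available on fitted
# AutoContinuous guides such as AutoMultivariateNormal; model is illustrative.
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal

def model():
    pyro.sample("rate", dist.Gamma(2.0, 2.0))

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()

print(guide.median())   # {"rate": tensor(...)}, positive thanks to biject_to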
def __init__(self, base_dist, **parameters):
    super().__init__()
    self.base_dist = base_dist
    for k, v in parameters.items():
        self.register_buffer(k, v if isinstance(v, torch.Tensor) else torch.tensor(v))
    self.bijection = biject_to(self().support)
    self.shape = self().event_shape
def apply(self, msg):
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]
    if name not in self.guide.prototype_trace.nodes:
        return {"fn": fn, "value": value, "is_observed": is_observed}
    if is_observed:
        raise NotImplementedError(
            f"At pyro.sample({repr(name)},...), "
            "NeuTraReparam does not support observe statements.")

    log_density = 0.0
    compute_density = poutine.get_mask() is not False
    if name not in self.x_unconstrained:  # On first sample site.
        # Sample a shared latent.
        try:
            self.transform = self.guide.get_transform()
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`") from e
        with ExitStack() as stack:
            for plate in self.guide.plates.values():
                stack.enter_context(block_plate(dim=plate.dim, strict=False))
            z_unconstrained = pyro.sample(
                f"{name}_shared_latent", self.guide.get_base_dist().mask(False))

        # Differentiably transform.
        x_unconstrained = self.transform(z_unconstrained)
        if compute_density:
            log_density = self.transform.log_abs_det_jacobian(
                z_unconstrained, x_unconstrained)
        self.x_unconstrained = {
            site["name"]: (site, unconstrained_value)
            for site, unconstrained_value in self.guide._unpack_latent(x_unconstrained)
        }

    # Extract a single site's value from the shared latent.
    site, unconstrained_value = self.x_unconstrained.pop(name)
    transform = biject_to(fn.support)
    value = transform(unconstrained_value)
    if compute_density:
        logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
        logdet = sum_rightmost(logdet, logdet.dim() - value.dim() + fn.event_dim)
        log_density = log_density + fn.log_prob(value) + logdet
    new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
    return {"fn": new_fn, "value": value, "is_observed": True}
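# Hedged usage sketch for NeuTraReparam: first fit a normalizing-flow guide,
# then reparameterize the model so downstream inference (e.g. HMC) runs in the
# guide's warped, roughly isotropic space. The funnel-like model and the
# hyperparameters are illustrative.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoIAFNormal
from pyro.infer.reparam import NeuTraReparam

def model():
    x = pyro.sample("x", dist.Normal(0.0, 3.0))
    pyro.sample("y", dist.Normal(x, torch.exp(x / 2.0)))

guide = AutoIAFNormal(model)
svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(500):
    svi.step()

# Route every latent site through the trained flow via this apply() method.
neutra = NeuTraReparam(guide)
neutra_model = poutine.reparam(model, config=lambda site: neutra)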
def _init_loc(self):
    """
    Creates an initial latent vector using a per-site init function.
    """
    parts = []
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        constrained_value = site["value"].detach()
        unconstrained_value = biject_to(site["fn"].support).inv(constrained_value)
        parts.append(unconstrained_value.reshape(-1))
    latent = torch.cat(parts)
    assert latent.size() == (self.latent_dim,)
    return latent
def get_posterior(
    self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
    with helpful_support_errors({"name": name, "fn": prior}):
        transform = biject_to(prior.support)
    loc, scale = self._get_params(name, prior)
    affine = dist.transforms.AffineTransform(
        loc, scale, event_dim=transform.domain.event_dim, cache_size=1)
    posterior = dist.TransformedDistribution(
        prior, [transform.inv.with_cache(), affine, transform.with_cache()])
    return posterior
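# Hedged standalone sketch of the prior-preconditioned posterior built above:
# pull the prior back to unconstrained space, apply a learnable affine, and
# push forward again (the loc/scale values are illustrative):
import torch
import pyro.distributions as dist
from torch.distributions import biject_to

prior = dist.Gamma(2.0, 2.0)
transform = biject_to(prior.support)
loc, scale = torch.tensor(0.1), torch.tensor(0.9)
affine = dist.transforms.AffineTransform(loc, scale,
                                         event_dim=transform.domain.event_dim)
posterior = dist.TransformedDistribution(
    prior, [transform.inv.with_cache(), affine, transform.with_cache()])
sample = posterior.sample()          # still lands in the prior's support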
def _setup_prototype(self, *args, **kwargs):
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    self._unconstrained_shapes = {}
    self._cond_indep_stacks = {}
    for name, site in self.prototype_trace.nodes.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue

        # Collect the shapes of unconstrained values.
        # These may differ from the shapes of constrained values.
        self._unconstrained_shapes[name] = \
            biject_to(site["fn"].support).inv(site["value"]).shape

        # Collect independence contexts.
        self._cond_indep_stacks[name] = site["cond_indep_stack"]

    self.latent_dim = sum(_product(shape)
                          for shape in self._unconstrained_shapes.values())
    if self.latent_dim == 0:
        raise RuntimeError(
            '{} found no latent variables; Use an empty guide instead'.format(
                type(self).__name__))
def quantiles(self, quantiles, *args, **kwargs):
    """
    Returns posterior quantiles for each latent variable. Example::

        print(guide.quantiles([0.05, 0.5, 0.95]))

    :param quantiles: A list of requested quantiles between 0 and 1.
    :type quantiles: torch.Tensor or list
    :return: A dict mapping sample site name to a list of quantile values.
    :rtype: dict
    """
    loc, scale = self._loc_scale(*args, **kwargs)
    quantiles = loc.new_tensor(quantiles).unsqueeze(-1)
    latents = dist.Normal(loc, scale).icdf(quantiles)
    result = {}
    for latent in latents:
        for site, unconstrained_value in self._unpack_latent(latent):
            result.setdefault(site["name"], []).append(
                biject_to(site["fn"].support)(unconstrained_value))
    return result