def guide(self):
    """Approximate posterior (guide) for the horseshoe prior.

    The global hyper-parameters -- unconstrained means ``mu`` and positive
    scales ``tau`` -- share one joint multivariate normal distribution, and
    each subject's parameter vector ``locs`` gets its own, independent,
    multivariate normal distribution.
    """
    n_subjects = self.runs
    n_params = self.npar
    to_positive = biject_to(constraints.positive)

    # Joint Gaussian over [mu, unconstrained tau], sampled as an auxiliary site.
    hyp_loc = param('m_hyp', zeros(2 * n_params))
    hyp_tril = param('scale_tril_hyp', torch.eye(2 * n_params),
                     constraint=constraints.lower_cholesky)
    hyp = sample('hyp', dist.MultivariateNormal(hyp_loc, scale_tril=hyp_tril),
                 infer={'is_auxiliary': True})
    mu_unconstrained = hyp[..., :n_params]
    tau_unconstrained = hyp[..., n_params:]

    # Map tau onto the positive reals and carry the log-abs-det-Jacobian.
    tau = to_positive(tau_unconstrained)
    tau_log_det = to_positive.inv.log_abs_det_jacobian(tau, tau_unconstrained)
    tau_log_det = sum_rightmost(tau_log_det, tau_log_det.dim() - tau.dim() + 1)

    sample("mu", dist.Delta(mu_unconstrained, event_dim=1))
    sample("tau", dist.Delta(tau, log_density=tau_log_det, event_dim=1))

    # Independent per-subject Gaussians inside the 'runs' plate.
    locs_loc = param('m_locs', zeros(n_subjects, n_params))
    locs_tril = param('scale_tril_locs',
                      torch.eye(n_params).repeat(n_subjects, 1, 1),
                      constraint=constraints.lower_cholesky)
    with plate('runs', n_subjects):
        sample("locs", dist.MultivariateNormal(locs_loc, scale_tril=locs_tril))
def new_guide(obsmat):
    """Guide for fitting the topic probabilities of a new participant.

    Re-registers the previously learned ``AutoDelta`` global parameters
    (``topic_weights``, ``topic_a``, ``topic_b``) under
    ``poutine.block(hide_types=["param"])`` so, per the original comment,
    their learned values are kept. It then adds a new simplex-constrained
    parameter ``new_participant_topic_q`` used as a point-mass (Delta) site.

    NOTE(review): ``obsmat`` is not read in this body; presumably it exists so
    the guide's signature matches the model's -- confirm.
    NOTE(review): nothing is returned; ``participant_topics`` is unused.
    """
    # These are just the previous values we can use to initialize params here
    initial_topic_weights = pyro.get_param_store()['AutoDelta.topic_weights']
    # NOTE(review): this reads the same key as ``initial_topic_weights`` and
    # is later used to initialize a simplex-constrained param. Confirm that
    # the intended key is 'AutoDelta.topic_weights' and that its value lies
    # on the simplex (sums to 1).
    initial_alpha = pyro.get_param_store()['AutoDelta.topic_weights']
    initial_topic_a = pyro.get_param_store()['AutoDelta.topic_a']
    initial_topic_b = pyro.get_param_store()['AutoDelta.topic_b']
    # Use poutine.block to keep our learned values of global parameters.
    with poutine.block(hide_types=["param"]):
        # This has to match the structure of the model
        with pyro.plate('topic', tm.K):
            # We manually define the AutoDelta params we had from before here
            topic_weights_q = pyro.param('AutoDelta.topic_weights', initial_topic_weights)
            topic_a_q = pyro.param('AutoDelta.topic_a', initial_topic_a)
            topic_b_q = pyro.param('AutoDelta.topic_b', initial_topic_b)
            # Each of the sample statements in the above model needs to have a corresponding
            # statement here where we insert our tuneable params
            pyro.sample("topic_weights", dist.Delta(topic_weights_q))
            pyro.sample('topic_a', dist.Delta(topic_a_q).to_event(2))
            pyro.sample('topic_b', dist.Delta(topic_b_q).to_event(2))
    # We define a new learnable parameter for the new participant that
    # sums to 1 (via constraint) and plug this in as their topic probabilities.
    # (Placed outside the param-blocking context, matching "learnable".)
    probs = pyro.param('new_participant_topic_q', initial_alpha, constraint=constraints.simplex)
    participant_topics = pyro.sample("new_participant_topic", dist.Delta(probs).to_event(1))
def _get_sample_fn(module, name):
    """Return the distribution to sample for parameter ``name`` of ``module``.

    In ``"model"`` mode this is the registered prior. Otherwise the guide
    registered for ``name`` is built from the module's ``{name}_{arg}``
    attributes. A ``Delta`` guide uses ``{name}_map`` directly. For priors
    with real support the guide is returned as-is (as an event). For any
    other support the guide is sampled as an auxiliary ``{name}_latent``
    site in unconstrained space, mapped through ``biject_to(support)``, and
    wrapped in a ``Delta`` that carries the summed log-abs-det-Jacobian.
    """
    if module.mode == "model":
        return module._priors[name]

    constructor, arg_names = module._guides[name]
    if constructor is dist.Delta:
        map_value = getattr(module, "{}_map".format(name))
        return dist.Delta(map_value, event_dim=map_value.dim())

    kwargs = {}
    for arg_name in arg_names:
        kwargs[arg_name] = getattr(module, "{}_{}".format(name, arg_name))
    guide_dist = constructor(**kwargs)

    prior_support = module._priors[name].support
    if _is_real_support(prior_support):
        # Constrained and unconstrained spaces coincide: no transform needed.
        return guide_dist.to_event()

    # TODO: move this logic to infer.autoguide or somewhere else
    latent_name = module._pyro_get_fullname("{}_latent".format(name))
    z = pyro.sample(latent_name, guide_dist.to_event(),
                    infer={"is_auxiliary": True})
    bijection = biject_to(prior_support)
    constrained = bijection(z)
    log_jac = bijection.inv.log_abs_det_jacobian(constrained, z)
    return dist.Delta(constrained, log_jac.sum(), event_dim=constrained.dim())
def test_batch_log_prob(self):
    """Batched Delta log_prob: 0 everywhere on a match, -inf on any miss."""
    point_mass = dist.Delta(self.vs_expanded)

    matching = point_mass.log_prob(self.batch_test_data_1).data
    assert_equal(matching.sum().item(), 0)

    mismatching = point_mass.log_prob(self.batch_test_data_2).data
    assert_equal(mismatching.sum().item(), float('-inf'))
def MAP_guide(prior_params, logits, labels):
    """
    Guide for MAP inference: point masses at learnable ``beta`` and ``delta``.

    Only ``logits.shape[1]`` (the number of classes) is read. ``prior_params``
    and ``labels`` are unused here; presumably they keep the signature in line
    with the model's.
    """
    num_classes = logits.shape[1]
    beta_point = pyro.param('beta_MAP', torch.ones(num_classes, requires_grad=True))
    delta_point = pyro.param('delta_MAP', torch.zeros(num_classes, requires_grad=True))
    for site_name, point in (('beta', beta_point), ('delta', delta_point)):
        pyro.sample(site_name, dist.Delta(point))
def setUp(self):
    """Build the point values, test data and Delta distributions under test."""
    self.v = Variable(torch.Tensor([3]))
    self.vs = Variable(torch.Tensor([[0], [1], [2], [3]]))
    self.test_data = Variable(torch.Tensor([3, 3, 3]))

    # Column vector 0..3 broadcast across 3 columns.
    column = torch.arange(0, 4).unsqueeze(1)
    self.batch_test_data = Variable(column.expand(4, 3))

    self.dist = dist.Delta(self.v)
    self.batch_dist = dist.Delta(self.vs, batch_size=2)
def templates_guide_mvn(self):
    """Multivariate normal guide for template parameters.

    One auxiliary MVN (``mvn.loc`` / ``mvn.scale_tril``) is sampled in
    unconstrained space as ``"states_" + self.name_prefix``. Its entries are
    consumed in order: first one per Poissonian template (``n_poiss``), then
    ``n_ps_params`` per point-source template (``n_ps``). Each entry is mapped
    onto its prior's support with ``biject_to`` and emitted as a ``Delta``
    site that carries the log-abs-det-Jacobian.

    :return: dict mapping sample-site name to the constrained sampled value.
    """
    loc = _deep_getattr(self, "mvn.loc")
    scale_tril = _deep_getattr(self, "mvn.scale_tril")
    dt = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    states = pyro.sample("states_" + self.name_prefix, dt,
                         infer={"is_auxiliary": True})

    def constrained_site(site_name, prior, unconstrained):
        # Push one unconstrained coordinate onto ``prior.support`` and
        # register it as a Delta site with the matching Jacobian term.
        transform = biject_to(prior.support)
        value = transform(unconstrained)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + prior.event_dim)
        return pyro.sample(
            site_name,
            dist.Delta(value, log_density=log_density,
                       event_dim=prior.event_dim))

    result = {}
    # Plain ``range`` (rather than ``torch.arange``) so the Python lists are
    # indexed with ints instead of 0-d tensors.
    for i_poiss in range(self.n_poiss):
        label = self.poiss_labels[i_poiss]
        result[label] = constrained_site(
            label, self.poiss_priors[i_poiss], states[i_poiss])

    # Point-source parameters follow the Poissonian ones in ``states``.
    i_param = self.n_poiss
    for i_ps in range(self.n_ps):
        for i_ps_param in range(self.n_ps_params):
            label = (self.ps_param_labels[i_ps_param] + "_"
                     + self.ps_labels[i_ps])
            result[label] = constrained_site(
                label, self.ps_priors[i_ps][i_ps_param], states[i_param])
            i_param += 1
    return result
def module_ppca_gm_means_sigma_guide(self, input_batch, epsilon):
    """Guide for the group-level PPCA means and noise sigma.

    Registers point-mass (Delta) sites for the noise sigma (only when
    ``self.likelihood == 'normal'``) and for ``alpha_gm``. It then asks
    ``self.group_term`` for a posterior mean/covariance of the latent ``z``
    and maps ``zk_mean + epsilon[:, :group_hidden_dim] @ zk_cov`` through
    ``self.group_term``.

    :return: ``(ppca_gm_means, ppca_gm_sigma)``.
    """
    # NOTE(review): batch_size is computed but not used below.
    batch_size = input_batch.shape[0]
    if self.likelihood == 'normal':
        if self.group_isotropic:
            # Single shared positive noise scale.
            ppca_gm_sigma_p = pyro.param(f'ppca_gm_sigma_p', input_batch.new_ones(1, 1), constraint=constraints.positive)
            # NOTE(review): ``.independent(1)`` looks like an older alias of
            # ``.to_event(1)``; verify it exists in the installed Pyro.
            ppca_gm_sigma = pyro.sample(
                f'ppca_gm_sigma', dist.Delta(ppca_gm_sigma_p).independent(1))
        else:
            # One positive noise vector per mode i (length self.n[i]); they are
            # folded together with torch_utils.krp_cw_torch.
            ppca_gm_sigma = input_batch.new_ones(1, 1)
            ppca_gm_sigma_list = []
            for i in range(self.d):
                ppca_gm_sigma_p = pyro.param(
                    f'ppca_gm_sigma_{i}_p', input_batch.new_ones(1, self.n[i]), constraint=constraints.positive)
                ppca_gm_sigma_list.append(
                    pyro.sample(
                        f'ppca_gm_sigma_{i}', dist.Delta(ppca_gm_sigma_p).independent(1)))
                ppca_gm_sigma = torch_utils.krp_cw_torch(
                    ppca_gm_sigma_list[i], ppca_gm_sigma, column=False)
    else:
        # Non-normal likelihood: no sigma sites, fixed unit noise.
        ppca_gm_sigma = input_batch.new_ones(1, 1)
        ppca_gm_sigma_list = [
            input_batch.new_ones(1, self.n[i]) for i in range(self.d)
        ]
    alpha_gm_p = pyro.param(
        f'alpha_gm_p', input_batch.new_ones([1, self.group_hidden_dim]))
    alpha_gm = pyro.sample(f'alpha_gm', dist.Delta(alpha_gm_p).independent(1))
    # NOTE(review): ``group_iterm`` vs ``group_term`` -- confirm these are two
    # distinct attributes and not a typo.
    if self.group_iterm is None:
        z_mu = self.group_term.linear_mapping.inverse_batch(input_batch)
    else:
        z_mu = self.group_iterm(torch_utils.flatten_torch(input_batch), T=True)
    if self.group_isotropic:
        # NOTE(review): ppca_gm_sigma is always assigned above, so the
        # ``is not None`` fallback is never taken.
        zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
            input_batch,
            noise_sigma=ppca_gm_sigma[0] if ppca_gm_sigma is not None else input_batch.new_ones(1),
            z_mu=z_mu,
            z_sigma=alpha_gm[0])
    else:
        zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
            input_batch,
            noise_sigma=[x for x in ppca_gm_sigma_list],
            z_mu=z_mu,
            z_sigma=alpha_gm[0])
    # Assumes epsilon is (batch, >= group_hidden_dim) noise and that zk_cov
    # (reshaped to a square matrix) acts as a factor of the covariance --
    # TODO confirm; multiplying by the covariance itself would give the wrong
    # spread.
    ppca_gm_means = self.group_term(
        zk_mean + epsilon[:, :self.group_hidden_dim].mm(
            zk_cov.view(self.group_hidden_dim, self.group_hidden_dim)))
    return ppca_gm_means, ppca_gm_sigma
def guide(self):
    """MVN guide over all template parameters, in unconstrained space.

    Learnable ``a_locs`` (zeros) and a lower-Cholesky ``a_scales``
    (0.1 * identity) define one auxiliary ``"states"`` MVN sample. Its entries
    are consumed in order: one per Poissonian template (``n_poiss``), then
    ``n_ps_params`` per point-source template (``n_ps``). Each entry is mapped
    onto its prior's support with ``biject_to`` and emitted as a ``Delta``
    site that carries the log-abs-det-Jacobian.

    :return: dict mapping sample-site name to the constrained sampled value.
    """
    a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
    a_scales_tril = pyro.param(
        "a_scales", lambda: 0.1 * eye_like(a_locs, self.n_params),
        constraint=constraints.lower_cholesky)
    dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
    states = pyro.sample("states", dt, infer={"is_auxiliary": True})

    def constrained_site(site_name, prior, unconstrained):
        # Push one unconstrained coordinate onto ``prior.support`` and
        # register it as a Delta site with the matching Jacobian term.
        transform = biject_to(prior.support)
        value = transform(unconstrained)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + prior.event_dim)
        return pyro.sample(
            site_name,
            dist.Delta(value, log_density=log_density,
                       event_dim=prior.event_dim))

    result = {}
    # Plain ``range`` (rather than ``torch.arange``) so the Python lists are
    # indexed with ints instead of 0-d tensors.
    for i_poiss in range(self.n_poiss):
        label = self.labels_poiss[i_poiss]
        result[label] = constrained_site(
            label, self.poiss_priors[i_poiss], states[i_poiss])

    # Point-source parameters follow the Poissonian ones in ``states``.
    i_param = self.n_poiss
    for i_ps in range(self.n_ps):
        for i_ps_param in range(self.n_ps_params):
            label = (self.labels_ps_params[i_ps_param] + "_"
                     + self.labels_ps[i_ps])
            result[label] = constrained_site(
                label, self.ps_priors[i_ps][i_ps_param], states[i_param])
            i_param += 1
    return result
def reward(state, i=0):
    """Reward function given a state.

    Registers a deterministic (``Delta``) sample site named
    ``f'reward{state}{i}'`` and returns its value:

    * goal (state 15): ``+1``
    * holes (states 5, 7, 11, 12): ``-10``
    * any other state: ``1 / (15 - state + 1)``, which grows as ``state``
      approaches the goal.
    """
    if state == 15:
        value = torch.tensor(1.)
    elif state in [5, 7, 11, 12]:
        value = torch.tensor(-10.)
    else:
        value = torch.tensor(float(1 / (15 - state + 1)))
    return pyro.sample(f'reward{state}{i}', dist.Delta(value))
def guide(self, evidence=None, noise=None):
    """A "smart" guide function for the SCM model above. It propagates the
    information from an observed deterministic node to its noise node, so
    that you don't end up with many rejected samples. This is slightly
    different from the model schema for the sake of sampling efficiency.

    Args:
        evidence (dict, optional): a dictionary of {node_name: value}
            evidence data. ``None`` (the default) means no evidence.
        noise (None): a useless parameter that exists because in Pyro, the
            guide fn has the same inputs as the model fn.

    Returns:
        guide_dict (dict): a sample from the guide in {node_name: value}
            format (sampled exogenous parents are included too).

    TODO: Make all endogenous nodes deterministic rather than delta variables
    """
    # Build a fresh dict per call. The former ``evidence={}`` default was a
    # single object shared by every call, so any mutation leaked across calls.
    if evidence is None:
        evidence = {}

    guide_dict = {}
    # The order is a little complex. Any observed nodes have to go first,
    # then the non-twin endog, then twin.
    for node in self._get_guide_order(evidence):
        # Assumes every node has at least one exogenous parent
        # (``[0]`` raises IndexError otherwise).
        exog_parent = [
            n for n in self.G_inference.predecessors(node)
            if self.scm._is_exog(n, self.G_inference)
        ][0]
        endog_parents = sorted(
            n for n in self.G_inference.predecessors(node)
            if not self.scm._is_exog(n, self.G_inference)
        )
        # Relies on the guide order having already placed every endogenous
        # parent in guide_dict.
        parent_values = [guide_dict[n] for n in endog_parents]

        # Each exogenous parent is sampled at most once.
        if exog_parent not in guide_dict:
            if node not in evidence:
                # Unobserved node: draw its noise from the exogenous prior.
                guide_dict[exog_parent] = pyro.sample(exog_parent, self.exog_fn)
            elif not endog_parents:
                # Observed node with only an exogenous parent: the noise
                # must equal the observed value.
                guide_dict[exog_parent] = pyro.sample(
                    exog_parent, dist.Delta(evidence[node]))
            else:
                # Observed node with exog & endog parents: invert the
                # structural function to recover the noise value.
                predicted_val = self._scm_function(node, parent_values)
                exog_val = self.invert_fn(evidence[node], predicted_val)
                guide_dict[exog_parent] = pyro.sample(
                    exog_parent, dist.Delta(exog_val))
            # TODO: consider self._assign_delta_node(...) instead of Delta sites.

        val = self._scm_function(node, parent_values, guide_dict[exog_parent])
        guide_dict[node] = pyro.sample(node, dist.Delta(val))
    return guide_dict
def guide():
    """Guide built only from learnable params fed into Delta (point-mass) sites.

    :return: tuple of the ``average_price_log_mean`` and
        ``average_price_log_var`` sampled values, in that order.
    """
    log_mean = pyro.param('average_price_log_mean_param', torch.tensor(1.0))
    log_var = pyro.param('average_price_log_var_param', torch.tensor(1.0))
    mean_sample = pyro.sample('average_price_log_mean', dist.Delta(log_mean))
    var_sample = pyro.sample('average_price_log_var', dist.Delta(log_var))
    return mean_sample, var_sample
def guide_map(anime_matrix_train, k=k):
    """MAP guide for the matrix-factorization model.

    Point masses at learnable factor matrices ``u`` (rows x k) and ``v``
    (cols x k), both zero-initialized, and at a positive ``sigma``
    (initialized to 1). The default ``k`` is bound once, at definition time,
    from the enclosing module's ``k``.
    """
    num_rows = anime_matrix_train.shape[0]
    num_cols = anime_matrix_train.shape[1]

    u_point = pyro.param('u_map', torch.zeros([num_rows, k]))
    v_point = pyro.param('v_map', torch.zeros([num_cols, k]))
    sigma_point = pyro.param("sigma_map", torch.tensor(1.0),
                             constraint=constraints.positive)

    pyro.sample("u", dist.Delta(u_point).to_event(2))
    pyro.sample("v", dist.Delta(v_point).to_event(2))
    pyro.sample("sigma", dist.Delta(sigma_point))
def guide_horseshoe_plus(self):
    """Approximate posterior for the horseshoe-plus model.

    Global level: one auxiliary MVN ``hyp`` over [mu_g, unconstrained sigma_g].
    ``mu_g`` is used as-is, and ``sigma_g`` is mapped to the positive reals
    with its log-abs-det-Jacobian. Subject level (plate ``'subjects'``): an
    independent auxiliary MVN ``tmp`` per subject over
    [x, unconstrained sigma_x], treated the same way.

    :return: dict with the sampled ``mu_g``, ``sigma_g``, ``sigma_x`` and ``x``.
    """
    npar = self.npars  # number of parameters
    nsub = self.runs  # number of subjects
    trns = biject_to(constraints.positive)

    m_hyp = param('m_hyp', zeros(2 * npar))
    st_hyp = param('scale_tril_hyp', torch.eye(2 * npar),
                   constraint=constraints.lower_cholesky)
    hyp = sample('hyp', dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                 infer={'is_auxiliary': True})

    # Split along the last (event) dimension, as is done for ``tmp`` below.
    # Slicing ``hyp[:npar]`` would cut the leading dimension instead, which
    # is wrong whenever ``hyp`` carries batch dims (e.g. vectorized particles).
    unc_mu = hyp[..., :npar]
    unc_sigma = hyp[..., npar:]

    c_sigma = trns(unc_sigma)
    ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma)
    ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1)

    mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1))
    sigma_g = sample("sigma_g", dist.Delta(c_sigma, log_density=ld_sigma,
                                           event_dim=1))

    m_tmp = param('m_tmp', zeros(nsub, 2 * npar))
    st_tmp = param('s_tmp', torch.eye(2 * npar).repeat(nsub, 1, 1),
                   constraint=constraints.lower_cholesky)
    with plate('subjects', nsub):
        tmp = sample('tmp', dist.MultivariateNormal(m_tmp, scale_tril=st_tmp),
                     infer={'is_auxiliary': True})

        unc_locs = tmp[..., :npar]
        unc_scale = tmp[..., npar:]

        c_scale = trns(unc_scale)
        ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale)
        ld_scale = sum_rightmost(ld_scale, ld_scale.dim() - c_scale.dim() + 1)

        x = sample("x", dist.Delta(unc_locs, event_dim=1))
        sigma_x = sample("sigma_x", dist.Delta(c_scale, log_density=ld_scale,
                                               event_dim=1))

    return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
def forward(self, design, target_labels=None):
    """
    Sample the posterior.

    In training mode every target site is a point mass at ``self.means``
    (MAP). Otherwise it is a multivariate normal with ``self.scale_trils``
    (Laplace approximation). Batch plates are derived from
    ``design.shape[:-2]``.

    :param torch.Tensor design: tensor of possible designs.
    :param list target_labels: list indicating the sample sites that are
        targets, i.e. for which information gain should be measured.
        Defaults to every key of ``self.means``.
    """
    if target_labels is None:
        target_labels = list(self.means.keys())

    pyro.module("laplace_guide", self)
    with ExitStack() as stack:
        for batch_plate in iter_plates_to_shape(design.shape[:-2]):
            stack.enter_context(batch_plate)

        for label in target_labels:
            if self.training:
                # MAP via a Delta guide.
                site_dist = dist.Delta(self.means[label]).to_event(1)
            else:
                # Laplace approximation via MVN with hessian.
                site_dist = dist.MultivariateNormal(
                    self.means[label], scale_tril=self.scale_trils[label])
            pyro.sample(label, site_dist)
def __call__(self, name, fn, obs):
    """Split site ``name`` into independently sampled sections.

    The event is cut at dimension ``-self.event_dim`` into pieces of sizes
    ``self.sections``. Each piece is drawn from an ``ImproperUniform`` site
    ``"{name}_split_{i}"`` and the pieces are concatenated back. The returned
    ``Delta`` carries ``fn.log_prob`` of the combined value, or 0 when the
    current mask is ``False``.

    :return: ``(new_fn, value)``.
    """
    assert fn.event_dim >= self.event_dim
    assert obs is None, "SplitReparam does not support observe statements"

    split_at = fn.event_dim - self.event_dim
    prefix_shape = fn.event_shape[:split_at]
    suffix_shape = fn.event_shape[split_at + 1:]
    pieces = [
        pyro.sample(
            "{}_split_{}".format(name, index),
            dist.ImproperUniform(fn.support, fn.batch_shape,
                                 prefix_shape + (section, ) + suffix_shape))
        for index, section in enumerate(self.sections)
    ]
    value = torch.cat(pieces, dim=-self.event_dim)

    log_density = 0.0 if poutine.get_mask() is False else fn.log_prob(value)
    new_fn = dist.Delta(value, event_dim=fn.event_dim, log_density=log_density)
    return new_fn, value
def apply(self, msg):
    """Reparameterize a ``ProjectedNormal`` site through normal noise.

    Samples parameter-free ``Normal(0, 1)`` noise at ``"{name}_normal"`` and
    returns ``safe_normalize(noise + concentration)`` as a masked ``Delta``
    (a deterministic-style site). When a value is already present it is
    passed as ``obs`` through an arbitrary injection
    (``value - concentration``) that is valid only for initialization.

    :raises NotImplementedError: for observed sites.
    """
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]
    if is_observed:
        raise NotImplementedError(
            "ProjectedNormalReparam does not support observe statements"
        )

    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, dist.ProjectedNormal)

    # Differentiably invert transform (initialization-only injection).
    value_normal = None if value is None else value - fn.concentration

    noise_dist = dist.Normal(torch.zeros_like(fn.concentration), 1).to_event(1)
    x = pyro.sample(
        "{}_normal".format(name),
        self._wrap(noise_dist, event_dim),
        obs=value_normal,
        infer={"is_observed": is_observed},
    )

    if value is None:
        value = safe_normalize(x + fn.concentration)

    # Simulate a pyro.deterministic() site.
    return {
        "fn": dist.Delta(value, event_dim=event_dim).mask(False),
        "value": value,
        "is_observed": True,
    }
def test_kl_delta_normal_shape(batch_shape):
    """KL(Delta || Normal) keeps the common batch shape."""
    point = torch.randn(batch_shape)
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    kl = kl_divergence(dist.Delta(point), dist.Normal(loc, scale))
    assert kl.shape == batch_shape
def map_estimate(self, name):
    """
    Construct a maximum a posteriori (MAP) guide using Delta distributions.

    On first use for ``name`` a ``PyroParam`` attribute named ``name`` is set
    on ``self``. It is initialized from the prototype trace's (detached) value
    and constrained to the site's support. Later calls reuse that attribute.

    :param str name: The name of a model sample site.
    :return: A sampled value.
    :rtype: torch.Tensor
    """
    site = self.prototype_trace.nodes[name]
    fn = site["fn"]
    event_dim = fn.event_dim
    # The parameter is created lazily, as an attribute named after the site.
    init_needed = not hasattr(self, name)
    if init_needed:
        init_value = site["value"].detach()
    with ExitStack() as stack:
        for frame in site["cond_indep_stack"]:
            plate = self.plate(frame.name)
            # Enter any plate that is not already active.
            if plate not in runtime._PYRO_STACK:
                stack.enter_context(plate)
            # An already-active subsampled plate: the prototype value only
            # covers the subsample (see the assert), so tile it cyclically.
            elif init_needed and plate.subsample_size < plate.size:
                # Repeat the init_value to full size.
                dim = plate.dim - event_dim
                assert init_value.size(dim) == plate.subsample_size
                ind = torch.arange(plate.size, device=init_value.device)
                ind = ind % plate.subsample_size
                init_value = init_value.index_select(dim, ind)
        if init_needed:
            setattr(self, name, PyroParam(init_value, fn.support, event_dim))
        value = getattr(self, name)
        return pyro.sample(name, dist.Delta(value, event_dim=event_dim))
def unpack(self, group_z: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
    """Split a packed group sample into per-site ``Delta`` sample sites.

    The last dimension of ``group_z`` holds each site's free coordinates back
    to back, with slice lengths taken from ``self.sizes``. For each site:

    * the slice is scattered into a copy of ``self.inits[name]`` at
      ``self.masks[name]``;
    * the result is transformed with the site's transform;
    * it is emitted via ``pyro.sample(name, Delta(...))``.

    When ``self.include_det_jac`` is set and the transform is bijective, the
    ``Delta`` carries the inverse transform's log-abs-det-Jacobian.

    :return: dict mapping site name to the sampled (transformed) value.
    """
    from itertools import accumulate

    model_zs = {}
    # ``accumulate`` yields the end offset of each site's slice (running sum
    # of sizes). NOTE(review): the three dicts are zipped positionally, so
    # this assumes ``self.sizes``, ``self.sites`` and ``self.transforms``
    # share the same key order.
    for end, (name, fn, frames), transform in zip(
        accumulate(self.sizes.values()),
        map(itemgetter('name', 'fn', 'cond_indep_stack'), self.sites.values()),
        self.transforms.values(),
    ):
        fn: dist.TorchDistribution
        zs = group_z[..., end - self.sizes[name]:end]
        # Start from the initial values and overwrite only the masked entries.
        z = self.inits[name].expand(zs.shape[:-1] + self.masks[name].shape).clone()
        z[..., self.masks[name]] = zs
        x = transform(z)
        if self.include_det_jac and transform.bijective:
            log_density = transform.inv.log_abs_det_jacobian(x, z)
            # Sum the rightmost dims down to the site's batch shape.
            log_density = log_density.sum(
                list(range(-(log_density.ndim - z.ndim + fn.event_dim), 0)))
        else:
            log_density = 0.
        delta = dist.Delta(x, log_density=log_density, event_dim=fn.event_dim)
        model_zs[name] = pyro.sample(name, delta)
    return model_zs
def get_posterior(self, *args, **kwargs):
    """
    Return a Delta posterior distribution for MAP inference.

    The point mass sits at the learnable ``"{prefix}_loc"`` parameter. It is
    initialized to zeros of length ``self.latent_dim`` and treated as one
    event dimension.
    """
    loc_name = "{}_loc".format(self.prefix)
    loc = pyro.param(loc_name, lambda: torch.zeros(self.latent_dim))
    return dist.Delta(loc).to_event(1)
def parametrized_guide(predictor, data, args, batch_size=None):
    """Guide for the topic model.

    Global variables get conjugate guides: a Gamma posterior over topic
    weights and a Dirichlet posterior over topic words. Per-document topic
    mixtures are amortized: ``predictor`` maps each document's word-count
    histogram to a point mass (Delta) at its topic vector.
    """
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive,
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5),
    )
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.0))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        batch = data[:, ind]
        # Turn word-index vectors into histograms of shape (num_words, batch)
        # for the neural network.
        ones = torch.ones(batch.shape)
        counts = torch.zeros(args.num_words, ind.size(0)).scatter_add(0, batch, ones)
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
def forward(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    .. note:: This method is used internally by :class:`~torch.nn.Module`.
        Users should instead use :meth:`~torch.nn.Module.__call__`.

    Every stochastic site of the prototype trace becomes a point mass (Delta)
    at the attribute of ``self`` named after the site, inside its vectorized
    plates.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # Run the model once, lazily, so its structure can be inspected.
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    plates = self._create_plates(*args, **kwargs)
    result = {}
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    stack.enter_context(plates[frame.name])
            point = operator.attrgetter(name)(self)
            point_mass = dist.Delta(point, event_dim=site["fn"].event_dim)
            result[name] = pyro.sample(name, point_mass)
    return result
def apply(self, msg):
    """Reparameterize an S0 ``Stable`` site through auxiliary noise.

    Draws ``u ~ Uniform(-pi/2, pi/2)`` at ``"{name}_uniform"`` and
    ``e ~ Exponential(1)`` at ``"{name}_exponential"``. These are combined
    with ``_standard_stable`` and shifted/scaled by ``loc``/``scale``. The
    result is returned as a masked ``Delta`` (a deterministic-style site).
    Any ``msg["value"]`` is ignored.

    :raises NotImplementedError: for observed sites.
    """
    name = msg["name"]
    fn = msg["fn"]
    is_observed = msg["is_observed"]

    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, dist.Stable) and fn.coords == "S0"
    if is_observed:
        raise NotImplementedError(
            f"At pyro.sample({repr(name)},...), "
            "LatentStableReparam does not support observe statements")

    # Parameter-free noise, shaped like the stability parameter.
    proto = fn.stability
    half_pi = proto.new_tensor(math.pi / 2)
    uniform_noise = dist.Uniform(-half_pi, half_pi).expand(proto.shape)
    u = pyro.sample("{}_uniform".format(name),
                    self._wrap(uniform_noise, event_dim))
    exponential_noise = dist.Exponential(proto.new_ones(proto.shape))
    e = pyro.sample("{}_exponential".format(name),
                    self._wrap(exponential_noise, event_dim))

    # Differentiably transform.
    standard = _standard_stable(fn.stability, fn.skew, u, e, coords="S0")
    value = fn.loc + fn.scale * standard

    # Simulate a pyro.deterministic() site.
    return {
        "fn": dist.Delta(value, event_dim=event_dim).mask(False),
        "value": value,
        "is_observed": True,
    }
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    A single latent sample is drawn with ``sample_latent`` and unpacked per
    site. Each piece is mapped onto the site's support with ``biject_to``
    and registered as a ``Delta`` carrying the log-abs-det-Jacobian, inside
    the site's plates.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # Run the model once, lazily, so its structure can be inspected.
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    latent = self.sample_latent(*args, **kwargs)
    plates = self._create_plates()

    result = {}
    for site, unconstrained in self._unpack_latent(latent):
        name = site["name"]
        site_fn = site["fn"]
        bijection = biject_to(site_fn.support)
        value = bijection(unconstrained)
        log_jac = bijection.inv.log_abs_det_jacobian(value, unconstrained)
        log_jac = sum_rightmost(
            log_jac, log_jac.dim() - value.dim() + site_fn.event_dim)
        point_mass = dist.Delta(value, log_density=log_jac,
                                event_dim=site_fn.event_dim)
        with ExitStack() as stack:
            for frame in self._cond_indep_stacks[name]:
                stack.enter_context(plates[frame.name])
            result[name] = pyro.sample(name, point_mass)
    return result
def parametrized_guide(predictor, data, num_words_per_doc, args):
    """Topic-model guide that processes one document at a time.

    Global variables get conjugate guides (Gamma over topic weights,
    Dirichlet over topic words). Iterating ``pyro.plate("documents", ...)``
    visits the documents one by one. For each document ``predictor`` turns a
    word-count histogram into a point mass (Delta) over topics. Each of the
    document's ``num_words_per_doc[doc]`` words then gets a categorical topic
    drawn from that vector.
    """
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    pyro.module("predictor", predictor)
    for doc in pyro.plate("documents", args.num_docs, args.batch_size):
        # Word-count histogram of this document, shape (num_words, 1).
        counts = torch.zeros(args.num_words, 1)
        for word in data[doc]:
            counts[word] += 1
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics_{}".format(doc),
                    dist.Delta(doc_topics, event_dim=1))
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            pyro.sample("word_topics_{}".format(doc),
                        dist.Categorical(doc_topics))
def sample(self, guide_name, fn, infer=None):
    """
    Wrapper around ``pyro.sample()`` to create a single auxiliary sample site
    and then unpack to multiple sample sites for model replay.

    :param str guide_name: The name of the auxiliary guide site.
    :param callable fn: A distribution with shape ``self.event_shape``.
    :param dict infer: Optional inference configuration dict. It is copied
        before ``"is_auxiliary"`` is set, so the caller's dict is unchanged.
    :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the single
        concatenated blob and ``model_zs`` is a dict mapping site name to
        constrained model sample.
    :rtype: tuple
    :raises ValueError: if ``fn.event_shape`` differs from ``self.event_shape``.
    """
    # Sample a packed tensor.
    if fn.event_shape != self.event_shape:
        raise ValueError(
            "Invalid fn.event_shape for group: expected {}, actual {}".format(
                tuple(self.event_shape), tuple(fn.event_shape)))
    # Copy rather than mutate the caller-supplied dict.
    infer = {} if infer is None else dict(infer)
    infer["is_auxiliary"] = True
    guide_z = pyro.sample(guide_name, fn, infer=infer)
    common_batch_shape = guide_z.shape[:-1]

    model_zs = {}
    pos = 0
    for site in self.prototype_sites:
        name = site["name"]
        site_fn = site["fn"]

        # Extract slice from packed sample.
        size = self._site_sizes[name]
        batch_shape = broadcast_shape(common_batch_shape,
                                      self._site_batch_shapes[name])
        unconstrained_z = guide_z[..., pos:pos + size]
        unconstrained_z = unconstrained_z.reshape(batch_shape + site_fn.event_shape)
        pos += size

        # Transform to constrained space.
        transform = biject_to(site_fn.support)
        z = transform(unconstrained_z)
        log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
        log_density = sum_rightmost(
            log_density, log_density.dim() - z.dim() + site_fn.event_dim)
        delta_dist = dist.Delta(z, log_density=log_density,
                                event_dim=site_fn.event_dim)

        # Replay model sample statement, entering only plates not yet active.
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                plate = self.guide.plate(frame.name)
                if plate not in runtime._PYRO_STACK:
                    stack.enter_context(plate)
            model_zs[name] = pyro.sample(name, delta_dist)

    return guide_z, model_zs
def __call__(self, name, fn, obs):
    """Apply a (partial) location/scale decentering to site ``name``.

    If ``centered`` is identically one the site is returned unchanged.
    Otherwise the site's ``loc``/``scale`` become ``loc * centered`` and
    ``scale ** centered``, the decentered value is sampled as
    ``"{name}_decentered"``, and it is mapped back by
    ``loc + scale ** (1 - centered) * (decentered - centered * loc)``.
    When ``self.centered`` is ``None``, ``centered`` is a learnable
    per-site parameter ``"{name}_centered"`` in [0, 1], initialized to 0.5.

    :return: ``(new_fn, value)`` where ``new_fn`` is a masked Delta, or the
        untouched ``(name, fn, obs)`` when fully centered.
    """
    assert obs is None, "LocScaleReparam does not support observe statements"
    centered = self.centered
    if is_identically_one(centered):
        return name, fn, obs
    event_shape = fn.event_shape
    fn, event_dim = self._unwrap(fn)

    # Apply a partial decentering transform.
    params = {key: getattr(fn, key) for key in self.shape_params}
    if self.centered is None:
        # Format the site name into the param name. The previous literal
        # "{}_centered" made every site share a single parameter.
        centered = pyro.param("{}_centered".format(name),
                              lambda: fn.loc.new_full(event_shape, 0.5),
                              constraint=constraints.unit_interval)
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale ** centered
    decentered_fn = type(fn)(**params)

    # Draw decentered noise.
    decentered_value = pyro.sample("{}_decentered".format(name),
                                   self._wrap(decentered_fn, event_dim))

    # Differentiably transform.
    delta = decentered_value - centered * fn.loc
    value = fn.loc + fn.scale.pow(1 - centered) * delta

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
    return new_fn, value
def apply(self, msg):
    """Replace a guide-covered model site with the guide's ``Delta``.

    Sites absent from the guide's prototype trace pass through unchanged.
    On the first covered site, all deltas are fetched at once from
    ``self.guide.get_deltas()`` with the guide's plate dims blocked. Each
    site then pops its own delta. Unless the current mask is ``False``, the
    model's ``fn.log_prob`` of the value is added to the delta's log_density.

    :raises NotImplementedError: for observed sites.
    """
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]

    if name not in self.guide.prototype_trace.nodes:
        return {"fn": fn, "value": value, "is_observed": is_observed}
    if is_observed:
        raise NotImplementedError(
            f"At pyro.sample({repr(name)},...), "
            "StructuredReparam does not support observe statements")

    if name not in self.deltas:
        # On first sample site: fetch every delta in one go.
        with ExitStack() as stack:
            for guide_plate in self.guide.plates.values():
                stack.enter_context(
                    block_plate(dim=guide_plate.dim, strict=False))
            self.deltas = self.guide.get_deltas()

    site_delta = self.deltas.pop(name)
    value = site_delta.v
    if poutine.get_mask() is False:
        new_fn = site_delta
    else:
        combined = site_delta.log_density + fn.log_prob(value)
        new_fn = dist.Delta(value, combined, site_delta.event_dim)
    return {"fn": new_fn, "value": value, "is_observed": True}
def belief_value_model(belief, action, t, discount=1.0, discount_factor=0.95,
                       max_depth=10, bu_nsteps=10, bu_lr=0.1):
    """Returns Pr(Value | b,a).

    Samples a state from ``belief``, a transition under ``action`` and a
    reward. At a terminal next state the reward is the value (a Delta site
    ``"v{t}"``). Otherwise the belief is updated from a sampled observation,
    a next action is drawn from ``belief_policy_model``, and the function
    recurses until ``max_depth``.

    The value follows the Bellman backup ``reward + discount_factor * V(next)``,
    so a reward k steps ahead is weighted by ``discount_factor ** k``.
    ``discount`` is the cumulative discount (multiplied by ``discount_factor``
    each step). It is forwarded to ``belief_policy_model`` and to the
    recursive call unchanged in meaning; presumably the policy uses it --
    confirm.
    """
    if t > max_depth:
        # Depth cut-off: a near-zero placeholder value.
        return tensor(1e-9)

    state = states[pyro.sample("s%d" % t, belief)]
    next_state = states[pyro.sample("next_s%d" % t, transition_dist(state, action))]
    reward = pyro.sample("r%d" % t, reward_dist(state, action, next_state))
    if next_state == "terminal":
        return pyro.sample("v%d" % t, dist.Delta(reward))

    # Compute the future value.
    discount = discount * discount_factor
    observation = observations[pyro.sample("o%d" % t, observation_dist(next_state, action))]
    # Hide the belief-update's internal "bu*" sites from the outer trace.
    with poutine.block(hide_fn=lambda site: site["name"].startswith("bu")):
        next_belief = belief_update(belief, action, observation,
                                    num_steps=bu_nsteps, lr=bu_lr, suffix=str(t))
    next_action = belief_policy_model(next_belief, t + 1, discount=discount,
                                      discount_factor=discount_factor,
                                      max_depth=max_depth)
    future_value = belief_value_model(next_belief, next_action, t + 1,
                                      discount=discount,
                                      discount_factor=discount_factor,
                                      max_depth=max_depth,
                                      bu_nsteps=bu_nsteps, bu_lr=bu_lr)
    # The recursive value is already discounted internally, so apply only one
    # step of discount here. Scaling by the cumulative ``discount`` compounded
    # it, weighting a reward k steps ahead by
    # discount_factor ** (k * (k + 1) / 2) instead of discount_factor ** k.
    return reward + discount_factor * future_value