def templates_guide_mvn(self):
    """
    Multivariate normal guide for template parameters.

    Draws one auxiliary vector ``states`` from a full-covariance
    ``MultivariateNormal`` whose ``loc`` / ``scale_tril`` are looked up with
    ``_deep_getattr`` under the dotted paths ``"mvn.loc"`` and
    ``"mvn.scale_tril"``. Each coordinate of ``states`` is then mapped into
    the support of its prior and registered as a ``Delta`` sample site.

    Coordinates are consumed in order: first ``self.n_poiss`` Poissonian
    template parameters, then ``self.n_ps_params`` parameters for each of the
    ``self.n_ps`` point-source templates.

    :return: dict mapping each sample-site name to its constrained value.
    """
    loc = _deep_getattr(self, "mvn.loc")
    scale_tril = _deep_getattr(self, "mvn.scale_tril")
    dt = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    states = pyro.sample("states_" + self.name_prefix, dt,
                         infer={"is_auxiliary": True})

    def _constrained_delta(name, prior, unconstrained):
        # Map the unconstrained coordinate into the prior's support and use the
        # inverse transform's log|det J| (reduced to the value's batch rank)
        # as the Delta's log_density.
        transform = biject_to(prior.support)
        value = transform(unconstrained)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + prior.event_dim)
        return pyro.sample(
            name,
            dist.Delta(value, log_density=log_density,
                       event_dim=prior.event_dim))

    result = {}
    for i_poiss in range(self.n_poiss):
        label = self.poiss_labels[i_poiss]
        # Index the last dim so batched draws of ``states`` are split correctly.
        result[label] = _constrained_delta(
            label, self.poiss_priors[i_poiss], states[..., i_poiss])

    i_param = self.n_poiss
    for i_ps in range(self.n_ps):
        for i_ps_param in range(self.n_ps_params):
            label = self.ps_param_labels[i_ps_param] + "_" + self.ps_labels[i_ps]
            result[label] = _constrained_delta(
                label, self.ps_priors[i_ps][i_ps_param], states[..., i_param])
            i_param += 1
    return result
def score_parts(self, value):
    """
    Score parts of the reinterpreted (Independent) distribution.

    Each term returned by ``self.base_dist.score_parts(value)`` has its
    rightmost ``self.reinterpreted_batch_ndims`` dims summed. It is then
    expanded to the broadcast of ``self.batch_shape`` with the batch part of
    ``value.shape``. Score-function and entropy terms that are plain numbers
    are passed through untouched.

    :return: ``ScoreParts(log_prob, score_function, entropy_term)``.
    """
    ndims = self.reinterpreted_batch_ndims
    batch_part = value.shape[:value.dim() - self.event_dim]
    target_shape = broadcast_shape(self.batch_shape, batch_part)

    def _reduce(term):
        return sum_rightmost(term, ndims).expand(target_shape)

    base_log_prob, base_score, base_entropy = self.base_dist.score_parts(value)
    log_prob = _reduce(base_log_prob)
    score_function = (base_score if isinstance(base_score, numbers.Number)
                      else _reduce(base_score))
    entropy_term = (base_entropy if isinstance(base_entropy, numbers.Number)
                    else _reduce(base_entropy))
    return ScoreParts(log_prob, score_function, entropy_term)
def guide(self):
    """
    Multivariate normal guide over all template parameters.

    Learns a location ``a_locs`` (initialised to zeros, length
    ``self.n_params``) and a lower-Cholesky ``a_scales`` (initialised to
    ``0.1 * eye``). It draws one auxiliary vector ``states`` and maps each
    coordinate into the support of its prior as a ``Delta`` sample site.

    Coordinates are consumed in order: first ``self.n_poiss`` Poissonian
    template parameters, then ``self.n_ps_params`` parameters for each of the
    ``self.n_ps`` point-source templates.

    :return: dict mapping each sample-site name to its constrained value.
    """
    a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
    a_scales_tril = pyro.param(
        "a_scales", lambda: 0.1 * eye_like(a_locs, self.n_params),
        constraint=constraints.lower_cholesky)
    dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
    states = pyro.sample("states", dt, infer={"is_auxiliary": True})

    def _constrained_delta(name, prior, unconstrained):
        # Map the unconstrained coordinate into the prior's support and use the
        # inverse transform's log|det J| (reduced to the value's batch rank)
        # as the Delta's log_density.
        transform = biject_to(prior.support)
        value = transform(unconstrained)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + prior.event_dim)
        return pyro.sample(
            name,
            dist.Delta(value, log_density=log_density,
                       event_dim=prior.event_dim))

    result = {}
    for i_poiss in range(self.n_poiss):
        label = self.labels_poiss[i_poiss]
        # Index the last dim so batched draws of ``states`` are split correctly.
        result[label] = _constrained_delta(
            label, self.poiss_priors[i_poiss], states[..., i_poiss])

    i_param = self.n_poiss
    for i_ps in range(self.n_ps):
        for i_ps_param in range(self.n_ps_params):
            label = self.labels_ps_params[i_ps_param] + "_" + self.labels_ps[i_ps]
            result[label] = _constrained_delta(
                label, self.ps_priors[i_ps][i_ps_param], states[..., i_param])
            i_param += 1
    return result
def guide_horseshoe_plus(self):
    """
    Guide with a global and a per-subject multivariate normal level.

    Global level: one ``MultivariateNormal`` of size ``2 * npar`` over
    ``[mu_g, unconstrained sigma_g]``. Per-subject level: inside plate
    ``'subjects'`` (size ``self.runs``), one ``MultivariateNormal`` of size
    ``2 * npar`` per subject over ``[x, unconstrained sigma_x]``. The sigma
    halves are mapped to positive values, and the Jacobian of that map is
    attached to the ``Delta`` sites.

    :return: dict with keys ``'mu_g'``, ``'sigma_g'``, ``'sigma_x'``, ``'x'``.
    """
    npar = self.npars  # number of parameters
    nsub = self.runs  # number of subjects
    trns = biject_to(constraints.positive)

    # Global hyper-parameters.
    m_hyp = param('m_hyp', zeros(2 * npar))
    st_hyp = param('scale_tril_hyp', torch.eye(2 * npar),
                   constraint=constraints.lower_cholesky)
    hyp = sample('hyp', dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                 infer={'is_auxiliary': True})

    # Slice along the last dim so batched draws of ``hyp`` split correctly.
    unc_mu = hyp[..., :npar]
    unc_sigma = hyp[..., npar:]

    c_sigma = trns(unc_sigma)
    ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma)
    ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1)

    mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1))
    sigma_g = sample("sigma_g",
                     dist.Delta(c_sigma, log_density=ld_sigma, event_dim=1))

    # Per-subject parameters.
    m_tmp = param('m_tmp', zeros(nsub, 2 * npar))
    st_tmp = param('s_tmp', torch.eye(2 * npar).repeat(nsub, 1, 1),
                   constraint=constraints.lower_cholesky)
    with plate('subjects', nsub):
        tmp = sample('tmp', dist.MultivariateNormal(m_tmp, scale_tril=st_tmp),
                     infer={'is_auxiliary': True})

        unc_locs = tmp[..., :npar]
        unc_scale = tmp[..., npar:]

        c_scale = trns(unc_scale)
        ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale)
        ld_scale = sum_rightmost(ld_scale, ld_scale.dim() - c_scale.dim() + 1)

        x = sample("x", dist.Delta(unc_locs, event_dim=1))
        sigma_x = sample("sigma_x",
                         dist.Delta(c_scale, log_density=ld_scale, event_dim=1))

    return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
def guide(self):
    """Approximate posterior for the horseshoe prior.

    The global mean and standard deviation are modelled jointly by one
    multivariate normal distribution. The parameters of each subject are
    modelled by independent multivariate normal distributions.
    """
    n_runs = self.runs  # number of subjects
    n_par = self.npar  # number of parameters
    to_positive = biject_to(constraints.positive)

    hyp_loc = param('m_hyp', zeros(2 * n_par))
    hyp_scale_tril = param('scale_tril_hyp', torch.eye(2 * n_par),
                           constraint=constraints.lower_cholesky)
    hyp = sample('hyp',
                 dist.MultivariateNormal(hyp_loc, scale_tril=hyp_scale_tril),
                 infer={'is_auxiliary': True})

    # First half: unconstrained means; second half: unconstrained tau.
    raw_mu, raw_tau = hyp[..., :n_par], hyp[..., n_par:]
    tau = to_positive(raw_tau)
    tau_ld = to_positive.inv.log_abs_det_jacobian(tau, raw_tau)
    tau_ld = sum_rightmost(tau_ld, tau_ld.dim() - tau.dim() + 1)

    sample("mu", dist.Delta(raw_mu, event_dim=1))
    sample("tau", dist.Delta(tau, log_density=tau_ld, event_dim=1))

    locs_loc = param('m_locs', zeros(n_runs, n_par))
    locs_scale_tril = param('scale_tril_locs',
                            torch.eye(n_par).repeat(n_runs, 1, 1),
                            constraint=constraints.lower_cholesky)
    with plate('runs', n_runs):
        sample("locs",
               dist.MultivariateNormal(locs_loc, scale_tril=locs_scale_tril))
def _kl_independent_independent(p, q):
    """
    KL divergence between two ``Independent`` distributions.

    Only supported when both sides reinterpret the same number of batch dims.
    The base-distribution KL is then summed over those rightmost dims.

    :raises NotImplementedError: if ``p`` and ``q`` reinterpret a different
        number of batch dims.
    """
    ndims = p.reinterpreted_batch_ndims
    if q.reinterpreted_batch_ndims != ndims:
        raise NotImplementedError
    base_kl = kl_divergence(p.base_dist, q.base_dist)
    return sum_rightmost(base_kl, ndims) if ndims else base_kl
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # Trace the model lazily on the first call so its structure is known.
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    packed = self.sample_latent(*args, **kwargs)
    plates = self._create_plates()

    samples = {}
    # Each latent site receives its own slice of the packed unconstrained draw.
    for site, raw_value in self._unpack_latent(packed):
        site_name = site["name"]
        site_fn = site["fn"]
        bijection = biject_to(site_fn.support)
        constrained = bijection(raw_value)
        jacobian = bijection.inv.log_abs_det_jacobian(constrained, raw_value)
        jacobian = sum_rightmost(
            jacobian, jacobian.dim() - constrained.dim() + site_fn.event_dim)
        delta = dist.Delta(constrained, log_density=jacobian,
                           event_dim=site_fn.event_dim)
        with ExitStack() as stack:
            for frame in self._cond_indep_stacks[site_name]:
                stack.enter_context(plates[frame.name])
            samples[site_name] = pyro.sample(site_name, delta)
    return samples
def sample(self, guide_name, fn, infer=None):
    """
    Wrapper around ``pyro.sample()`` to create a single auxiliary sample site
    and then unpack to multiple sample sites for model replay.

    :param str guide_name: The name of the auxiliary guide site.
    :param callable fn: A distribution with shape ``self.event_shape``.
    :param dict infer: Optional inference configuration dict. It is copied
        before ``"is_auxiliary"`` is set, so the caller's dict is not modified.
    :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the single
        concatenated blob and ``model_zs`` is a dict mapping site name to
        constrained model sample.
    :rtype: tuple
    :raises ValueError: if ``fn.event_shape`` differs from ``self.event_shape``.
    """
    # Sample a packed tensor.
    if fn.event_shape != self.event_shape:
        raise ValueError(
            "Invalid fn.event_shape for group: expected {}, actual {}".
            format(tuple(self.event_shape), tuple(fn.event_shape)))
    # Copy instead of mutating the caller's dict in place.
    infer = {} if infer is None else dict(infer)
    infer["is_auxiliary"] = True
    guide_z = pyro.sample(guide_name, fn, infer=infer)
    common_batch_shape = guide_z.shape[:-1]

    model_zs = {}
    pos = 0
    for site in self.prototype_sites:
        name = site["name"]
        site_fn = site["fn"]

        # Extract slice from packed sample.
        size = self._site_sizes[name]
        batch_shape = broadcast_shape(common_batch_shape,
                                      self._site_batch_shapes[name])
        unconstrained_z = guide_z[..., pos:pos + size]
        unconstrained_z = unconstrained_z.reshape(batch_shape + site_fn.event_shape)
        pos += size

        # Transform to constrained space.
        transform = biject_to(site_fn.support)
        z = transform(unconstrained_z)
        log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
        log_density = sum_rightmost(
            log_density, log_density.dim() - z.dim() + site_fn.event_dim)
        delta_dist = dist.Delta(z, log_density=log_density,
                                event_dim=site_fn.event_dim)

        # Replay model sample statement, entering only plates not already active.
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                plate = self.guide.plate(frame.name)
                if plate not in runtime._PYRO_STACK:
                    stack.enter_context(plate)
            model_zs[name] = pyro.sample(name, delta_dist)

    return guide_z, model_zs
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # Trace the model lazily on the first call so its structure is known.
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    packed = self.sample_latent(*args, **kwargs)
    iaranges = self._create_iaranges()

    samples = {}
    # Each latent site receives its own slice of the packed unconstrained draw.
    for site, raw_value in self._unpack_latent(packed):
        site_name = site["name"]
        site_fn = site["fn"]
        bijection = biject_to(site_fn.support)
        constrained = bijection(raw_value)
        jacobian = bijection.inv.log_abs_det_jacobian(constrained, raw_value)
        jacobian = sum_rightmost(
            jacobian, jacobian.dim() - constrained.dim() + site_fn.event_dim)
        delta = dist.Delta(constrained, log_density=jacobian,
                           event_dim=site_fn.event_dim)
        with ExitStack() as stack:
            for frame in self._cond_indep_stacks[site_name]:
                stack.enter_context(iaranges[frame.name])
            samples[site_name] = pyro.sample(site_name, delta)
    return samples
def test_kl_independent_normal(batch_shape, event_shape):
    # KL between Independent-wrapped Normals must equal the base KL summed
    # over the reinterpreted (event) dims.
    full_shape = batch_shape + event_shape
    n_event = len(event_shape)

    def _random_normal():
        return dist.Normal(torch.randn(full_shape), torch.randn(full_shape).exp())

    p = _random_normal()
    q = _random_normal()
    actual = kl_divergence(dist.Independent(p, n_event),
                           dist.Independent(q, n_event))
    expected = sum_rightmost(kl_divergence(p, q), n_event)
    assert_close(actual, expected)
def _kl_transformed_transformed(p, q):
    """
    KL divergence between two distributions built with identical transforms.

    When ``p`` and ``q`` share transforms and event shape, the KL reduces to
    the KL of their base distributions. Base batch dims that the transforms
    absorb into the event are summed out.

    :raises NotImplementedError: if transforms or event shapes differ.
    """
    if p.transforms != q.transforms or p.event_shape != q.event_shape:
        raise NotImplementedError
    absorbed_dims = len(p.base_dist.batch_shape) - len(p.batch_shape)
    return sum_rightmost(kl_divergence(p.base_dist, q.base_dist), absorbed_dims)
def forward(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    .. note:: This method is used internally by :class:`~torch.nn.Module`.
        Users should instead use :meth:`~torch.nn.Module.__call__`.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # Trace the model lazily on the first call so its structure is known.
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    hidden = self.encode(*args, **kwargs)
    plates = self._create_plates(*args, **kwargs)

    samples = {}
    for site_name, site in self.prototype_trace.iter_stochastic_nodes():
        site_fn = site["fn"]
        bijection = biject_to(site_fn.support)
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    stack.enter_context(plates[frame.name])

            loc, scale = self._get_loc_and_scale(site_name, hidden)
            aux_dist = dist.Normal(loc, scale).to_event(self._event_dims[site_name])
            raw_value = pyro.sample(site_name + "_unconstrained", aux_dist,
                                    infer={"is_auxiliary": True})
            constrained = bijection(raw_value)

            if pyro.poutine.get_mask() is not False:
                jacobian = bijection.inv.log_abs_det_jacobian(constrained, raw_value)
                jacobian = sum_rightmost(
                    jacobian,
                    jacobian.dim() - constrained.dim() + site_fn.event_dim)
            else:
                # Masked out (get_mask() is False): use a zero log-density.
                jacobian = 0.0

            delta = dist.Delta(constrained, log_density=jacobian,
                               event_dim=site_fn.event_dim)
            samples[site_name] = pyro.sample(site_name, delta)
    return samples
def conjugate_update(self, other):
    """
    EXPERIMENTAL.

    Conjugate update of an ``Independent`` distribution.

    The reinterpreted dims are moved back into the batch of ``other`` via
    ``other.to_event(-n)``. The base distribution's conjugate update is
    applied, and the result is re-wrapped with ``to_event(n)``. The log
    normalizer is summed over those ``n`` rightmost dims.

    :return: pair ``(updated, log_normalizer)``.
    """
    # Attribute name must match ``reinterpreted_batch_ndims`` (previously
    # misspelled, which raised AttributeError on every call).
    n = self.reinterpreted_batch_ndims
    updated, log_normalizer = self.base_dist.conjugate_update(other.to_event(-n))
    updated = updated.to_event(n)
    log_normalizer = sum_rightmost(log_normalizer, n)
    return updated, log_normalizer
def _kl_independent_independent(p, q):
    """
    KL divergence between two ``Independent`` distributions.

    The reinterpreted dims common to both sides (the minimum of the two
    counts) are peeled off and summed at the end. Any extra reinterpreted
    dims on either side stay wrapped in ``Independent`` so the base KL
    sees consistent event shapes.
    """
    p_total = p.reinterpreted_batch_ndims
    q_total = q.reinterpreted_batch_ndims
    shared = min(p_total, q_total)

    def _strip_shared(d, total):
        # Keep only the reinterpreted dims that are not shared by both sides.
        extra = total - shared
        return Independent(d.base_dist, extra) if extra else d.base_dist

    kl = kl_divergence(_strip_shared(p, p_total), _strip_shared(q, q_total))
    return sum_rightmost(kl, shared) if shared else kl
def apply(self, msg):
    """
    Reparameterize one model sample site through the guide's shared latent.

    ``msg`` must provide ``"name"``, ``"fn"``, ``"value"`` and
    ``"is_observed"``. Sites absent from the guide's prototype trace are
    returned unchanged. For the first latent site, a shared latent is drawn
    from ``self.guide.get_base_dist()`` and pushed through
    ``self.guide.get_transform()``. The per-site unconstrained slices are
    cached in ``self.x_unconstrained``. Each site then pops its slice, maps it
    into ``fn.support`` and is returned as an observed ``Delta``.

    :raises NotImplementedError: for observed sites known to the guide.
    :raises ValueError: if ``get_transform()`` raises NotImplementedError or
        TypeError.
    :return: dict with keys ``"fn"``, ``"value"`` and ``"is_observed"``.
    """
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]
    if name not in self.guide.prototype_trace.nodes:
        return {"fn": fn, "value": value, "is_observed": is_observed}
    if is_observed:
        raise NotImplementedError(
            f"At pyro.sample({repr(name)},...), "
            "NeuTraReparam does not support observe statements.")

    log_density = 0.0
    # Skip density work when the site is masked out (mask is exactly False).
    needs_density = poutine.get_mask() is not False

    if name not in self.x_unconstrained:
        # First sample site: obtain the transform and draw a shared latent.
        try:
            self.transform = self.guide.get_transform()
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`") from e

        with ExitStack() as stack:
            # Block the guide's plates (non-strictly) while drawing the latent.
            for guide_plate in self.guide.plates.values():
                stack.enter_context(
                    block_plate(dim=guide_plate.dim, strict=False))
            z_shared = pyro.sample(
                f"{name}_shared_latent",
                self.guide.get_base_dist().mask(False))

        # Differentiably transform.
        x_shared = self.transform(z_shared)
        if needs_density:
            log_density = self.transform.log_abs_det_jacobian(z_shared, x_shared)
        self.x_unconstrained = {
            site["name"]: (site, unconstrained_value)
            for site, unconstrained_value in self.guide._unpack_latent(x_shared)
        }

    # Extract a single site's value from the shared latent.
    site, unconstrained_value = self.x_unconstrained.pop(name)
    bijection = biject_to(fn.support)
    constrained = bijection(unconstrained_value)
    if needs_density:
        logdet = bijection.log_abs_det_jacobian(unconstrained_value, constrained)
        logdet = sum_rightmost(logdet, logdet.dim() - constrained.dim() + fn.event_dim)
        log_density = log_density + fn.log_prob(constrained) + logdet
    new_fn = dist.Delta(constrained, log_density, event_dim=fn.event_dim)
    return {"fn": new_fn, "value": constrained, "is_observed": True}
def test_sum_rightmost():
    # Positive dims sum that many rightmost dims; negative dims keep that many
    # leftmost dims; INF sums everything.
    x = torch.ones(2, 3, 4)
    cases = [
        (0, (2, 3, 4)),
        (1, (2, 3)),
        (2, (2, )),
        (-1, (2, )),
        (-2, (2, 3)),
        (INF, ()),
    ]
    for dim, expected_shape in cases:
        assert sum_rightmost(x, dim).shape == expected_shape
def test_sum_rightmost():
    # Positive dims sum that many rightmost dims; negative dims keep that many
    # leftmost dims; infinity sums everything.
    x = torch.ones(2, 3, 4)
    cases = [
        (0, (2, 3, 4)),
        (1, (2, 3)),
        (2, (2,)),
        (-1, (2,)),
        (-2, (2, 3)),
        (float('inf'), ()),
    ]
    for dim, expected_shape in cases:
        assert sum_rightmost(x, dim).shape == expected_shape
def forward(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    .. note:: This method is used internally by :class:`~torch.nn.Module`.
        Users should instead use :meth:`~torch.nn.Module.__call__`.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    # Trace the model lazily on the first call so its structure is known.
    if self.prototype_trace is None:
        self._setup_prototype(*args, **kwargs)

    packed = self.sample_latent(*args, **kwargs)
    plates = self._create_plates(*args, **kwargs)

    samples = {}
    # Each latent site receives its own slice of the packed unconstrained draw.
    for site, raw_value in self._unpack_latent(packed):
        site_name = site["name"]
        site_fn = site["fn"]
        bijection = biject_to(site_fn.support)
        constrained = bijection(raw_value)
        if poutine.get_mask() is not False:
            jacobian = bijection.inv.log_abs_det_jacobian(constrained, raw_value)
            jacobian = sum_rightmost(
                jacobian,
                jacobian.dim() - constrained.dim() + site_fn.event_dim)
        else:
            # Masked out (get_mask() is False): use a zero log-density.
            jacobian = 0.0
        delta = dist.Delta(constrained, log_density=jacobian,
                           event_dim=site_fn.event_dim)
        with ExitStack() as stack:
            for frame in self._cond_indep_stacks[site_name]:
                stack.enter_context(plates[frame.name])
            samples[site_name] = pyro.sample(site_name, delta)
    return samples
def __call__(self, name, fn, obs):
    """
    Replace a latent site's distribution with a Delta driven by the shared latent.

    Sites absent from the guide's prototype trace are returned unchanged. On
    the first latent site, a shared latent is drawn from
    ``self.guide.get_base_dist()`` and pushed through
    ``self.guide.get_transform()``. Its unpacked per-site slices are queued in
    ``self.x_unconstrained``. Each site then consumes the next slice in order.

    :raises ValueError: if ``get_transform()`` raises NotImplementedError or
        TypeError.
    :return: ``(fn, obs)`` for unknown sites, otherwise ``(Delta, value)``.
    """
    if name not in self.guide.prototype_trace.nodes:
        return fn, obs
    assert obs is None, "NeuTraReparam does not support observe statements"

    log_density = 0.0
    # Skip density work when the site is masked out (mask is exactly False).
    needs_density = (poutine.get_mask() is not False)

    if not self.x_unconstrained:
        # First sample site: obtain the transform and draw a shared latent.
        try:
            self.transform = self.guide.get_transform()
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`") from e
        z_shared = pyro.sample(
            "{}_shared_latent".format(name),
            self.guide.get_base_dist().mask(False))
        # Differentiably transform.
        x_shared = self.transform(z_shared)
        if needs_density:
            log_density = self.transform.log_abs_det_jacobian(z_shared, x_shared)
        # Stored reversed so that list.pop() yields sites in unpack order.
        self.x_unconstrained = list(self.guide._unpack_latent(x_shared))[::-1]

    # Extract a single site's value from the shared latent.
    site, unconstrained_value = self.x_unconstrained.pop()
    assert name == site["name"], "model structure changed"
    bijection = biject_to(fn.support)
    constrained = bijection(unconstrained_value)
    if needs_density:
        logdet = bijection.log_abs_det_jacobian(unconstrained_value, constrained)
        logdet = sum_rightmost(logdet, logdet.dim() - constrained.dim() + fn.event_dim)
        log_density = log_density + fn.log_prob(constrained) + logdet
    return dist.Delta(constrained, log_density, event_dim=fn.event_dim), constrained
def __call__(self, name, fn, obs):
    """
    Replace a latent site's distribution with a Delta driven by the shared latent.

    Sites absent from the guide's prototype trace are returned unchanged. On
    the first latent site, the guide's posterior must be a
    ``TransformedDistribution``. A shared latent is drawn from its base
    distribution and pushed through the composed transforms. The unpacked
    per-site slices are queued in ``self.x_unconstrained``. Each site then
    consumes the next slice in order.

    :raises ValueError: if the guide's posterior is not a
        ``TransformedDistribution``.
    :return: ``(fn, obs)`` for unknown sites, otherwise ``(Delta, value)``.
    """
    if name not in self.guide.prototype_trace.nodes:
        return fn, obs
    assert obs is None, "NeuTraReparam does not support observe statements"

    log_density = 0.
    if not self.x_unconstrained:
        # First sample site: draw a single shared latent.
        # TODO(fehiepsi) Consider adding a method to extract transform from an Auto*Normal(posterior).
        posterior = self.guide.get_posterior()
        if not isinstance(posterior, dist.TransformedDistribution):
            raise ValueError(
                "NeuTraReparam only supports guides whose posteriors are "
                "TransformedDistributions but got a posterior of type {}"
                .format(type(posterior)))
        self.transform = dist.transforms.ComposeTransform(posterior.transforms)
        z_shared = pyro.sample("{}_shared_latent".format(name),
                               posterior.base_dist.mask(False))
        # Differentiably transform.
        x_shared = self.transform(z_shared)
        log_density = self.transform.log_abs_det_jacobian(z_shared, x_shared)
        # Stored reversed so that list.pop() yields sites in unpack order.
        self.x_unconstrained = list(self.guide._unpack_latent(x_shared))[::-1]

    # Extract a single site's value from the shared latent.
    site, unconstrained_value = self.x_unconstrained.pop()
    assert name == site["name"], "model structure changed"
    bijection = biject_to(fn.support)
    constrained = bijection(unconstrained_value)
    logdet = bijection.log_abs_det_jacobian(unconstrained_value, constrained)
    logdet = sum_rightmost(logdet, logdet.dim() - constrained.dim() + fn.event_dim)
    log_density = log_density + fn.log_prob(constrained) + logdet
    return dist.Delta(constrained, log_density, event_dim=fn.event_dim), constrained
def evaluate_log_posterior_density(model, posterior_samples, baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data (season hits)
    given a model and posterior distribution over the parameters.

    The result is logged (not returned).
    """
    _, test, _ = train_test_split(baseball_dataset)
    at_bats_season = test[:, 0]
    hits_season = test[:, 1]
    with ignore_experimental_warning():
        trace = predictive(model, posterior_samples, at_bats_season, hits_season,
                           parallel=True, return_trace=True)

    # log(1/num_samples * sum_i p(new_data | theta^i)) is computed with
    # logsumexp, where theta^i are parameter samples from the posterior.
    trace.compute_log_prob()
    log_joint = 0.
    for site in trace.nodes.values():
        if site["type"] != "sample" or site_is_subsample(site):
            continue
        # sum_rightmost(x, -1) sums every dim of x except the leftmost one,
        # which indexes the posterior samples.
        log_joint += sum_rightmost(site['log_prob'], -1)

    n_samples = log_joint.shape[0]
    posterior_pred_density = torch.logsumexp(log_joint, dim=0) - math.log(n_samples)
    logging.info("\nLog posterior predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(posterior_pred_density))
def log_prob(self, x):
    """
    Log-density of a Delta distribution at ``x``.

    Elementwise, the result is ``log(1) = 0`` where ``x`` equals the support
    point ``self.v`` (expanded to ``self.shape()``) and ``-inf`` elsewhere. It
    is summed over the rightmost ``self.event_dim`` dims and offset by
    ``self.log_density``.
    """
    v = self.v.expand(self.shape())
    # Cast the boolean match mask directly. The previous x.new_tensor(mask)
    # copy-constructed from a tensor and triggered PyTorch's copy warning.
    log_prob = (x == v).to(x.dtype).log()
    log_prob = sum_rightmost(log_prob, self.event_dim)
    return log_prob + self.log_density
def log_prob(self, x):
    """
    Log-density of a Delta distribution at ``x``.

    Elementwise, the result is ``0`` where ``x`` equals the support point
    (``self.v`` expanded to ``self.shape()``) and ``-inf`` elsewhere. It is
    summed over the rightmost ``self.event_dim`` dims and offset by
    ``self.log_density``.
    """
    support_point = self.v.expand(self.shape())
    matches = (x == support_point).type(x.dtype)
    return sum_rightmost(matches.log(), self.event_dim) + self.log_density
def templates_guide_iaf(self, indices, gp_sample=None):
    """
    IAF guide for template parameters.

    Draws one auxiliary vector ``states`` from a flow built on
    ``self.base_dist`` and ``self.transform``. With ``guide_name == "IAF"`` the
    flow is unconditional. With ``"ConditionalIAF"`` it is conditioned on
    summary statistics of ``exp(gp_sample)``:
    - its dot products with the Poissonian and point-source templates,
      divided by ``self.n_pix``;
    - its mean;
    - its standard deviation.
    Each coordinate of ``states`` is then mapped into the support of its prior
    as a ``Delta`` sample site.

    :param indices: column indices into ``self.poiss_temps`` /
        ``self.ps_temps`` (presumably pixel indices — TODO confirm); only used
        by ``"ConditionalIAF"``.
    :param gp_sample: GP draw used to build the conditioning context; required
        for ``"ConditionalIAF"``.
    :raises ValueError: if ``self.guide_name`` is neither ``"IAF"`` nor
        ``"ConditionalIAF"``, or if ``gp_sample`` is missing for
        ``"ConditionalIAF"``.
    :return: dict mapping each sample-site name to its constrained value.
    """
    if self.guide_name == "IAF":
        td = dist.TransformedDistribution(self.base_dist, self.transform)
    elif self.guide_name == "ConditionalIAF":
        if gp_sample is None:
            raise ValueError("ConditionalIAF guide requires gp_sample")
        # Context variables (GP summary statistics) that condition the flow.
        context_vars = torch.zeros(self.n_poiss + self.n_ps + 2)
        gp_exp = gp_sample.exp()
        # Dot products of Poissonian / non-Poissonian templates with the GP.
        context_vars[:self.n_poiss] = (
            self.poiss_temps[:, indices] @ gp_exp.double()) / self.n_pix
        context_vars[self.n_poiss:self.n_poiss + self.n_ps] = (
            self.ps_temps[:, indices] @ gp_exp.double()) / self.n_pix
        # GP mean and standard deviation.
        context_vars[-2] = torch.mean(gp_exp)
        context_vars[-1] = torch.var(gp_exp).sqrt()
        td = dist.ConditionalTransformedDistribution(
            self.base_dist, self.transform).condition(context=context_vars)
    else:
        raise ValueError(
            "Unknown guide_name {!r}; expected 'IAF' or 'ConditionalIAF'"
            .format(self.guide_name))

    states = pyro.sample("states_" + self.name_prefix, td,
                         infer={"is_auxiliary": True})

    def _constrained_delta(name, prior, unconstrained):
        # Map the unconstrained coordinate into the prior's support and use the
        # inverse transform's log|det J| (reduced to the value's batch rank)
        # as the Delta's log_density.
        transform = biject_to(prior.support)
        value = transform(unconstrained)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
        log_density = sum_rightmost(
            log_density, log_density.dim() - value.dim() + prior.event_dim)
        return pyro.sample(
            name,
            dist.Delta(value, log_density=log_density,
                       event_dim=prior.event_dim))

    result = {}
    for i_poiss in range(self.n_poiss):
        label = self.poiss_labels[i_poiss]
        # Index the last dim so batched draws of ``states`` are split correctly.
        result[label] = _constrained_delta(
            label, self.poiss_priors[i_poiss], states[..., i_poiss])

    i_param = self.n_poiss
    for i_ps in range(self.n_ps):
        for i_ps_param in range(self.n_ps_params):
            label = self.ps_param_labels[i_ps_param] + "_" + self.ps_labels[i_ps]
            result[label] = _constrained_delta(
                label, self.ps_priors[i_ps][i_ps_param], states[..., i_param])
            i_param += 1
    return result
def log_prob(self, value):
    """
    Log-probability of an ``Independent`` distribution.

    The base distribution's log-prob has its rightmost
    ``self.reinterpreted_batch_ndims`` dims summed. It is then expanded to the
    broadcast of ``self.batch_shape`` with the batch part of ``value.shape``.
    """
    batch_part = value.shape[:value.dim() - self.event_dim]
    target_shape = broadcast_shape(self.batch_shape, batch_part)
    base_lp = self.base_dist.log_prob(value)
    summed = sum_rightmost(base_lp, self.reinterpreted_batch_ndims)
    return summed.expand(target_shape)