def __init__(self, dataset_size, scale=None, precision=None, event_dim=1,
             name="", data_name="data"):
    super().__init__(dataset_size, event_dim=event_dim, name=name,
                     data_name=data_name)
    if int(scale is None) + int(precision is None) != 1:
        raise ValueError("Exactly one of scale and precision must be specified")
    elif isinstance(scale, (dist.Distribution, torchdist.Distribution)):
        # If the scale or precision is a distribution, it is used as the prior
        # for a PyroSample. Allowing plain PyTorch distributions may be risky,
        # since they might not have the correct event_dim; it could be safer to
        # check e.g. that the batch shape is empty and raise an error otherwise.
        precision = PyroSample(prior=dist.TransformedDistribution(
            scale, transforms.PowerTransform(-2.)))
        scale = PyroSample(prior=scale)
    elif isinstance(precision, (dist.Distribution, torchdist.Distribution)):
        scale = PyroSample(prior=dist.TransformedDistribution(
            precision, transforms.PowerTransform(-0.5)))
        precision = PyroSample(prior=precision)
    else:
        # Nothing to do: precision or scale is a number/tensor/parameter.
        pass
    self._scale = scale
    self._precision = precision
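# Usage sketch (standalone; the Gamma prior here is an assumed example, not
# taken from the snippet above): if precision ~ Gamma, then
# scale = precision ** -0.5, which is exactly the relationship that
# PowerTransform(-0.5) encodes above.
import torch
import pyro.distributions as dist
from torch.distributions import transforms

precision_prior = dist.Gamma(torch.tensor(2.0), torch.tensor(2.0))
scale_prior = dist.TransformedDistribution(
    precision_prior, transforms.PowerTransform(-0.5))
tau = precision_prior.sample()
sigma = tau.pow(-0.5)  # the scale implied by a precision draw
print(scale_prior.log_prob(sigma))  # density of that scale under the transformed prior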
def get_posterior(
    self,
    name: str,
    prior: Distribution,
) -> Union[Distribution, torch.Tensor]:
    if self._computing_median:
        return self._get_posterior_median(name, prior)
    if self._computing_quantiles:
        return self._get_posterior_quantiles(name, prior)
    if self._computing_mi:
        # The messenger autoguide needs the output to fit certain dimensions.
        # This is a hack that saves MI to self.mi but returns
        # cheap-to-compute medians.
        self.mi[name] = self._get_mutual_information(name, prior)
        return self._get_posterior_median(name, prior)
    with helpful_support_errors({"name": name, "fn": prior}):
        transform = biject_to(prior.support)
        # If hierarchical_sites is not specified, all sites are assumed to be
        # hierarchical.
        if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
            loc, scale, weight = self._get_params(name, prior)
            loc = loc + transform.inv(prior.mean) * weight
            posterior = dist.TransformedDistribution(
                dist.Normal(loc, scale).to_event(transform.domain.event_dim),
                transform.with_cache(),
            )
            return posterior
        else:
            # Fall back to mean field when hierarchical_sites is non-empty
            # and the site is not in the list.
            loc, scale = self._get_params(name, prior)
            posterior = dist.TransformedDistribution(
                dist.Normal(loc, scale).to_event(transform.domain.event_dim),
                transform.with_cache(),
            )
            return posterior
def stable_model():
    zero = torch.zeros(2)
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Stable(1.5, 0.0, b, a))
    d = pyro.sample("d", dist.Stable(1.5, 0.0, b, 0.0), obs=a)
    e = pyro.sample("e", dist.Stable(1.5, 0.1, b, a))
    f = pyro.sample("f", dist.Stable(1.5, 0.1, b, 0.0), obs=a)
    g = pyro.sample("g", dist.Stable(1.5, zero, b, a).to_event(1))
    h = pyro.sample("h", dist.Stable(1.5, zero, b, 0).to_event(1), obs=a)
    i = pyro.sample(
        "i",
        dist.TransformedDistribution(
            dist.Stable(1.5, 0, b, a), dist.transforms.ExpTransform()),
    )
    j = pyro.sample(
        "j",
        dist.TransformedDistribution(
            dist.Stable(1.5, 0, b, a), dist.transforms.ExpTransform()),
        obs=a.exp(),
    )
    k = pyro.sample(
        "k",
        dist.TransformedDistribution(
            dist.Stable(1.5, zero, b, a),
            dist.transforms.ExpTransform()).to_event(1),
    )
    l = pyro.sample(
        "l",
        dist.TransformedDistribution(
            dist.Stable(1.5, zero, b, a),
            dist.transforms.ExpTransform()).to_event(1),
        obs=a.exp() + zero,
    )
    return a, b, c, d, e, f, g, h, i, j, k, l
def model(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    alpha = pyro.sample("alpha", dist.Normal(0.0, 1.0))
    sd_att = pyro.sample(
        "sd_att",
        dist.TransformedDistribution(
            dist.StudentT(3.0, 0.0, 2.5),
            FoldedTransform(),
        ),
    )
    sd_def = pyro.sample(
        "sd_def",
        dist.TransformedDistribution(
            dist.StudentT(3.0, 0.0, 2.5),
            FoldedTransform(),
        ),
    )
    home = pyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with pyro.plate("plate_teams", nt):
        attack = pyro.sample("attack", dist.Normal(0, sd_att))
        defend = pyro.sample("defend", dist.Normal(0, sd_def))

    # likelihood
    theta1 = torch.exp(alpha + home + attack[home_id] - defend[away_id])
    theta2 = torch.exp(alpha + attack[away_id] - defend[home_id])

    with pyro.plate("data", len(home_id)):
        pyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        pyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)
def test_kl_transformed_transformed(shape, event_dim, transform):
    p_base = dist.Normal(torch.zeros(shape), torch.ones(shape)).to_event(event_dim)
    q_base = dist.Normal(torch.ones(shape) * 2, torch.ones(shape)).to_event(event_dim)
    p = dist.TransformedDistribution(p_base, transform)
    q = dist.TransformedDistribution(q_base, transform)
    kl = kl_divergence(q, p)
    expected_shape = shape[:-1] if max(transform.event_dim, event_dim) == 1 else shape
    assert kl.shape == expected_shape
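# Minimal sketch of the property the test above relies on: when both
# distributions share the same bijective transform, the KL divergence reduces
# to the KL of the base distributions (this assumes Pyro's registered KL rule
# for TransformedDistribution pairs with matching transforms).
import torch
import pyro.distributions as dist
from torch.distributions import kl_divergence

p_base = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
q_base = dist.Normal(torch.full((3,), 2.0), torch.ones(3)).to_event(1)
t = dist.transforms.ExpTransform()
p = dist.TransformedDistribution(p_base, t)
q = dist.TransformedDistribution(q_base, t)
assert torch.allclose(kl_divergence(q, p), kl_divergence(q_base, p_base))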
def random_dist(Dist, shape, transform=None):
    if Dist is dist.FoldedDistribution:
        return Dist(random_dist(dist.Normal, shape))
    elif Dist is dist.MaskedDistribution:
        base_dist = random_dist(dist.Normal, shape)
        mask = torch.empty(shape, dtype=torch.bool).bernoulli_(0.5)
        return base_dist.mask(mask)
    elif Dist is dist.TransformedDistribution:
        base_dist = random_dist(dist.Normal, shape)
        transforms = [
            dist.transforms.ExpTransform(),
            dist.transforms.ComposeTransform([
                dist.transforms.AffineTransform(1, 1),
                dist.transforms.ExpTransform().inv,
            ]),
        ]
        return dist.TransformedDistribution(base_dist, transforms)
    elif Dist in (dist.GaussianHMM, dist.LinearHMM):
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        hidden_dim = obs_dim + 1
        init_dist = random_dist(dist.Normal, batch_shape + (hidden_dim,)).to_event(1)
        trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
        trans_dist = random_dist(dist.Normal,
                                 batch_shape + (duration, hidden_dim)).to_event(1)
        obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
        obs_dist = random_dist(dist.Normal,
                               batch_shape + (duration, obs_dim)).to_event(1)
        if Dist is dist.LinearHMM and transform is not None:
            obs_dist = dist.TransformedDistribution(obs_dist, transform)
        return Dist(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                    duration=duration)
    elif Dist is dist.IndependentHMM:
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        base_shape = batch_shape + (obs_dim, duration, 1)
        base_dist = random_dist(dist.GaussianHMM, base_shape)
        return Dist(base_dist)
    elif Dist is dist.MultivariateNormal:
        return random_mvn(shape[:-1], shape[-1])
    elif Dist is dist.Uniform:
        low = torch.randn(shape)
        high = low + torch.randn(shape).exp()
        return Dist(low, high)
    else:
        params = {
            name: transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5)
            for name in UNIVARIATE_DISTS[Dist]
        }
        return Dist(**params)
def __init__(self, model, node_idx: int, k: int, x, edge_index,
             sharp: float = 0.01, splines: int = 6, sigmoid: bool = True):
    # CUDA selection is intentionally disabled here; everything runs on CPU.
    device = torch.device('cpu')
    self.model = model.to(device)
    self.x = x.to(device)
    self.edge_index = edge_index.to(device)
    self.node_idx = node_idx
    self.k = k
    self.sharp = sharp
    # Extract the k-hop subgraph around the target node, with relabeled nodes.
    self.subset, self.edge_index_adj, self.mapping, self.edge_mask_hard = k_hop_subgraph(
        self.node_idx, k, self.edge_index, relabel_nodes=True)
    self.x_adj = self.x[self.subset]
    self.device = device
    with torch.no_grad():
        self.preds = model(self.x, self.edge_index_adj)

    # One flow dimension per edge of the subgraph.
    self.N = self.edge_index_adj.size(1)
    self.base_dist = dist.Normal(torch.zeros(self.N).to(device),
                                 torch.ones(self.N).to(device))
    self.splines = [T.spline(self.N).to(device) for _ in range(splines)]
    self.flow_dist = dist.TransformedDistribution(self.base_dist, self.splines)
    self.sigmoid = sigmoid
def _condition(self, H):
    self.cond_transforms = [t.condition(H) for t in self.transforms]
    # Interleave each conditioned transform with a permutation, then drop the
    # trailing permutation.
    self.generative_flows = list(
        itertools.chain(*zip(self.cond_transforms, self.perms)))[:-1]
    self.normalizing_flows = self.generative_flows[::-1]
    return dist.TransformedDistribution(self.base_dist, self.generative_flows)
def test_deep_to_transformed(shape, dtype):
    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=False)
    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=False)
    d1 = dist.TransformedDistribution(
        dist.Normal(loc, scale), dist.transforms.AffineTransform(a, b)
    )
    if shape is not None:
        d1 = d1.expand(shape)
    d2 = deep_to(d1, dtype)
    d2.log_prob(d2.sample().detach())  # smoke test
    assert type(d1) is type(d2)
    assert d2.event_shape == d1.event_shape
    assert d2.batch_shape == d1.batch_shape
    assert type(d1.base_dist) is type(d2.base_dist)
    assert len(d1.transforms) == len(d2.transforms)
    assert_equal(d1.base_dist.loc.to(dtype), d2.base_dist.loc)
    assert_equal(d1.base_dist.scale.to(dtype), d2.base_dist.scale)
    assert_equal(d1.transforms[0].loc.to(dtype), d2.transforms[0].loc)
    assert_equal(d1.transforms[0].scale.to(dtype), d2.transforms[0].scale)
    assert d2.base_dist.loc.requires_grad
    assert not d2.base_dist.scale.requires_grad
    assert d2.transforms[0].loc.requires_grad
    assert not d2.transforms[0].scale.requires_grad
def __call__(self, name, fn, obs):
    # Reparameterize the initial distribution as conditionally Gaussian.
    init_dist = fn.initial_dist
    if self.init is not None:
        init_dist, _ = self.init("{}_init".format(name), init_dist, None)

    # Reparameterize the transition distribution as conditionally Gaussian.
    trans_dist = fn.transition_dist
    if self.trans is not None:
        trans_dist, _ = self.trans("{}_trans".format(name),
                                   trans_dist.to_event(1), None)
        trans_dist = trans_dist.to_event(-1)

    # Reparameterize the observation distribution as conditionally Gaussian.
    obs_dist = fn.observation_dist
    if self.obs is not None:
        obs_dist, obs = self.obs("{}_obs".format(name),
                                 obs_dist.to_event(1), obs)
        obs_dist = obs_dist.to_event(-1)

    # Reparameterize the entire HMM as conditionally Gaussian.
    hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
                           fn.observation_matrix, obs_dist)

    # Apply any observation transforms.
    if fn.transforms:
        hmm = dist.TransformedDistribution(hmm, fn.transforms)

    return hmm, obs
def guide(data):
    loc_eta = torch.randn(J, 1)
    # note that we initialize our scales to be pretty narrow
    scale_eta = 0.1 * torch.rand(J, 1)
    loc_mu = torch.randn(1)
    scale_mu = 0.1 * torch.rand(1)
    loc_logtau = torch.randn(1)
    scale_logtau = 0.1 * torch.rand(1)

    # register learnable params in the param store
    m_eta_param = pyro.param("loc_eta", loc_eta)
    s_eta_param = pyro.param("scale_eta", scale_eta,
                             constraint=constraints.positive)
    m_mu_param = pyro.param("loc_mu", loc_mu)
    s_mu_param = pyro.param("scale_mu", scale_mu,
                            constraint=constraints.positive)
    m_logtau_param = pyro.param("loc_logtau", loc_logtau)
    s_logtau_param = pyro.param("scale_logtau", scale_logtau,
                                constraint=constraints.positive)

    # guide distributions
    dist_eta = dist.Normal(m_eta_param, s_eta_param)
    dist_mu = dist.Normal(m_mu_param, s_mu_param)
    dist_tau = dist.TransformedDistribution(
        dist.Normal(m_logtau_param, s_logtau_param),
        transforms=transforms.ExpTransform())

    pyro.sample('eta', dist_eta)
    pyro.sample('mu', dist_mu)
    pyro.sample('tau', dist_tau)
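# Side note on the 'tau' guide above: exponentiating a Normal gives a
# LogNormal, so dist_tau is equivalent to
# dist.LogNormal(m_logtau_param, s_logtau_param). A quick standalone check:
import torch
import pyro.distributions as dist
from torch.distributions import transforms

loc, scale = torch.tensor(0.3), torch.tensor(0.7)
d1 = dist.TransformedDistribution(dist.Normal(loc, scale),
                                  transforms.ExpTransform())
d2 = dist.LogNormal(loc, scale)
x = d2.sample()
assert torch.allclose(d1.log_prob(x), d2.log_prob(x))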
def get_posterior(self, *args, **kwargs):
    """
    Returns the posterior distribution.
    """
    base_dist = self.get_base_dist()
    transform = self.get_transform(*args, **kwargs)
    return dist.TransformedDistribution(base_dist, transform)
def __call__(self, name, fn, obs):
    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
    if fn.duration is None:
        raise ValueError(
            "LinearHMMReparam requires duration to be specified "
            "on targeted LinearHMM distributions")

    # Unwrap IndependentHMM.
    if isinstance(fn, dist.IndependentHMM):
        if obs is not None:
            obs = obs.transpose(-1, -2).unsqueeze(-1)
        hmm, obs = self(name, fn.base_dist.to_event(1), obs)
        hmm = dist.IndependentHMM(hmm.to_event(-1))
        if obs is not None:
            obs = obs.squeeze(-1).transpose(-1, -2)
        return hmm, obs

    # Reparameterize the initial distribution as conditionally Gaussian.
    init_dist = fn.initial_dist
    if self.init is not None:
        init_dist, _ = self.init("{}_init".format(name),
                                 self._wrap(init_dist, event_dim - 1), None)
        init_dist = init_dist.to_event(1 - init_dist.event_dim)

    # Reparameterize the transition distribution as conditionally Gaussian.
    trans_dist = fn.transition_dist
    if self.trans is not None:
        if trans_dist.batch_shape[-1] != fn.duration:
            trans_dist = trans_dist.expand(
                trans_dist.batch_shape[:-1] + (fn.duration,))
        trans_dist, _ = self.trans("{}_trans".format(name),
                                   self._wrap(trans_dist, event_dim), None)
        trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

    # Reparameterize the observation distribution as conditionally Gaussian.
    obs_dist = fn.observation_dist
    if self.obs is not None:
        if obs_dist.batch_shape[-1] != fn.duration:
            obs_dist = obs_dist.expand(
                obs_dist.batch_shape[:-1] + (fn.duration,))
        obs_dist, obs = self.obs("{}_obs".format(name),
                                 self._wrap(obs_dist, event_dim), obs)
        obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)

    # Reparameterize the entire HMM as conditionally Gaussian.
    hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
                           fn.observation_matrix, obs_dist,
                           duration=fn.duration)
    hmm = self._wrap(hmm, event_dim)

    # Apply any observation transforms.
    if fn.transforms:
        hmm = dist.TransformedDistribution(hmm, fn.transforms)

    return hmm, obs
def generate(self, x, num_particles):
    z_dist = self.encoder.predict(x)
    z = z_dist.sample()
    x_pred_dist = self.decoder.predict(z)
    x_base_dist = dist.Normal(
        torch.zeros_like(x, requires_grad=False).view(x.shape[0], -1),
        torch.ones_like(x, requires_grad=False).view(x.shape[0], -1),
    ).to_event(1)
    if ('normal' in self.decoder_output
            or self.decoder_output in ('deepvar', 'deepmean')):
        transform = AffineTransform(x_pred_dist.mean, x_pred_dist.stddev, 1)
    elif self.decoder_output == 'low_rank_mvn':
        transform = LowerCholeskyAffine(x_pred_dist.loc, x_pred_dist.scale_tril)
    else:
        raise Exception('Unknown decoder output')
    x_dist = dist.TransformedDistribution(x_base_dist,
                                          ComposeTransform([transform]))
    recons = []
    for _ in range(num_particles):
        recon = pyro.sample('x', x_dist).view(x.shape[0], self.shape, 3)
        recons.append(recon)
    return torch.stack(recons).mean(0)
def _test_shape(self, base_shape, make_flow):
    base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
    last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape
    flow = make_flow(input_dim=last_dim)
    sample = dist.TransformedDistribution(base_dist, [flow]).sample()
    assert sample.shape == base_shape
def _(d, batch_shape):
    base_dist = reshape_batch(d.base_dist, batch_shape)
    old_shape = d.base_dist.shape()
    new_shape = base_dist.shape()
    transforms = [
        reshape_transform_batch(t, old_shape, new_shape)
        for t in d.transforms
    ]
    return dist.TransformedDistribution(base_dist, transforms)
def model():
    with pyro.plate_stack("plates", shape):
        with pyro.plate("particles", 200000):
            return pyro.sample(
                "x",
                dist.TransformedDistribution(
                    dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
                    [AffineTransform(loc, scale), ExpTransform()]))
def model():
    fn = dist.TransformedDistribution(
        dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
        [AffineTransform(loc, scale), ExpTransform()])
    if event_shape:
        fn = fn.to_event(len(event_shape))
    with pyro.plate_stack("plates", batch_shape):
        with pyro.plate("particles", 200000):
            return pyro.sample("x", fn)
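# Shape note for the model above: to_event(k) moves the k rightmost batch
# dimensions of the TransformedDistribution into its event shape, so the
# surrounding plates only see the remaining batch dims. A standalone sketch:
import torch
import pyro.distributions as dist
from torch.distributions.transforms import AffineTransform, ExpTransform

loc = torch.zeros(4, 3)
scale = torch.ones(4, 3)
fn = dist.TransformedDistribution(
    dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
    [AffineTransform(loc, scale), ExpTransform()])
assert fn.batch_shape == (4, 3) and fn.event_shape == ()
fn1 = fn.to_event(1)
assert fn1.batch_shape == (4,) and fn1.event_shape == (3,)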
def model(self):
    """
    Generative model of behavior with a NormalGamma prior over free model
    parameters.
    """
    runs = self.runs  # number of independent runs of the experiment
    npar = self.npar  # number of parameters

    # define hyper priors over model parameters
    a = param('a', ones(npar), constraint=constraints.positive)
    lam = param('lam', ones(npar), constraint=constraints.positive)
    tau = sample('tau', dist.Gamma(a, a / lam).to_event(1))

    sig = 1 / torch.sqrt(tau)

    # each model parameter has a hyperprior defining the group-level mean
    m = param('m', zeros(npar))
    s = param('s', ones(npar), constraint=constraints.positive)
    mu = sample('mu', dist.Normal(m, s * sig).to_event(1))

    # define prior mean over model parameters and subjects
    with plate('runs', runs):
        base_dist = dist.Normal(0., 1.).expand_by([npar]).to_event(1)
        transform = dist.transforms.AffineTransform(mu, sig)
        locs = sample('locs',
                      dist.TransformedDistribution(base_dist, [transform]))

    if self.fixed_values:
        x = zeros(locs.shape[:-1] + (self.agent.npar,))
        x[..., self.locs['fixed']] = self.values
        x[..., self.locs['free']] = locs
    else:
        x = locs

    self.agent.set_parameters(x)

    for b in range(self.nb):
        for t in range(self.nt):
            # update single trial
            offers = self.stimulus['offers'][b, t]
            self.agent.planning(b, t, offers)

            outcomes = self.stimulus['outcomes'][b, t]
            responses = self.responses[b, t]
            mask = self.stimulus['mask'][b, t]
            self.agent.update_beliefs(b, t, [responses, outcomes], mask=mask)

            mask = self.notnans[b, t]
            logits = self.agent.logits[-1]
            sample('obs_{}_{}'.format(b, t),
                   dist.Categorical(logits=logits).mask(mask),
                   obs=responses)
def model(self):
    """
    Generative model of behavior with a hierarchical (horseshoe) prior over
    free model parameters.
    """
    runs = self.runs  # number of independent runs of the experiment
    npar = self.npar  # number of parameters

    # define hyper priors over model parameters;
    # each model parameter has a hyperprior defining the group-level mean
    m = param('m', zeros(npar))
    s = param('s', ones(npar), constraint=constraints.positive)
    mu = sample('mu', dist.Normal(m, s).to_event(1))

    # define prior uncertainty over model parameters and subjects
    lam = param('lam', ones(1), constraint=constraints.positive)
    tau = sample('tau', dist.HalfCauchy(ones(npar)).to_event(1))

    # define prior mean over model parameters and subjects
    with plate('runs', runs):
        base_dist = dist.Normal(0., 1.).expand([npar]).to_event(1)
        transform = dist.transforms.AffineTransform(mu, lam * tau)
        locs = sample('locs',
                      dist.TransformedDistribution(base_dist, [transform]))

    if self.fixed_values:
        x = zeros(runs, self.agent.npar)
        x[:, self.locs['fixed']] = self.values
        x[:, self.locs['free']] = locs
    else:
        x = locs

    self.agent.set_parameters(x)

    for b in range(self.nb):
        for t in range(self.nt):
            # update single trial
            offers = self.stimulus['offers'][b, t]
            self.agent.planning(b, t, offers)

            outcomes = self.stimulus['outcomes'][b, t]
            responses = self.responses[b, t]
            mask = self.stimulus['mask'][b, t]
            self.agent.update_beliefs(b, t, [responses, outcomes], mask=mask)

            logits = self.agent.logits[-1]
            sample('obs_{}_{}'.format(b, t),
                   dist.Categorical(logits=logits).mask(self.notnans[b, t]),
                   obs=responses)
def true_model(design):
    w1 = torch.tensor([-1., 1.])
    w2 = torch.tensor([-.5, .5, -.5, .5, -.5, 2., -2., 2., -2., 0.])
    w = torch.cat([w1, w2], dim=-1)
    k = torch.tensor(.1)

    response_mean = rmv(design, w)
    base_dist = dist.Normal(response_mean, torch.tensor(1.)).to_event(1)
    k = k.expand(response_mean.shape)
    transforms = [AffineTransform(loc=0., scale=k), SigmoidTransform()]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    return pyro.sample("y", response_dist)
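# Side note on the response distribution above: a draw is sigmoid(k * z) with
# z ~ Normal(response_mean, 1), so responses land in the open interval (0, 1).
# A quick standalone check (shapes here are assumed for illustration):
import torch
import pyro.distributions as dist
from torch.distributions.transforms import AffineTransform, SigmoidTransform

base = dist.Normal(torch.zeros(4), torch.tensor(1.)).to_event(1)
k = torch.tensor(0.1).expand(4)
d = dist.TransformedDistribution(
    base, [AffineTransform(loc=0., scale=k), SigmoidTransform()])
y = d.sample()
assert ((y > 0) & (y < 1)).all()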
def get_posterior(
    self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
    with helpful_support_errors({"name": name, "fn": prior}):
        transform = biject_to(prior.support)
        loc, scale = self._get_params(name, prior)
        affine = dist.transforms.AffineTransform(
            loc, scale, event_dim=transform.domain.event_dim, cache_size=1)
        posterior = dist.TransformedDistribution(
            prior, [transform.inv.with_cache(), affine, transform.with_cache()])
        return posterior
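# Sketch of the construction above: the posterior pushes prior samples to
# unconstrained space, applies an affine map, and pushes back, i.e.
# y = T(loc + scale * T.inv(x)) with x ~ prior and T = biject_to(prior.support).
# (Standalone illustration with fixed loc/scale, not the guide's parameters.)
import torch
import pyro.distributions as dist
from torch.distributions import biject_to

prior = dist.LogNormal(0.0, 1.0)
T = biject_to(prior.support)  # an exp transform, since the support is positive
loc, scale = torch.tensor(0.5), torch.tensor(2.0)
affine = dist.transforms.AffineTransform(loc, scale)
posterior = dist.TransformedDistribution(prior, [T.inv, affine, T])
print(posterior.sample())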
def get_posterior(self, *args, **kwargs):
    """
    Returns a diagonal Normal posterior distribution transformed by
    :class:`~pyro.distributions.iaf.InverseAutoregressiveFlow`.
    """
    if self.latent_dim == 1:
        raise ValueError(
            'latent dim = 1. Consider using AutoDiagonalNormal instead')
    if self.hidden_dim is None:
        self.hidden_dim = self.latent_dim
    iaf = dist.InverseAutoregressiveFlow(
        AutoRegressiveNN(self.latent_dim, [self.hidden_dim]))
    pyro.module("{}_iaf".format(self.prefix), iaf)
    iaf_dist = dist.TransformedDistribution(
        dist.Normal(0., 1.).expand([self.latent_dim]), [iaf])
    return iaf_dist.to_event(1)
def get_posterior(
    self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
    if self._computing_median:
        return self._get_posterior_median(name, prior)
    with helpful_support_errors({"name": name, "fn": prior}):
        transform = biject_to(prior.support)
        loc, scale = self._get_params(name, prior)
        posterior = dist.TransformedDistribution(
            dist.Normal(loc, scale).to_event(transform.domain.event_dim),
            transform.with_cache(),
        )
        return posterior
def guide(self):
    """Approximate posterior for the Dirichlet process prior."""
    nsub = self.runs  # number of subjects
    npar = self.npar  # number of parameters
    kmax = self.kmax  # maximum number of components

    gaa = param("ga_a", ones(1), constraint=constraints.positive)
    gar = param("ga_r", .1 * ones(1), constraint=constraints.positive)
    sample('alpha', dist.Gamma(gaa, gaa / gar))

    gba = param("gb_beta_a", ones(kmax - 1), constraint=constraints.positive)
    gbb = param("gb_beta_b", ones(kmax - 1), constraint=constraints.positive)
    beta = sample("beta", dist.Beta(gba, gbb).to_event(1))

    with plate('classes', kmax, dim=-2):
        m_mu = param('m_mu', zeros(kmax, npar))
        st_mu = param('scale_tril_mu', torch.eye(npar).repeat(kmax, 1, 1),
                      constraint=constraints.lower_cholesky)
        mu = sample("mu", dist.MultivariateNormal(m_mu, scale_tril=st_mu))

        m_tau = param('m_hyp', zeros(kmax, npar))
        st_tau = param('scale_tril_hyp', torch.eye(npar).repeat(kmax, 1, 1),
                       constraint=constraints.lower_cholesky)
        mn = dist.MultivariateNormal(m_tau, scale_tril=st_tau)
        sample("tau",
               dist.TransformedDistribution(
                   mn, [dist.transforms.ExpTransform()]))

    m_locs = param('m_locs', zeros(nsub, npar))
    st_locs = param('scale_tril_locs', torch.eye(npar).repeat(nsub, 1, 1),
                    constraint=constraints.lower_cholesky)
    class_probs = param('class_probs', ones(nsub, kmax) / kmax,
                        constraint=constraints.simplex)
    with plate('subjects', nsub, dim=-2):
        sample('class', dist.Categorical(class_probs),
               infer={"enumerate": "parallel"})
        sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
def get_posterior(self, *args, **kwargs):
    """
    Returns a diagonal Normal posterior distribution transformed by
    :class:`~pyro.distributions.transforms.iaf.InverseAutoregressiveFlow`.
    """
    if self.latent_dim == 1:
        raise ValueError(
            'latent dim = 1. Consider using AutoDiagonalNormal instead')
    if self.hidden_dim is None:
        self.hidden_dim = self.latent_dim
    if self.arn is None:
        self.arn = AutoRegressiveNN(self.latent_dim, [self.hidden_dim])
    iaf = transforms.AffineAutoregressive(self.arn)
    iaf_dist = dist.TransformedDistribution(
        dist.Normal(0., 1.).expand([self.latent_dim]), [iaf])
    return iaf_dist
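# Minimal standalone sketch of the flow posterior above, using Pyro's
# affine_autoregressive helper (which builds the AutoRegressiveNN internally;
# dimensions here are assumed for illustration):
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

latent_dim, hidden_dim = 4, 8
iaf = T.affine_autoregressive(latent_dim, hidden_dims=[hidden_dim])
flow_dist = dist.TransformedDistribution(
    dist.Normal(torch.zeros(latent_dim), 1.).to_event(1), [iaf])
z = flow_dist.rsample()
print(flow_dist.log_prob(z))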
def test_overdispersed_asymptote(probs, overdispersion):
    total_count = 100000

    # Check binomial_dist converges in distribution to LogitNormal.
    d1 = binomial_dist(total_count, probs)
    d2 = dist.TransformedDistribution(
        dist.Normal(math.log(probs / (1 - probs)), overdispersion),
        SigmoidTransform())

    # CRPS is equivalent to the Cramer-von Mises test.
    # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion
    k = torch.arange(0., total_count + 1.)
    cdf1 = d1.log_prob(k).exp().cumsum(-1)
    cdf2 = d2.cdf(k / total_count)
    crps = (cdf1 - cdf2).pow(2).mean()
    assert crps < 0.02
def _test_shape(self, base_shape, transform):
    base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
    sample = dist.TransformedDistribution(base_dist, [transform]).sample()
    assert sample.shape == base_shape

    batch_shape = base_shape[:len(base_shape) - transform.domain.event_dim]
    input_event_shape = base_shape[len(base_shape) - transform.domain.event_dim:]
    output_event_shape = base_shape[len(base_shape) - transform.codomain.event_dim:]
    output_shape = batch_shape + output_event_shape

    assert transform.forward_shape(input_event_shape) == output_event_shape
    assert transform.forward_shape(base_shape) == output_shape
    assert transform.inverse_shape(output_event_shape) == input_event_shape
    assert transform.inverse_shape(output_shape) == base_shape
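# Standalone sketch of the forward_shape/inverse_shape contract checked above,
# using a transform whose domain and codomain event dims differ:
# CorrCholeskyTransform maps a vector of d*(d-1)/2 entries to a d x d
# Cholesky factor of a correlation matrix.
import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()  # domain.event_dim == 1, codomain.event_dim == 2
assert t.forward_shape((6,)) == (4, 4)   # 6 = 4 * 3 / 2 off-diagonal entries
assert t.inverse_shape((4, 4)) == (6,)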
def model(self, *args, **kargs):
    self.inv_softplus_sigma = pyro.param("inv_softplus_sigma",
                                         torch.ones(self.rank))
    sigma = self.sigma  # alternatively: softplus(self.inv_softplus_sigma)
    # Pavel: introducing `sigma` in the IAF distribution (instead of a unit
    # Normal base) makes training more stable in terms of the scale of the
    # distribution we are trying to learn.
    base_dist = dist.Normal(torch.zeros(self.rank), sigma)
    ann = AutoRegressiveNN(self.rank, self.n_hid, skip_connections=True)
    iaf = dist.InverseAutoregressiveFlow(ann)
    iaf_module = pyro.module("my_iaf", iaf)
    iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
    self.t = pyro.sample("t", iaf_dist.to_event(1))
    return self.t
def __call__(self, name, fn, obs):
    assert obs is None, "TransformReparam does not support observe statements"
    assert fn.event_dim >= -self.dim, (
        "Cannot transform along batch dimension; "
        "try converting a batch dimension to an event dimension")

    # Draw noise from the base distribution.
    transform = DiscreteCosineTransform(dim=self.dim, cache_size=1)
    x_dct = pyro.sample("{}_dct".format(name),
                        dist.TransformedDistribution(fn, transform))

    # Differentiably transform.
    x = transform.inv(x_dct)  # should be free due to transform cache

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(x, event_dim=fn.event_dim)
    return new_fn, x
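# Hypothetical usage sketch: the reparameterizer above matches the behavior of
# pyro.infer.reparam.DiscreteCosineReparam, which is applied per-site via
# poutine.reparam (the model and site name here are assumptions):
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import DiscreteCosineReparam

def model():
    pyro.sample("x", dist.Normal(torch.zeros(8), 1.).to_event(1))

reparam_model = poutine.reparam(model, config={"x": DiscreteCosineReparam()})
trace = poutine.trace(reparam_model).get_trace()
print(sorted(trace.nodes.keys()))  # includes the auxiliary "x_dct" site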