def model(self, zero_data, covariates):
    """Lévy-stable drift plus per-item bias and per-item linear regression.

    Drift stability, drift scale and the observation scale are shared by all
    items; ``bias`` and ``weight`` are sampled independently per item.
    """
    duration, data_dim = zero_data.shape
    feature_dim = covariates.size(-1)

    # Item-shared parameters of the stable drift process.
    drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
    drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))

    with pyro.plate("item", data_dim, dim=-2):
        bias = pyro.sample("bias", dist.Normal(5.0, 10.0))
        with pyro.plate('features', feature_dim, dim=-1):
            weight = pyro.sample("weight", dist.Normal(0.0, 5))

        # The drift scale above is time-global; inside the time plate we draw
        # time-local iid increments.
        with self.time_plate:
            # The inner SymmetricStableReparam is required for the Stable site;
            # the outer LocScaleReparam is optional but appears to improve inference.
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                    drift = pyro.sample("drift", dist.Stable(drift_stability, 0, drift_scale))

    # Integrating the stable increments over time gives a Lévy stable motion.
    motion = drift.cumsum(-1)

    # Per-item regression; the matmul leaves a trailing singleton dim that the
    # sum removes (assumes covariates is (duration, feature_dim) — TODO confirm).
    regression = torch.matmul(covariates, weight[..., None]).sum(axis=-1)
    prediction = motion + bias + regression
    # Move the item dim to the end, keeping a singleton at -3.
    prediction = prediction.unsqueeze(-1).transpose(-1, -3)

    # A single observation scale shared across all time series.
    obs_scale = pyro.sample("obs_scale", dist.LogNormal(-5, 5))
    noise_dist = dist.Normal(loc=0.0, scale=obs_scale.unsqueeze(-1))
    self.predict(noise_dist, prediction)
def __init__(self, model, data, covariates=None, *, num_warmup=1000, num_samples=1000,
             num_chains=1, time_reparam=None, dense_mass=False, jit_compile=False,
             max_tree_depth=10):
    """
    Fit ``model`` to ``data`` with NUTS and keep the posterior samples.

    :param model: A forecasting model, called as ``model(data, covariates)``.
    :param data: Observed tensor whose dim -2 is time.
    :param covariates: Covariate tensor whose dim -2 is time. When ``None``,
        an empty ``(duration, 0)`` tensor is used.
    :param int num_warmup: Number of NUTS warmup steps.
    :param int num_samples: Number of posterior samples per chain.
    :param int num_chains: Number of MCMC chains.
    :param str time_reparam: ``None``, ``"haar"`` or ``"dct"``.
    :param bool dense_mass: Passed to NUTS as ``full_mass``.
    :param bool jit_compile: Passed to NUTS.
    :param int max_tree_depth: Passed to NUTS.
    :raises ValueError: If ``data`` and ``covariates`` have different time
        lengths, or ``time_reparam`` is unknown.
    """
    if covariates is None:
        # The default used to crash on ``covariates.size``; a covariate-free
        # model still needs a tensor carrying the time dimension.
        covariates = data.new_zeros(data.size(-2), 0)
    if data.size(-2) != covariates.size(-2):
        raise ValueError(
            "data and covariates have different durations: {} vs {}".format(
                data.size(-2), covariates.size(-2)))
    super().__init__()

    if time_reparam == "haar":
        model = poutine.reparam(model, time_reparam_haar)
    elif time_reparam == "dct":
        model = poutine.reparam(model, time_reparam_dct)
    elif time_reparam is not None:
        raise ValueError("unknown time_reparam: {}".format(time_reparam))
    self.model = model

    max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {})
    self.max_plate_nesting = max(max_plate_nesting, 1)  # force a time plate

    # Give NUTS the same (forced) nesting that the particle plate below relies on.
    kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile, ignore_jit_warnings=True,
                  max_tree_depth=max_tree_depth, max_plate_nesting=self.max_plate_nesting)
    mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples, num_chains=num_chains)
    mcmc.run(data, covariates)
    # Only summarize when there are enough samples to compute rhat.
    if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2):
        mcmc.summary()

    # Inspect the model with a size-1 particles plate so that samples can later be
    # reshaped to add any missing plate dim in front.
    with poutine.trace() as tr:
        with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
            model(data, covariates)

    self._trace = tr.trace
    self._samples = mcmc.get_samples()
    self._num_samples = num_samples * num_chains
    # Keep only trace nodes for which MCMC produced samples.
    for name in list(self._trace.nodes):
        if name not in self._samples:
            del self._trace.nodes[name]
def test_stable(Reparam, shape):
    """Moments and their parameter gradients must survive a Stable reparametrizer."""
    # Draw parameters in a fixed order: stability, skew, scale, loc.
    stability = torch.empty(shape).uniform_(1.5, 2.).requires_grad_()
    skew = torch.empty(shape).uniform_(-0.5, 0.5).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_()
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    params = [stability, skew, scale, loc]

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 100000):
                return pyro.sample("x", dist.Stable(stability, skew, scale, loc))

    expected_moments = get_moments(model())

    trace = poutine.trace(poutine.reparam(model, {"x": Reparam()})).get_trace()
    site_fn = trace.nodes["x"]["fn"]
    if Reparam is LatentStableReparam:
        assert isinstance(site_fn, MaskedDistribution)
        assert isinstance(site_fn.base_dist, dist.Delta)
    else:
        assert isinstance(site_fn, dist.Normal)
    trace.compute_log_prob()  # smoke test only
    actual_moments = get_moments(trace.nodes["x"]["value"])
    assert_close(actual_moments, expected_moments, atol=0.05)

    # Tolerances per parameter, in the order of ``params``.
    grad_atols = (0.2, 0.1, 0.1, 0.1)
    for actual_m, expected_m in zip(actual_moments, expected_moments):
        expected_grads = grad(expected_m.sum(), params, retain_graph=True)
        actual_grads = grad(actual_m.sum(), params, retain_graph=True)
        for actual_g, expected_g, atol in zip(actual_grads, expected_grads, grad_atols):
            assert_close(actual_g, expected_g, atol=atol)
def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """StudentT-noise LinearHMM reparametrizes to a GaussianHMM with gamma auxiliaries."""
    init_dist = random_studentt(batch_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_studentt(batch_shape + (duration, hidden_dim)).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_studentt(batch_shape + (duration, obs_dim)).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    rep = StudentTReparam()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            model(data)

    nodes = tr.trace.nodes
    assert isinstance(nodes["x"]["fn"], dist.GaussianHMM)
    expected_event_shapes = {
        "x_init_gamma": (hidden_dim, ),
        "x_trans_gamma": (duration, hidden_dim),
        "x_obs_gamma": (duration, obs_dim),
    }
    for site, event_shape in expected_event_shapes.items():
        assert nodes[site]["fn"].event_shape == event_shape
    tr.trace.compute_log_prob()  # smoke test only
def model(self, zero_data, covariates):
    """Random walk with LogNormal jump scales and a Stable LinearHMM noise model."""
    with pyro.plate("batch", len(zero_data), dim=-2):
        zero_data = pyro.subsample(zero_data, event_dim=1)
        covariates = pyro.subsample(covariates, event_dim=1)

        # Jump scale centred on the first time step of each series.
        scale = pyro.sample("scale", dist.LogNormal(zero_data[..., :1, :], 1).to_event(1))
        with self.time_plate:
            jumps = pyro.sample("jumps", dist.Normal(0, scale).to_event(1))
        prediction = jumps.cumsum(-2)

        duration, obs_dim = zero_data.shape[-2:]

        def stable_noise():
            # Symmetric (skew 0) Stable noise with stability 1.9 on every obs dim.
            return dist.Stable(1.9, 0).expand([obs_dim]).to_event(1)

        noise_dist = dist.LinearHMM(
            stable_noise(),
            torch.eye(obs_dim),
            stable_noise(),
            torch.eye(obs_dim),
            stable_noise(),
            duration=duration,
        )
        # "residual" is the sample site created inside self.predict().
        rep = StableReparam()
        hmm_reparam = LinearHMMReparam(rep, rep, rep)
        with poutine.reparam(config={"residual": hmm_reparam}):
            self.predict(noise_dist, prediction)
def test_beta_binomial_dependent_sample():
    """ConjugateReparam with a data-dependent likelihood recovers the Beta posterior."""
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    posterior = dist.Beta(concentration1 + counts, concentration0 + total - counts)

    def model(counts):
        prob = pyro.sample("prob", prior)
        pyro.sample("counts", dist.Binomial(total, prob), obs=counts)

    def likelihood(counts):
        # Beta-shaped likelihood for ``prob`` given the Binomial counts.
        return dist.Beta(1 + counts, 1 + total - counts)

    reparam_model = poutine.reparam(model, {"prob": ConjugateReparam(likelihood)})

    with poutine.trace() as tr, pyro.plate("particles", 10000):
        reparam_model(counts)
    samples = tr.trace.nodes["prob"]["value"]

    assert_close(samples.mean(), posterior.mean, atol=0.01)
    assert_close(samples.std(), posterior.variance.sqrt(), atol=0.01)
def test_normal(shape, dim, flip):
    """HaarReparam preserves moments and gradients of a Normal site."""
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                pyro.sample("x", dist.Normal(loc, scale).expand(shape).to_event(-dim))

    expected_probe = get_moments(poutine.trace(model).get_trace().nodes["x"]["value"])

    haar = HaarReparam(dim=dim, flip=flip)
    nodes = poutine.trace(poutine.reparam(model, {"x": haar})).get_trace().nodes
    assert isinstance(nodes["x_haar"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)
    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)

    # Only the first ten moments are compared at the gradient level.
    wrt = [loc, scale]
    for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]):
        expected_grads = grad(expected_m.sum(), wrt, retain_graph=True)
        actual_grads = grad(actual_m.sum(), wrt, retain_graph=True)
        for actual_g, expected_g in zip(actual_grads, expected_grads):
            assert_close(actual_g, expected_g, atol=0.05)
def test_beta_binomial_hmc():
    """HMC over a ConjugateReparam'd Beta-Binomial model recovers the exact posterior."""
    num_samples = 1000
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    likelihood = dist.Beta(1 + counts, 1 + total - counts)
    posterior = dist.Beta(concentration1 + counts, concentration0 + total - counts)

    def model():
        prob = pyro.sample("prob", prior)
        pyro.sample("counts", dist.Binomial(total, prob), obs=counts)

    reparam_model = poutine.reparam(model, {"prob": ConjugateReparam(likelihood)})

    kernel = HMC(reparam_model)
    # MCMC.run() returns None; the posterior draws must be fetched with
    # get_samples(). Previously ``samples`` was None, so Predictive never saw them.
    mcmc = MCMC(kernel, num_samples, warmup_steps=0)
    mcmc.run()
    posterior_samples = mcmc.get_samples()

    pred = Predictive(reparam_model, posterior_samples, num_samples=num_samples)
    trace = pred.get_vectorized_trace()
    samples = trace.nodes["prob"]["value"]
    assert_close(samples.mean(), posterior.mean, atol=0.01)
    assert_close(samples.std(), posterior.variance.sqrt(), atol=0.01)
def test_normal(dist_type, centered, shape):
    """LocScaleReparam preserves moments and gradients for Normal and StudentT sites."""
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()
    if isinstance(centered, torch.Tensor):
        centered = centered.expand(shape)

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                # Compare the parameter, not the literal string "dist_type";
                # the old comparison was always False, so Normal was never tested.
                if dist_type == "Normal":
                    pyro.sample("x", dist.Normal(loc, scale))
                else:
                    pyro.sample("x", dist.StudentT(10.0, loc, scale))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    # ``centered`` was previously computed but never passed to the reparametrizer.
    if dist_type == "Normal":
        reparam = LocScaleReparam(centered=centered)
    else:
        reparam = LocScaleReparam(centered=centered, shape_params=["df"])
    reparam_model = poutine.reparam(model, {"x": reparam})
    value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"]
    actual_probe = get_moments(value)
    assert_close(actual_probe, expected_probe, atol=0.1)

    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grads = grad(expected_m.sum(), [loc, scale], retain_graph=True)
        actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.05)
        assert_close(actual_grads[1], expected_grads[1], atol=0.05)
def test_log_normal(shape):
    """TransformReparam turns a LogNormal-like site into a Delta with the same moments."""
    loc = torch.empty(shape).uniform_(-1, 1)
    scale = torch.empty(shape).uniform_(0.5, 1.5)

    def model():
        base = dist.Normal(torch.zeros_like(loc), torch.ones_like(scale))
        transforms = [AffineTransform(loc, scale), ExpTransform()]
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                return pyro.sample("x", dist.TransformedDistribution(base, transforms))

    # Without reparametrization the site keeps its TransformedDistribution.
    with poutine.trace() as tr:
        direct_value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.TransformedDistribution)
    expected_moments = get_moments(direct_value)

    # With TransformReparam the site becomes a Delta.
    with poutine.reparam(config={"x": TransformReparam()}), poutine.trace() as tr:
        reparam_value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.Delta)
    actual_moments = get_moments(reparam_value)
    assert_close(actual_moments, expected_moments, atol=0.05)
def test_pyrocov_reparam(model, Guide, backend):
    """Smoke test: SVI, guide sampling and Predictive all run on a reparametrized pyrocov model."""
    T, P, S, F = 2, 3, 4, 5
    dataset = {
        "features": torch.randn(S, F),
        "local_time": torch.randn(T, P),
        "weekly_strains": torch.randn(T, P, S).exp().round(),
    }

    # Decenter the listed sites; pyrocov_model itself gets no "rate_loc" reparam.
    config = {
        "coef": LocScaleReparam(),
        "rate_loc": None if model is pyrocov_model else LocScaleReparam(),
        "rate": LocScaleReparam(),
        "init_loc": LocScaleReparam(),
        "init": LocScaleReparam(),
    }
    reparam_model = poutine.reparam(model, config)

    guide = Guide(reparam_model, backend=backend)
    svi = SVI(reparam_model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
    for _ in range(2):
        with xfail_if_not_implemented():
            svi.step(dataset)

    guide(dataset)
    Predictive(reparam_model, guide=guide, num_samples=2)(dataset)
def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """A Stable-noise LinearHMM reparametrizes to a GaussianHMM."""
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    # Stability broadcast against one or two trailing dims.
    stability_1 = stability.unsqueeze(-1)
    stability_2 = stability_1.unsqueeze(-1)

    init_dist = random_stable(batch_shape + (hidden_dim, ), stability_1,
                              skew=skew).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_stable(batch_shape + (duration, hidden_dim), stability_2,
                               skew=skew).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_stable(batch_shape + (duration, obs_dim), stability_2,
                             skew=skew).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()
    with poutine.trace() as tr, poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
        model(data)
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
def model(self, zero_data, covariates):
    """Brownian-motion drift with bias and covariate regression, plus Normal noise."""
    data_dim = zero_data.size(-1)
    feature_dim = covariates.size(-1)
    bias = pyro.sample("bias", dist.Normal(5.0, 10.0).expand((data_dim,)).to_event(1))
    weight = pyro.sample("weight", dist.Normal(0.0, 5).expand((feature_dim,)).to_event(1))

    # Time-global drift scale outside the time plate; iid increments inside it.
    drift_scale = pyro.sample("drift_scale",
                              dist.LogNormal(-20, 5.0).expand((1,)).to_event(1))
    with self.time_plate:
        # LocScaleReparam is not needed for correctness but appears to improve the
        # variational fit.
        with poutine.reparam(config={"drift": LocScaleReparam()}):
            drift_dist = dist.Normal(zero_data.double(), drift_scale.double()).to_event(1)
            drift = pyro.sample("drift", drift_dist)

    # Keep sites inside the plate independent; apply time-dependent transforms
    # (here a cumulative sum) outside of it.
    motion = drift.cumsum(-2)  # A Brownian motion.
    regression = (weight * covariates).sum(-1, keepdim=True)
    prediction = motion + bias + regression
    assert prediction.shape[-2:] == zero_data.shape

    noise_scale = pyro.sample("noise_scale",
                              dist.LogNormal(0.0, 1.0).expand((1,)).to_event(1))
    self.predict(dist.Normal(0, noise_scale), prediction)
def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim):
    """Reparametrized Stable LinearHMM samples match direct samples in moments."""
    init_dist = random_stable((hidden_dim, ), stability, skew=skew).to_event(1)
    trans_mat = torch.randn(duration, hidden_dim, hidden_dim)
    trans_dist = random_stable((duration, hidden_dim), stability, skew=skew).to_event(1)
    obs_mat = torch.randn(duration, hidden_dim, obs_dim)
    obs_dist = random_stable((duration, obs_dim), stability, skew=skew).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    num_samples = 200000
    flat_shape = (num_samples, duration * obs_dim)

    direct = hmm.sample([num_samples]).reshape(flat_shape)
    expected_loc, expected_scale, expected_corr = get_hmm_moments(direct)

    rep = SymmetricStableReparam() if skew == 0 else StableReparam()
    with pyro.plate("samples", num_samples):
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            reparam = pyro.sample("x", hmm).reshape(flat_shape)
    actual_loc, actual_scale, actual_corr = get_hmm_moments(reparam)

    assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05)
    assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05)
    assert_close(actual_corr, expected_corr, atol=0.01)
def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """LogNormal observations reparametrize to a TransformedDistribution over a GaussianHMM."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (duration, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_shape = batch_shape + (duration, obs_dim)
    obs_dist = dist.LogNormal(torch.randn(obs_shape), torch.rand(obs_shape).exp()).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    with poutine.trace() as tr, poutine.reparam(config={"x": LinearHMMReparam()}):
        model(data)

    site_fn = tr.trace.nodes["x"]["fn"]
    assert isinstance(site_fn, dist.TransformedDistribution)
    assert isinstance(site_fn.base_dist, dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
def test_log_normal(batch_shape, event_shape):
    """TransformReparam preserves moments of a (possibly event-batched) LogNormal site."""
    shape = batch_shape + event_shape
    loc = torch.empty(shape).uniform_(-1, 1)
    scale = torch.empty(shape).uniform_(0.5, 1.5)

    def model():
        site_dist = dist.TransformedDistribution(
            dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
            [AffineTransform(loc, scale), ExpTransform()])
        if event_shape:
            site_dist = site_dist.to_event(len(event_shape))
        with pyro.plate_stack("plates", batch_shape):
            with pyro.plate("particles", 200000):
                return pyro.sample("x", site_dist)

    with poutine.trace() as tr:
        direct_value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"],
                      (dist.TransformedDistribution, dist.Independent))
    expected_moments = get_moments(direct_value)

    with poutine.reparam(config={"x": TransformReparam()}), poutine.trace() as tr:
        reparam_value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"], (dist.Delta, dist.MaskedDistribution))
    actual_moments = get_moments(reparam_value)
    assert_close(actual_moments, expected_moments, atol=0.05)
def test_symmetric_stable(shape):
    """SymmetricStableReparam preserves moments and parameter gradients."""
    stability = torch.empty(shape).uniform_(1.6, 1.9).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_()
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    params = [stability, scale, loc]

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                return pyro.sample("x", dist.Stable(stability, 0, scale, loc))

    expected_moments = get_moments(model())

    reparam_model = poutine.reparam(model, {"x": SymmetricStableReparam()})
    trace = poutine.trace(reparam_model).get_trace()
    assert isinstance(trace.nodes["x"]["fn"], dist.Normal)
    trace.compute_log_prob()  # smoke test only
    actual_moments = get_moments(trace.nodes["x"]["value"])
    assert_close(actual_moments, expected_moments, atol=0.05)

    # The stability gradient gets a looser tolerance (0.2) than scale and loc.
    for actual_m, expected_m in zip(actual_moments, expected_moments):
        expected_grads = grad(expected_m.sum(), params, retain_graph=True)
        actual_grads = grad(actual_m.sum(), params, retain_graph=True)
        for atol, actual_g, expected_g in zip((0.2, 0.1, 0.1), actual_grads, expected_grads):
            assert_close(actual_g, expected_g, atol=atol)
def test_stable_hmm_smoke(batch_shape, num_steps, hidden_dim, obs_dim):
    """Smoke test: LinearHMMReparam and ConjugateReparam can be stacked on one site."""
    init_dist = random_stable(batch_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_stable(batch_shape + (num_steps, hidden_dim)).to_event(1)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_stable(batch_shape + (num_steps, obs_dim)).to_event(1)
    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)

    def model(data):
        hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                             duration=num_steps)
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", hmm)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    # First LinearHMMReparam, then ConjugateReparam on top of it, both on "z".
    hmm_reparam = LinearHMMReparam(StableReparam(), StableReparam(), StableReparam())
    reparam_model = poutine.reparam(model, {"z": hmm_reparam})
    conjugate = ConjugateReparam(dist.Normal(data, 1).to_event(2))
    reparam_model = poutine.reparam(reparam_model, {"z": conjugate})
    # The autoguide covers the auxiliary variables the reparametrizers introduce.
    reparam_guide = AutoDiagonalNormal(reparam_model)

    elbo = Trace_ELBO(num_particles=5, vectorize_particles=True)
    loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    torch.autograd.grad(loss, [trans_mat, obs_mat], retain_graph=True)
def test_distribution(df, loc, scale):
    """KS-compare StudentT draws with and without StudentTReparam."""
    def draw():
        with pyro.plate("particles", 20000):
            return pyro.sample("x", dist.StudentT(df, loc, scale))

    direct_samples = draw()
    with poutine.reparam(config={"x": StudentTReparam()}):
        reparam_samples = draw()

    result = ks_2samp(direct_samples, reparam_samples)
    assert result.pvalue > 0.05
def test_sphere_reparam_ok(auto_class, init_loc_fn):
    """An autoguide over a ProjectedNormalReparam'd model initializes and traces cleanly."""
    def model():
        x = pyro.sample("x", dist.Normal(0., 1.).expand([3]).to_event(1))
        y = pyro.sample("y", dist.ProjectedNormal(x))
        # NOTE(review): obs has 2 entries while x (and presumably y) has 3; this
        # seems to pass only because the "obs" log_prob is never evaluated here.
        pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1., 0.]))

    model = poutine.reparam(model, {"y": ProjectedNormalReparam()})
    # Forward the parametrized init strategy; it was previously accepted but ignored,
    # so every init_loc_fn case exercised the guide's default initialization.
    guide = auto_class(model, init_loc_fn=init_loc_fn)
    poutine.trace(guide).get_trace().compute_log_prob()
def model(self, zero_data, covariates):
    """Lévy-stable drift plus weekly seasonality for each "origin" time series.

    Drift stability/scale and the observation scale are shared across series;
    drift increments and the 24*7 seasonal components are sampled per origin.
    """
    duration, data_dim = zero_data.shape

    # Process parameters shared by every series live outside the "origin" plate.
    drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
    drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))

    with pyro.plate("origin", data_dim, dim=-2):
        # Given the shared parameters, the series in this plate are independent.
        with self.time_plate:
            # The inner SymmetricStableReparam is required for the Stable site;
            # the outer LocScaleReparam is optional but appears to improve inference.
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                    drift = pyro.sample("drift",
                                        dist.Stable(drift_stability, 0, drift_scale))

        hours_per_week = 24 * 7
        with pyro.plate("hour_of_week", hours_per_week, dim=-1):
            seasonal = pyro.sample("seasonal", dist.Normal(0, 5))

    # Time-dependent operations (tiling, integration) happen outside the time plate.
    seasonal = periodic_repeat(seasonal, duration, dim=-1)
    motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
    prediction = motion + seasonal

    # The forecasting framework expects (duration, data_dim), but the "origin"
    # plate sits left of the time plate, so the prediction is (data_dim, duration):
    assert prediction.shape[-2:] == (data_dim, duration)
    # Swap those dims while keeping a singleton at -3, so any sample dims Pyro
    # adds on the left stay intact.
    prediction = prediction.unsqueeze(-1).transpose(-1, -3)
    assert prediction.shape[-3:] == (1, duration, data_dim), prediction.shape

    # One observation scale shared across all time series.
    obs_scale = pyro.sample("obs_scale", dist.LogNormal(-5, 5))
    self.predict(dist.Normal(0, obs_scale.unsqueeze(-1)), prediction)
def reparam(self, model):
    """
    Wrap a model with ``poutine.reparam``.

    Every site in ``self.dims`` is moved to Haar coordinates along its dim; if
    ``self.split`` is set, each resulting ``<name>_haar`` site is then split
    into two parts of sizes ``self.split`` and ``self.duration - self.split``.
    """
    haar_config = {name: HaarReparam(dim=dim, flip=True)
                   for name, dim in self.dims.items()}
    model = poutine.reparam(model, haar_config)

    if not self.split:
        return model

    # Split into low- and high-frequency parts.
    splits = [self.split, self.duration - self.split]
    split_config = {name + "_haar": SplitReparam(splits, dim=dim)
                    for name, dim in self.dims.items()}
    return poutine.reparam(model, split_config)
def test_normal_auto(centered):
    """AutoReparam decenters every latent site unless ``centered == 1.0``."""
    strategy = AutoReparam(centered=centered)
    actual = trace_name_is_observed(strategy(normal_model))

    # (site name, is_observed) layout when no decentering happens.
    centered_sites = [
        ("a", False), ("b_base", False), ("b", True), ("c", False),
        ("d_base", False), ("d", True), ("e", False), ("f_base", False),
        ("f", True), ("g", True), ("h", True), ("i", False),
        ("j_base", False), ("j", True),
    ]
    if centered == 1.0:  # i.e. no decentering
        expected = centered_sites
    else:
        # Each latent site X gains a latent "X_decentered" and X itself is then
        # reported as observed; already-observed sites are unchanged.
        expected = []
        for name, is_observed in centered_sites:
            if not is_observed:
                expected.append((name + "_decentered", False))
            expected.append((name, True))
    assert actual == expected

    # The strategy must also have recorded its choices in a reusable config dict.
    config = strategy.config
    assert isinstance(config, dict)
    assert trace_name_is_observed(poutine.reparam(normal_model, config)) == expected
def test_distribution(stability, skew, Reparam):
    """KS-compare Stable draws with and without a Stable reparametrizer."""
    symmetric_only = Reparam is SymmetricStableReparam
    if symmetric_only and (skew != 0 or stability == 2):
        pytest.skip()
    if stability == 2 and skew in (-1, 1):
        pytest.skip()

    def draw():
        with pyro.plate("particles", 20000):
            return pyro.sample("x", dist.Stable(stability, skew))

    direct_samples = draw()
    with poutine.reparam(config={"x": Reparam()}):
        reparam_samples = draw()
    assert ks_2samp(direct_samples, reparam_samples).pvalue > 0.05
def model(self, zero_data, covariates):
    """Brownian-motion drift with bias and covariate regression, plus Stable noise."""
    data_dim = zero_data.size(-1)
    feature_dim = covariates.size(-1)
    bias = pyro.sample("bias", dist.Normal(5.0, 10.0).expand((data_dim,)).to_event(1))
    weight = pyro.sample("weight", dist.Normal(0.0, 5).expand((feature_dim,)).to_event(1))

    # Time-global drift scale outside the time plate; iid increments inside it.
    drift_scale = pyro.sample("drift_scale",
                              dist.LogNormal(0, 5.0).expand((1,)).to_event(1))
    with self.time_plate:
        # LocScaleReparam is not needed for correctness but appears to improve the
        # variational fit.
        with poutine.reparam(config={"drift": LocScaleReparam()}):
            drift_dist = dist.Normal(zero_data.double(), drift_scale.double()).to_event(1)
            drift = pyro.sample("drift", drift_dist)

    # Dependent (cumulative) transforms are applied outside the plate.
    motion = drift.cumsum(-2)  # A Brownian motion.
    regression = (weight * covariates).sum(-1, keepdim=True)
    prediction = motion + bias + regression
    assert prediction.shape[-2:] == zero_data.shape

    def global_param(name, prior):
        # A single shared value, sampled as a length-1 event.
        return pyro.sample(name, prior.expand((1,)).to_event(1))

    # Bayesian noise model with priors over the Stable parameters.
    stability = global_param("noise_stability", dist.Uniform(1, 2))
    skew = global_param("noise_skew", dist.Uniform(-1, 1))
    scale = global_param("noise_scale", dist.LogNormal(-5, 5))
    noise_dist = dist.Stable(stability, skew, scale)

    # The Stable likelihood needs a reparametrizer; "residual" is the sample site
    # created inside self.predict().
    with poutine.reparam(config={"residual": StableReparam()}):
        self.predict(noise_dist, prediction)
def __call__(self, msg_or_fn: Union[dict, Callable]):
    """
    Strategies can be used as decorators to reparametrize a model.

    :param msg_or_fn: Public use: a model to be decorated. (Internal use: a
        site message to be configured for reparametrization).
    """
    if not isinstance(msg_or_fn, dict):
        # Public use: wrap the model with this strategy as its reparam config.
        return poutine.reparam(msg_or_fn, self)

    # Internal use: decide once per site name and cache the decision.
    name = msg_or_fn["name"]
    if name not in self.config:
        self.config[name] = self.configure(msg_or_fn)
    return self.config[name]
def test_uniform(shape, dim, smooth):
    """DiscreteCosineReparam preserves the moments of a Uniform site."""
    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                pyro.sample("x", dist.Uniform(0, 1).expand(shape).to_event(-dim))

    expected_probe = get_moments(poutine.trace(model).get_trace().nodes["x"]["value"])

    dct = DiscreteCosineReparam(dim=dim, smooth=smooth)
    nodes = poutine.trace(poutine.reparam(model, {"x": dct})).get_trace().nodes
    assert isinstance(nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)
    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)
def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim):
    """ConjugateReparam with an empty guide matches the exact-posterior guide's ELBO."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)
    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)

    prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    likelihood = dist.Normal(data, 1).to_event(2)
    posterior, _ = prior.conjugate_update(likelihood)

    def model(data):
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", prior)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    def guide(data):
        with pyro.plate_stack("plates", batch_shape):
            pyro.sample("z", posterior)

    reparam_model = poutine.reparam(model, {"z": ConjugateReparam(likelihood)})

    def reparam_guide(data):
        # Intentionally empty: the reparametrized model is compared against the
        # exact-posterior guide above.
        pass

    elbo = Trace_ELBO(num_particles=1000, vectorize_particles=True)
    expected_loss = elbo.differentiable_loss(model, guide, data)
    actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    assert_close(actual_loss, expected_loss, atol=0.01)

    params = [trans_mat, obs_mat]
    expected_grads = torch.autograd.grad(expected_loss, params, retain_graph=True)
    actual_grads = torch.autograd.grad(actual_loss, params, retain_graph=True)
    for actual_g, expected_g in zip(actual_grads, expected_grads):
        assert_close(actual_g, expected_g, rtol=0.01)
def check_init_reparam(model, reparam):
    """Check that ``reparam`` reproduces an initial value set via ``init_to_value``.

    Skips the calling test when a captured RuntimeWarning mentions "falling back
    to default"; every other captured warning is re-emitted.
    """
    assert isinstance(reparam, Reparam)
    with poutine.block():
        init_value = model()

    with InitMessenger(init_to_value(values={"x": init_value})):
        # Sanity check: without reparametrizing, the init value is reproduced.
        actual = model()
        assert_close(actual, init_value)

        # The same must hold once the reparametrizer is applied.
        with poutine.reparam(config={"x": reparam}):
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always", category=RuntimeWarning)
                actual = model()
            for w in caught:
                fell_back = (w.category == RuntimeWarning
                             and "falling back to default" in str(w))
                if fell_back:
                    pytest.skip("overwriting initial value")  # raises
                warnings.warn(str(w.message), category=w.category)
        assert_close(actual, init_value)
def test_projected_normal(shape, dim):
    """ProjectedNormalReparam keeps samples on the sphere and preserves moments/gradients."""
    concentration = torch.randn(shape + (dim,)).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 10000):
                pyro.sample("x", dist.ProjectedNormal(concentration))

    def sample_on_sphere(fn):
        # Trace ``fn`` and check that "x" lies in the ProjectedNormal support.
        value = poutine.trace(fn).get_trace().nodes["x"]["value"]
        assert dist.ProjectedNormal.support.check(value).all()
        return value

    expected_probe = get_moments(sample_on_sphere(model))
    reparam_model = poutine.reparam(model, {"x": ProjectedNormalReparam()})
    actual_probe = get_moments(sample_on_sphere(reparam_model))
    assert_close(actual_probe, expected_probe, atol=0.05)

    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grad = grad(expected_m, [concentration], retain_graph=True)
        actual_grad = grad(actual_m, [concentration], retain_graph=True)
        assert_close(actual_grad, expected_grad, atol=0.1)