def model(self, zero_data, covariates):
    """Forecasting model whose residual noise is a Stable-driven LinearHMM.

    The batch dimension (dim -2) is subsampled with ``pyro.subsample``. Each
    series gets a LogNormal ``scale`` located at its first time step, and
    per-step Normal ``jumps`` are drawn under ``self.time_plate``. Their
    running sum over time is the prediction passed to ``self.predict``
    together with a ``LinearHMM`` noise distribution.

    NOTE(review): ``covariates`` is subsampled but otherwise unused here.
    """
    with pyro.plate("batch", len(zero_data), dim=-2):
        # Slice data and covariates consistently with the batch plate.
        zero_data = pyro.subsample(zero_data, event_dim=1)
        covariates = pyro.subsample(covariates, event_dim=1)

        first_step = zero_data[..., :1, :]
        scale = pyro.sample("scale", dist.LogNormal(first_step, 1).to_event(1))
        with self.time_plate:
            jumps = pyro.sample("jumps", dist.Normal(0, scale).to_event(1))
        prediction = jumps.cumsum(-2)

        duration, obs_dim = zero_data.shape[-2:]

        def stable_noise():
            # Stable(1.9, 0) noise expanded over obs_dim, treated as one event.
            return dist.Stable(1.9, 0).expand([obs_dim]).to_event(1)

        noise_dist = dist.LinearHMM(
            stable_noise(),
            torch.eye(obs_dim),
            stable_noise(),
            torch.eye(obs_dim),
            stable_noise(),
            duration=duration,
        )

        # Reparameterize the site named "residual" (presumably the noise site
        # created by self.predict -- confirm) with one StableReparam shared by
        # the initial, transition and observation noise.
        stable_rep = StableReparam()
        reparam_config = {
            "residual": LinearHMMReparam(stable_rep, stable_rep, stable_rep),
        }
        with poutine.reparam(config=reparam_config):
            self.predict(noise_dist, prediction)
def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim):
    """Check that LinearHMMReparam preserves the moments of a Stable LinearHMM.

    Samples are drawn from the HMM directly and through the reparameterized
    ``pyro.sample`` path. The location, scale and correlation summaries from
    ``get_hmm_moments`` of the flattened samples are then compared.
    """

    def stable(shape):
        return random_stable(shape, stability, skew=skew).to_event(1)

    initial = stable((hidden_dim,))
    transition_matrix = torch.randn(duration, hidden_dim, hidden_dim)
    transition = stable((duration, hidden_dim))
    observation_matrix = torch.randn(duration, hidden_dim, obs_dim)
    observation = stable((duration, obs_dim))
    hmm = dist.LinearHMM(
        initial, transition_matrix, transition, observation_matrix, observation
    )

    num_samples = 200000
    # Flatten (time, obs) into one axis so moments are per flattened coordinate.
    flat_shape = (num_samples, duration * obs_dim)

    direct = hmm.sample([num_samples]).reshape(flat_shape)
    expected_loc, expected_scale, expected_corr = get_hmm_moments(direct)

    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()
    with pyro.plate("samples", num_samples), poutine.reparam(
        config={"x": LinearHMMReparam(rep, rep, rep)}
    ):
        reparameterized = pyro.sample("x", hmm).reshape(flat_shape)
    actual_loc, actual_scale, actual_corr = get_hmm_moments(reparameterized)

    assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05)
    assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05)
    assert_close(actual_corr, expected_corr, atol=0.01)
def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Smoke-test LinearHMMReparam on a batched Stable LinearHMM.

    Stability is drawn per batch element and the HMM's batch/event shapes are
    checked. Data is then sampled under a plate stack, and the model is
    replayed under ``LinearHMMReparam`` to confirm that the ``"x"`` site
    becomes a ``GaussianHMM`` whose log-probability can be computed.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    init_dist = random_stable(
        batch_shape + (hidden_dim,), stability.unsqueeze(-1), skew=skew
    ).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_stable(
        batch_shape + (duration, hidden_dim),
        stability.unsqueeze(-1).unsqueeze(-1),
        skew=skew,
    ).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_stable(
        batch_shape + (duration, obs_dim),
        stability.unsqueeze(-1).unsqueeze(-1),
        skew=skew,
    ).to_event(1)
    # Pass the duration explicitly (as the other LinearHMM shape tests do) and
    # pin the resulting shapes. A shape regression then fails here with a
    # clear message instead of surfacing later inside the reparameterizer.
    hmm = dist.LinearHMM(
        init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration
    )
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (duration, obs_dim)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    rep = SymmetricStableReparam() if skew == 0 else StableReparam()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            model(data)
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """Check LinearHMMReparam with StudentTReparam on a batched LinearHMM.

    After the model is replayed under the reparameterizer, the ``"x"`` site
    must be a ``GaussianHMM``. The auxiliary ``*_gamma`` sites for the
    initial, transition and observation noise must also have the expected
    event shapes.
    """
    steps = batch_shape + (duration,)
    initial = random_studentt(batch_shape + (hidden_dim,)).to_event(1)
    transition_matrix = torch.randn(steps + (hidden_dim, hidden_dim))
    transition = random_studentt(steps + (hidden_dim,)).to_event(1)
    observation_matrix = torch.randn(steps + (hidden_dim, obs_dim))
    observation = random_studentt(steps + (obs_dim,)).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition,
        observation_matrix,
        observation,
        duration=duration,
    )

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    observed = model()
    student_rep = StudentTReparam()
    reparam_config = {"x": LinearHMMReparam(student_rep, student_rep, student_rep)}
    with poutine.trace() as tracer:
        with poutine.reparam(config=reparam_config):
            model(observed)
    nodes = tracer.trace.nodes

    assert isinstance(nodes["x"]["fn"], dist.GaussianHMM)
    expected_event_shapes = {
        "x_init_gamma": (hidden_dim,),
        "x_trans_gamma": (duration, hidden_dim),
        "x_obs_gamma": (duration, obs_dim),
    }
    for site_name, event_shape in expected_event_shapes.items():
        assert nodes[site_name]["fn"].event_shape == event_shape
    tracer.trace.compute_log_prob()  # smoke test only
def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """Check LinearHMMReparam on a LinearHMM with LogNormal observations.

    With no noise reparameterizers configured, the reparameterized ``"x"``
    site must be a ``TransformedDistribution`` whose base distribution is a
    ``GaussianHMM``.
    """
    steps = batch_shape + (duration,)
    initial = random_mvn(batch_shape, hidden_dim)
    transition_matrix = torch.randn(steps + (hidden_dim, hidden_dim))
    transition = random_mvn(steps, hidden_dim)
    observation_matrix = torch.randn(steps + (hidden_dim, obs_dim))

    obs_shape = steps + (obs_dim,)
    obs_loc = torch.randn(obs_shape)
    obs_scale = torch.rand(obs_shape).exp()
    observation = dist.LogNormal(obs_loc, obs_scale).to_event(1)

    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition,
        observation_matrix,
        observation,
        duration=duration,
    )

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    observed = model()
    with poutine.trace() as tracer:
        with poutine.reparam(config={"x": LinearHMMReparam()}):
            model(observed)
    site_fn = tracer.trace.nodes["x"]["fn"]
    assert isinstance(site_fn, dist.TransformedDistribution)
    assert isinstance(site_fn.base_dist, dist.GaussianHMM)
    tracer.trace.compute_log_prob()  # smoke test only
def test_init_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Check a batched Stable LinearHMM's shapes and run ``check_init_reparam``.

    Stability is drawn per batch element and broadcast over the trailing
    state/observation dimension(s) of each noise distribution.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    stab_1d = stability.unsqueeze(-1)
    stab_2d = stab_1d.unsqueeze(-1)
    steps = batch_shape + (duration,)

    initial = random_stable(
        batch_shape + (hidden_dim,), stab_1d, skew=skew
    ).to_event(1)
    transition_matrix = torch.randn(steps + (hidden_dim, hidden_dim))
    transition = random_stable(steps + (hidden_dim,), stab_2d, skew=skew).to_event(1)
    observation_matrix = torch.randn(steps + (hidden_dim, obs_dim))
    observation = random_stable(steps + (obs_dim,), stab_2d, skew=skew).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition,
        observation_matrix,
        observation,
        duration=duration,
    )
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (duration, obs_dim)

    def model():
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm)

    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()
    check_init_reparam(model, LinearHMMReparam(rep, rep, rep))
def test_stable_hmm_smoke(batch_shape, num_steps, hidden_dim, obs_dim):
    """Smoke-test composing LinearHMMReparam and ConjugateReparam under SVI.

    A latent Stable LinearHMM ``z`` is observed through unit Normal noise.
    The ELBO of the doubly reparameterized model against an
    ``AutoDiagonalNormal`` guide must be differentiable with respect to the
    transition and observation matrices.
    """
    steps = batch_shape + (num_steps,)
    initial = random_stable(batch_shape + (hidden_dim,)).to_event(1)
    transition_matrix = torch.randn(
        steps + (hidden_dim, hidden_dim), requires_grad=True
    )
    transition = random_stable(steps + (hidden_dim,)).to_event(1)
    observation_matrix = torch.randn(
        steps + (hidden_dim, obs_dim), requires_grad=True
    )
    observation = random_stable(steps + (obs_dim,)).to_event(1)
    data = observation.sample()
    assert data.shape == steps + (obs_dim,)

    def model(data):
        hmm = dist.LinearHMM(
            initial,
            transition_matrix,
            transition,
            observation_matrix,
            observation,
            duration=num_steps,
        )
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", hmm)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    # Layer two reparameterizers on the same site "z". The inner handler
    # rewrites the LinearHMM; the outer ConjugateReparam is given the
    # Normal(data, 1) likelihood.
    hmm_reparam = LinearHMMReparam(StableReparam(), StableReparam(), StableReparam())
    reparam_model = poutine.reparam(model, {"z": hmm_reparam})
    conjugate_reparam = ConjugateReparam(dist.Normal(data, 1).to_event(2))
    reparam_model = poutine.reparam(reparam_model, {"z": conjugate_reparam})
    # The autoguide models the auxiliary variables the reparameterizers add.
    reparam_guide = AutoDiagonalNormal(reparam_model)

    # Smoke test only: the loss must backpropagate to both matrices.
    elbo = Trace_ELBO(num_particles=5, vectorize_particles=True)
    loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    torch.autograd.grad(
        loss, [transition_matrix, observation_matrix], retain_graph=True
    )
def test_stable_hmm_shape_error(batch_shape, duration, hidden_dim, obs_dim):
    """LinearHMMReparam must raise ValueError for mismatched observed data.

    The HMM is built with a single time step (event shape ``(1, obs_dim)``).
    It is then observed with data of shape ``(duration, obs_dim)``, which has
    no batch dimensions.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    stab_1d = stability.unsqueeze(-1)
    stab_2d = stab_1d.unsqueeze(-1)
    one_step = batch_shape + (1,)

    initial = random_stable(batch_shape + (hidden_dim,), stab_1d).to_event(1)
    transition_matrix = torch.randn(one_step + (hidden_dim, hidden_dim))
    transition = random_stable(one_step + (hidden_dim,), stab_2d).to_event(1)
    observation_matrix = torch.randn(one_step + (hidden_dim, obs_dim))
    observation = random_stable(one_step + (obs_dim,), stab_2d).to_event(1)
    hmm = dist.LinearHMM(
        initial, transition_matrix, transition, observation_matrix, observation
    )
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (1, obs_dim)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    mismatched = torch.randn(duration, obs_dim)
    stable_rep = StableReparam()
    reparam_config = {"x": LinearHMMReparam(stable_rep, stable_rep, stable_rep)}
    with poutine.reparam(config=reparam_config):
        with pytest.raises(ValueError):
            model(mismatched)