def test_stable_hmm_shape(
    init_shape,
    trans_mat_shape,
    trans_dist_shape,
    obs_mat_shape,
    obs_dist_shape,
    hidden_dim,
    obs_dim,
):
    """Check batch/event shapes and sample shapes of a stable LinearHMM.

    The HMM is built with an explicit ``duration`` of 4, and the expected
    shape is obtained by broadcasting the component batch shapes against
    ``init_shape + (4,)``.
    """
    num_steps = 4
    stability = dist.Uniform(0, 2).sample()

    # Components are created in a fixed order so random draws stay aligned.
    initial = random_stable(stability, init_shape + (hidden_dim,)).to_event(1)
    transition_matrix = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    transition_noise = random_stable(
        stability, trans_dist_shape + (hidden_dim,)
    ).to_event(1)
    observation_matrix = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    observation_noise = random_stable(
        stability, obs_dist_shape + (obs_dim,)
    ).to_event(1)

    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
        duration=num_steps,
    )

    # The last broadcast dim is time; everything to its left is batch.
    full_shape = broadcast_shape(
        init_shape + (num_steps,),
        trans_mat_shape,
        trans_dist_shape,
        obs_mat_shape,
        obs_dist_shape,
    )
    expected_batch_shape = full_shape[:-1]
    expected_event_shape = full_shape[-1:] + (obs_dim,)

    assert hmm.batch_shape == expected_batch_shape
    assert hmm.event_shape == expected_event_shape
    assert hmm.support.event_dim == hmm.event_dim

    draw = hmm.rsample()
    assert draw.shape == hmm.shape()

    draw = hmm.rsample((6,))
    assert draw.shape == (6,) + hmm.shape()

    draw = hmm.expand((6, 5)).rsample()
    assert draw.shape == (6, 5) + hmm.event_shape
def test_studentt_hmm_shape(
    init_shape,
    trans_mat_shape,
    trans_dist_shape,
    obs_mat_shape,
    obs_dist_shape,
    hidden_dim,
    obs_dim,
):
    """Check batch/event shapes and sample shapes of a Student-t LinearHMM.

    No ``duration`` is passed to the HMM here; the expected shape is
    obtained by broadcasting the component batch shapes against
    ``init_shape + (1,)``.
    """
    # Components are created in a fixed order so random draws stay aligned.
    initial = random_studentt(init_shape + (hidden_dim,)).to_event(1)
    transition_matrix = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    transition_noise = random_studentt(
        trans_dist_shape + (hidden_dim,)
    ).to_event(1)
    observation_matrix = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    observation_noise = random_studentt(obs_dist_shape + (obs_dim,)).to_event(1)

    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
    )

    # The last broadcast dim is time; everything to its left is batch.
    full_shape = broadcast_shape(
        init_shape + (1,),
        trans_mat_shape,
        trans_dist_shape,
        obs_mat_shape,
        obs_dist_shape,
    )
    expected_batch_shape = full_shape[:-1]
    expected_event_shape = full_shape[-1:] + (obs_dim,)

    assert hmm.batch_shape == expected_batch_shape
    assert hmm.event_shape == expected_event_shape
    assert hmm.support.event_dim == hmm.event_dim

    draw = hmm.rsample()
    assert draw.shape == hmm.shape()

    draw = hmm.rsample((6,))
    assert draw.shape == (6,) + hmm.shape()

    draw = hmm.expand((6, 5)).rsample()
    assert draw.shape == (6, 5) + hmm.event_shape
def model(self, zero_data, covariates):
    """Random-walk prediction with stable LinearHMM noise.

    Samples a per-series ``scale`` and per-step ``jumps``, accumulates the
    jumps along the time axis to form the prediction, and calls
    ``self.predict`` under a ``LinearHMMReparam`` applied to the
    ``"residual"`` site.
    """
    with pyro.plate("batch", len(zero_data), dim=-2):
        zero_data = pyro.subsample(zero_data, event_dim=1)
        # The subsampled covariates are not used below; the call is kept
        # as in the original model.
        covariates = pyro.subsample(covariates, event_dim=1)

        # First time step only, used as the location of the scale prior.
        loc = zero_data[..., :1, :]
        scale_prior = dist.LogNormal(loc, 1).to_event(1)
        scale = pyro.sample("scale", scale_prior)

        with self.time_plate:
            jump_dist = dist.Normal(0, scale).to_event(1)
            jumps = pyro.sample("jumps", jump_dist)
        prediction = jumps.cumsum(-2)

        duration, obs_dim = zero_data.shape[-2:]

        def stable_noise():
            # A fresh Stable(1.9, 0) over obs_dim, as one event dimension.
            return dist.Stable(1.9, 0).expand([obs_dim]).to_event(1)

        init_dist = stable_noise()
        trans_matrix = torch.eye(obs_dim)
        trans_dist = stable_noise()
        obs_matrix = torch.eye(obs_dim)
        obs_dist = stable_noise()
        noise_dist = dist.LinearHMM(
            init_dist,
            trans_matrix,
            trans_dist,
            obs_matrix,
            obs_dist,
            duration=duration,
        )

        stable_rep = StableReparam()
        hmm_rep = LinearHMMReparam(stable_rep, stable_rep, stable_rep)
        with poutine.reparam(config={"residual": hmm_rep}):
            self.predict(noise_dist, prediction)
def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim):
    """Compare moments of direct HMM samples with reparametrized samples.

    Samples are drawn once from the ``LinearHMM`` itself and once through
    ``LinearHMMReparam``; the statistics returned by ``get_hmm_moments``
    for both sets must agree within the tolerances below.
    """
    # Components are created in a fixed order so random draws stay aligned.
    initial = random_stable((hidden_dim,), stability, skew=skew).to_event(1)
    transition_matrix = torch.randn(duration, hidden_dim, hidden_dim)
    transition_noise = random_stable(
        (duration, hidden_dim), stability, skew=skew
    ).to_event(1)
    observation_matrix = torch.randn(duration, hidden_dim, obs_dim)
    observation_noise = random_stable(
        (duration, obs_dim), stability, skew=skew
    ).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
    )

    num_samples = 200000
    flat_size = duration * obs_dim

    # Reference statistics from sampling the HMM directly.
    direct = hmm.sample([num_samples]).reshape(num_samples, flat_size)
    expected_loc, expected_scale, expected_corr = get_hmm_moments(direct)

    if skew == 0:
        component_rep = SymmetricStableReparam()
    else:
        component_rep = StableReparam()
    hmm_rep = LinearHMMReparam(component_rep, component_rep, component_rep)

    # Statistics from sampling through the reparametrizer.
    with pyro.plate("samples", num_samples), poutine.reparam(config={"x": hmm_rep}):
        reparametrized = pyro.sample("x", hmm).reshape(num_samples, flat_size)
    actual_loc, actual_scale, actual_corr = get_hmm_moments(reparametrized)

    assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05)
    assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05)
    assert_close(actual_corr, expected_corr, atol=0.01)
def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Smoke-test LinearHMMReparam on a batched stable LinearHMM.

    After reparametrization the ``"x"`` site must hold a ``GaussianHMM``,
    and the trace's log-probability must be computable.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    # Trailing singleton dims so stability lines up with the event dim
    # (per_state) and with the time and event dims (per_step).
    per_state = stability.unsqueeze(-1)
    per_step = stability.unsqueeze(-1).unsqueeze(-1)

    # Components are created in a fixed order so random draws stay aligned.
    initial = random_stable(
        batch_shape + (hidden_dim,), per_state, skew=skew
    ).to_event(1)
    transition_matrix = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    transition_noise = random_stable(
        batch_shape + (duration, hidden_dim), per_step, skew=skew
    ).to_event(1)
    observation_matrix = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    observation_noise = random_stable(
        batch_shape + (duration, obs_dim), per_step, skew=skew
    ).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
    )

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()

    if skew == 0:
        component_rep = SymmetricStableReparam()
    else:
        component_rep = StableReparam()
    hmm_rep = LinearHMMReparam(component_rep, component_rep, component_rep)

    with poutine.trace() as tr, poutine.reparam(config={"x": hmm_rep}):
        model(data)

    assert isinstance(tr.trace.nodes["x"]["fn"], dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """Smoke-test LinearHMMReparam on a batched Student-t LinearHMM.

    After reparametrization the ``"x"`` site must hold a ``GaussianHMM``,
    the auxiliary gamma sites must have the expected event shapes, and the
    trace's log-probability must be computable.
    """
    # Components are created in a fixed order so random draws stay aligned.
    initial = random_studentt(batch_shape + (hidden_dim,)).to_event(1)
    transition_matrix = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    transition_noise = random_studentt(
        batch_shape + (duration, hidden_dim)
    ).to_event(1)
    observation_matrix = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    observation_noise = random_studentt(
        batch_shape + (duration, obs_dim)
    ).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
        duration=duration,
    )

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()

    component_rep = StudentTReparam()
    hmm_rep = LinearHMMReparam(component_rep, component_rep, component_rep)
    with poutine.trace() as tr, poutine.reparam(config={"x": hmm_rep}):
        model(data)

    nodes = tr.trace.nodes
    assert isinstance(nodes["x"]["fn"], dist.GaussianHMM)

    # Event shapes expected at each auxiliary gamma site.
    expected_gamma_shapes = {
        "x_init_gamma": (hidden_dim,),
        "x_trans_gamma": (duration, hidden_dim),
        "x_obs_gamma": (duration, obs_dim),
    }
    for site, event_shape in expected_gamma_shapes.items():
        assert nodes[site]["fn"].event_shape == event_shape

    tr.trace.compute_log_prob()  # smoke test only
def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """Smoke-test LinearHMMReparam with a LogNormal observation distribution.

    With no component reparametrizers, the ``"x"`` site must become a
    ``TransformedDistribution`` whose base is a ``GaussianHMM``, and the
    trace's log-probability must be computable.
    """
    time_batch = batch_shape + (duration,)

    # Components are created in a fixed order so random draws stay aligned.
    initial = random_mvn(batch_shape, hidden_dim)
    transition_matrix = torch.randn(time_batch + (hidden_dim, hidden_dim))
    transition_noise = random_mvn(time_batch, hidden_dim)
    observation_matrix = torch.randn(time_batch + (hidden_dim, obs_dim))
    obs_loc = torch.randn(time_batch + (obs_dim,))
    obs_scale = torch.rand(time_batch + (obs_dim,)).exp()
    observation_noise = dist.LogNormal(obs_loc, obs_scale).to_event(1)

    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
        duration=duration,
    )

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()

    with poutine.trace() as tr, poutine.reparam(config={"x": LinearHMMReparam()}):
        model(data)

    fn = tr.trace.nodes["x"]["fn"]
    assert isinstance(fn, dist.TransformedDistribution)
    assert isinstance(fn.base_dist, dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
def test_init_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Run ``check_init_reparam`` on a batched stable LinearHMM model.

    Also asserts that the HMM reports ``batch_shape`` as its batch shape
    and ``(duration, obs_dim)`` as its event shape.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    # Trailing singleton dims so stability lines up with the event dim
    # (per_state) and with the time and event dims (per_step).
    per_state = stability.unsqueeze(-1)
    per_step = stability.unsqueeze(-1).unsqueeze(-1)

    # Components are created in a fixed order so random draws stay aligned.
    initial = random_stable(
        batch_shape + (hidden_dim,), per_state, skew=skew
    ).to_event(1)
    transition_matrix = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    transition_noise = random_stable(
        batch_shape + (duration, hidden_dim), per_step, skew=skew
    ).to_event(1)
    observation_matrix = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    observation_noise = random_stable(
        batch_shape + (duration, obs_dim), per_step, skew=skew
    ).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
        duration=duration,
    )
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (duration, obs_dim)

    def model():
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm)

    if skew == 0:
        component_rep = SymmetricStableReparam()
    else:
        component_rep = StableReparam()
    hmm_rep = LinearHMMReparam(component_rep, component_rep, component_rep)
    check_init_reparam(model, hmm_rep)
def model(data):
    """Sample a latent ``"z"`` from a LinearHMM and observe ``"x"`` around it.

    The HMM components, ``num_steps`` and ``batch_shape`` are read from the
    enclosing scope.
    """
    latent_prior = dist.LinearHMM(
        init_dist,
        trans_mat,
        trans_dist,
        obs_mat,
        obs_dist,
        duration=num_steps,
    )
    with pyro.plate_stack("plates", batch_shape):
        latent = pyro.sample("z", latent_prior)
        # Unit-scale Normal likelihood with the last two dims as event dims.
        likelihood = dist.Normal(latent, 1).to_event(2)
        pyro.sample("x", likelihood, obs=data)
def test_stable_hmm_shape_error(batch_shape, duration, hidden_dim, obs_dim):
    """Expect a ValueError when observed data does not fit the HMM's shape.

    The HMM is built with a time dimension of size 1, while the observed
    data has shape ``(duration, obs_dim)``; running the model under
    ``LinearHMMReparam`` must raise ``ValueError``.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    # Trailing singleton dims so stability lines up with the event dim
    # (per_state) and with the time and event dims (per_step).
    per_state = stability.unsqueeze(-1)
    per_step = stability.unsqueeze(-1).unsqueeze(-1)

    # Components are created in a fixed order so random draws stay aligned.
    initial = random_stable(batch_shape + (hidden_dim,), per_state).to_event(1)
    transition_matrix = torch.randn(batch_shape + (1, hidden_dim, hidden_dim))
    transition_noise = random_stable(
        batch_shape + (1, hidden_dim), per_step
    ).to_event(1)
    observation_matrix = torch.randn(batch_shape + (1, hidden_dim, obs_dim))
    observation_noise = random_stable(
        batch_shape + (1, obs_dim), per_step
    ).to_event(1)
    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
    )
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (1, obs_dim)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = torch.randn(duration, obs_dim)

    component_rep = StableReparam()
    hmm_rep = LinearHMMReparam(component_rep, component_rep, component_rep)
    with poutine.reparam(config={"x": hmm_rep}), pytest.raises(ValueError):
        model(data)