def test_discrete_hmm_shape(ok, init_shape, trans_shape, obs_shape, event_shape, state_dim):
    """Exercise ``dist.DiscreteHMM`` shape handling for one parameter combination.

    With ``ok`` false, building the distribution and scoring the data is
    expected to raise ``ValueError``.  With ``ok`` true, ``log_prob`` must have
    the broadcast of ``init_shape``, ``trans_shape[:-1]`` and
    ``obs_shape[:-1]`` as its shape, and ``filter`` must return a
    ``Categorical`` with an empty event shape whose support has
    ``state_dim - 1`` as its upper bound.
    """
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim,) + event_shape)
    obs_dist = dist.Bernoulli(logits=obs_logits).to_event(len(event_shape))

    # Keep all leading ``obs_shape`` dims and take index 0 along the state dim.
    first_state = (slice(None),) * len(obs_shape) + (0,)
    data = obs_dist.sample()[first_state]

    if not ok:
        # Either construction or scoring may be the step that raises.
        with pytest.raises(ValueError):
            bad_hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
            bad_hmm.log_prob(data)
        return

    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert d_event_dim_matches(hmm) if False else hmm.support.event_dim == hmm.event_dim

    log_prob = hmm.log_prob(data)
    batch_shape = broadcast_shape(init_shape, trans_shape[:-1], obs_shape[:-1])
    assert log_prob.shape == batch_shape
    check_expand(hmm, data)

    # Filtering should give a categorical posterior over the hidden state.
    posterior = hmm.filter(data)
    assert isinstance(posterior, dist.Categorical)
    assert posterior.batch_shape == hmm.batch_shape
    assert posterior.event_shape == ()
    assert posterior.support.upper_bound == state_dim - 1
def test_discrete_hmm_diag_normal(num_steps):
    """Check ``DiscreteHMM.log_prob`` with diagonal-normal observations.

    The negated, summed log-probability is compared with the loss that
    ``TraceEnum_ELBO`` reports for an explicitly enumerated model built from
    the same parameters.
    """
    state_dim = 3
    event_size = 2
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    param_shape = (num_steps, state_dim, event_size)
    loc = torch.randn(param_shape)
    scale = torch.randn(param_shape).exp()  # exp keeps the scale positive

    obs_dist = dist.Normal(loc, scale).to_event(1)
    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)

    # Take the sample drawn for state index 0 at every step.
    data = obs_dist.sample()[:, 0]
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == hmm.batch_shape
    check_expand(hmm, data)

    # Check loss against TraceEnum_ELBO.
    @config_enumerate
    def model(observations):
        state = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for step in range(num_steps):
            step_logits = Vindex(trans_logits)[..., step, state, :]
            state = pyro.sample("x_{}".format(step), dist.Categorical(logits=step_logits))
            emission = dist.Normal(
                Vindex(loc)[..., step, state, :],
                Vindex(scale)[..., step, state, :],
            ).to_event(1)
            pyro.sample("obs_{}".format(step), emission, obs=observations[..., step, :])

    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    actual_loss = -float(log_prob.sum())
    assert_close(actual_loss, expected_loss)
def test_discrete_hmm_categorical(num_steps):
    """Check ``DiscreteHMM.log_prob`` with categorical observations.

    The negated, summed log-probability is compared with the loss that
    ``TraceEnum_ELBO`` reports for an explicitly enumerated model built from
    the same parameters.
    """
    state_dim = 3
    obs_dim = 4
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    obs_logits = torch.randn(num_steps, state_dim, obs_dim)
    obs_dist = dist.Categorical(logits=obs_logits)
    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)

    # Zero logits give a uniform draw over the obs_dim categories per step.
    uniform_obs = dist.Categorical(logits=torch.zeros(num_steps, obs_dim))
    data = uniform_obs.sample()
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == hmm.batch_shape
    check_expand(hmm, data)

    # Check loss against TraceEnum_ELBO.
    @config_enumerate
    def model(observations):
        state = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for step in range(num_steps):
            step_logits = Vindex(trans_logits)[..., step, state, :]
            state = pyro.sample("x_{}".format(step), dist.Categorical(logits=step_logits))
            emission_logits = Vindex(obs_dist.logits)[..., step, state, :]
            pyro.sample(
                "obs_{}".format(step),
                dist.Categorical(logits=emission_logits),
                obs=observations[..., step],
            )

    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    actual_loss = -float(log_prob.sum())
    assert_close(actual_loss, expected_loss)
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    """Pyro model scoring observed sequences under a ``dist.DiscreteHMM``.

    Emissions are Bernoulli with logits produced by the module-level
    ``tones_generator``, which is created on first use.

    Args:
        sequences: tensor whose shape unpacks to
            ``(num_sequences, max_length, data_dim)``.
        lengths: tensor of shape ``(num_sequences,)``; no entry may exceed
            ``max_length``.
        args: object providing ``hidden_dim`` and ``jit``; it is also passed
            to ``TonesGenerator``.
        batch_size: forwarded to ``pyro.plate`` as its third argument.
        include_prior: mask applied to the ``"probs_x"`` sample site.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = (int(size) for size in sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    hidden_dim = args.hidden_dim
    transition_prior = dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample("probs_x", transition_prior)

    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        batch_lengths = lengths[batch]
        if args.jit:
            observed = sequences[batch]
        else:
            # Outside jit mode, trim to the longest sequence in this batch.
            observed = sequences[batch, :batch_lengths.max()]

        states = torch.arange(hidden_dim)
        steps = torch.arange(observed.size(1))

        # Logit 0 for state 0 and -inf elsewhere: all initial mass on state 0.
        init_logits = torch.full((hidden_dim,), -float("inf"))
        init_logits[0] = 0
        trans_logits = probs_x.log()

        with ignore_jit_warnings():
            emission_logits = tones_generator(states, observed.unsqueeze(-2))
            # True only for time steps before each sequence's own length.
            in_bounds = (steps < batch_lengths.unsqueeze(-1)).unsqueeze(-1)
            obs_dist = dist.Bernoulli(logits=emission_logits).to_event(1)
            obs_dist = obs_dist.mask(in_bounds)
            hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
        pyro.sample("y", hmm_dist, obs=observed)
def test_discrete_hmm_distribution():
    """Check the empirical mean of ``DiscreteHMM`` samples on a two-step chain.

    Observations are normal with mean 0 or 1 depending on the state, so the
    mean observation at each step equals the probability of being in state 1
    at that step; the expected tensor spells out those probabilities.
    """
    noisy_identity = [[0.9, 0.1], [0.1, 0.9]]
    noisy_flip = [[0.1, 0.9], [0.9, 0.1]]
    init_probs = torch.tensor([0.9, 0.1])
    trans_probs = torch.tensor([noisy_identity, noisy_flip])

    obs_dist = dist.Normal(torch.tensor([0.0, 1.0]), 0.1)
    hmm = dist.DiscreteHMM(init_probs.log(), trans_probs.log(), obs_dist)

    num_samples = 1000000
    actual = hmm.sample([num_samples]).mean(0)

    # Step 0: the identity transition leaves state 1 via (0 -> 1) or (1 -> 1).
    # Step 1: the same bookkeeping after the additional flip transition.
    expected = torch.tensor([
        0.1 * 0.9 + 0.9 * 0.1,
        0.9**3 + 3 * 0.9 * 0.1**2,
    ])
    assert_close(actual, expected, atol=1e-3)
def test_discrete_hmm_homogeneous_trick(init_shape, trans_shape, obs_shape, event_shape, state_dim, num_steps):
    """Check that a one-step ``DiscreteHMM`` can score ``num_steps`` of data.

    The distribution is asserted to have ``(1,) + event_shape`` as its event
    shape, yet ``log_prob`` is called on data with ``num_steps`` time steps
    and must return a tensor of the broadcast batch shape.
    """
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim,) + event_shape)
    obs_dist = dist.Bernoulli(logits=obs_logits).to_event(len(event_shape))
    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert hmm.event_shape == (1,) + event_shape

    batch_shape = broadcast_shape(init_shape, trans_shape[:-1], obs_shape[:-1])

    # Sample over num_steps time steps, then keep index 0 along the state dim.
    stretched_obs = obs_dist.expand(batch_shape + (num_steps, state_dim))
    first_state = (slice(None),) * (len(batch_shape) + 1) + (0,)
    data = stretched_obs.sample()[first_state]
    assert data.shape == batch_shape + (num_steps,) + event_shape

    log_prob = hmm.log_prob(data)
    assert log_prob.shape == batch_shape
def test_discrete_normal_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    """Compare ``DiscreteHMM`` with ``dist.DiscreteHMM`` for normal observations.

    Both distributions are built from the same random parameters and must
    agree on event shape, batch shape and (within ``rtol=5e-5``) ``log_prob``.
    NOTE(review): the bare ``DiscreteHMM`` is presumably an alternative
    implementation imported at module level -- confirm against the imports.
    """
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    param_shape = obs_shape + (state_dim,)
    loc = torch.randn(param_shape)
    scale = torch.randn(param_shape).exp()  # exp keeps the scale positive
    obs_dist = dist.Normal(loc, scale)

    hmm_args = (init_logits, trans_logits, obs_dist)
    actual_dist = DiscreteHMM(*hmm_args)
    expected_dist = dist.DiscreteHMM(*hmm_args)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # Sample at the full broadcast shape, then keep index 0 along the state dim.
    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    first_state = (slice(None),) * len(batch_shape) + (0,)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()[first_state]

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, rtol=5e-5)
    check_expand(actual_dist, data)
def test_discrete_categorical_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    """Compare ``DiscreteHMM`` with ``dist.DiscreteHMM`` for categorical observations.

    Both distributions are built from the same random parameters and must
    agree on event shape, batch shape and ``log_prob``.
    NOTE(review): the bare ``DiscreteHMM`` is presumably an alternative
    implementation imported at module level -- confirm against the imports.
    """
    obs_dim = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim, obs_dim))
    obs_dist = dist.Categorical(logits=obs_logits)

    hmm_args = (init_logits, trans_logits, obs_dist)
    actual_dist = DiscreteHMM(*hmm_args)
    expected_dist = dist.DiscreteHMM(*hmm_args)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # Sample at the full broadcast shape, then keep index 0 along the state dim.
    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    first_state = (slice(None),) * len(batch_shape) + (0,)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()[first_state]

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)
def test_discrete_mvn_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    """Compare ``DiscreteHMM`` with ``dist.DiscreteHMM`` for multivariate-normal observations.

    Both distributions are built from the same random parameters and must
    agree on event shape, batch shape and ``log_prob``.
    NOTE(review): the bare ``DiscreteHMM`` is presumably an alternative
    implementation imported at module level -- confirm against the imports.
    """
    event_size = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim, event_size))

    # A wide random factor times its own transpose gives a symmetric
    # (event_size x event_size) covariance suitable for Cholesky.
    cov = torch.randn(obs_shape + (state_dim, event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))

    # ``torch.cholesky`` is deprecated in favour of ``torch.linalg.cholesky``;
    # both return the lower-triangular factor by default.  Fall back to the
    # legacy function only on PyTorch builds that predate the new one.
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "cholesky"):
        scale_tril = linalg.cholesky(cov)
    else:
        scale_tril = torch.cholesky(cov)
    obs_dist = dist.MultivariateNormal(loc, scale_tril=scale_tril)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # Sample at the full broadcast shape, then keep index 0 along the state dim.
    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)