Beispiel #1
0
def test_discrete_hmm_shape(ok, init_shape, trans_shape, obs_shape,
                            event_shape, state_dim):
    """Check DiscreteHMM shape handling with Bernoulli emissions.

    When ``ok`` is false, building the HMM or scoring data must raise
    ``ValueError``. Otherwise ``log_prob`` must have the broadcast batch shape,
    ``expand`` must behave (via ``check_expand``), and ``filter`` must return
    a scalar-event Categorical over ``state_dim`` states.
    """
    num_event_dims = len(event_shape)
    init_logits = torch.randn(init_shape + (state_dim, ))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim, ) + event_shape)
    obs_dist = dist.Bernoulli(logits=obs_logits).to_event(num_event_dims)
    # Keep every leading obs dim and pick hidden state 0, dropping the state axis.
    state0_index = (slice(None), ) * len(obs_shape) + (0, )
    data = obs_dist.sample()[state0_index]

    if not ok:
        with pytest.raises(ValueError):
            dist.DiscreteHMM(init_logits, trans_logits, obs_dist).log_prob(data)
        return

    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert hmm.support.event_dim == hmm.event_dim

    log_prob = hmm.log_prob(data)
    # The rightmost entry of trans_shape / obs_shape is not a batch dim, so it
    # is dropped before broadcasting.
    batch_shape = broadcast_shape(init_shape, trans_shape[:-1],
                                  obs_shape[:-1])
    assert log_prob.shape == batch_shape
    check_expand(hmm, data)

    posterior = hmm.filter(data)
    assert isinstance(posterior, dist.Categorical)
    assert posterior.batch_shape == hmm.batch_shape
    assert posterior.event_shape == ()
    assert posterior.support.upper_bound == state_dim - 1
Beispiel #2
0
def test_discrete_hmm_diag_normal(num_steps):
    """Compare DiscreteHMM with diagonal-Normal emissions to TraceEnum_ELBO.

    The summed ``log_prob`` of the HMM must equal the negated ELBO loss of an
    explicitly enumerated Pyro model with the same parameters and an empty
    guide.
    """
    state_dim, event_size = 3, 2
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    loc = torch.randn(num_steps, state_dim, event_size)
    scale = torch.randn(num_steps, state_dim, event_size).exp()
    obs_dist = dist.Normal(loc, scale).to_event(1)
    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    # One draw per (step, state); keep the draw for state 0 at every step.
    data = obs_dist.sample()[:, 0]
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == hmm.batch_shape
    check_expand(hmm, data)

    @config_enumerate
    def model(data):
        state = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for step in range(num_steps):
            next_logits = Vindex(trans_logits)[..., step, state, :]
            state = pyro.sample("x_{}".format(step),
                                dist.Categorical(logits=next_logits))
            emission = dist.Normal(Vindex(loc)[..., step, state, :],
                                   Vindex(scale)[..., step, state, :]).to_event(1)
            pyro.sample("obs_{}".format(step), emission,
                        obs=data[..., step, :])

    # The ELBO loss is expected to be the negated total log-likelihood.
    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    assert_close(-float(log_prob.sum()), expected_loss)
Beispiel #3
0
def test_discrete_hmm_categorical(num_steps):
    """Compare DiscreteHMM with Categorical emissions to TraceEnum_ELBO.

    The summed ``log_prob`` of the HMM must equal the negated ELBO loss of an
    explicitly enumerated Pyro model with the same parameters and an empty
    guide.
    """
    state_dim, obs_dim = 3, 4
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    obs_dist = dist.Categorical(
        logits=torch.randn(num_steps, state_dim, obs_dim))
    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    # Zero logits: observations drawn uniformly over obs_dim symbols per step.
    data = dist.Categorical(logits=torch.zeros(num_steps, obs_dim)).sample()
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == hmm.batch_shape
    check_expand(hmm, data)

    @config_enumerate
    def model(data):
        state = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for step in range(num_steps):
            next_logits = Vindex(trans_logits)[..., step, state, :]
            state = pyro.sample("x_{}".format(step),
                                dist.Categorical(logits=next_logits))
            emission_logits = Vindex(obs_dist.logits)[..., step, state, :]
            pyro.sample("obs_{}".format(step),
                        dist.Categorical(logits=emission_logits),
                        obs=data[..., step])

    # The ELBO loss is expected to be the negated total log-likelihood.
    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    assert_close(-float(log_prob.sum()), expected_loss)
Beispiel #4
0
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    """Pyro model: discrete HMM with Bernoulli emissions from ``tones_generator``.

    A transition matrix ``probs_x`` is sampled from a Dirichlet prior. The
    initial-state distribution puts all mass on hidden state 0. Per-state
    Bernoulli emission logits come from the module-level ``tones_generator``,
    a ``TonesGenerator`` created lazily on the first call. Each batch of
    sequences is then scored with ``dist.DiscreteHMM``.

    Args:
        sequences: tensor unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: tensor of shape ``(num_sequences,)`` with every entry
            ``<= max_length``. Time steps at or beyond a sequence's length are
            masked out of the observation likelihood.
        args: config object. This function reads ``args.hidden_dim`` and
            ``args.jit``, and passes ``args`` to ``TonesGenerator``.
        batch_size: optional subsample size for the ``"sequences"`` plate.
        include_prior: passed as the mask to ``poutine.mask`` around the
            ``probs_x`` site. A false value masks out that site's log-density.
    """
    # Presumably suppresses JIT tracer warnings from turning tensor shapes and
    # values into Python ints / bools below.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    # It is built once (it needs data_dim from the first batch), reused on later
    # calls, and registered with Pyro on every call.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        # Row-wise Dirichlet with concentration 1.0 on the diagonal and 0.1
        # off it, which biases each row toward self-transitions.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        lengths = lengths[batch]
        # With JIT, keep the full padded length (presumably so traced shapes
        # stay fixed). Otherwise truncate to the longest sequence in the batch.
        y = sequences[batch] if args.jit else sequences[batch, :lengths.max()]
        x = torch.arange(args.hidden_dim)  # indices of all hidden states
        t = torch.arange(y.size(1))  # time-step indices
        # -inf everywhere except state 0: the chain starts in state 0.
        init_logits = torch.full((args.hidden_dim, ), -float("inf"))
        init_logits[0] = 0
        trans_logits = probs_x.log()
        with ignore_jit_warnings():
            # y.unsqueeze(-2) adds an axis, presumably so emissions broadcast
            # over the hidden states in x. TODO confirm TonesGenerator's
            # output shape.
            obs_dist = dist.Bernoulli(
                logits=tones_generator(x, y.unsqueeze(-2))).to_event(1)
            # Keep only time steps t < length for each sequence. The trailing
            # unsqueeze(-1) presumably aligns the mask with the state axis.
            obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1))
            hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
        pyro.sample("y", hmm_dist, obs=y)
Beispiel #5
0
def test_discrete_hmm_distribution():
    """Check DiscreteHMM sample means against closed-form marginals.

    Two time steps: a noisy identity transition, then a noisy flip. Emissions
    are Normals with locs 0.0 (state 0) and 1.0 (state 1), so the mean
    observation at each step equals P(state == 1) at that step.
    """
    noisy_identity = [[0.9, 0.1], [0.1, 0.9]]
    noisy_flip = [[0.1, 0.9], [0.9, 0.1]]
    init_probs = torch.tensor([0.9, 0.1])
    trans_probs = torch.tensor([noisy_identity, noisy_flip])
    obs_dist = dist.Normal(torch.tensor([0.0, 1.0]), 0.1)
    hmm = dist.DiscreteHMM(init_probs.log(), trans_probs.log(), obs_dist)
    sample_mean = hmm.sample([1000000]).mean(0)
    # P(state == 1) after the noisy identity applied to init_probs.
    p_first = 0.1 * 0.9 + 0.9 * 0.1
    # P(state == 1) after the subsequent noisy flip (= 0.82 * 0.9 + 0.18 * 0.1).
    p_second = 0.9**3 + 3 * 0.9 * 0.1**2
    assert_close(sample_mean, torch.tensor([p_first, p_second]), atol=1e-3)
Beispiel #6
0
def test_discrete_hmm_homogeneous_trick(init_shape, trans_shape, obs_shape, event_shape, state_dim, num_steps):
    """A DiscreteHMM whose event time dim is 1 should score data of any length.

    The HMM is expected to report an event shape of ``(1,) + event_shape``.
    It must still accept ``num_steps`` observations and return log-probs with
    the broadcast batch shape.
    """
    batch_shape = broadcast_shape(init_shape, trans_shape[:-1], obs_shape[:-1])
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim,) + event_shape)
    obs_dist = dist.Bernoulli(logits=obs_logits).to_event(len(event_shape))

    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert hmm.event_shape == (1,) + event_shape

    # One draw per (step, state) for num_steps steps; then keep state 0.
    per_state = obs_dist.expand(batch_shape + (num_steps, state_dim)).sample()
    state0_index = (slice(None),) * (len(batch_shape) + 1) + (0,)
    data = per_state[state0_index]
    assert data.shape == batch_shape + (num_steps,) + event_shape
    assert hmm.log_prob(data).shape == batch_shape
Beispiel #7
0
def test_discrete_normal_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    """Compare two DiscreteHMM implementations on scalar Normal emissions.

    ``DiscreteHMM`` is unqualified and imported elsewhere in this file;
    presumably it is an alternative implementation under test (confirm). It
    must match ``dist.DiscreteHMM`` on event shape, batch shape and log_prob,
    and must pass ``check_expand``.
    """
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim,))
    scale = torch.randn(obs_shape + (state_dim,)).exp()
    obs_dist = dist.Normal(loc, scale)

    hmm_args = (init_logits, trans_logits, obs_dist)
    actual_dist = DiscreteHMM(*hmm_args)
    expected_dist = dist.DiscreteHMM(*hmm_args)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # init_shape gets a size-1 axis so it lines up with trans/obs shapes, whose
    # rightmost entry is not part of init's shape.
    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    per_state = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = per_state[(slice(None),) * len(batch_shape) + (0,)]
    assert_close(actual_dist.log_prob(data), expected_dist.log_prob(data),
                 rtol=5e-5)
    check_expand(actual_dist, data)
Beispiel #8
0
def test_discrete_categorical_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    """Compare two DiscreteHMM implementations on Categorical emissions.

    ``DiscreteHMM`` is unqualified and imported elsewhere in this file;
    presumably it is an alternative implementation under test (confirm). It
    must match ``dist.DiscreteHMM`` on event shape, batch shape and log_prob,
    and must pass ``check_expand``.
    """
    obs_dim = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_dist = dist.Categorical(
        logits=torch.randn(obs_shape + (state_dim, obs_dim)))

    candidate = DiscreteHMM(init_logits, trans_logits, obs_dist)
    reference = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert candidate.event_shape == reference.event_shape
    assert candidate.batch_shape == reference.batch_shape

    # Pad init_shape with a size-1 axis so it broadcasts against trans/obs.
    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    # One draw per hidden state, then keep the draw for state 0.
    per_state = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = per_state[(slice(None),) * len(batch_shape) + (0,)]
    assert_close(candidate.log_prob(data), reference.log_prob(data))
    check_expand(candidate, data)
Beispiel #9
0
def test_discrete_mvn_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    """Compare two DiscreteHMM implementations on MultivariateNormal emissions.

    ``DiscreteHMM`` is unqualified and imported elsewhere in this file;
    presumably it is an alternative implementation under test (confirm). It
    must match ``dist.DiscreteHMM`` on event shape, batch shape and log_prob,
    and must pass ``check_expand``.
    """
    event_size = 4
    init_logits = torch.randn(init_shape + (state_dim, ))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim, event_size))
    # A random (event_size x 2*event_size) factor F gives F @ F^T, which is
    # symmetric positive definite with probability one.
    factor = torch.randn(obs_shape + (state_dim, event_size, 2 * event_size))
    cov = factor.matmul(factor.transpose(-1, -2))
    # torch.cholesky is deprecated. torch.linalg.cholesky returns the same
    # lower-triangular factor.
    scale_tril = torch.linalg.cholesky(cov)
    obs_dist = dist.MultivariateNormal(loc, scale_tril=scale_tril)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # Pad init_shape with a size-1 axis so it broadcasts against trans/obs.
    batch_shape = broadcast_shape(init_shape + (1, ), trans_shape, obs_shape)
    # One draw per hidden state, then keep the draw for state 0.
    data = obs_dist.expand(batch_shape + (state_dim, )).sample()
    data = data[(slice(None), ) * len(batch_shape) + (0, )]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)