Example #1
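This and the following snippets assume the setup below. The import path for SwitchingLinearHMM and the bodies of the helpers are assumptions for illustration; in the original suite they come from shared test utilities (check_expand, used in Example #2, is another harness helper not reproduced here).

import torch

import pyro.distributions as dist
from pyro.distributions import GaussianHMM
from pyro.distributions.util import broadcast_shape
# Assumed location of SwitchingLinearHMM; adjust to wherever the class lives
# in your checkout.
from funsor.pyro.hmm import SwitchingLinearHMM


def random_mvn(batch_shape, dim):
    # Plausible stand-in for the suite's helper: a random mean plus a random
    # positive-definite covariance built as root @ root.T with jitter.
    loc = torch.randn(batch_shape + (dim,))
    root = torch.randn(batch_shape + (dim, 2 * dim))
    cov = root @ root.transpose(-1, -2) + 0.1 * torch.eye(dim)
    return dist.MultivariateNormal(loc, cov)


def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
    # Minimal stand-in for the suite's assert_close; rtol=None means
    # "absolute tolerance only".
    assert torch.allclose(actual, expected, atol=atol, rtol=rtol or 0.0)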
def test_switching_linear_hmm_log_prob(exact, num_steps, hidden_dim, obs_dim,
                                       num_components):
    # This tests agreement between an SLDS and an HMM when all SLDS components
    # are identical, so that the discrete latent state can be marginalized out
    # exactly.
    torch.manual_seed(2)
    init_logits = torch.rand(num_components)
    init_mvn = random_mvn((), hidden_dim)
    trans_logits = torch.rand(num_components)
    trans_matrix = torch.randn(hidden_dim, hidden_dim)
    trans_mvn = random_mvn((), hidden_dim)
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_mvn = random_mvn((), obs_dim)

    expected_dist = GaussianHMM(init_mvn,
                                trans_matrix.expand(num_steps, -1, -1),
                                trans_mvn, obs_matrix, obs_mvn)
    actual_dist = SwitchingLinearHMM(init_logits,
                                     init_mvn,
                                     trans_logits,
                                     trans_matrix.expand(
                                         num_steps, num_components, -1, -1),
                                     trans_mvn,
                                     obs_matrix,
                                     obs_mvn,
                                     exact=exact)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    data = obs_mvn.sample(expected_dist.batch_shape + (num_steps,))
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=None)
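In the suite these arguments arrive via pytest parametrization; a direct invocation with an illustrative (not the original) configuration looks like this:

for exact in [True, False]:
    test_switching_linear_hmm_log_prob(exact=exact, num_steps=5,
                                       hidden_dim=2, obs_dim=3,
                                       num_components=3)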
Example #2
def test_switching_linear_hmm_shape(init_cat_shape, init_mvn_shape,
                                    trans_cat_shape, trans_mat_shape, trans_mvn_shape,
                                    obs_mat_shape, obs_mvn_shape):
    hidden_dim, obs_dim = obs_mat_shape[-2:]
    assert trans_mat_shape[-2:] == (hidden_dim, hidden_dim)

    init_logits = torch.randn(init_cat_shape)
    init_mvn = random_mvn(init_mvn_shape, hidden_dim)
    trans_logits = torch.randn(trans_cat_shape)
    trans_matrix = torch.randn(trans_mat_shape)
    trans_mvn = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape)
    obs_mvn = random_mvn(obs_mvn_shape, obs_dim)

    init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)
    # Insert a size-1 time dim before the component dim of the initial
    # distribution, then broadcast against the time-dependent shapes.
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            trans_cat_shape[:-1],
                            trans_mat_shape[:-2],
                            trans_mvn_shape,
                            obs_mat_shape[:-2],
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-2], shape[-2:-1]
    expected_event_shape = time_shape + (obs_dim,)

    actual_dist = SwitchingLinearHMM(init_logits, init_mvn,
                                     trans_logits, trans_matrix, trans_mvn,
                                     obs_matrix, obs_mvn)
    assert actual_dist.event_shape == expected_event_shape
    assert actual_dist.batch_shape == expected_batch_shape

    # Sample data of the correct event shape by dropping the component dim.
    data = obs_mvn.expand(shape).sample()[..., 0, :]
    actual_log_prob = actual_dist.log_prob(data)
    assert actual_log_prob.shape == expected_batch_shape
    check_expand(actual_dist, data)

    final_cat, final_mvn = actual_dist.filter(data)
    assert isinstance(final_cat, dist.Categorical)
    assert isinstance(final_mvn, dist.MultivariateNormal)
    assert final_cat.batch_shape == actual_dist.batch_shape
    assert final_mvn.batch_shape == actual_dist.batch_shape + final_cat.logits.shape[-1:]
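The broadcast bookkeeping above is the subtle part. Here is one concrete instance of the shape arithmetic, with shapes chosen purely for illustration:

init_cat_shape = (5, 3)    # batch (5,), 3 components
init_mvn_shape = (3,)      # broadcasts against the component dim
init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)  # (5, 3)
# A size-1 time dim is inserted before the component dim, so batch dims of
# (7, 3) from the time-dependent parameters contribute 7 time steps:
shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]), (7, 3))
assert shape == (5, 7, 3)  # batch (5,), time 7, 3 components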
Example #3
def test_switching_linear_hmm_log_prob_alternating(exact, num_steps,
                                                   num_components):
    # This tests agreement between an SLDS and an HMM in the case where the
    # SLDS discrete state alternates deterministically between 0 and 1.

    torch.manual_seed(0)

    hidden_dim = 4
    obs_dim = 3
    extra_components = num_components - 2

    # Place all initial probability mass on discrete state 1.
    init_logits = torch.tensor([float("-inf"), 0.0] +
                               extra_components * [float("-inf")])
    init_mvn = random_mvn((num_components,), hidden_dim)

    # left_logits sends the chain to discrete state 0, right_logits to state 1.
    left_logits = torch.tensor([0.0, float("-inf")] +
                               extra_components * [float("-inf")])
    right_logits = torch.tensor([float("-inf"), 0.0] +
                                extra_components * [float("-inf")])
    # Alternating them yields the deterministic path 0, 1, 0, 1, ... after the
    # initial state 1. The unsqueeze adds a broadcast dim of size 1 over the
    # previous discrete state, giving shape (num_steps, 1, num_components).
    trans_logits = torch.stack([
        left_logits if t % 2 == 0 else right_logits for t in range(num_steps)
    ])
    trans_logits = trans_logits.unsqueeze(-2)

    hmm_trans_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
    # .clone() gives the expanded view its own storage; without it the
    # in-place scrambling below would silently write through to every
    # component and to hmm_trans_matrix itself.
    switching_trans_matrix = hmm_trans_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()

    trans_mvn = random_mvn((num_steps, num_components), hidden_dim)
    hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim)
    switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()
    obs_mvn = random_mvn((num_steps, num_components), obs_dim)

    # Per-time-step Gaussian parameters for the equivalent non-switching HMM.
    hmm_trans_mvn_loc = torch.empty(num_steps, hidden_dim)
    hmm_trans_mvn_cov = torch.empty(num_steps, hidden_dim, hidden_dim)
    hmm_obs_mvn_loc = torch.empty(num_steps, obs_dim)
    hmm_obs_mvn_cov = torch.empty(num_steps, obs_dim, obs_dim)

    for t in range(num_steps):
        # Select the parameters the HMM actually uses, given the deterministic
        # dynamics in discrete space.
        s = t % 2  # 0, 1, 0, 1, ...
        hmm_trans_mvn_loc[t] = trans_mvn.loc[t, s]
        hmm_trans_mvn_cov[t] = trans_mvn.covariance_matrix[t, s]
        hmm_obs_mvn_loc[t] = obs_mvn.loc[t, s]
        hmm_obs_mvn_cov[t] = obs_mvn.covariance_matrix[t, s]

        # Scramble matrices in places that should never be accessed,
        # given the deterministic dynamics in discrete space.
        s = 1 - (t % 2)  # 1, 0, 1, 0, ...
        switching_trans_matrix[t, s, :, :] = torch.rand(hidden_dim, hidden_dim)
        switching_obs_matrix[t, s, :, :] = torch.rand(hidden_dim, obs_dim)

    # The equivalent HMM starts from SLDS component 1, matching init_logits.
    expected_dist = GaussianHMM(
        dist.MultivariateNormal(init_mvn.loc[1], init_mvn.covariance_matrix[1]),
        hmm_trans_matrix,
        dist.MultivariateNormal(hmm_trans_mvn_loc, hmm_trans_mvn_cov),
        hmm_obs_matrix,
        dist.MultivariateNormal(hmm_obs_mvn_loc, hmm_obs_mvn_cov))

    actual_dist = SwitchingLinearHMM(init_logits,
                                     init_mvn,
                                     trans_logits,
                                     switching_trans_matrix,
                                     trans_mvn,
                                     switching_obs_matrix,
                                     obs_mvn,
                                     exact=exact)

    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    # Drop the component dim to get data of shape (num_steps, obs_dim).
    data = obs_mvn.sample()[:, 0, :]
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-2, rtol=None)
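A quick sanity check of the mechanism used above: a logit of -inf zeroes out a component under softmax, so each row of trans_logits encodes a deterministic transition.

left = torch.tensor([0.0, float("-inf"), float("-inf")])
right = torch.tensor([float("-inf"), 0.0, float("-inf")])
assert left.softmax(-1).tolist() == [1.0, 0.0, 0.0]   # always move to state 0
assert right.softmax(-1).tolist() == [0.0, 1.0, 0.0]  # always move to state 1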