Example #1
0
def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape,
                                  trans_mvn_shape, obs_mat_shape,
                                  obs_mvn_shape, hidden_dim, obs_dim):
    """Check the shapes produced by GammaGaussianHMM, log_prob and filter.

    Every component (scale, init, transition matrix/noise, observation
    matrix/noise) gets its own batch shape.  The HMM's batch shape must be the
    broadcast of all of them minus the rightmost dim, which acts as time.
    """
    # Build components in a fixed order so the random draws stay the same.
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    scale_dist = random_gamma(scale_shape)
    hmm = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                                obs_mat, obs_dist)

    # scale_shape and init_shape get a trailing singleton so they broadcast
    # against the rightmost (time) dim of the other components.
    full_shape = broadcast_shape(scale_shape + (1, ), init_shape + (1, ),
                                 trans_mat_shape, trans_mvn_shape,
                                 obs_mat_shape, obs_mvn_shape)
    batch_shape = full_shape[:-1]
    event_shape = full_shape[-1:] + (obs_dim, )
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == event_shape
    assert hmm.support.event_dim == hmm.event_dim

    # An observation sequence with the HMM's full (batch + event) shape.
    data = obs_dist.expand(full_shape).sample()
    assert data.shape == hmm.shape()
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == batch_shape
    check_expand(hmm, data)

    # filter() yields a Gamma mixing distribution and a final
    # MultivariateNormal over the hidden state, both batched like the HMM.
    mixing, final = hmm.filter(data)
    expectations = (
        (mixing, dist.Gamma, ()),
        (final, dist.MultivariateNormal, (hidden_dim, )),
    )
    for result, expected_type, expected_event in expectations:
        assert isinstance(result, expected_type)
        assert result.batch_shape == hmm.batch_shape
        assert result.event_shape == expected_event
Example #2
0
def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps,
                                     hidden_dim, obs_dim):
    """Compare GammaGaussianHMM.log_prob with a hand-unrolled joint density.

    The HMM's per-step factors are padded into one large joint gamma-gaussian
    over all hidden and observed variables. That joint is contracted with
    gamma_gaussian_tensordot(), and the compound distribution's log_prob is
    evaluated at the flattened data.
    """
    T = num_steps
    # Build components in a fixed order so the random draws stay the same.
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (T, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (T, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (T, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (T, ), obs_dim)
    scale_dist = random_gamma(batch_shape)
    hmm = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                                obs_mat, obs_dist)
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + hmm.shape()
    actual_log_prob = hmm.log_prob(data)

    # Reference density: enormous unrolled joint gaussian-gammas laid out as
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and combined with gamma_gaussian_tensordot().
    init = gamma_and_mvn_to_gamma_gaussian(scale_dist, init_dist)
    trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_dist)
    # Event width of one per-step obs factor; the size asserts below imply it
    # is hidden_dim + obs_dim.
    step_dim = obs.dim()

    # Shift each time slice into its own window of the joint event.
    unrolled_trans = reduce(operator.add, (
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)))
    unrolled_obs = reduce(operator.add, (
        obs[..., t].event_pad(left=t * step_dim,
                              right=(T - t - 1) * step_dim)
        for t in range(T)))

    # Reorder obs variables from HOHOHO to HHHOOO.
    hidden_idx = [torch.arange(hidden_dim) + t * step_dim for t in range(T)]
    observed_idx = [torch.arange(obs_dim) + hidden_dim + t * step_dim
                    for t in range(T)]
    unrolled_obs = unrolled_obs.event_permute(
        torch.cat(hidden_idx + observed_idx))
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    joint = gamma_gaussian_tensordot(init, unrolled_trans, hidden_dim)
    joint = gamma_gaussian_tensordot(joint, unrolled_obs, T * hidden_dim)
    # compute log_prob of the joint student-t distribution
    expected_log_prob = joint.compound().log_prob(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
Example #3
0
def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim):
    """Check gamma_and_mvn_to_gamma_gaussian against a direct computation.

    The converted GammaGaussian's log density at (value, s) should equal
    gamma.log_prob(s) plus the log_prob of an MVN with the original loc.
    That MVN's precision matrix is multiplied by s.
    """
    # Draw order fixed: gamma, mvn, then value, then s.
    gamma = random_gamma(batch_shape)
    mvn = random_mvn(batch_shape, dim)
    converted = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)
    value = mvn.sample(sample_shape)
    s = gamma.sample(sample_shape)
    actual = converted.log_density(value, s)

    # Append two singleton dims to s so it scales each precision matrix.
    scaled_mvn = dist.MultivariateNormal(
        mvn.loc, precision_matrix=mvn.precision_matrix * s[..., None, None])
    expected = gamma.log_prob(s) + scaled_mvn.log_prob(value)
    assert_close(actual, expected)