Esempio n. 1
0
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    """Check ``dist.GaussianMRF.log_prob`` against a hand-computed density.

    The reference value is obtained by unrolling the init/trans/obs factors
    over all time steps into large joint Gaussians and contracting them with
    ``gaussian_tensordot()``.
    """
    time_batch_shape = batch_shape + (num_steps,)
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(time_batch_shape, 2 * hidden_dim)
    obs_dist = random_mvn(time_batch_shape, hidden_dim + obs_dim)
    mrf = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    # Only the trailing (observed) coordinates of each draw are used as data.
    joint_draw = obs_dist.sample(sample_shape)
    data = joint_draw[..., hidden_dim:]
    assert data.shape == sample_shape + mrf.shape()
    actual_log_prob = mrf.log_prob(data)

    # Reference computation: build enormous unrolled joint gaussians with
    # the following layout and combine them using gaussian_tensordot().
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)
    step_dim = obs.dim()  # event size of a single obs factor

    # Pad every per-step factor out to the full unrolled width, then sum.
    padded_trans = (
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    )
    unrolled_trans = reduce(operator.add, padded_trans)
    padded_obs = (
        obs[..., t].event_pad(left=t * step_dim,
                              right=(T - t - 1) * step_dim)
        for t in range(T)
    )
    unrolled_obs = reduce(operator.add, padded_obs)

    # Permute obs from HOHOHO to HHHOOO.
    hidden_index = [torch.arange(hidden_dim) + t * step_dim
                    for t in range(T)]
    observed_index = [torch.arange(obs_dim) + hidden_dim + t * step_dim
                      for t in range(T)]
    unrolled_obs = unrolled_obs.event_permute(
        torch.cat(hidden_index + observed_index))

    # Merge the trailing (time, obs_dim) axes of data into one event axis.
    flat_event_shape = data.shape[:-2] + (T * obs_dim,)
    unrolled_data = data.reshape(flat_event_shape)

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)

    hidden_joint = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    full_joint = gaussian_tensordot(hidden_joint, unrolled_obs,
                                    T * hidden_dim)
    # full_joint must be formed before the obs marginal is folded into
    # hidden_joint, which then only serves as the normalizer.
    hidden_joint += unrolled_obs.marginalize(right=T * obs_dim)
    unnormalized = full_joint.log_density(unrolled_data)
    log_normalizer = hidden_joint.event_logsumexp()
    expected_log_prob = unnormalized - log_normalizer
    assert_close(actual_log_prob, expected_log_prob)
Esempio n. 2
0
def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    """Check that ``GaussianMRF`` matches ``dist.GaussianMRF`` in shapes and ``log_prob``."""
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_dist = random_mvn(trans_shape, 2 * hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)
    factors = (init_dist, trans_dist, obs_dist)

    # Same three factors feed the implementation under test and the reference.
    actual_dist = GaussianMRF(*factors)
    expected_dist = dist.GaussianMRF(*factors)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # init has no time axis, so a size-1 axis is appended before broadcasting.
    full_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    joint_draw = obs_dist.expand(full_shape).sample()
    data = joint_draw[..., hidden_dim:]

    tolerance = 1e-4
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob,
                 atol=tolerance, rtol=tolerance)
    check_expand(actual_dist, data)
Esempio n. 3
0
def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    """Check ``dist.GaussianMRF`` batch/event shapes and the shape of ``log_prob``."""
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_dist = random_mvn(trans_shape, 2 * hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)
    mrf = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    # The broadcast shape splits into leading batch axes and one time axis.
    full_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    expected_batch_shape = full_shape[:-1]
    expected_event_shape = full_shape[-1:] + (obs_dim,)
    assert mrf.batch_shape == expected_batch_shape
    assert mrf.event_shape == expected_event_shape

    # Keep only the observed coordinates of a draw from the obs factor.
    joint_draw = obs_dist.expand(full_shape).sample()
    data = joint_draw[..., hidden_dim:]
    assert data.shape == mrf.shape()

    log_prob = mrf.log_prob(data)
    assert log_prob.shape == expected_batch_shape
    check_expand(mrf, data)
Esempio n. 4
0
def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """Check ``dist.GaussianMRF.log_prob`` when the obs factor is block-diagonal.

    With the hidden/observed cross terms of the obs precision zeroed out,
    observations are independent of the hidden state, so the expected value
    is the observed-block marginal ``log_prob`` summed over the last axis.
    """
    time_batch_shape = batch_shape + (num_steps,)
    hidden = slice(None, hidden_dim)
    observed = slice(hidden_dim, None)

    # Construct a block-diagonal obs dist by zeroing the off-diagonal blocks
    # of the precision matrix in place.
    raw_obs_dist = random_mvn(time_batch_shape, hidden_dim + obs_dim)
    precision = raw_obs_dist.precision_matrix
    precision[..., hidden, observed] = 0
    precision[..., observed, hidden] = 0
    obs_dist = dist.MultivariateNormal(raw_obs_dist.loc, precision_matrix=precision)

    # Marginal over the observed block only.
    marginal_obs_dist = dist.MultivariateNormal(
        obs_dist.loc[..., observed],
        precision_matrix=precision[..., observed, observed],
    )

    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(time_batch_shape, 2 * hidden_dim)
    mrf = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    joint_draw = obs_dist.sample(sample_shape)
    data = joint_draw[..., observed]
    assert data.shape == sample_shape + mrf.shape()

    actual_log_prob = mrf.log_prob(data)
    per_step_log_prob = marginal_obs_dist.log_prob(data)
    expected_log_prob = per_step_log_prob.sum(-1)
    assert_close(actual_log_prob, expected_log_prob)