Exemplo n.º 1
0
def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps,
                                     hidden_dim, obs_dim):
    """Check ``GammaGaussianHMM.log_prob`` against an unrolled joint density.

    Random initial, transition and observation factors are drawn, the HMM's
    ``log_prob`` is evaluated on sampled data, and the result is compared
    with a density built by padding every per-step factor out to the full
    time horizon and contracting the factors with
    ``gamma_gaussian_tensordot()``.
    """
    # Per-time-step parameters carry an extra trailing batch dim of size
    # num_steps.
    time_batch = batch_shape + (num_steps, )

    # NOTE: the order of these draws is kept fixed so that a seeded run
    # consumes the random stream in the same sequence.
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(time_batch + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(time_batch, hidden_dim)
    obs_mat = torch.randn(time_batch + (hidden_dim, obs_dim))
    obs_dist = random_mvn(time_batch, obs_dim)
    scale_dist = random_gamma(batch_shape)
    hmm = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                                obs_mat, obs_dist)

    # Any tensor of the right shape will do as data; sampling the observation
    # noise is a convenient way to get one.
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + hmm.shape()
    actual_log_prob = hmm.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussian-gammas with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gamma_gaussian_tensordot().
    T = num_steps
    init = gamma_and_mvn_to_gamma_gaussian(scale_dist, init_dist)
    trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_dist)

    # Pad each transition factor so that it sits at its own time offset in
    # the (1 + T) * hidden_dim event space, then add the factors together.
    padded_trans = []
    for t in range(T):
        factor = trans[..., t]
        padded_trans.append(
            factor.event_pad(left=t * hidden_dim,
                             right=(T - t - 1) * hidden_dim))
    unrolled_trans = reduce(operator.add, padded_trans)

    # Same for the observation factors; each one spans a block of
    # obs.dim() event dimensions (a hidden part followed by an observed
    # part, per the permutation below).
    step_dim = obs.dim()
    padded_obs = []
    for t in range(T):
        factor = obs[..., t]
        padded_obs.append(
            factor.event_pad(left=t * step_dim,
                             right=(T - t - 1) * step_dim))
    unrolled_obs = reduce(operator.add, padded_obs)

    # Permute obs from HOHOHO to HHHOOO.
    hidden_index = [torch.arange(hidden_dim) + t * step_dim for t in range(T)]
    observed_index = [
        torch.arange(obs_dim) + hidden_dim + t * step_dim for t in range(T)
    ]
    perm = torch.cat(hidden_index + observed_index)
    unrolled_obs = unrolled_obs.event_permute(perm)

    # Flatten the trailing (time, obs_dim) dims of the data to match.
    flat_event = (T * obs_dim, )
    unrolled_data = data.reshape(data.shape[:-2] + flat_event)

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)

    # Contract init with trans over the first hidden block, then contract
    # the result with obs over all T hidden blocks.
    joint = gamma_gaussian_tensordot(init, unrolled_trans, hidden_dim)
    joint = gamma_gaussian_tensordot(joint, unrolled_obs, T * hidden_dim)

    # compute log_prob of the joint student-t distribution
    expected_log_prob = joint.compound().log_prob(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
Exemplo n.º 2
0
 def __init__(self,
              scale_dist,
              initial_dist,
              transition_matrix,
              transition_dist,
              observation_matrix,
              observation_dist,
              validate_args=None):
     """Validate the HMM components and store them as gamma-gaussian factors.

     :param scale_dist: a ``Gamma`` distribution.
     :param initial_dist: a ``MultivariateNormal`` over the hidden state.
     :param transition_matrix: a ``torch.Tensor`` whose last two dims are
         ``(hidden_dim, hidden_dim)``.
     :param transition_dist: a ``MultivariateNormal`` with event shape
         ``(hidden_dim,)``.
     :param observation_matrix: a ``torch.Tensor`` whose last two dims are
         ``(hidden_dim, obs_dim)``.
     :param observation_dist: a ``MultivariateNormal`` with event shape
         ``(obs_dim,)``.
     :param validate_args: forwarded unchanged to the parent constructor.
     :raises AssertionError: if an argument has the wrong type or an
         inconsistent trailing shape.
     """
     # Type checks, performed in argument order.
     expected_types = (
         (scale_dist, Gamma),
         (initial_dist, MultivariateNormal),
         (transition_matrix, torch.Tensor),
         (transition_dist, MultivariateNormal),
         (observation_matrix, torch.Tensor),
         (observation_dist, MultivariateNormal),
     )
     for argument, expected_type in expected_types:
         assert isinstance(argument, expected_type)

     # The observation matrix fixes both dimensions; every other argument
     # must agree with it.
     hidden_dim, obs_dim = observation_matrix.shape[-2:]
     hidden_event = (hidden_dim, )
     assert initial_dist.event_shape == hidden_event
     assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
     assert transition_dist.event_shape == hidden_event
     assert observation_dist.event_shape == (obs_dim, )

     # scale_dist and initial_dist have no time dim, so a size-1 dim is
     # appended to let them broadcast against the per-step arguments, whose
     # last batch dim is split off below as the time dim.
     full_shape = broadcast_shape(
         scale_dist.batch_shape + (1, ),
         initial_dist.batch_shape + (1, ),
         transition_matrix.shape[:-2],
         transition_dist.batch_shape,
         observation_matrix.shape[:-2],
         observation_dist.batch_shape,
     )
     batch_shape = full_shape[:-1]
     time_shape = full_shape[-1:]
     super(GammaGaussianHMM, self).__init__(batch_shape,
                                            time_shape + (obs_dim, ),
                                            validate_args=validate_args)

     self.hidden_dim = hidden_dim
     self.obs_dim = obs_dim
     # Keep each component in gamma-gaussian form.
     self._init = gamma_and_mvn_to_gamma_gaussian(scale_dist, initial_dist)
     self._trans = matrix_and_mvn_to_gamma_gaussian(transition_matrix,
                                                    transition_dist)
     self._obs = matrix_and_mvn_to_gamma_gaussian(observation_matrix,
                                                  observation_dist)
Exemplo n.º 3
0
def test_matrix_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, x_dim,
                                          y_dim):
    """Check ``matrix_and_mvn_to_gamma_gaussian`` against an explicit MVN.

    The factor's ``log_density(xy, s)`` is compared with the log-probability
    of ``y`` under a ``MultivariateNormal`` whose mean is
    ``x @ matrix + y_mvn.loc`` and whose precision is ``y_mvn``'s precision
    multiplied by ``s``.
    """
    # NOTE: the order of these draws is kept fixed so that a seeded run
    # consumes the random stream in the same sequence.
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    g = matrix_and_mvn_to_gamma_gaussian(matrix, y_mvn)

    full_batch = sample_shape + batch_shape
    xy = torch.randn(full_batch + (x_dim + y_dim, ))
    s = torch.rand(full_batch)
    actual_log_prob = g.log_density(xy, s)

    # Split the concatenated vector into its x part and its y part.
    x = xy[..., :x_dim]
    y = xy[..., x_dim:]

    # Treat x as a batch of row vectors so that matmul broadcasts over the
    # leading dims, then drop the temporary row dim again.
    y_pred = torch.matmul(x.unsqueeze(-2), matrix).squeeze(-2)

    # Broadcast the scalar scale over both matrix dims of the precision.
    expected_mvn = dist.MultivariateNormal(
        y_pred + y_mvn.loc,
        precision_matrix=y_mvn.precision_matrix * s[..., None, None])
    expected_log_prob = expected_mvn.log_prob(y)
    assert_close(actual_log_prob, expected_log_prob)