def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape,
                                  trans_mvn_shape, obs_mat_shape, obs_mvn_shape,
                                  hidden_dim, obs_dim):
    """Check the shapes exposed by ``dist.GammaGaussianHMM``.

    Builds an HMM from randomly generated components with the given batch
    shapes, then verifies its batch/event shapes, the shape of ``log_prob``
    on a sample, expansion via ``check_expand``, and the types and shapes of
    the ``(mixing, final)`` pair returned by ``filter``.
    """
    initial = random_mvn(init_shape, hidden_dim)
    transition_matrix = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    transition_noise = random_mvn(trans_mvn_shape, hidden_dim)
    observation_matrix = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    observation_noise = random_mvn(obs_mvn_shape, obs_dim)
    scale = random_gamma(scale_shape)
    hmm = dist.GammaGaussianHMM(
        scale,
        initial,
        transition_matrix,
        transition_noise,
        observation_matrix,
        observation_noise,
    )

    # The trailing broadcast axis is treated as time; the scale and init
    # shapes get a size-1 placeholder there before broadcasting.
    full_shape = broadcast_shape(
        scale_shape + (1,),
        init_shape + (1,),
        trans_mat_shape,
        trans_mvn_shape,
        obs_mat_shape,
        obs_mvn_shape,
    )
    batch_part = full_shape[:-1]
    time_part = full_shape[-1:]
    assert hmm.batch_shape == batch_part
    assert hmm.event_shape == time_part + (obs_dim,)
    assert hmm.support.event_dim == hmm.event_dim

    # A sample from the expanded observation noise has the HMM's full shape.
    data = observation_noise.expand(full_shape).sample()
    assert data.shape == hmm.shape()
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == batch_part
    check_expand(hmm, data)

    mixing, final = hmm.filter(data)

    assert isinstance(mixing, dist.Gamma)
    assert mixing.batch_shape == hmm.batch_shape
    assert mixing.event_shape == ()

    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == hmm.batch_shape
    assert final.event_shape == (hidden_dim,)
def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps,
                                     hidden_dim, obs_dim):
    """Compare ``GammaGaussianHMM.log_prob`` with an unrolled joint density.

    The expected value is obtained by unrolling the init, transition and
    observation factors over all time steps into large gamma-gaussians,
    contracting them with ``gamma_gaussian_tensordot`` and evaluating the
    resulting compound distribution on the flattened data.
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    scale_dist = random_gamma(batch_shape)
    hmm = dist.GammaGaussianHMM(
        scale_dist, init_dist, trans_mat, trans_dist, obs_mat, obs_dist
    )

    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + hmm.shape()
    actual_log_prob = hmm.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussian-gammas with shapes:
    #     t | 0 1 2 3 1 2 3      T = num_steps = 3 in this example
    # ------+-----------------------------------------
    #  init | H
    # trans | H H H H            H = hidden
    #   obs |   H H H O O O      O = observed
    # and then combine these using gamma_gaussian_tensordot().
    init = gamma_and_mvn_to_gamma_gaussian(scale_dist, init_dist)
    trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_dist)

    # Pad each per-step transition factor to the full unrolled width and sum.
    padded_trans = [
        trans[..., step].event_pad(
            left=step * hidden_dim,
            right=(num_steps - step - 1) * hidden_dim,
        )
        for step in range(num_steps)
    ]
    unrolled_trans = reduce(operator.add, padded_trans)

    # Same for the observation factors, whose per-step width is obs.dim().
    step_width = obs.dim()
    padded_obs = [
        obs[..., step].event_pad(
            left=step * step_width,
            right=(num_steps - step - 1) * step_width,
        )
        for step in range(num_steps)
    ]
    unrolled_obs = reduce(operator.add, padded_obs)

    # Permute obs from HOHOHO to HHHOOO.
    hidden_index = [
        torch.arange(hidden_dim) + step * step_width
        for step in range(num_steps)
    ]
    observed_index = [
        torch.arange(obs_dim) + hidden_dim + step * step_width
        for step in range(num_steps)
    ]
    perm = torch.cat(hidden_index + observed_index)
    unrolled_obs = unrolled_obs.event_permute(perm)

    # Flatten the trailing (time, obs_dim) axes of the data into one axis.
    flat_shape = data.shape[:-2] + (num_steps * obs_dim,)
    unrolled_data = data.reshape(flat_shape)

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + num_steps) * hidden_dim
    assert unrolled_obs.dim() == num_steps * (hidden_dim + obs_dim)

    joint = gamma_gaussian_tensordot(init, unrolled_trans, hidden_dim)
    joint = gamma_gaussian_tensordot(joint, unrolled_obs, num_steps * hidden_dim)

    # compute log_prob of the joint student-t distribution
    expected_log_prob = joint.compound().log_prob(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim):
    """Check ``gamma_and_mvn_to_gamma_gaussian`` against a factored density.

    The converted object's ``log_density(value, s)`` must match the Gamma
    log-prob of ``s`` plus the log-prob of ``value`` under a multivariate
    normal whose precision matrix is the original one multiplied by ``s``.
    """
    gamma = random_gamma(batch_shape)
    mvn = random_mvn(batch_shape, dim)
    joint = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)

    value = mvn.sample(sample_shape)
    s = gamma.sample(sample_shape)
    actual = joint.log_density(value, s)

    gamma_term = gamma.log_prob(s)
    # Two trailing singleton axes let s scale every entry of the precision matrix.
    conditional = dist.MultivariateNormal(
        mvn.loc,
        precision_matrix=mvn.precision_matrix * s[..., None, None],
    )
    expected = gamma_term + conditional.log_prob(value)
    assert_close(actual, expected)