def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """Check ``dist.GaussianMRF.log_prob`` against a brute-force unrolled density.

    The MRF is built from random per-step factors and evaluated on data sampled
    from the observation factor. The result is compared with a reference obtained
    by padding every factor into one large joint Gaussian and contracting the
    pieces with ``gaussian_tensordot``.
    """
    init_mvn = random_mvn(batch_shape, hidden_dim)
    trans_mvn = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    obs_mvn = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    mrf = dist.GaussianMRF(init_mvn, trans_mvn, obs_mvn)

    # Keep only the trailing obs_dim (observed) coordinates of each sample.
    observations = obs_mvn.sample(sample_shape)[..., hidden_dim:]
    assert observations.shape == sample_shape + mrf.shape()
    actual_log_prob = mrf.log_prob(observations)

    # Reference computation. Every factor is laid out in one big joint Gaussian
    # whose event dims look like this (shown for T = 3):
    #
    #       t | 0 1 2 3 1 2 3
    #   ------+---------------
    #    init | H
    #   trans | H H H H          H = hidden
    #     obs |   H H H O O O    O = observed
    #
    # and the pieces are then combined with gaussian_tensordot().
    steps = num_steps
    init_g = mvn_to_gaussian(init_mvn)
    trans_g = mvn_to_gaussian(trans_mvn)
    obs_g = mvn_to_gaussian(obs_mvn)
    block = obs_g.dim()  # width of one per-step (hidden, obs) block

    # Step t's transition factor is padded so it occupies hidden slots t and t + 1.
    trans_terms = [
        trans_g[..., step].event_pad(left=step * hidden_dim, right=(steps - 1 - step) * hidden_dim)
        for step in range(steps)
    ]
    joint_trans = reduce(operator.add, trans_terms)

    # Step t's observation factor is padded into the t-th block, giving an HOHOHO layout.
    obs_terms = [
        obs_g[..., step].event_pad(left=step * block, right=(steps - 1 - step) * block)
        for step in range(steps)
    ]
    joint_obs = reduce(operator.add, obs_terms)

    # Reorder event dims from HOHOHO to HHHOOO so the hidden part lines up with joint_trans.
    hidden_idx = [torch.arange(hidden_dim) + step * block for step in range(steps)]
    observed_idx = [torch.arange(obs_dim) + hidden_dim + step * block for step in range(steps)]
    joint_obs = joint_obs.event_permute(torch.cat(hidden_idx + observed_idx))
    flat_obs = observations.reshape(observations.shape[:-2] + (steps * obs_dim,))

    assert init_g.dim() == hidden_dim
    assert joint_trans.dim() == (1 + steps) * hidden_dim
    assert joint_obs.dim() == steps * (hidden_dim + obs_dim)

    # Contract init with trans over hidden slot 0: a factor over hidden slots 1..T.
    hidden_factor = gaussian_tensordot(init_g, joint_trans, hidden_dim)
    # Contract that with the obs factor over all remaining hidden slots: a factor over observations.
    joint_factor = gaussian_tensordot(hidden_factor, joint_obs, steps * hidden_dim)
    # Normalizer: fold in the obs factor with the observed dims marginalized out.
    normalizer = hidden_factor + joint_obs.marginalize(right=steps * obs_dim)
    expected_log_prob = joint_factor.log_density(flat_obs) - normalizer.event_logsumexp()
    assert_close(actual_log_prob, expected_log_prob)
def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    """Check that ``GaussianMRF`` agrees with ``dist.GaussianMRF``.

    Both distributions are built from the same random factors. Their batch and
    event shapes must match, and their log densities on sampled observations
    must agree to within 1e-4 (absolute and relative).
    """
    # NOTE(review): another test with this exact name appears in this source. If both
    # end up in one module, the later definition shadows the earlier one and pytest
    # collects only one of them. Consider renaming one of the two.
    init_mvn = random_mvn(init_shape, hidden_dim)
    trans_mvn = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_mvn = random_mvn(obs_shape, hidden_dim + obs_dim)
    factors = (init_mvn, trans_mvn, obs_mvn)

    actual_dist = GaussianMRF(*factors)
    expected_dist = dist.GaussianMRF(*factors)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    # init_shape gets a trailing singleton so it broadcasts against the per-step shapes.
    full_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    obs_values = obs_mvn.expand(full_shape).sample()[..., hidden_dim:]
    assert_close(
        actual_dist.log_prob(obs_values),
        expected_dist.log_prob(obs_values),
        atol=1e-4,
        rtol=1e-4,
    )
    check_expand(actual_dist, obs_values)
def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    """Check the batch/event shapes of ``dist.GaussianMRF`` and the shape of its log_prob.

    The broadcast of all factor shapes (with init padded by a trailing singleton)
    splits into batch dims plus one trailing time dim. Each event is then
    (time, obs_dim), and log_prob reduces to the batch shape.
    """
    init_mvn = random_mvn(init_shape, hidden_dim)
    trans_mvn = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_mvn = random_mvn(obs_shape, hidden_dim + obs_dim)
    mrf = dist.GaussianMRF(init_mvn, trans_mvn, obs_mvn)

    full_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    batch_expected = full_shape[:-1]
    event_expected = full_shape[-1:] + (obs_dim,)
    assert mrf.batch_shape == batch_expected
    assert mrf.event_shape == event_expected

    # Sample full (hidden, obs) vectors and keep only the observed coordinates.
    obs_values = obs_mvn.expand(full_shape).sample()[..., hidden_dim:]
    assert obs_values.shape == mrf.shape()
    log_prob = mrf.log_prob(obs_values)
    assert log_prob.shape == batch_expected
    check_expand(mrf, obs_values)
def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """Check ``dist.GaussianMRF.log_prob`` when observations are independent of the hidden chain.

    The observation factor is given a block-diagonal precision, so it factorizes
    into a hidden part and an observed part. The MRF's log density must then equal
    the sum over time steps of the observed-block marginal log densities.
    """
    # Construct a block-diagonal obs dist, so observations are independent of hidden state.
    obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    # Clone before writing: in torch, MultivariateNormal.precision_matrix is a lazily
    # cached property and may be returned as an expanded view. Zeroing it in place
    # would mutate obs_dist's cached state and can fail on views with overlapping memory.
    precision = obs_dist.precision_matrix.clone()
    precision[..., :hidden_dim, hidden_dim:] = 0
    precision[..., hidden_dim:, :hidden_dim] = 0
    obs_dist = dist.MultivariateNormal(obs_dist.loc, precision_matrix=precision)
    # With a block-diagonal precision, the observed block of the precision is exactly
    # the precision of the observed marginal.
    marginal_obs_dist = dist.MultivariateNormal(
        obs_dist.loc[..., hidden_dim:],
        precision_matrix=precision[..., hidden_dim:, hidden_dim:])

    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    # Keep only the observed coordinates of each sampled (hidden, obs) vector.
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)
    # Per-step marginal log densities summed over the time axis (dim -1 of the result).
    expected_log_prob = marginal_obs_dist.log_prob(data).sum(-1)
    assert_close(actual_log_prob, expected_log_prob)