import torch

import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape

# GaussianHMM, SwitchingLinearHMM, and the helpers random_mvn, check_expand,
# and assert_close are provided by the surrounding test suite (hedged sketches
# of the two helpers appear below). Test parameters such as `exact` and
# `num_steps` are supplied by pytest parametrization.


def test_switching_linear_hmm_log_prob(exact, num_steps, hidden_dim, obs_dim, num_components):
    # This tests agreement between an SLDS and an HMM when all components
    # are identical, so that the discrete latent can be marginalized out.
    torch.manual_seed(2)
    init_logits = torch.rand(num_components)
    init_mvn = random_mvn((), hidden_dim)
    trans_logits = torch.rand(num_components)
    trans_matrix = torch.randn(hidden_dim, hidden_dim)
    trans_mvn = random_mvn((), hidden_dim)
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_mvn = random_mvn((), obs_dim)

    expected_dist = GaussianHMM(init_mvn,
                                trans_matrix.expand(num_steps, -1, -1),
                                trans_mvn, obs_matrix, obs_mvn)
    actual_dist = SwitchingLinearHMM(init_logits, init_mvn, trans_logits,
                                     trans_matrix.expand(num_steps, num_components, -1, -1),
                                     trans_mvn, obs_matrix, obs_mvn,
                                     exact=exact)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    data = obs_mvn.sample(expected_dist.batch_shape + (num_steps,))
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=None)
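
# `random_mvn` is defined elsewhere in the test suite; the sketch below is a
# plausible stand-in rather than the suite's actual helper. It draws a random
# MultivariateNormal with the given batch shape and event dimension, building
# a positive-definite covariance from a random rectangular factor.
def random_mvn(batch_shape, dim):
    loc = torch.randn(batch_shape + (dim,))
    factor = torch.randn(batch_shape + (dim, 2 * dim))
    cov = factor.matmul(factor.transpose(-1, -2))  # PSD, almost surely PD
    return dist.MultivariateNormal(loc, covariance_matrix=cov)
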
def test_switching_linear_hmm_shape(init_cat_shape, init_mvn_shape,
                                    trans_cat_shape, trans_mat_shape, trans_mvn_shape,
                                    obs_mat_shape, obs_mvn_shape):
    hidden_dim, obs_dim = obs_mat_shape[-2:]
    assert trans_mat_shape[-2:] == (hidden_dim, hidden_dim)

    init_logits = torch.randn(init_cat_shape)
    init_mvn = random_mvn(init_mvn_shape, hidden_dim)
    trans_logits = torch.randn(trans_cat_shape)
    trans_matrix = torch.randn(trans_mat_shape)
    trans_mvn = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape)
    obs_mvn = random_mvn(obs_mvn_shape, obs_dim)

    init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            trans_cat_shape[:-1],
                            trans_mat_shape[:-2],
                            trans_mvn_shape,
                            obs_mat_shape[:-2],
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-2], shape[-2:-1]
    expected_event_shape = time_shape + (obs_dim,)

    actual_dist = SwitchingLinearHMM(init_logits, init_mvn, trans_logits,
                                     trans_matrix, trans_mvn, obs_matrix, obs_mvn)
    assert actual_dist.event_shape == expected_event_shape
    assert actual_dist.batch_shape == expected_batch_shape

    data = obs_mvn.expand(shape).sample()[..., 0, :]
    actual_log_prob = actual_dist.log_prob(data)
    assert actual_log_prob.shape == expected_batch_shape
    check_expand(actual_dist, data)

    final_cat, final_mvn = actual_dist.filter(data)
    assert isinstance(final_cat, dist.Categorical)
    assert isinstance(final_mvn, dist.MultivariateNormal)
    assert final_cat.batch_shape == actual_dist.batch_shape
    assert final_mvn.batch_shape == actual_dist.batch_shape + final_cat.logits.shape[-1:]
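
# `check_expand` is likewise a suite helper. A minimal sketch, assuming it
# verifies that .expand() prepends a batch dimension and that log_prob values
# are unchanged under broadcasting of both the distribution and the data:
def check_expand(old_dist, old_data):
    new_batch_shape = (2,) + tuple(old_dist.batch_shape)
    new_dist = old_dist.expand(new_batch_shape)
    assert new_dist.batch_shape == torch.Size(new_batch_shape)

    old_log_prob = new_dist.log_prob(old_data)  # data broadcasts over the new dim
    assert old_log_prob.shape == torch.Size(new_batch_shape)

    new_data = old_data.expand(new_batch_shape + tuple(new_dist.event_shape))
    new_log_prob = new_dist.log_prob(new_data)
    assert_close(old_log_prob, new_log_prob)
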
def test_switching_linear_hmm_log_prob_alternating(exact, num_steps, num_components):
    # This tests agreement between an SLDS and an HMM in the case that the two
    # SLDS discrete states alternate deterministically between 0 and 1.
    torch.manual_seed(0)
    hidden_dim = 4
    obs_dim = 3
    extra_components = num_components - 2

    init_logits = torch.tensor([float("-inf"), 0.0] + extra_components * [float("-inf")])
    init_mvn = random_mvn((num_components,), hidden_dim)
    left_logits = torch.tensor([0.0, float("-inf")] + extra_components * [float("-inf")])
    right_logits = torch.tensor([float("-inf"), 0.0] + extra_components * [float("-inf")])
    trans_logits = torch.stack([left_logits if t % 2 == 0 else right_logits
                                for t in range(num_steps)])
    trans_logits = trans_logits.unsqueeze(-2)

    # The .clone() calls are necessary: expand() returns a zero-stride view,
    # so the in-place scrambling below would otherwise write through to every
    # component at once.
    hmm_trans_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
    switching_trans_matrix = hmm_trans_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()
    trans_mvn = random_mvn((num_steps, num_components), hidden_dim)
    hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim)
    switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()
    obs_mvn = random_mvn((num_steps, num_components), obs_dim)

    hmm_trans_mvn_loc = torch.empty(num_steps, hidden_dim)
    hmm_trans_mvn_cov = torch.empty(num_steps, hidden_dim, hidden_dim)
    hmm_obs_mvn_loc = torch.empty(num_steps, obs_dim)
    hmm_obs_mvn_cov = torch.empty(num_steps, obs_dim, obs_dim)

    for t in range(num_steps):
        # Select the relevant pieces for the HMM, given the deterministic
        # dynamics in discrete space.
        s = t % 2  # 0, 1, 0, 1, ...
        hmm_trans_mvn_loc[t] = trans_mvn.loc[t, s]
        hmm_trans_mvn_cov[t] = trans_mvn.covariance_matrix[t, s]
        hmm_obs_mvn_loc[t] = obs_mvn.loc[t, s]
        hmm_obs_mvn_cov[t] = obs_mvn.covariance_matrix[t, s]

        # Scramble the matrices in places that should never be accessed, given
        # the deterministic dynamics in discrete space.
        s = 1 - (t % 2)  # 1, 0, 1, 0, ...
        switching_trans_matrix[t, s, :, :] = torch.rand(hidden_dim, hidden_dim)
        switching_obs_matrix[t, s, :, :] = torch.rand(hidden_dim, obs_dim)

    expected_dist = GaussianHMM(
        dist.MultivariateNormal(init_mvn.loc[1], init_mvn.covariance_matrix[1]),
        hmm_trans_matrix,
        dist.MultivariateNormal(hmm_trans_mvn_loc, hmm_trans_mvn_cov),
        hmm_obs_matrix,
        dist.MultivariateNormal(hmm_obs_mvn_loc, hmm_obs_mvn_cov))
    actual_dist = SwitchingLinearHMM(init_logits, init_mvn, trans_logits,
                                     switching_trans_matrix, trans_mvn,
                                     switching_obs_matrix, obs_mvn,
                                     exact=exact)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    data = obs_mvn.sample()[:, 0, :]
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-2, rtol=None)
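
# A minimal smoke run of the two log-prob tests with assumed parameter values
# (illustrative only; the real suite drives these functions through pytest
# parametrization over larger grids, and the shape test additionally takes
# tuple-valued shape parameters from that grid).
if __name__ == "__main__":
    for exact in [True, False]:
        test_switching_linear_hmm_log_prob(exact, num_steps=3, hidden_dim=2,
                                           obs_dim=2, num_components=2)
        test_switching_linear_hmm_log_prob_alternating(exact, num_steps=4,
                                                       num_components=2)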