def test_affine_normal(batch_shape, x_dim, y_dim):
    """Check that the ``AffineNormal`` fast path agrees with the dense ``Gaussian`` path.

    Diagonal noise given as ``Normal(...).to_event(1)`` makes
    ``matrix_and_mvn_to_gaussian`` return an ``AffineNormal``.  The same noise
    given as a ``MultivariateNormal`` with diagonal ``scale_tril`` returns a
    plain ``Gaussian``.  Both are conditioned on ``y`` via ``condition`` and on
    ``x`` via ``left_condition``, and the resulting Gaussians are compared
    field by field.
    """
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    loc = torch.randn(batch_shape + (y_dim,))
    scale = torch.randn(batch_shape + (y_dim,)).exp()
    y = torch.randn(batch_shape + (y_dim,))

    normal = dist.Normal(loc, scale).to_event(1)
    actual = matrix_and_mvn_to_gaussian(matrix, normal)
    assert isinstance(actual, AffineNormal)
    actual_like = actual.condition(y)
    assert isinstance(actual_like, Gaussian)

    mvn = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())
    expected = matrix_and_mvn_to_gaussian(matrix, mvn)
    assert isinstance(expected, Gaussian)
    expected_like = expected.condition(y)
    assert isinstance(expected_like, Gaussian)

    assert_close(actual_like.log_normalizer, expected_like.log_normalizer)
    assert_close(actual_like.info_vec, expected_like.info_vec)
    assert_close(actual_like.precision, expected_like.precision)

    x = torch.randn(batch_shape + (x_dim,))
    permute_actual = actual.left_condition(x)
    assert isinstance(permute_actual, AffineNormal)
    permute_actual = permute_actual.to_gaussian()

    # Condition the dense Gaussian on the same leading block ``x`` that the
    # AffineNormal was conditioned on, so both describe a density over ``y``.
    permute_expected = expected.left_condition(x)
    assert isinstance(permute_expected, Gaussian)

    # Compare the AffineNormal result against the dense-Gaussian result.
    assert_close(permute_actual.log_normalizer, permute_expected.log_normalizer)
    assert_close(permute_actual.info_vec, permute_expected.info_vec)
    assert_close(permute_actual.precision, permute_expected.precision)
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """Compare ``GaussianHMM.log_prob`` with a brute-force unrolled Gaussian.

    Random HMM parameters are drawn.  The HMM's log density on samples of the
    observation noise is then checked against the density of one large joint
    Gaussian, built by hand from the initial, transition and observation
    factors, with the hidden states contracted away by ``gaussian_tensordot``.
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)

    if diag:
        # Replace the dense observation noise with a diagonal Independent(Normal).
        # The reference computation still needs an equivalent dense MVN.
        diag_scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, diag_scale).to_event(1)
        noise = obs_dist.base_dist
        obs_mvn = dist.MultivariateNormal(noise.loc, scale_tril=noise.scale.diag_embed())
    else:
        obs_mvn = obs_dist

    hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + hmm.shape()
    actual_log_prob = hmm.log_prob(data)

    # Hand-computed reference.  Enormous unrolled joint Gaussians are built
    # with these event layouts (T = 3 in this example):
    #
    #   t     |  0  1  2  3  1  2  3
    #   ------+----------------------
    #   init  |  H
    #   trans |  H  H  H  H              H = hidden
    #   obs   |     H  H  H  O  O  O     O = observed
    #
    # and then combined with gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    def unroll(factor, stride):
        # Pad time slice t to event offset t * stride, then sum the slices.
        return reduce(operator.add, [
            factor[..., t].event_pad(left=t * stride, right=(T - t - 1) * stride)
            for t in range(T)
        ])

    unrolled_trans = unroll(trans, hidden_dim)
    unrolled_obs = unroll(obs, obs.dim())

    # Reorder the obs events from H O H O H O to H H H O O O.
    step = obs.dim()
    hidden_idx = [torch.arange(hidden_dim) + t * step for t in range(T)]
    observed_idx = [torch.arange(obs_dim) + hidden_dim + t * step for t in range(T)]
    unrolled_obs = unrolled_obs.event_permute(torch.cat(hidden_idx + observed_idx))
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)

    # Integrate out the initial state, then all remaining hidden states.
    joint = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    joint = gaussian_tensordot(joint, unrolled_obs, T * hidden_dim)
    assert_close(actual_log_prob, joint.log_density(unrolled_data))
def test_matrix_and_mvn_to_gaussian(sample_shape, batch_shape, x_dim, y_dim):
    """Check the joint density of ``matrix_and_mvn_to_gaussian`` plus an MVN factor.

    The Gaussian from ``matrix_and_mvn_to_gaussian(matrix, y_mvn)`` is added to
    a full ``(x, y)`` MVN factor.  Its log density at random points must equal
    ``xy_mvn.log_prob(xy) + y_mvn.log_prob(y - x @ matrix)``.
    """
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    xy_mvn = random_mvn(batch_shape, x_dim + y_dim)
    combined = matrix_and_mvn_to_gaussian(matrix, y_mvn) + mvn_to_gaussian(xy_mvn)

    # Evaluation points carry singleton batch dims so they broadcast over batch_shape.
    point_shape = sample_shape + (1,) * len(batch_shape) + (x_dim + y_dim,)
    xy = torch.randn(point_shape)
    x = xy[..., :x_dim]
    y = xy[..., x_dim:]

    actual_log_prob = combined.log_density(xy)

    # Row-vector x times matrix gives the predicted mean offset for y.
    residual = y - x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    expected_log_prob = xy_mvn.log_prob(xy) + y_mvn.log_prob(residual)
    assert_close(actual_log_prob, expected_log_prob)
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    """Assemble a Gaussian HMM from its initial, transition and observation parts.

    :param initial_dist: ``MultivariateNormal`` over the initial hidden state,
        with event shape ``(hidden_dim,)``.
    :param transition_matrix: tensor whose trailing dims are
        ``(hidden_dim, hidden_dim)``.
    :param transition_dist: ``MultivariateNormal`` transition noise with event
        shape ``(hidden_dim,)``.
    :param observation_matrix: tensor whose trailing dims are
        ``(hidden_dim, obs_dim)``.  ``hidden_dim`` and ``obs_dim`` are read
        from this tensor.
    :param observation_dist: ``MultivariateNormal``, or ``Independent`` wrapping
        a ``Normal``, with event shape ``(obs_dim,)``.
    :param validate_args: forwarded to the base-class constructor.
    """
    mvn_cls = torch.distributions.MultivariateNormal
    assert isinstance(initial_dist, mvn_cls)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_dist, mvn_cls)
    assert isinstance(observation_matrix, torch.Tensor)
    is_diag_obs = (isinstance(observation_dist, torch.distributions.Independent)
                   and isinstance(observation_dist.base_dist, torch.distributions.Normal))
    assert isinstance(observation_dist, mvn_cls) or is_diag_obs

    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert initial_dist.event_shape == (hidden_dim,)
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim,)
    assert observation_dist.event_shape == (obs_dim,)

    # Broadcast all batch shapes together.  The rightmost resulting dim is time.
    # initial_dist gets a singleton time dim so it broadcasts across steps.
    full_shape = broadcast_shape(initial_dist.batch_shape + (1,),
                                 transition_matrix.shape[:-2],
                                 transition_dist.batch_shape,
                                 observation_matrix.shape[:-2],
                                 observation_dist.batch_shape)
    batch_shape = full_shape[:-1]
    event_shape = full_shape[-1:] + (obs_dim,)
    super(GaussianHMM, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    # Each component is converted once, up front, into Gaussian form.
    self._init = mvn_to_gaussian(initial_dist)
    self._trans = matrix_and_mvn_to_gaussian(transition_matrix, transition_dist)
    self._obs = matrix_and_mvn_to_gaussian(observation_matrix, observation_dist)
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    """Integrating ``x`` out of ``x_mvn`` times the affine factor gives the closed-form ``y`` MVN.

    With ``y = x @ matrix + noise``, the marginal of ``y`` has mean
    ``matrix^T x_loc + y_loc`` and covariance ``matrix^T x_cov matrix + y_cov``.
    ``gaussian_tensordot`` over the ``x`` dims must reproduce exactly that Gaussian.
    """
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)

    matrix_t = matrix.transpose(-2, -1)
    pushed_cov = matrix_t.matmul(x_mvn.covariance_matrix).matmul(matrix)
    pushed_loc = matrix_t.matmul(x_mvn.loc.unsqueeze(-1)).squeeze(-1)
    y_marginal = dist.MultivariateNormal(pushed_loc + y_mvn.loc,
                                         pushed_cov + y_mvn.covariance_matrix)
    expected = mvn_to_gaussian(y_marginal)

    x_factor = mvn_to_gaussian(x_mvn)
    affine_factor = matrix_and_mvn_to_gaussian(matrix, y_mvn)
    actual = gaussian_tensordot(x_factor, affine_factor, dims=x_dim)
    assert_close_gaussian(expected, actual)
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """Check ``GaussianHMM`` log density, conjugate update and sample moments.

    * ``log_prob`` is compared against a brute-force unrolled joint Gaussian.
    * ``conjugate_update`` is checked via
      ``prior.log_prob + like.log_prob == posterior.log_prob + log_normalizer``.
    * For unbatched cases only, the sample mean, std and correlation of the
      prior and posterior are compared with the moments of the unrolled Gaussian.
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=num_steps)
    # The reference computation below needs the observation noise as a dense MVN.
    if diag:
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes
    # (T = 3 in this example):
    #
    #   t     |  0  1  2  3  1  2  3
    #   ------+----------------------
    #   init  |  H
    #   trans |  H  H  H  H              H = hidden
    #   obs   |     H  H  H  O  O  O     O = observed
    #   like  |              O  O  O
    #
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    # An arbitrary unit-scale diagonal likelihood over every observed value,
    # used to exercise conjugate_update().
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(
        d.log_prob(data) + like_dist.log_prob(data),
        d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    # Use a distinct loop name so the prior HMM ``d`` is not rebound inside the loop.
    for name, hmm, g in [prior, posterior]:
        logging.info("testing %s moments", name)
        with torch.no_grad():
            num_samples = 100000
            samples = hmm.sample([num_samples]).reshape(num_samples, T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) * actual_std.unsqueeze(-2))

            # Moments of the unrolled Gaussian: cov = precision^-1, mean = cov @ info_vec.
            expected_cov = g.precision.cholesky().cholesky_inverse()
            expected_mean = expected_cov.matmul(g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) * expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)
def _obs(self):
    """Return the observation factor in Gaussian form.

    Built from ``self.observation_matrix`` and ``self.observation_dist`` via
    ``matrix_and_mvn_to_gaussian``.
    """
    obs_matrix = self.observation_matrix
    obs_noise = self.observation_dist
    return matrix_and_mvn_to_gaussian(obs_matrix, obs_noise)
def _trans(self):
    """Return the transition factor in Gaussian form.

    Built from ``self.transition_matrix`` and ``self.transition_dist`` via
    ``matrix_and_mvn_to_gaussian``.
    """
    trans_matrix = self.transition_matrix
    trans_noise = self.transition_dist
    return matrix_and_mvn_to_gaussian(trans_matrix, trans_noise)