def test_gaussian_tensordot(dot_dims, x_batch_shape, x_dim, x_rank,
                            y_batch_shape, y_dim, y_rank):
    x_rank = min(x_rank, x_dim)
    y_rank = min(y_rank, y_dim)
    x = random_gaussian(x_batch_shape, x_dim, x_rank)
    y = random_gaussian(y_batch_shape, y_dim, y_rank)
    na = x_dim - dot_dims
    nb = dot_dims
    nc = y_dim - dot_dims
    try:
        torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb])
    except RuntimeError:
        pytest.skip("Cannot marginalize the common variables of two Gaussians.")

    z = gaussian_tensordot(x, y, dot_dims)
    assert z.dim() == x_dim + y_dim - 2 * dot_dims

    # Make these precision matrices positive definite to test the math.
    x.precision = x.precision + 1e-1 * torch.eye(x.dim())
    y.precision = y.precision + 1e-1 * torch.eye(y.dim())
    z = gaussian_tensordot(x, y, dot_dims)

    # Compare against broadcasting, adding, and marginalizing.
    precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision, (na, 0, na, 0))
    info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0))
    covariance = torch.inverse(precision)
    loc = (covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1)
           if info_vec.size(-1) > 0 else info_vec)
    z_covariance = torch.inverse(z.precision)
    z_loc = z_covariance.matmul(
        z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0),))).sum(-1)
    assert_close(loc[..., :na], z_loc[..., :na])
    assert_close(loc[..., x_dim:], z_loc[..., na:])
    assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na])
    assert_close(covariance[..., :na, x_dim:], z_covariance[..., :na, na:])
    assert_close(covariance[..., x_dim:, :na], z_covariance[..., na:, :na])
    assert_close(covariance[..., x_dim:, x_dim:], z_covariance[..., na:, na:])

    # Assume a = c = 0 and integrate out b by Monte Carlo.
    # FIXME: this might not be a stable way to compute the integral.
    num_samples = 200000
    scale = 20
    # Generate samples in [-10, 10].
    value_b = torch.rand((num_samples,) + z.batch_shape + (nb,)) * scale - scale / 2
    value_x = pad(value_b, (na, 0))
    value_y = pad(value_b, (0, nc))
    expect = torch.logsumexp(x.log_density(value_x) + y.log_density(value_y), dim=0)
    expect += math.log(scale ** nb / num_samples)
    actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(),)))
    # TODO(fehiepsi): find some condition to make this test stable, so we can
    # compare large-value log densities.
    assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1)
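# ----------------------------------------------------------------------
# A minimal usage sketch of gaussian_tensordot (not part of the original
# tests; it assumes pyro's pyro.ops.gaussian API, which provides
# mvn_to_gaussian): contracting two Gaussians over shared event dimensions
# integrates those dimensions out.
import torch
import pyro.distributions as dist
from pyro.ops.gaussian import gaussian_tensordot, mvn_to_gaussian

# x is a Gaussian over (a, b); y is a Gaussian over (b, c). Contracting over
# the shared 1-dimensional variable b leaves a Gaussian over (a, c).
x = mvn_to_gaussian(dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
y = mvn_to_gaussian(dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
z = gaussian_tensordot(x, y, dims=1)
assert z.dim() == 2  # x.dim() + y.dim() - 2 * dims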
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp_h = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp_oh = gaussian_tensordot(logp_h, unrolled_obs, T * hidden_dim)
    logp_h += unrolled_obs.marginalize(right=T * obs_dim)
    expected_log_prob = logp_oh.log_density(unrolled_data) - logp_h.event_logsumexp()
    assert_close(actual_log_prob, expected_log_prob)
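# ----------------------------------------------------------------------
# As a worked equation, the normalization the last four lines implement is
# (my reading of the test; f denotes the unnormalized joint factor):
#
#   log p(x_{1:T}) = log \int f(h, x) dh  -  log \iint f(h, x) dh dx,
#
#   f(h, x) = p_init(h_0) \prod_{t=1}^{T} p_trans(h_{t-1}, h_t) p_obs(h_t, x_t),
#
# where logp_oh.log_density(unrolled_data) evaluates the first term at the
# observed data, and logp_h.event_logsumexp() computes the second term with
# the observations marginalized out first, then the hidden states.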
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
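# ----------------------------------------------------------------------
# The unrolling above relies on Gaussian.event_pad to embed each
# per-timestep factor into the full unrolled event space. A minimal sketch
# of that step (not from the original tests; it assumes pyro's
# pyro.ops.gaussian API):
import torch
import pyro.distributions as dist
from pyro.ops.gaussian import mvn_to_gaussian

hidden_dim, T = 2, 3
# A factor over (h_t, h_{t+1}), embedded into (h_0, ..., h_T) by padding
# with uninformative dimensions on both sides.
factor = mvn_to_gaussian(dist.MultivariateNormal(torch.zeros(2 * hidden_dim),
                                                 torch.eye(2 * hidden_dim)))
t = 1
padded = factor.event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
assert padded.dim() == (1 + T) * hidden_dim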
def log_prob(self, value):
    # We compute a normalized distribution as p(obs,hidden) / p(hidden).
    logp_oh = self._trans
    logp_h = self._trans

    # Combine observation and transition factors.
    logp_oh += self._obs.condition(value).event_pad(left=self.hidden_dim)
    logp_h += self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim)

    # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian.
    batch_dim = 1 + max(len(self._init.batch_shape) + 1, len(logp_oh.batch_shape))
    batch_shape = (1,) * (batch_dim - len(logp_oh.batch_shape)) + logp_oh.batch_shape
    logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

    # Eliminate time dimension.
    logp = _sequential_gaussian_tensordot(logp)

    # Combine initial factor.
    logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

    # Marginalize out final state.
    logp_oh, logp_h = logp.event_logsumexp()
    return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )
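# ----------------------------------------------------------------------
# A usage sketch for the distribution this method belongs to (not from the
# original source; shapes follow the GaussianMRF test above):
import torch
import pyro.distributions as dist

hidden_dim, obs_dim, T = 2, 3, 4
init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
# trans is a joint factor over (h_{t-1}, h_t); obs is a joint factor over
# (h_t, x_t). Both carry a time dimension of size T in their batch shape.
trans_dist = dist.MultivariateNormal(torch.zeros(2 * hidden_dim),
                                     torch.eye(2 * hidden_dim)).expand((T,))
obs_dist = dist.MultivariateNormal(torch.zeros(hidden_dim + obs_dim),
                                   torch.eye(hidden_dim + obs_dim)).expand((T,))
mrf = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
logp = mrf.log_prob(torch.randn(T, obs_dim))  # a scalar log density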
def filter(self, value):
    """
    Compute posterior over final state given a sequence of observations.

    :param ~torch.Tensor value: A sequence of observations.
    :return: A posterior distribution over latent states at the final time
        step. ``result`` can then be used as ``initial_dist`` in a
        sequential Pyro model for prediction.
    :rtype: ~pyro.distributions.MultivariateNormal
    """
    # Combine observation and transition factors.
    logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

    # Eliminate time dimension.
    logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))

    # Combine initial factor.
    logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

    # Convert to a distribution.
    precision = logp.precision
    loc = logp.info_vec.unsqueeze(-1).cholesky_solve(
        torch.linalg.cholesky(precision)).squeeze(-1)
    return MultivariateNormal(loc, precision_matrix=precision,
                              validate_args=self._validate_args)
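# ----------------------------------------------------------------------
# A forecasting sketch using filter (not from the original source; it assumes
# the dist.GaussianHMM constructor and duration argument used in the tests):
import torch
import pyro.distributions as dist

hidden_dim, obs_dim, T = 2, 3, 5
init = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
trans = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
obs = dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim))
trans_mat = 0.9 * torch.eye(hidden_dim)
obs_mat = torch.randn(hidden_dim, obs_dim)

hmm = dist.GaussianHMM(init, trans_mat, trans, obs_mat, obs, duration=T)
data = hmm.sample()           # shape (T, obs_dim)
posterior = hmm.filter(data)  # MultivariateNormal over the final hidden state
# The filtering posterior can seed a new HMM for the next window of data.
forecaster = dist.GaussianHMM(posterior, trans_mat, trans, obs_mat, obs, duration=T)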
def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps):
    g = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim)
    actual = _sequential_gaussian_tensordot(g)
    assert actual.dim() == g.dim()
    assert actual.batch_shape == batch_shape

    # Check against hand computation.
    expected = g[..., 0]
    for t in range(1, num_steps):
        expected = gaussian_tensordot(expected, g[..., t], state_dim)
    assert_close_gaussian(actual, expected)
def log_prob(self, value):
    # Combine observation and transition factors.
    result = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

    # Eliminate time dimension.
    result = _sequential_gaussian_tensordot(result.expand(result.batch_shape))

    # Combine initial factor.
    result = gaussian_tensordot(self._init, result, dims=self.hidden_dim)

    # Marginalize out final state.
    result = result.event_logsumexp()
    return result
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)
    Mx_cov = matrix.transpose(-2, -1).matmul(x_mvn.covariance_matrix).matmul(matrix)
    Mx_loc = matrix.transpose(-2, -1).matmul(x_mvn.loc.unsqueeze(-1)).squeeze(-1)
    mvn = dist.MultivariateNormal(Mx_loc + y_mvn.loc,
                                  Mx_cov + y_mvn.covariance_matrix)
    expected = mvn_to_gaussian(mvn)

    actual = gaussian_tensordot(mvn_to_gaussian(x_mvn),
                                matrix_and_mvn_to_gaussian(matrix, y_mvn),
                                dims=x_dim)
    assert_close_gaussian(expected, actual)
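# ----------------------------------------------------------------------
# For reference, matrix_and_mvn_to_gaussian(matrix, mvn) represents the
# affine relation y = x @ matrix + e with e ~ mvn as a joint Gaussian over
# (x, y). A minimal sketch (not from the original tests; it assumes pyro's
# pyro.ops.gaussian API):
import torch
import pyro.distributions as dist
from pyro.ops.gaussian import matrix_and_mvn_to_gaussian

x_dim, y_dim = 2, 3
matrix = torch.randn(x_dim, y_dim)
noise = dist.MultivariateNormal(torch.zeros(y_dim), torch.eye(y_dim))
g = matrix_and_mvn_to_gaussian(matrix, noise)
assert g.dim() == x_dim + y_dim  # event dims of the joint (x, y)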
def _sequential_gaussian_tensordot(gaussian):
    """
    Integrates a Gaussian ``x`` whose rightmost batch dimension is time,
    computing::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
    """
    assert isinstance(gaussian, Gaussian)
    assert gaussian.dim() % 2 == 0, "dim is not even"
    batch_shape = gaussian.batch_shape[:-1]
    state_dim = gaussian.dim() // 2
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        even_time = time // 2 * 2
        even_part = gaussian[..., :even_time]
        x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        contracted = gaussian_tensordot(x, y, state_dim)
        if time > even_time:
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    return gaussian[..., 0]
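# ----------------------------------------------------------------------
# Because gaussian_tensordot is associative, the loop above halves the time
# dimension on every pass, giving O(log T) contraction depth instead of a
# T-step fold. The same pairwise-doubling pattern, sketched on ordinary
# matrix products (a hypothetical helper, for illustration only):
import torch

def _sequential_matmul(x):
    """For x of shape (..., T, n, n), returns x[..., 0] @ ... @ x[..., T-1]."""
    while x.size(-3) > 1:
        time = x.size(-3)
        even_time = time // 2 * 2
        even = x[..., :even_time, :, :]
        pairs = even.reshape(x.shape[:-3] + (even_time // 2, 2) + x.shape[-2:])
        contracted = pairs[..., 0, :, :] @ pairs[..., 1, :, :]
        if time > even_time:  # carry the odd tail forward unchanged
            contracted = torch.cat([contracted, x[..., -1:, :, :]], dim=-3)
        x = contracted
    return x[..., 0, :, :]

# Sanity check against a left-to-right fold.
mats = torch.randn(7, 3, 3)
expected = mats[0]
for t in range(1, 7):
    expected = expected @ mats[t]
assert torch.allclose(_sequential_matmul(mats), expected, atol=1e-4)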
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=num_steps)
    if diag:
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    #    like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(d.log_prob(data) + like_dist.log_prob(data),
                 d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples, T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) * actual_std.unsqueeze(-2))

            expected_cov = torch.cholesky_inverse(torch.linalg.cholesky(g.precision))
            expected_mean = expected_cov.matmul(g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) * expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)