def test_add(shape, dim):
    """Check that the log density of ``x + y`` is the sum of the two log densities."""
    lhs = random_gamma_gaussian(shape, dim)
    rhs = random_gamma_gaussian(shape, dim)
    point = torch.randn(dim)
    # exp() of a standard normal draw gives a strictly positive scalar.
    s = torch.randn(()).exp()

    density_of_sum = (lhs + rhs).log_density(point, s)
    sum_of_densities = lhs.log_density(point, s) + rhs.log_density(point, s)
    assert_close(density_of_sum, sum_of_densities)
def test_gamma_gaussian_tensordot(dot_dims, x_batch_shape, x_dim, x_rank,
                                  y_batch_shape, y_dim, y_rank):
    """Check ``gamma_gaussian_tensordot`` against dense algebra and a Monte Carlo estimate.

    The event dims are split as ``x = (a, b)`` and ``y = (b, c)`` where ``b``
    has ``dot_dims`` entries and is the part that gets contracted away.

    The test proceeds in three stages:

    1. Skip if the precision of the shared ``b`` block has no Cholesky factor,
       then check the event size of the contraction.
    2. After adding ``3 * I`` to both precisions, compare the mean and
       covariance of the result with those obtained by padding, adding and
       inverting the joint precision.
    3. Compare the log density at ``a = c = 0`` with a uniform-sampling
       estimate of the integral over ``b``.
    """
    # Requested ranks are capped at the event size.
    x = random_gamma_gaussian(x_batch_shape, x_dim, min(x_rank, x_dim))
    y = random_gamma_gaussian(y_batch_shape, y_dim, min(y_rank, y_dim))
    na = x_dim - dot_dims
    nb = dot_dims
    nc = y_dim - dot_dims

    try:
        torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb])
    except RuntimeError:
        pytest.skip("Cannot marginalize the common variables of two Gaussians.")

    z = gamma_gaussian_tensordot(x, y, dot_dims)
    assert z.dim() == x_dim + y_dim - 2 * dot_dims

    # Make both precision matrices positive definite so that the inverses
    # below exist, then redo the contraction with the shifted inputs.
    x.precision = x.precision + 3 * torch.eye(x.dim())
    y.precision = y.precision + 3 * torch.eye(y.dim())
    z = gamma_gaussian_tensordot(x, y, dot_dims)

    # Reference: embed both factors in the joint (a, b, c) space, add, invert.
    joint_precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision, (na, 0, na, 0))
    joint_info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0))
    joint_cov = torch.inverse(joint_precision)
    if joint_info_vec.size(-1) > 0:
        joint_loc = joint_cov.matmul(joint_info_vec.unsqueeze(-1)).squeeze(-1)
    else:
        joint_loc = joint_info_vec

    z_cov = torch.inverse(z.precision)
    # The trailing axis has size 1 normally and size 0 when z has no event
    # dims; summing over it yields a vector in both cases.
    trailing = (int(z.dim() > 0),)
    z_loc = z_cov.matmul(z.info_vec.view(z.info_vec.shape + trailing)).sum(-1)

    # (index into the joint space, index into z) for the a block and the c block.
    kept = (
        (slice(None, na), slice(None, na)),
        (slice(x_dim, None), slice(na, None)),
    )
    for joint_idx, z_idx in kept:
        assert_close(joint_loc[..., joint_idx], z_loc[..., z_idx])
    for joint_row, z_row in kept:
        for joint_col, z_col in kept:
            assert_close(joint_cov[..., joint_row, joint_col],
                         z_cov[..., z_row, z_col])

    s = torch.randn(z.batch_shape).exp()
    # Assume a = c = 0 and integrate out b by uniform sampling.
    num_samples = 200000
    scale = 10
    # Samples of b are uniform on [-scale / 2, scale / 2), i.e. [-5, 5).
    value_b = torch.rand((num_samples,) + z.batch_shape + (nb,)) * scale - scale / 2
    value_x = pad(value_b, (na, 0))
    value_y = pad(value_b, (0, nc))
    sample_log_density = x.log_density(value_x, s) + y.log_density(value_y, s)
    # log(box volume / sample count) converts the log-sum into a log-integral.
    log_weight = math.log(scale**nb / num_samples)
    expect = torch.logsumexp(sample_log_density, dim=0) + log_weight
    actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(),)), s)
    assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1)
def test_marginalize_condition(sample_shape, batch_shape, left, right):
    """Check that marginalizing the left dims agrees with conditioning on the right dims.

    The density of the left-marginalized object at ``x`` is compared with the
    density obtained by conditioning on ``x`` and then applying
    ``event_logsumexp``.
    """
    g = random_gamma_gaussian(batch_shape, left + right)
    # Singleton axes let the sample dims broadcast against the batch dims.
    x_shape = sample_shape + (1,) * len(batch_shape) + (right,)
    x = torch.randn(x_shape)
    s = torch.randn(batch_shape).exp()

    via_marginal = g.marginalize(left=left).log_density(x, s)
    via_condition = g.condition(x).event_logsumexp().log_density(s)
    assert_close(via_marginal, via_condition)
def test_pad(shape, left, right, dim):
    """Check that ``event_pad`` grows the event size and keeps the original block intact."""
    original = random_gamma_gaussian(shape, dim)
    padded = original.event_pad(left=left, right=right)

    assert padded.batch_shape == original.batch_shape
    assert padded.dim() == left + original.dim() + right

    # The unpadded entries sit between the left and right padding.
    core = slice(left, padded.dim() - right)
    assert_close(padded.info_vec[..., core], original.info_vec)
    assert_close(padded.precision[..., core, core], original.precision)
def test_reshape(old_shape, new_shape, dim):
    """Check that reshaping the batch shape and reshaping back is a round trip."""
    original = random_gamma_gaussian(old_shape, dim)

    # Forward: old batch shape -> new batch shape.
    reshaped = original.reshape(new_shape)
    assert reshaped.batch_shape == new_shape

    # Backward: new batch shape -> old batch shape.
    restored = reshaped.reshape(old_shape)
    assert_close_gamma_gaussian(restored, original)
def test_marginalize(batch_shape, left, right):
    """Check that marginalizing part of the event does not change the full ``event_logsumexp``.

    This is verified once for marginalizing ``left`` dims and once for
    marginalizing ``right`` dims.
    """
    g = random_gamma_gaussian(batch_shape, left + right)
    s = torch.randn(batch_shape).exp()

    for side in ({"left": left}, {"right": right}):
        partial_then_full = g.marginalize(**side).event_logsumexp().log_density(s)
        full = g.event_logsumexp().log_density(s)
        assert_close(partial_then_full, full)
def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim, num_steps):
    """Check ``_sequential_gamma_gaussian_tensordot`` against a step-by-step contraction.

    The rightmost batch dim of the input has ``num_steps`` entries; the
    reference folds them left to right with ``gamma_gaussian_tensordot``,
    contracting ``state_dim`` dims at each step.
    """
    chain = random_gamma_gaussian(batch_shape + (num_steps,), state_dim + state_dim)

    actual = _sequential_gamma_gaussian_tensordot(chain)
    assert actual.dim() == chain.dim()
    assert actual.batch_shape == batch_shape

    # Hand computation: start from the first step and absorb the rest in order.
    expected = chain[..., 0]
    for step in range(1, num_steps):
        factor = chain[..., step]
        expected = gamma_gaussian_tensordot(expected, factor, state_dim)

    assert_close_gamma_gaussian(actual, expected)
def test_logsumexp(batch_shape, dim):
    """Check ``event_logsumexp`` against a uniform-sampling estimate of the integral."""
    g = random_gamma_gaussian(batch_shape, dim)
    g.info_vec *= 0.1  # approximately centered
    g.precision += torch.eye(dim) * 0.1
    s = torch.randn(batch_shape).exp() + 0.2

    num_samples = 200000
    scale = 10
    # Uniform draws on [-scale / 2, scale / 2); the singleton axes broadcast
    # against the batch dims of g.
    draw_shape = (num_samples,) + (1,) * len(batch_shape) + (dim,)
    samples = torch.rand(draw_shape) * scale - scale / 2

    # log(box volume / sample count) converts the log-sum into a log-integral.
    log_weight = math.log(scale**dim / num_samples)
    expected = g.log_density(samples, s).logsumexp(0) + log_weight
    actual = g.event_logsumexp().log_density(s)
    assert_close(actual, expected, atol=0.05, rtol=0.05)
def test_condition(sample_shape, batch_shape, left, right):
    """Check that conditioning on the right dims preserves the log density.

    The conditioned object, evaluated on the left part of a value, must match
    the original object evaluated on the whole value.
    """
    dim = left + right
    g = random_gamma_gaussian(batch_shape, dim)
    g.precision += torch.eye(dim) * 0.1

    # Singleton axes let the sample dims broadcast against the batch dims.
    full_value = torch.randn(sample_shape + (1,) * len(batch_shape) + (dim,))
    head = full_value[..., :left]
    tail = full_value[..., left:]

    conditioned = g.condition(tail)
    assert conditioned.batch_shape == sample_shape + g.batch_shape
    assert conditioned.dim() == left

    s = torch.randn(batch_shape).exp()
    actual = conditioned.log_density(head, s)
    expected = g.log_density(full_value, s)
    assert_close(actual, expected)
def test_cat(shape, cat_dim, split, dim):
    """Check that slicing along a batch dim and ``GammaGaussian.cat`` are inverse operations.

    ``split`` lists the sizes of consecutive slices along ``cat_dim`` and must
    sum to ``shape[cat_dim]``. Only ``cat_dim`` values of -1, -2 and 1 are
    handled; any other value raises ``ValueError`` when a slice is taken.
    """
    assert sum(split) == shape[cat_dim]
    whole = random_gamma_gaussian(shape, dim)

    def take(beg, end):
        # Slice ``whole`` over [beg, end) along the batch dim being tested.
        if cat_dim == -1:
            return whole[..., beg:end]
        if cat_dim == -2:
            return whole[..., beg:end, :]
        if cat_dim == 1:
            return whole[:, beg:end]
        raise ValueError

    pieces = []
    offset = 0
    for size in split:
        pieces.append(take(offset, offset + size))
        offset = offset + size

    rejoined = GammaGaussian.cat(pieces, cat_dim)
    assert_close_gamma_gaussian(rejoined, whole)
def test_marginalize_shape(batch_shape, left, right):
    """Check the event size left over after marginalizing from either side."""
    g = random_gamma_gaussian(batch_shape, left + right)

    without_left = g.marginalize(left=left)
    assert without_left.dim() == right

    without_right = g.marginalize(right=right)
    assert without_right.dim() == left