def test_gaussian_tensordot(dot_dims,
                            x_batch_shape, x_dim, x_rank,
                            y_batch_shape, y_dim, y_rank):
    """Check ``gaussian_tensordot`` against a dense reference computation.

    ``x`` has event dims laid out as ``[a, b]`` and ``y`` as ``[b, c]``, where
    ``b`` (size ``dot_dims``) is the shared block that gets integrated out.
    The result ``z`` should have event dims ``[a, c]``.

    The test does three things:
    1. Checks the output dimension.
    2. Compares ``z``'s mean and covariance against a dense joint Gaussian
       built from zero-padded precisions / info vectors.
    3. Compares ``z``'s log density at ``a = c = 0`` against a Monte-Carlo
       estimate of the integral over ``b``.
    """
    # The rank passed to random_gaussian cannot exceed the event dim.
    # NOTE(review): presumably controls the rank of the random precision; confirm.
    x_rank = min(x_rank, x_dim)
    y_rank = min(y_rank, y_dim)
    x = random_gaussian(x_batch_shape, x_dim, x_rank)
    y = random_gaussian(y_batch_shape, y_dim, y_rank)
    na = x_dim - dot_dims  # size of x-only block a
    nb = dot_dims          # size of shared block b
    nc = y_dim - dot_dims  # size of y-only block c
    # Integrating out b requires the summed b-block of the precision to be
    # positive definite. If the Cholesky factorization fails, skip this case.
    try:
        torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb])
    except RuntimeError:
        pytest.skip(
            "Cannot marginalize the common variables of two Gaussians.")

    z = gaussian_tensordot(x, y, dot_dims)
    assert z.dim() == x_dim + y_dim - 2 * dot_dims

    # We make these precision matrices positive definite to test the math
    x.precision = x.precision + 1e-1 * torch.eye(x.dim())
    y.precision = y.precision + 1e-1 * torch.eye(y.dim())
    z = gaussian_tensordot(x, y, dot_dims)
    # compare against broadcasting, adding, and marginalizing
    # The joint over [a, b, c] is formed by padding x's parameters on the
    # right by nc and y's on the left by na, then summing.
    # NOTE(review): assumes `pad` is torch.nn.functional.pad (last dim first).
    precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision, (na, 0, na, 0))
    info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0))
    covariance = torch.inverse(precision)
    # Mean = covariance @ info_vec. An empty event dim is passed through unchanged.
    loc = (covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1)
           if info_vec.size(-1) > 0 else info_vec)
    z_covariance = torch.inverse(z.precision)
    # Trick that works for both z.dim() > 0 and z.dim() == 0:
    # - dim > 0: append a trailing size-1 axis, matmul, then sum it away.
    # - dim == 0: append a trailing size-0 axis, so the sum yields an empty last axis.
    z_loc = z_covariance.matmul(
        z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0), ))).sum(-1)
    # Marginalizing b in moment form means dropping its rows/cols, so the
    # joint indices [:na] and [x_dim:] map to z's [:na] and [na:].
    assert_close(loc[..., :na], z_loc[..., :na])
    assert_close(loc[..., x_dim:], z_loc[..., na:])
    assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na])
    assert_close(covariance[..., :na, x_dim:], z_covariance[..., :na, na:])
    assert_close(covariance[..., x_dim:, :na], z_covariance[..., na:, :na])
    assert_close(covariance[..., x_dim:, x_dim:], z_covariance[..., na:, na:])

    # Assume a = c = 0, integrate out b
    # FIXME: this might be not a stable way to compute integral
    num_samples = 200000
    scale = 20
    # generate samples in [-10, 10]
    value_b = torch.rand((num_samples, ) + z.batch_shape + (nb, )) * scale - scale / 2
    # Embed b into x's [a, b] layout and y's [b, c] layout, with a = c = 0.
    value_x = pad(value_b, (na, 0))
    value_y = pad(value_b, (0, nc))
    # Uniform-sampling MC estimate of log ∫ exp(x + y) db over the box:
    # logsumexp over samples plus log(box volume / num_samples).
    expect = torch.logsumexp(x.log_density(value_x) + y.log_density(value_y), dim=0)
    expect += math.log(scale**nb / num_samples)
    actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(), )))
    # TODO(fehiepsi): find some condition to make this test stable, so we can compare large value
    # log densities.
    assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1)
def test_marginalize_condition(sample_shape, batch_shape, left, right):
    """Marginalizing out the left block must equal conditioning on the right
    block and then integrating over the remaining (left) event dims."""
    gaussian = random_gaussian(batch_shape, left + right)
    # Sample dims in front; singleton dims broadcast against the batch shape.
    broadcast_prefix = sample_shape + (1, ) * len(batch_shape)
    right_value = torch.randn(broadcast_prefix + (right, ))
    marginal_log_density = gaussian.marginalize(left=left).log_density(right_value)
    conditioned_mass = gaussian.condition(right_value).event_logsumexp()
    assert_close(marginal_log_density, conditioned_mass)
def test_marginalize(batch_shape, left, right):
    """Marginalizing either end of the event dims preserves total log mass."""
    gaussian = random_gaussian(batch_shape, left + right)
    for marginalize_kwargs in ({"left": left}, {"right": right}):
        reduced = gaussian.marginalize(**marginalize_kwargs)
        assert_close(reduced.event_logsumexp(), gaussian.event_logsumexp())
def test_pad(shape, left, right, dim):
    """``event_pad`` keeps the batch shape, grows the event dim by
    ``left + right``, and leaves the original parameters in the middle."""
    original = random_gaussian(shape, dim)
    padded = original.event_pad(left=left, right=right)

    assert padded.batch_shape == original.batch_shape
    assert padded.dim() == left + original.dim() + right

    interior = slice(left, padded.dim() - right)
    assert_close(padded.info_vec[..., interior], original.info_vec)
    assert_close(padded.precision[..., interior, interior], original.precision)
def test_reshape(old_shape, new_shape, dim):
    """Reshaping the batch dims and reshaping back is a lossless round trip."""
    original = random_gaussian(old_shape, dim)
    reshaped = original.reshape(new_shape)
    assert reshaped.batch_shape == new_shape
    round_tripped = reshaped.reshape(old_shape)
    assert_close_gaussian(round_tripped, original)
def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps):
    """The sequential contraction matches a step-by-step left fold of
    ``gaussian_tensordot`` over the time axis."""
    steps = random_gaussian(batch_shape + (num_steps, ), 2 * state_dim)
    actual = _sequential_gaussian_tensordot(steps)
    assert actual.dim() == steps.dim()
    assert actual.batch_shape == batch_shape

    # Reference: contract the time steps one at a time.
    running = steps[..., 0]
    for step in range(1, num_steps):
        running = gaussian_tensordot(running, steps[..., step], state_dim)
    assert_close_gaussian(actual, running)
def test_logsumexp(batch_shape, dim):
    """``event_logsumexp`` agrees with a uniform Monte-Carlo integral of the
    density over the box [-scale/2, scale/2)^dim."""
    gaussian = random_gaussian(batch_shape, dim)
    gaussian.info_vec *= 0.1  # approximately centered
    gaussian.precision += torch.eye(dim) * 0.1

    num_samples = 200000
    scale = 10
    draw_shape = (num_samples, ) + (1, ) * len(batch_shape) + (dim, )
    samples = torch.rand(draw_shape) * scale - scale / 2

    log_box_weight = math.log(scale**dim / num_samples)
    expected = gaussian.log_density(samples).logsumexp(0) + log_box_weight
    actual = gaussian.event_logsumexp()
    assert_close(actual, expected, atol=0.05, rtol=0.05)
def test_condition(sample_shape, batch_shape, left, right):
    """Conditioning on the right block and evaluating at the left block
    gives the same log density as evaluating the full Gaussian."""
    dim = left + right
    gaussian = random_gaussian(batch_shape, dim)
    gaussian.precision += torch.eye(dim) * 0.1
    value = torch.randn(sample_shape + (1, ) * len(batch_shape) + (dim, ))
    left_value = value[..., :left]
    right_value = value[..., left:]

    conditioned = gaussian.condition(right_value)
    assert conditioned.batch_shape == sample_shape + gaussian.batch_shape
    assert conditioned.dim() == left

    assert_close(conditioned.log_density(left_value), gaussian.log_density(value))
def test_cat(shape, cat_dim, split, dim):
    """Slicing a Gaussian into consecutive pieces along ``cat_dim`` and
    re-joining them with ``Gaussian.cat`` reproduces the original.

    ``split`` lists the piece sizes and must sum to ``shape[cat_dim]``.
    Only ``cat_dim`` values -1, -2 and 1 are handled by the slicing below.

    Raises:
        ValueError: if ``cat_dim`` is not one of the handled values.
    """
    assert sum(split) == shape[cat_dim]
    gaussian = random_gaussian(shape, dim)

    def _slice_along_cat_dim(beg, end):
        # Select batch entries [beg, end) along cat_dim.
        if cat_dim == -1:
            return gaussian[..., beg:end]
        if cat_dim == -2:
            return gaussian[..., beg:end, :]
        if cat_dim == 1:
            return gaussian[:, beg:end]
        # Name the offending value so a mis-parametrized case is easy to diagnose.
        raise ValueError("Unsupported cat_dim: {}".format(cat_dim))

    parts = []
    end = 0
    for size in split:
        beg, end = end, end + size
        parts.append(_slice_along_cat_dim(beg, end))
    actual = Gaussian.cat(parts, cat_dim)
    assert_close_gaussian(actual, gaussian)
def test_sequential_gaussian_filter_sample(sample_shape, batch_shape, state_dim, num_steps):
    """Filter sampling returns one state vector per sample, batch element
    and time step."""
    init = random_gaussian(batch_shape, state_dim)
    trans = random_gaussian(batch_shape + (num_steps, ), 2 * state_dim)
    sample = _sequential_gaussian_filter_sample(init, trans, sample_shape)
    expected_shape = sample_shape + batch_shape + (num_steps, state_dim)
    assert sample.shape == expected_shape
def test_marginalize_shape(batch_shape, left, right):
    """Marginalizing one end leaves exactly the other end's event dims."""
    gaussian = random_gaussian(batch_shape, left + right)
    cases = (({"left": left}, right), ({"right": right}, left))
    for marginalize_kwargs, remaining_dim in cases:
        assert gaussian.marginalize(**marginalize_kwargs).dim() == remaining_dim
def test_add(shape, dim):
    """Adding two Gaussians adds their log densities pointwise."""
    lhs = random_gaussian(shape, dim)
    rhs = random_gaussian(shape, dim)
    point = torch.randn(dim)
    combined = lhs + rhs
    summed_log_density = lhs.log_density(point) + rhs.log_density(point)
    assert_close(combined.log_density(point), summed_log_density)