def test_expand(extra_shape, log_normalizer_shape, info_vec_shape,
                precision_shape, alpha_shape, beta_shape, dim):
    """
    Check that ``GammaGaussian.expand`` can prepend ``extra_shape`` to the
    broadcast of all parameter batch shapes.
    """
    rank = 2 * dim
    log_normalizer = torch.randn(log_normalizer_shape)
    info_vec = torch.randn(info_vec_shape + (dim,))
    factor = torch.randn(precision_shape + (dim, rank))
    # factor @ factor^T yields a symmetric positive semi-definite precision.
    precision = factor.matmul(factor.transpose(-1, -2))
    # Exponentiate so both Gamma parameters are strictly positive.
    alpha = torch.exp(torch.randn(alpha_shape))
    beta = torch.exp(torch.randn(beta_shape))

    joint_batch_shape = broadcast_shape(
        log_normalizer_shape, info_vec_shape, precision_shape, alpha_shape, beta_shape)
    target_shape = extra_shape + joint_batch_shape

    original = GammaGaussian(log_normalizer, info_vec, precision, alpha, beta)
    expanded = original.expand(target_shape)
    assert expanded.batch_shape == target_shape
def random_gamma_gaussian(batch_shape, dim, rank=None):
    """
    Generate a random GammaGaussian for testing.

    :param batch_shape: batch shape of the result (a tuple / ``torch.Size``).
    :param int dim: event dimension of the result.
    :param rank: number of columns of the random factor used to build the
        precision matrix; defaults to ``2 * dim``.
    :return: a GammaGaussian whose ``dim()`` is ``dim`` and whose
        ``batch_shape`` is ``batch_shape`` (both are asserted).
    """
    rank = dim + dim if rank is None else rank

    log_normalizer = torch.randn(batch_shape)
    loc = torch.randn(batch_shape + (dim,))
    factor = torch.randn(batch_shape + (dim, rank))
    # factor @ factor^T is symmetric positive semi-definite.
    precision = factor.matmul(factor.transpose(-1, -2))

    # info_vec = precision @ loc; with dim == 0 the (empty) loc is used as-is.
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1) if dim > 0 else loc

    # alpha is offset by 0.5 * dim - 1; beta is offset by 0.5 * <info_vec, loc>,
    # which equals 0.5 * loc^T precision loc when dim > 0.
    alpha = torch.randn(batch_shape).exp() + 0.5 * dim - 1
    beta = torch.randn(batch_shape).exp() + 0.5 * (info_vec * loc).sum(-1)

    gamma_gaussian = GammaGaussian(log_normalizer, info_vec, precision, alpha, beta)
    assert gamma_gaussian.dim() == dim
    assert gamma_gaussian.batch_shape == batch_shape
    return gamma_gaussian
def test_cat(shape, cat_dim, split, dim):
    """
    Slice a random GammaGaussian into pieces of sizes ``split`` along
    ``cat_dim``, concatenate them back, and check the original is recovered.
    """
    assert sum(split) == shape[cat_dim]
    gamma_gaussian = random_gamma_gaussian(shape, dim)

    def take(start, stop):
        # Select batch positions [start, stop) along the concatenation dim.
        window = slice(start, stop)
        if cat_dim == -1:
            return gamma_gaussian[..., window]
        if cat_dim == -2:
            return gamma_gaussian[..., window, :]
        if cat_dim == 1:
            return gamma_gaussian[:, window]
        raise ValueError

    # Running boundaries: [0, s0, s0 + s1, ...].
    bounds = [0]
    for size in split:
        bounds.append(bounds[-1] + size)
    parts = [take(start, stop) for start, stop in zip(bounds[:-1], bounds[1:])]

    actual = GammaGaussian.cat(parts, cat_dim)
    assert_close_gamma_gaussian(actual, gamma_gaussian)
def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
    """
    Integrate a GammaGaussian ``x`` along its rightmost batch dimension (time).

    Computes the product::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]

    by repeatedly contracting adjacent pairs of time steps with
    ``gamma_gaussian_tensordot``. Each pass shrinks the time dimension from
    ``T`` to ``ceil(T / 2)``. When ``T`` is odd, the unpaired last step is
    carried into the next pass unchanged.

    :param GammaGaussian gamma_gaussian: batched GammaGaussian whose rightmost
        batch dimension is time and whose ``dim()`` is even. Each half,
        ``dim() // 2``, is the size contracted between neighbouring steps.
    :return: the product over time, with the time dimension indexed away.
    :rtype: GammaGaussian
    """
    assert isinstance(gamma_gaussian, GammaGaussian)
    assert gamma_gaussian.dim() % 2 == 0, "dim is not even"
    outer_shape = gamma_gaussian.batch_shape[:-1]
    half_dim = gamma_gaussian.dim() // 2

    result = gamma_gaussian
    while result.batch_shape[-1] > 1:
        steps = result.batch_shape[-1]
        num_pairs = steps // 2
        paired_len = 2 * num_pairs
        # Group steps (0, 1), (2, 3), ... into a trailing axis of size 2.
        pairs = result[..., :paired_len].reshape(outer_shape + (num_pairs, 2))
        merged = gamma_gaussian_tensordot(pairs[..., 0], pairs[..., 1], half_dim)
        if steps != paired_len:
            # Odd length: keep the leftover final step for a later pass.
            merged = GammaGaussian.cat((merged, result[..., -1:]), dim=-1)
        result = merged
    return result[..., 0]