def test_reshape(old_shape, new_shape, dim):
    """Round-trip a random Gaussian through ``reshape`` and back.

    Builds a Gaussian with batch shape ``old_shape``, reshapes it to
    ``new_shape`` (checking the reported ``batch_shape``), then reshapes
    back and compares the result with the starting Gaussian.

    NOTE(review): arguments are presumably supplied by a pytest
    parametrization that is not visible in this chunk.
    """
    original = random_gaussian(old_shape, dim)

    # Forward direction: the batch shape must be exactly what was requested.
    reshaped = original.reshape(new_shape)
    assert reshaped.batch_shape == new_shape

    # Reverse direction: undoing the reshape must reproduce the input.
    restored = reshaped.reshape(old_shape)
    assert_close_gaussian(restored, original)
def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps):
    """Compare ``_sequential_gaussian_tensordot`` with a step-by-step fold.

    A random Gaussian of batch shape ``batch_shape + (num_steps,)`` and
    dimension ``state_dim + state_dim`` is reduced by the function under
    test; the reference value is obtained by chaining ``gaussian_tensordot``
    over the trailing batch axis one slice at a time.

    NOTE(review): arguments are presumably supplied by a pytest
    parametrization that is not visible in this chunk.
    """
    full_shape = batch_shape + (num_steps,)
    joint_dim = state_dim + state_dim
    g = random_gaussian(full_shape, joint_dim)

    actual = _sequential_gaussian_tensordot(g)
    assert actual.dim() == g.dim()
    assert actual.batch_shape == batch_shape

    # Reference: fold the slices along the last batch axis left to right,
    # contracting ``state_dim`` dimensions at each step.
    expected = g[..., 0]
    for step in range(1, num_steps):
        current = g[..., step]
        expected = gaussian_tensordot(expected, current, state_dim)

    assert_close_gaussian(actual, expected)
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    """Check ``matrix_and_mvn_to_gaussian`` by contracting it against an MVN.

    The expected Gaussian is built from a ``MultivariateNormal`` whose
    location is ``M^T x_loc + y_loc`` and whose covariance is
    ``M^T cov_x M + cov_y``. The actual Gaussian is the
    ``gaussian_tensordot`` of ``mvn_to_gaussian(x_mvn)`` with
    ``matrix_and_mvn_to_gaussian(matrix, y_mvn)`` over ``x_dim`` dimensions.

    ``sample_shape`` is accepted but not used in the body.

    NOTE(review): arguments are presumably supplied by a pytest
    parametrization that is not visible in this chunk.
    """
    # Keep the draw order (matrix, y, x) so random streams are consumed
    # in the same sequence.
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)

    # Moments of x pushed through the matrix.
    matrix_t = matrix.transpose(-2, -1)
    Mx_cov = matrix_t @ x_mvn.covariance_matrix @ matrix
    Mx_loc = (matrix_t @ x_mvn.loc.unsqueeze(-1)).squeeze(-1)

    combined_loc = Mx_loc + y_mvn.loc
    combined_cov = Mx_cov + y_mvn.covariance_matrix
    mvn = dist.MultivariateNormal(combined_loc, combined_cov)
    expected = mvn_to_gaussian(mvn)

    x_gaussian = mvn_to_gaussian(x_mvn)
    conditional = matrix_and_mvn_to_gaussian(matrix, y_mvn)
    actual = gaussian_tensordot(x_gaussian, conditional, dims=x_dim)

    assert_close_gaussian(expected, actual)
def test_cat(shape, cat_dim, split, dim):
    """Slice a Gaussian into parts and check ``Gaussian.cat`` reassembles it.

    ``split`` lists the sizes of consecutive pieces along batch axis
    ``cat_dim``; their sum must equal ``shape[cat_dim]``. Only ``cat_dim``
    values of ``-1``, ``-2`` and ``1`` are handled; any other value raises
    ``ValueError`` while slicing.

    NOTE(review): arguments are presumably supplied by a pytest
    parametrization that is not visible in this chunk.
    """
    assert sum(split) == shape[cat_dim]
    gaussian = random_gaussian(shape, dim)

    pieces = []
    stop = 0
    for width in split:
        # Advance the half-open window [start, stop) by one piece.
        start = stop
        stop = start + width
        window = slice(start, stop)

        # Build the same index tuple the literal subscript syntax produces.
        if cat_dim == -1:
            index = (Ellipsis, window)
        elif cat_dim == -2:
            index = (Ellipsis, window, slice(None))
        elif cat_dim == 1:
            index = (slice(None), window)
        else:
            raise ValueError
        pieces.append(gaussian[index])

    actual = Gaussian.cat(pieces, cat_dim)
    assert_close_gaussian(actual, gaussian)