def test_reshape(old_shape, new_shape, dim):
    """Round-trip a random GammaGaussian through ``reshape`` and back.

    Reshaping to ``new_shape`` must report ``new_shape`` as the batch shape,
    and reshaping the result back to ``old_shape`` must reproduce the
    original object (compared with ``assert_close_gamma_gaussian``).
    """
    original = random_gamma_gaussian(old_shape, dim)

    # Forward: old_shape -> new_shape.
    reshaped = original.reshape(new_shape)
    assert reshaped.batch_shape == new_shape

    # Backward: new_shape -> old_shape should undo the first reshape.
    restored = reshaped.reshape(old_shape)
    assert_close_gamma_gaussian(restored, original)
def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim, num_steps):
    """Check ``_sequential_gamma_gaussian_tensordot`` against a manual left fold.

    A random GammaGaussian is built with batch shape
    ``batch_shape + (num_steps,)`` and ``state_dim + state_dim`` passed as its
    dimension argument. The sequential reduction must:

    * keep the same ``dim()`` as the input,
    * drop the trailing ``num_steps`` batch dimension, leaving ``batch_shape``,
    * match folding the per-step slices together left to right with
      ``gamma_gaussian_tensordot(..., state_dim)``.
    """
    steps = random_gamma_gaussian(batch_shape + (num_steps,), state_dim + state_dim)

    actual = _sequential_gamma_gaussian_tensordot(steps)
    assert actual.dim() == steps.dim()
    assert actual.batch_shape == batch_shape

    # Reference result: combine the slices one at a time, starting at step 0.
    expected = steps[..., 0]
    for step in range(1, num_steps):
        expected = gamma_gaussian_tensordot(expected, steps[..., step], state_dim)
    assert_close_gamma_gaussian(actual, expected)
def test_cat(shape, cat_dim, split, dim):
    """Split a random GammaGaussian along ``cat_dim`` and check ``cat`` rebuilds it.

    Args:
        shape: batch shape of the random GammaGaussian.
        cat_dim: batch dimension to split and concatenate along; may be
            negative (counted from the end of ``shape``).
        split: sizes of consecutive pieces; must sum to ``shape[cat_dim]``.
        dim: forwarded to ``random_gamma_gaussian`` as its second argument.
    """
    assert sum(split) == shape[cat_dim]
    gamma_gaussian = random_gamma_gaussian(shape, dim)

    parts = []
    end = 0
    for size in split:
        beg, end = end, end + size
        # Build a full batch index that slices only ``cat_dim``. Unlike a
        # fixed table of hand-written slicing expressions (which only covered
        # cat_dim in {-1, -2, 1}), this works for any positive or negative
        # batch dimension. Indexing is done with a tuple of slices.
        index = [slice(None)] * len(shape)
        index[cat_dim] = slice(beg, end)
        parts.append(gamma_gaussian[tuple(index)])

    actual = GammaGaussian.cat(parts, cat_dim)
    assert_close_gamma_gaussian(actual, gamma_gaussian)