def _tril_spherical_uniform(dimension, batch_shape, dtype, seed):
  """Returns a `Tensor` of samples of lower triangular matrices.

  Each row of the lower triangular part follows a spherical uniform
  distribution.

  Args:
    dimension: Positive scalar `int` representing the dimensionality of the
      output matrices. It drives Python control flow below (`if`, `min`,
      `range`), so it must be a concrete value such as a Python `int`.
    batch_shape: Vector-shaped, `int` `Tensor` representing batch shape of
      output. Its length must be statically known (`len` is taken below).
      The output will have shape `batch_shape + [dimension, dimension]`.
    dtype: TF `dtype` representing `dtype` of output.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    tril_spherical_uniform: `Tensor` with specified `batch_shape` and `dtype`
      consisting of real values drawn row-wise from a spherical uniform
      distribution.

  Raises:
    ValueError: If `dimension` is not positive.
  """
  # Essentially, we will draw lower triangular samples where each lower
  # triangular entry follows a normal distribution, then apply `x / norm(x)`
  # for each row of the samples.
  # To avoid possible NaNs, we will use spherical_uniform directly for
  # the first two rows.

  # Explicit check rather than `assert`: assertions are stripped when Python
  # runs with `-O`, which would let an invalid `dimension` slip through.
  if dimension <= 0:
    raise ValueError('`dimension` needs to be positive.')

  # One seed per directly-sampled row (at most two), plus one for the normal
  # draws when dimension > 2. Seeds are consumed with `pop()` (from the end
  # of the list); keep this order so samples stay reproducible per seed.
  num_seeds = min(dimension, 3)
  seeds = list(samplers.split_seed(seed, n=num_seeds, salt='sample_lkj'))

  # Row n-1 (for n = 1, 2) has n nonzero entries drawn on the unit sphere,
  # right-padded with zeros up to length `dimension`.
  rows = []
  paddings_prepend = [[0, 0]] * len(batch_shape)
  for n in range(1, min(dimension, 2) + 1):
    rows.append(
        tf.pad(
            random_ops.spherical_uniform(
                shape=batch_shape, dimension=n, dtype=dtype,
                seed=seeds.pop()),
            paddings_prepend + [[0, dimension - n]],
            constant_values=0.))
  samples = tf.stack(rows, axis=-2)

  if dimension > 2:
    # The first two rows account for 1 + 2 = 3 of the
    # dimension * (dimension + 1) // 2 lower triangular entries; draw
    # normals for all of the others.
    normal_shape = ps.concat(
        [batch_shape, [dimension * (dimension + 1) // 2 - 3]], axis=0)
    normal_samples = samplers.normal(
        shape=normal_shape, dtype=dtype, seed=seeds.pop())
    # We fill the first two rows of the triangular matrix with ones.
    # Note that fill_triangular fills elements in a clockwise spiral.
    normal_samples = tf.concat([
        normal_samples[..., :dimension],
        tf.ones(ps.concat([batch_shape, [1]], axis=0), dtype=dtype),
        normal_samples[..., dimension:(2 * dimension - 1)],
        tf.ones(ps.concat([batch_shape, [2]], axis=0), dtype=dtype),
        normal_samples[..., (2 * dimension - 1):],
    ], axis=-1)
    # The ones were only placeholders for rows 0 and 1; drop those rows.
    normal_samples = linalg.fill_triangular(
        normal_samples, upper=False)[..., 2:, :]
    # Normalize each row of normals onto the unit sphere (`x / norm(x)`).
    remaining_rows = normal_samples / tf.norm(
        normal_samples, ord=2, axis=-1, keepdims=True)
    samples = tf.concat([samples, remaining_rows], axis=-2)
  return samples
def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
  """Samples uniformly random "correlation-like" matrices.

  Here a "correlation-like" matrix is any symmetric square matrix whose
  entries all lie in [-1, 1] and whose main diagonal is all 1s. The positive
  semi-definite matrices among them are exactly the correlation matrices.

  Args:
    num_rows: Python `int`; each returned matrix is `num_rows x num_rows`.
    batch_shape: `Tensor` or Python `tuple` of `int` giving the batch shape
      of the result.
    dtype: `dtype` of the returned `Tensor`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    matrices: `Tensor` of shape `batch_shape + [num_rows, num_rows]` and
      dtype `dtype`. Every entry is in [-1, 1], and each matrix in the two
      innermost dimensions is symmetric with 1s on its main diagonal.
  """
  # One Uniform(-1, 1) draw per triangular slot, diagonal included:
  # `fill_triangular` consumes num_rows * (num_rows + 1) // 2 values, even
  # though the diagonal draws are overwritten with 1s at the end. (Filling a
  # full matrix and masking with `matrix_band_part` would be an alternative.)
  tril_size = num_rows * (num_rows + 1) // 2
  upper_bound = tf.ones(shape=[tril_size], dtype=dtype)
  lower_bound = -upper_bound
  draws = uniform.Uniform(lower_bound, upper_bound).sample(
      batch_shape, seed=seed)

  # Reflect the triangle across the diagonal. This doubles the diagonal,
  # which is harmless because the diagonal is replaced wholesale below.
  triangle = fill_triangular(draws)
  mirrored = triangle + tf.linalg.matrix_transpose(triangle)

  # Append `num_rows` to `batch_shape` to get the diagonal's shape.
  unit_diag_shape = prefer_static.pad(
      batch_shape, paddings=[[0, 1]], constant_values=num_rows)
  unit_diag = tf.ones(unit_diag_shape, dtype=dtype)
  return tf.linalg.set_diag(mirrored, unit_diag)
def _forward(self, x):
  """Packs `x` into a triangular matrix using `fill_triangular`.

  Whether the result is upper or lower triangular is decided by
  `self._upper`.
  """
  upper = self._upper
  return fill_triangular(x, upper=upper)