Example #1
0
def _tril_spherical_uniform(dimension, batch_shape, dtype, seed):
    """Returns a `Tensor` of samples of lower triangular matrices.

    Row `k` (0-based) of each matrix has `k + 1` entries on and below the
    diagonal; together they form a unit-norm vector drawn uniformly from the
    sphere in `R^(k+1)`. Entries above the diagonal are zero.

    Args:
      dimension: Positive Python `int`, the number of rows/columns of each
        output matrix. NOTE(review): the body passes this to `min`, `range`
        and Python comparisons, so a symbolic scalar `int` `Tensor` is
        unlikely to work outside eager mode -- confirm with callers.
      batch_shape: Vector-shaped `int` `Tensor` or Python sequence giving the
        batch shape of the output. It must support `len()`, because the
        padding spec is built from `len(batch_shape)`.
      dtype: TF `dtype` of the output.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details. It is split
        into `min(dimension, 3)` seeds, and each random draw below uses its
        own seed.

    Returns:
      tril_spherical_uniform: `Tensor` of shape
        `batch_shape + [dimension, dimension]` and dtype `dtype`, whose rows
        are drawn row-wise from spherical uniform distributions.

    Raises:
      AssertionError: If `dimension <= 0`. The check is an `assert`, so it is
        skipped when Python runs with `-O`.
    """
    # Essentially, we draw lower triangular samples where each lower
    # triangular entry follows a normal distribution, then apply `x / norm(x)`
    # to each row of the samples.
    # To avoid possible NaNs (e.g. a length-1 row whose single normal draw is
    # 0), we use spherical_uniform directly for the first two rows.
    assert dimension > 0, '`dimension` needs to be positive.'
    # One seed per `spherical_uniform` row (at most two), plus one for the
    # Gaussian draw that is only made when `dimension > 2`.
    num_seeds = min(dimension, 3)
    seeds = list(samplers.split_seed(seed, n=num_seeds, salt='sample_lkj'))
    rows = []
    # No padding on the batch axes; only the last axis is padded below.
    paddings_prepend = [[0, 0]] * len(batch_shape)
    for n in range(1, min(dimension, 2) + 1):
        # Row `n - 1`: a unit vector of length `n`, right-padded with zeros
        # to length `dimension`.
        rows.append(
            tf.pad(random_ops.spherical_uniform(shape=batch_shape,
                                                dimension=n,
                                                dtype=dtype,
                                                seed=seeds.pop()),
                   paddings_prepend + [[0, dimension - n]],
                   constant_values=0.))
    samples = tf.stack(rows, axis=-2)
    if dimension > 2:
        # One normal draw per lower-triangular entry, minus the 3 entries
        # (1 + 2) of rows 0 and 1, which were already drawn above.
        normal_shape = ps.concat(
            [batch_shape, [dimension * (dimension + 1) // 2 - 3]], axis=0)
        normal_samples = samplers.normal(shape=normal_shape,
                                         dtype=dtype,
                                         seed=seeds.pop())
        # Insert placeholder ones at the positions `fill_triangular` assigns
        # to rows 0 and 1. It fills elements in a clockwise spiral: the first
        # `dimension` values form the last row, the next value is row 0, and
        # the two values after the next `dimension - 1` draws form row 1.
        # Rows 0 and 1 are sliced off right below, so the placeholder values
        # never reach the output.
        normal_samples = tf.concat([
            normal_samples[..., :dimension],
            tf.ones(ps.concat([batch_shape, [1]], axis=0), dtype=dtype),
            normal_samples[..., dimension:(2 * dimension - 1)],
            tf.ones(ps.concat([batch_shape, [2]], axis=0), dtype=dtype),
            normal_samples[..., (2 * dimension - 1):],
        ],
                                   axis=-1)
        # Keep rows 2..dimension-1 of the lower triangular fill.
        normal_samples = linalg.fill_triangular(normal_samples,
                                                upper=False)[..., 2:, :]
        # Normalizing each row of i.i.d. normal draws gives a point uniform on
        # the sphere (assumes `samplers.normal` is zero-mean by default). The
        # zeros above the diagonal do not change the norm.
        remaining_rows = normal_samples / tf.norm(
            normal_samples, ord=2, axis=-1, keepdims=True)
        samples = tf.concat([samples, remaining_rows], axis=-2)
    return samples
def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
    """Samples symmetric matrices with a unit diagonal and entries in [-1, 1].

    A "correlation-like" matrix is a symmetric square matrix whose entries all
    lie between -1 and 1 (inclusive) and whose main diagonal is all ones. The
    positive semi-definite ones among them are exactly the correlation
    matrices. Each off-diagonal pair of entries here comes from a single
    `Uniform(-1, 1)` draw.

    Args:
      num_rows: Python `int`, the side length of each matrix.
      batch_shape: `Tensor` or Python `tuple` of `int`s giving the batch shape
        of the result.
      dtype: `dtype` of the returned `Tensor`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      matrices: `Tensor` of shape `batch_shape + [num_rows, num_rows]` and
        dtype `dtype`. Every entry is in [-1, 1], and each matrix in the two
        trailing dimensions is symmetric with ones on the main diagonal.
    """
    # `fill_triangular` consumes a full triangle, diagonal included. We
    # therefore draw num_rows * (num_rows + 1) // 2 values, even though the
    # diagonal draws are overwritten by `set_diag` at the end.
    tril_size = num_rows * (num_rows + 1) // 2
    unit = tf.ones(shape=[tril_size], dtype=dtype)
    draws = uniform.Uniform(-unit, unit).sample(batch_shape, seed=seed)
    triangle = fill_triangular(draws)

    # Adding the transpose mirrors the triangle. This also doubles the
    # diagonal, which is then replaced with ones.
    mirrored = triangle + tf.linalg.matrix_transpose(triangle)
    diag_shape = prefer_static.pad(batch_shape,
                                   paddings=[[0, 1]],
                                   constant_values=num_rows)
    return tf.linalg.set_diag(mirrored, tf.ones(diag_shape, dtype=dtype))
 def _forward(self, x):
   """Maps `x` to a triangular matrix via `fill_triangular`.

   `self._upper` is forwarded as the `upper` flag, selecting which triangle
   is filled.
   """
   upper = self._upper
   return fill_triangular(x, upper=upper)