def tril_with_diag_softplus_and_shift(x, diag_shift=1e-5, name=None):
    """Converts (batch of) vectors to (batch of) lower-triangular scale matrices.

    The rightmost dimension of `x` is packed into a lower-triangular matrix
    with `tfp_math.fill_triangular`. Its diagonal is then replaced by
    `softplus_and_shift(diagonal, diag_shift)`, so the diagonal entries are a
    softplus of the original entries plus `diag_shift`.

    Args:
      x: (Batch of) `float`-like `Tensor` representing vectors to be turned
        into lower-triangular scale matrices with positive diagonal elements.
        The rightmost size `n` must satisfy `n = dims * (dims + 1) / 2` for
        some positive integer `dims`.
      diag_shift: `Tensor` added to the `softplus` transformation of the
        diagonal elements.
        Default value: `1e-5`.
      name: A `name_scope` name for operations created by this function.
        Default value: `None` (i.e., "tril_with_diag_softplus_and_shift").

    Returns:
      scale_tril: (Batch of) lower-triangular `Tensor` with `x.dtype` and
        rightmost shape `[dims, dims]`, where `dims` solves
        `x.shape[-1] = dims * (dims + 1) / 2`.
    """
    # The scope's `values` list intentionally receives the caller's
    # unconverted `x`, exactly as passed in.
    with tf.compat.v1.name_scope(
            name, 'tril_with_diag_softplus_and_shift', [x, diag_shift]):
        vec = tf.convert_to_tensor(value=x, name='x')
        tril = tfp_math.fill_triangular(vec)
        # Rewrite only the diagonal; the strictly-lower entries are kept.
        positive_diag = softplus_and_shift(
            tf.linalg.diag_part(tril), diag_shift)
        return tf.linalg.set_diag(tril, positive_diag)
# Example No. 2
# 0
 def _forward(self, x):
   """Maps `x` to a triangular matrix via `tfp_math.fill_triangular`.

   `self._upper` is forwarded as the `upper` argument of `fill_triangular`.
   Presumably this selects upper- vs. lower-triangular output; confirm
   against the `fill_triangular` docs.
   """
   use_upper = self._upper
   return tfp_math.fill_triangular(x, upper=use_upper)