def tril_with_diag_softplus_and_shift(x, diag_shift=1e-5, name=None):
  """Maps (a batch of) vectors to lower-triangular scale matrices.

  The rightmost dimension of `x` is packed into a lower-triangular matrix
  with `tfp_math.fill_triangular`. The diagonal of that matrix is then
  replaced by `softplus_and_shift(diagonal, diag_shift)`. This is intended
  to give the resulting scale matrix a positive diagonal.

  Args:
    x: (Batch of) `float`-like `Tensor` of vectors to be packed into
      lower-triangular scale matrices. Its rightmost size `n` must satisfy
      `n = dims * (dims + 1) / 2` for some positive integer `dims`.
    diag_shift: `Tensor` added to the `softplus` transformation of the
      diagonal elements.
      Default value: `1e-5`.
    name: A `name_scope` name for the operations created by this function.
      Default value: `None` (i.e., "tril_with_diag_softplus_and_shift").

  Returns:
    scale_tril: (Batch of) lower-triangular `Tensor` with dtype `x.dtype` and
      rightmost shape `[dims, dims]`, where `dims` is the positive integer
      satisfying `x.shape[-1] = dims * (dims + 1) / 2`.
  """
  scope_inputs = [x, diag_shift]
  with tf.compat.v1.name_scope(
      name, 'tril_with_diag_softplus_and_shift', scope_inputs):
    # Convert inside the scope so the conversion op is named under it.
    tril = tfp_math.fill_triangular(tf.convert_to_tensor(value=x, name='x'))
    # Only the diagonal is transformed; off-diagonal entries pass through.
    transformed_diag = softplus_and_shift(
        tf.linalg.diag_part(tril), diag_shift)
    return tf.linalg.set_diag(tril, transformed_diag)
def _forward(self, x):
  """Packs `x` into a triangular matrix via `tfp_math.fill_triangular`.

  Args:
    x: Input `Tensor` forwarded unchanged to `tfp_math.fill_triangular`.

  Returns:
    The result of `tfp_math.fill_triangular`. `self._upper` is passed as its
    `upper` flag, which selects between the upper- and lower-triangular
    layouts.
  """
  fill_upper = self._upper
  return tfp_math.fill_triangular(x=x, upper=fill_upper)