def _inverse_event_shape(self, input_shape):
    """Computes the static event shape produced by the inverse transform.

    Args:
      input_shape: Static shape object exposing `rank`, indexing, slicing
        and `concatenate` (presumably a `tf.TensorShape` -- confirm against
        callers).

    Returns:
      `input_shape` unchanged when its rank is `None` or `0`. Otherwise the
      result of `FillTriangular().inverse_event_shape` applied to a shape
      whose two trailing dimensions are replaced by `[n - 1, n - 1]`, where
      `n` is the last dimension of `input_shape` (or `None` if `n` is
      unknown).
    """
    # A falsy rank covers both "unknown" (`None`) and rank zero.
    if not input_shape.rank:
        return input_shape

    last_dim = input_shape[-1]
    # An unknown trailing dimension stays unknown after shrinking.
    core_dim = None if last_dim is None else last_dim - 1
    core_shape = input_shape[:-2].concatenate([core_dim, core_dim])
    return fill_triangular.FillTriangular().inverse_event_shape(core_shape)
 def _forward_event_shape(self, input_shape):
     """Computes the static event shape produced by the forward transform.

     Args:
       input_shape: Static shape of the input event (presumably a
         `tf.TensorShape` -- confirm against callers).

     Returns:
       `input_shape` unchanged when `tensorshape_util.rank` reports an
       unknown rank. Otherwise the shape returned by
       `FillTriangular().forward_event_shape` with its two trailing
       dimensions replaced by `[n + 1, n + 1]`, where `n` is that shape's
       last dimension (or `None` if `n` is unknown).
     """
     if tensorshape_util.rank(input_shape) is None:
         return input_shape

     filled_shape = fill_triangular.FillTriangular().forward_event_shape(
         input_shape)
     last_dim = filled_shape[-1]
     # An unknown trailing dimension stays unknown after growing.
     grown_dim = None if last_dim is None else last_dim + 1
     return filled_shape[:-2].concatenate([grown_dim, grown_dim])
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name='fill_scale_tril'):
        """Instantiates the `FillScaleTriL` bijector.

        Args:
          diag_bijector: `Bijector` instance, used to transform the output
            diagonal to be positive. Must be an instance of
            `tf.__internal__.CompositeTensor` (including
            `tfb.AutoCompositeTensorBijector`).
            Default value: `None` (i.e., `tfb.Softplus()`).
          diag_shift: Float value broadcastable and added to all diagonal
            entries after applying the `diag_bijector`. Setting a positive
            value forces the output diagonal entries to be positive, but
            prevents inverting the transformation for matrices with
            diagonal entries less than this value. Pass `None` to skip the
            shift entirely.
            Default value: `1e-5`.
          validate_args: Python `bool` indicating whether arguments should
            be checked for correctness.
            Default value: `False` (i.e., arguments are not validated).
          name: Python `str` name given to ops managed by this object.
            Default value: `fill_scale_tril`.

        Raises:
          TypeError: if `diag_bijector` is not an instance of
            `tf.__internal__.CompositeTensor`.
        """
        # Must remain the first statement: it snapshots the locals at entry
        # (`self` and the constructor arguments) before anything else is
        # bound in this scope.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if diag_bijector is None:
                diag_bijector = softplus.Softplus(validate_args=validate_args)
            if not isinstance(diag_bijector, tf.__internal__.CompositeTensor):
                raise TypeError('`diag_bijector` must be an instance of '
                                '`tf.__internal__.CompositeTensor`.')

            if diag_shift is not None:
                shift_dtype = dtype_util.common_dtype(
                    [diag_bijector, diag_shift], tf.float32)
                diag_shift = tensor_util.convert_nonref_to_tensor(
                    diag_shift, name='diag_shift', dtype=shift_dtype)
                # Compose the shift with the diagonal bijector so the shift
                # is added after `diag_bijector` has been applied.
                shift_bijector = shift.Shift(shift=diag_shift)
                diag_bijector = chain.Chain([shift_bijector, diag_bijector])

            bijectors = [
                transform_diagonal.TransformDiagonal(
                    diag_bijector=diag_bijector),
                fill_triangular.FillTriangular(),
            ]
            super(FillScaleTriL, self).__init__(
                bijectors,
                validate_args=validate_args,
                validate_event_size=False,
                parameters=parameters,
                name=name)
    def _inverse(self, y):
        """Undoes the row scaling of `y` and inverts the triangular fill.

        Args:
          y: Tensor whose two trailing axes are treated as a matrix
            (presumably square and lower triangular -- confirm against
            callers).

        Returns:
          The result of `FillTriangular().inverse` applied to the rescaled
          matrix with its first row and last column removed.
        """
        y_shape = ps.shape(y)
        diag_shape = ps.concat([y_shape[:-2], [y_shape[-1]]], axis=-1)

        # The diagonal carries the reciprocal of the row norms; the trailing
        # singleton axis lets it broadcast across each row.
        inv_row_norm = tf.linalg.diag_part(y)[..., tf.newaxis]

        # Zero out the diagonal.
        off_diag = tf.linalg.set_diag(y, tf.zeros(diag_shape, dtype=y.dtype))

        # Multiplying by the norm (i.e. dividing by its reciprocal) recovers
        # the unconstrained reals in the strictly lower triangular part.
        unscaled = off_diag / inv_row_norm

        # Drop the first row and last column before inverting the
        # FillTriangular transformation.
        return fill_triangular.FillTriangular().inverse(
            unscaled[..., 1:, :-1])
# Example no. 5
# 0
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name="scale_tril"):
        """Instantiates the `ScaleTriL` bijector.

        Args:
          diag_bijector: `Bijector` instance, used to transform the output
            diagonal to be positive.
            Default value: `None` (i.e., `tfb.Softplus()`).
          diag_shift: Float value broadcastable and added to all diagonal
            entries after applying the `diag_bijector`. Setting a positive
            value forces the output diagonal entries to be positive, but
            prevents inverting the transformation for matrices with
            diagonal entries less than this value. Pass `None` to skip the
            shift entirely.
            Default value: `1e-5`.
          validate_args: Python `bool` indicating whether arguments should
            be checked for correctness.
            Default value: `False` (i.e., arguments are not validated).
          name: Python `str` name given to ops managed by this object.
            Default value: `scale_tril`.
        """
        with tf.compat.v1.name_scope(name, values=[diag_shift]) as name:
            if diag_bijector is None:
                diag_bijector = softplus.Softplus(validate_args=validate_args)

            if diag_shift is not None:
                # The shift is converted using the diagonal bijector's dtype.
                diag_shift = tf.convert_to_tensor(
                    value=diag_shift,
                    dtype=diag_bijector.dtype,
                    name="diag_shift")
                # Compose the shift with the diagonal bijector so the shift
                # is added after `diag_bijector` has been applied.
                shift_bijector = affine_scalar.AffineScalar(shift=diag_shift)
                diag_bijector = chain.Chain([shift_bijector, diag_bijector])

            bijectors = [
                transform_diagonal.TransformDiagonal(
                    diag_bijector=diag_bijector),
                fill_triangular.FillTriangular(),
            ]
            super(ScaleTriL, self).__init__(
                bijectors, validate_args=validate_args, name=name)
            # So input bijectors cache EagerTensors.
            self._use_tf_function = False
    def _forward(self, x):
        """Builds a unit-diagonal, row-normalized matrix from `x`.

        Args:
          x: Value convertible to a tensor; it is passed to
            `FillTriangular().forward`.

        Returns:
          The filled matrix, padded with a zero top row and a zero right
          column, with its diagonal set to ones and every row divided by its
          Euclidean (L2) norm.
        """
        x = tf.convert_to_tensor(x, name='x')

        # Pad zeros on the top row and right column; all leading axes are
        # left unpadded.
        tril = fill_triangular.FillTriangular().forward(x)
        leading_pads = ps.zeros([ps.rank(tril) - 2, 2], dtype=tf.int32)
        padded = tf.pad(
            tril, ps.concat([leading_pads, [[1, 0], [0, 1]]], axis=0))

        # Set the diagonal to 1s.
        diag_shape = ps.concat(
            [ps.shape(x)[:-1], [ps.shape(padded)[-1]]], axis=-1)
        unit_diag = tf.linalg.set_diag(
            padded, tf.ones(diag_shape, dtype=x.dtype))

        # Normalize each row to have Euclidean (L2) norm 1.
        row_norms = tf.norm(unit_diag, axis=-1)[..., tf.newaxis]
        return unit_diag / row_norms
 def _inverse_event_shape_tensor(self, input_shape):
     """Computes the dynamic event shape produced by the inverse transform.

     Args:
       input_shape: Indexable, sliceable shape value (presumably a 1-D
         integer tensor -- confirm against callers).

     Returns:
       The result of `FillTriangular().inverse_event_shape_tensor` applied
       to `input_shape` with its two trailing entries replaced by
       `[n - 1, n - 1]`, where `n` is the last entry of `input_shape`.
     """
     core_dim = input_shape[-1] - 1
     leading_dims = input_shape[:-2]
     core_shape = tf.concat([leading_dims, [core_dim, core_dim]], axis=-1)
     return fill_triangular.FillTriangular().inverse_event_shape_tensor(
         core_shape)
 def _forward_event_shape_tensor(self, input_shape):
     """Computes the dynamic event shape produced by the forward transform.

     Args:
       input_shape: Shape value accepted by
         `FillTriangular().forward_event_shape_tensor` (presumably a 1-D
         integer tensor -- confirm against callers).

     Returns:
       The shape returned by `FillTriangular().forward_event_shape_tensor`
       with its two trailing entries replaced by `[n + 1, n + 1]`, where
       `n` is that shape's last entry.
     """
     filler = fill_triangular.FillTriangular()
     filled_shape = filler.forward_event_shape_tensor(input_shape)
     grown_dim = filled_shape[-1] + 1
     return tf.concat([filled_shape[:-2], [grown_dim, grown_dim]], axis=-1)