Ejemplo n.º 1
0
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name='fill_scale_tril'):
        """Builds a `FillScaleTriL` bijector.

        The resulting bijector is a chain of `FillTriangular` (vector to
        lower-triangular matrix) followed by `TransformDiagonal`, which applies
        `diag_bijector` (optionally followed by a constant shift) to the
        diagonal of that matrix.

        Args:
          diag_bijector: `Bijector` applied to the diagonal of the filled
            matrix to make it positive. It must be a
            `tf.__internal__.CompositeTensor` (for example a
            `tfb.AutoCompositeTensorBijector`).
            Default value: `None`, meaning `tfb.Softplus()` is used.
          diag_shift: Float value, broadcastable against the diagonal, that is
            added to every diagonal entry after `diag_bijector` runs. A
            positive value keeps the diagonal strictly positive. It also makes
            the inverse undefined for matrices whose diagonal entries are
            smaller than this value. Pass `None` to skip the shift.
            Default value: `1e-5`.
          validate_args: Python `bool`; when `True`, arguments are checked for
            correctness.
            Default value: `False`.
          name: Python `str` name for the ops created by this object.
            Default value: `fill_scale_tril`.

        Raises:
          TypeError: if `diag_bijector` is not a
            `tf.__internal__.CompositeTensor`.
        """
        # Must be the very first statement: it records exactly the constructor
        # arguments (no other locals exist yet).
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            positive_diag = diag_bijector
            if positive_diag is None:
                positive_diag = softplus.Softplus(validate_args=validate_args)
            if not isinstance(positive_diag, tf.__internal__.CompositeTensor):
                raise TypeError('`diag_bijector` must be an instance of '
                                '`tf.__internal__.CompositeTensor`.')

            if diag_shift is not None:
                shift_value = tensor_util.convert_nonref_to_tensor(
                    diag_shift,
                    name='diag_shift',
                    dtype=dtype_util.common_dtype(
                        [positive_diag, diag_shift], tf.float32))
                # `Chain` runs its list right-to-left, so the diagonal bijector
                # is applied first and the shift is added afterwards.
                positive_diag = chain.Chain(
                    [shift.Shift(shift=shift_value), positive_diag])

            components = [
                transform_diagonal.TransformDiagonal(
                    diag_bijector=positive_diag),
                fill_triangular.FillTriangular(),
            ]
            # Explicit two-argument `super` on purpose: the zero-argument form
            # would add `__class__` to `locals()` above.
            super(FillScaleTriL, self).__init__(
                components,
                validate_args=validate_args,
                validate_event_size=False,
                parameters=parameters,
                name=name)
Ejemplo n.º 2
0
 def _default_event_space_bijector(self):
   """Returns the default bijector from unconstrained space to the event space.

   The core piece maps an unconstrained vector to a lower-triangular matrix
   whose diagonal is made positive by a single `Softplus`. If
   `self.input_output_cholesky` is true, that triangular bijector is returned
   directly. Otherwise it is chained after `CholeskyOuterProduct` so the
   triangular factor is turned into its outer product.

   Returns:
     A `Bijector` instance built with `validate_args=self.validate_args`.
   """
   # TODO(b/145620027) Finalize choice of bijector.
   # Apply `Softplus` to the diagonal exactly once and without a shift.
   # Stacking an extra `TransformDiagonal(Softplus)` on top of a default
   # `FillScaleTriL` (which already applies `Softplus` plus a 1e-5 shift)
   # confines every diagonal entry to (softplus(1e-5), inf), roughly
   # (0.693, inf). Factors with smaller diagonal entries would then be
   # unreachable, and their `inverse` would be NaN.
   tril_bijector = fill_scale_tril_bijector.FillScaleTriL(
       diag_bijector=softplus_bijector.Softplus(
           validate_args=self.validate_args),
       diag_shift=None,
       validate_args=self.validate_args)
   if self.input_output_cholesky:
     return tril_bijector
   return chain_bijector.Chain([
       cholesky_outer_product_bijector.CholeskyOuterProduct(
           validate_args=self.validate_args),
       tril_bijector
   ], validate_args=self.validate_args)
Ejemplo n.º 3
0
    def __init__(self,
                 diag_bijector=None,
                 diag_shift=1e-5,
                 validate_args=False,
                 name="scale_tril"):
        """Builds a `ScaleTriL` bijector.

        The resulting bijector chains `FillTriangular` (vector to
        lower-triangular matrix) with `TransformDiagonal`, which applies
        `diag_bijector` (optionally followed by a constant shift) to the
        diagonal of that matrix.

        Args:
          diag_bijector: `Bijector` applied to the diagonal of the filled
            matrix to make it positive.
            Default value: `None`, meaning `tfb.Softplus()` is used.
          diag_shift: Float value, broadcastable against the diagonal, that is
            added to every diagonal entry after `diag_bijector` runs. A
            positive value keeps the diagonal strictly positive. It also makes
            the inverse undefined for matrices whose diagonal entries are
            smaller than this value. Pass `None` to skip the shift.
            Default value: `1e-5`.
          validate_args: Python `bool`; when `True`, arguments are checked for
            correctness.
            Default value: `False`.
          name: Python `str` name for the ops created by this object.
            Default value: `scale_tril`.
        """
        with tf.compat.v1.name_scope(name, values=[diag_shift]) as name:
            positive_diag = (
                diag_bijector if diag_bijector is not None
                else softplus.Softplus(validate_args=validate_args))

            if diag_shift is not None:
                # The shift tensor takes the dtype reported by the diagonal
                # bijector.
                shift_tensor = tf.convert_to_tensor(
                    value=diag_shift,
                    dtype=positive_diag.dtype,
                    name="diag_shift")
                # `Chain` runs its list right-to-left: diagonal bijector first,
                # then the shift.
                positive_diag = chain.Chain([
                    affine_scalar.AffineScalar(shift=shift_tensor),
                    positive_diag,
                ])

            components = [
                transform_diagonal.TransformDiagonal(
                    diag_bijector=positive_diag),
                fill_triangular.FillTriangular(),
            ]
            super(ScaleTriL, self).__init__(
                components, validate_args=validate_args, name=name)
            # Turned off so that the input bijectors cache EagerTensors.
            self._use_tf_function = False