def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name='fill_scale_tril'):
  """Instantiates the `FillScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance used to map the output diagonal to
      positive values. Must be an instance of
      `tf.__internal__.CompositeTensor` (which includes
      `tfb.AutoCompositeTensorBijector`).
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float, broadcast-added to every diagonal entry after
      `diag_bijector` is applied. A positive value guarantees a positive
      output diagonal, at the cost of making the transformation
      non-invertible for matrices whose diagonal entries fall below it.
      Default value: `1e-5`.
    validate_args: Python `bool`; whether to check arguments for
      correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name for ops created by this object.
      Default value: `fill_scale_tril`.

  Raises:
    TypeError, if `diag_bijector` is not an instance of
      `tf.__internal__.CompositeTensor`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)
    if not isinstance(diag_bijector, tf.__internal__.CompositeTensor):
      raise TypeError('`diag_bijector` must be an instance of '
                      '`tf.__internal__.CompositeTensor`.')

    if diag_shift is not None:
      # Fold the shift into the diagonal transform so that the composed
      # bijector adds `diag_shift` after `diag_bijector`.
      common_dtype = dtype_util.common_dtype(
          [diag_bijector, diag_shift], tf.float32)
      shift_tensor = tensor_util.convert_nonref_to_tensor(
          diag_shift, name='diag_shift', dtype=common_dtype)
      diag_bijector = chain.Chain(
          [shift.Shift(shift=shift_tensor), diag_bijector])

    # Vector -> lower triangular matrix, then transform its diagonal.
    bijectors = [
        transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
        fill_triangular.FillTriangular(),
    ]
    super(FillScaleTriL, self).__init__(
        bijectors,
        validate_args=validate_args,
        validate_event_size=False,
        parameters=parameters,
        name=name)
def _default_event_space_bijector(self):
  """Returns the default bijector mapping R^n onto this event space."""
  # TODO(b/145620027) Finalize choice of bijector.
  validate = self.validate_args
  # Unconstrained vector -> lower-triangular matrix with positive diagonal.
  tril_bijector = chain_bijector.Chain(
      [
          transform_diagonal_bijector.TransformDiagonal(
              diag_bijector=softplus_bijector.Softplus(
                  validate_args=validate),
              validate_args=validate),
          fill_scale_tril_bijector.FillScaleTriL(validate_args=validate),
      ],
      validate_args=validate)
  if not self.input_output_cholesky:
    # Recover the full matrix from its Cholesky factor.
    tril_bijector = chain_bijector.Chain(
        [
            cholesky_outer_product_bijector.CholeskyOuterProduct(
                validate_args=validate),
            tril_bijector,
        ],
        validate_args=validate)
  return tril_bijector
def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name="scale_tril"):
  """Instantiates the `ScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance used to map the output diagonal to
      positive values.
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float, broadcast-added to every diagonal entry after
      `diag_bijector` is applied. A positive value guarantees a positive
      output diagonal, at the cost of making the transformation
      non-invertible for matrices whose diagonal entries fall below it.
      Default value: `1e-5`.
    validate_args: Python `bool`; whether to check arguments for
      correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name for ops created by this object.
      Default value: `scale_tril`.
  """
  with tf.compat.v1.name_scope(name, values=[diag_shift]) as name:
    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)

    if diag_shift is not None:
      # Fold the shift into the diagonal transform so that the composed
      # bijector adds `diag_shift` after `diag_bijector`.
      shift_tensor = tf.convert_to_tensor(
          value=diag_shift, dtype=diag_bijector.dtype, name="diag_shift")
      diag_bijector = chain.Chain(
          [affine_scalar.AffineScalar(shift=shift_tensor), diag_bijector])

    # Vector -> lower triangular matrix, then transform its diagonal.
    super(ScaleTriL, self).__init__(
        [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
         fill_triangular.FillTriangular()],
        validate_args=validate_args,
        name=name)
    self._use_tf_function = False  # So input bijectors cache EagerTensors.