def _inverse_event_shape(self, input_shape):
  if not input_shape.rank:
    return input_shape
  n = input_shape[-1]
  if n is not None:
    # The forward transform pads an extra row and column, so the underlying
    # FillTriangular output is one dimension smaller.
    n -= 1
  y_shape = input_shape[:-2].concatenate([n, n])
  return fill_triangular.FillTriangular().inverse_event_shape(y_shape)
def _forward_event_shape(self, input_shape):
  if tensorshape_util.rank(input_shape) is None:
    return input_shape
  tril_shape = fill_triangular.FillTriangular().forward_event_shape(
      input_shape)
  n = tril_shape[-1]
  if n is not None:
    # Padding a top row and right column enlarges the matrix by one.
    n += 1
  return tril_shape[:-2].concatenate([n, n])
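# A shape-arithmetic sketch for the two event-shape methods above. They match
# TFP's `CorrelationCholesky` bijector (an assumption here, based on the
# forward/inverse methods below): a vector of length n * (n - 1) / 2 fills the
# strictly lower triangle of an n x n output, so the raw FillTriangular fill
# is one row/column smaller than the result.
import tensorflow_probability as tfp

corr_chol = tfp.bijectors.CorrelationCholesky()
corr_chol.forward_event_shape([6])     # TensorShape([4, 4]); 6 == 4 * 3 / 2.
corr_chol.inverse_event_shape([4, 4])  # TensorShape([6]).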
def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name='fill_scale_tril'):
  """Instantiates the `FillScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance, used to transform the output diagonal
      to be positive. Must be an instance of
      `tf.__internal__.CompositeTensor` (including
      `tfb.AutoCompositeTensorBijector`).
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float value broadcastable and added to all diagonal entries
      after applying the `diag_bijector`. Setting a positive value forces the
      output diagonal entries to be positive, but prevents inverting the
      transformation for matrices with diagonal entries less than this value.
      Default value: `1e-5`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name given to ops managed by this object.
      Default value: `fill_scale_tril`.

  Raises:
    TypeError, if `diag_bijector` is not an instance of
      `tf.__internal__.CompositeTensor`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)
    if not isinstance(diag_bijector, tf.__internal__.CompositeTensor):
      raise TypeError('`diag_bijector` must be an instance of '
                      '`tf.__internal__.CompositeTensor`.')
    if diag_shift is not None:
      dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32)
      diag_shift = tensor_util.convert_nonref_to_tensor(
          diag_shift, name='diag_shift', dtype=dtype)
      diag_bijector = chain.Chain(
          [shift.Shift(shift=diag_shift), diag_bijector])
    super(FillScaleTriL, self).__init__(
        [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
         fill_triangular.FillTriangular()],
        validate_args=validate_args,
        validate_event_size=False,
        parameters=parameters,
        name=name)
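# A minimal usage sketch for `FillScaleTriL` (illustrative input values; the
# exact output numbers depend on the default softplus-plus-shift diagonal
# transform). A vector of length n * (n + 1) / 2 maps to an n x n
# lower-triangular matrix with a positive diagonal.
import tensorflow_probability as tfp

fill_scale_tril = tfp.bijectors.FillScaleTriL()
scale = fill_scale_tril.forward([-0.3, 1.2, 0.7])  # Shape [2, 2]; 3 == 2 * 3 / 2.
recovered = fill_scale_tril.inverse(scale)         # Recovers the input vector.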
def _inverse(self, y):
  n = ps.shape(y)[-1]
  batch_shape = ps.shape(y)[:-2]

  # Extract the reciprocal of the row norms from the diagonal.
  diag = tf.linalg.diag_part(y)[..., tf.newaxis]

  # Set the diagonal to 0s.
  y = tf.linalg.set_diag(
      y, tf.zeros(ps.concat([batch_shape, [n]], axis=-1), dtype=y.dtype))

  # Multiply with the norm (or divide by its reciprocal) to recover the
  # unconstrained reals in the (strictly) lower triangular part.
  x = y / diag

  # Remove the first row and last column before inverting the FillTriangular
  # transformation.
  return fill_triangular.FillTriangular().inverse(x[..., 1:, :-1])
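# Round-trip sketch for `_inverse` above (assuming it belongs to
# `tfp.bijectors.CorrelationCholesky`): inverting the forward transform
# recovers the unconstrained input vector, up to float32 tolerance.
import tensorflow as tf
import tensorflow_probability as tfp

corr_chol = tfp.bijectors.CorrelationCholesky()
x = tf.constant([0.5, -1.0, 2.0])  # Length 3 == 3 * 2 / 2 -> 3 x 3 output.
tf.debugging.assert_near(corr_chol.inverse(corr_chol.forward(x)), x)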
def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name="scale_tril"):
  """Instantiates the `ScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance, used to transform the output diagonal
      to be positive.
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float value broadcastable and added to all diagonal entries
      after applying the `diag_bijector`. Setting a positive value forces the
      output diagonal entries to be positive, but prevents inverting the
      transformation for matrices with diagonal entries less than this value.
      Default value: `1e-5`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name given to ops managed by this object.
      Default value: `scale_tril`.
  """
  with tf.compat.v1.name_scope(name, values=[diag_shift]) as name:
    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)
    if diag_shift is not None:
      diag_shift = tf.convert_to_tensor(
          value=diag_shift, dtype=diag_bijector.dtype, name="diag_shift")
      diag_bijector = chain.Chain(
          [affine_scalar.AffineScalar(shift=diag_shift), diag_bijector])
    super(ScaleTriL, self).__init__(
        [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
         fill_triangular.FillTriangular()],
        validate_args=validate_args,
        name=name)
    self._use_tf_function = False  # So input bijectors cache EagerTensors.
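# `ScaleTriL` is the older name for this transformation; it and
# `AffineScalar` were later deprecated in TFP in favor of `FillScaleTriL` and
# `Shift`/`Scale`. A hedged sketch of the equivalent composition using
# current bijector names (`Shift` replacing the shift-only `AffineScalar`):
import tensorflow_probability as tfp

tfb = tfp.bijectors
equivalent = tfb.Chain([
    tfb.TransformDiagonal(tfb.Chain([tfb.Shift(1e-5), tfb.Softplus()])),
    tfb.FillTriangular(),
])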
def _forward(self, x):
  x = tf.convert_to_tensor(x, name='x')
  batch_shape = ps.shape(x)[:-1]

  # Pad zeros on the top row and right column.
  y = fill_triangular.FillTriangular().forward(x)
  rank = ps.rank(y)
  paddings = ps.concat(
      [ps.zeros([rank - 2, 2], dtype=tf.int32), [[1, 0], [0, 1]]], axis=0)
  y = tf.pad(y, paddings)

  # Set diagonal to 1s.
  n = ps.shape(y)[-1]
  diag = tf.ones(ps.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
  y = tf.linalg.set_diag(y, diag)

  # Normalize each row to have Euclidean (L2) norm 1.
  y /= tf.norm(y, axis=-1)[..., tf.newaxis]
  return y
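# Property check for `_forward` above (again assuming it implements
# `tfp.bijectors.CorrelationCholesky`): each output row has unit L2 norm, so
# L @ L^T has ones on the diagonal, i.e. it is a correlation matrix.
import tensorflow as tf
import tensorflow_probability as tfp

corr_chol = tfp.bijectors.CorrelationCholesky()
L = corr_chol.forward([0.5, -1.0, 2.0])          # 3 x 3 lower triangular.
corr = tf.linalg.matmul(L, L, transpose_b=True)  # A correlation matrix.
tf.debugging.assert_near(tf.linalg.diag_part(corr), tf.ones(3))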
def _inverse_event_shape_tensor(self, input_shape):
  # Dynamic-shape counterpart of `_inverse_event_shape`.
  n = input_shape[-1] - 1
  y_shape = tf.concat([input_shape[:-2], [n, n]], axis=-1)
  return fill_triangular.FillTriangular().inverse_event_shape_tensor(y_shape)
def _forward_event_shape_tensor(self, input_shape):
  # Dynamic-shape counterpart of `_forward_event_shape`.
  tril_shape = fill_triangular.FillTriangular().forward_event_shape_tensor(
      input_shape)
  n = tril_shape[-1] + 1
  return tf.concat([tril_shape[:-2], [n, n]], axis=-1)
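# The `*_event_shape_tensor` variants above compute the same arithmetic as
# the static event-shape methods, but on dynamic (tensor-valued) shapes. A
# hedged sketch, under the same `CorrelationCholesky` assumption:
import tensorflow as tf
import tensorflow_probability as tfp

corr_chol = tfp.bijectors.CorrelationCholesky()
corr_chol.forward_event_shape_tensor(tf.constant([6]))     # -> [4, 4].
corr_chol.inverse_event_shape_tensor(tf.constant([4, 4]))  # -> [6].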