def _default_event_space_bijector(self):
  """Returns a `Softplus` bijector as the default event-space bijector.

  The bijector is constructed with this object's `validate_args` setting.
  """
  validate = self.validate_args
  return softplus_bijector.Softplus(validate_args=validate)
def _default_event_space_bijector(self):
  """Returns a `Chain` of `Reciprocal` and `Softplus` bijectors.

  The chain and both of its members are constructed with this object's
  `validate_args` setting.
  """
  validate = self.validate_args
  # Listed in the order handed to `Chain`: `Reciprocal` first, `Softplus` last.
  stages = [
      reciprocal_bijector.Reciprocal(validate_args=validate),
      softplus_bijector.Softplus(validate_args=validate),
  ]
  return chain_bijector.Chain(stages, validate_args=validate)
def _parameter_properties(cls, dtype, num_classes=None):
  """Describes the `power` parameter and its default constraining bijector.

  Args:
    cls: The class this is invoked on (not read by this function).
    dtype: Dtype used both for the epsilon offset and for the lower-bound
      tensor passed to `Softplus`.
    num_classes: Accepted for interface compatibility; not read here.

  Returns:
    A `dict` with a single key, `'power'`, mapping to a `ParameterProperties`
    whose default constraining bijector is a `Softplus` with
    `low = 1 + eps(dtype)`.
  """

  def _power_constraint():
    # Built lazily: the lower-bound tensor is only created when the
    # constraining bijector is actually requested.
    lower_bound = tf.convert_to_tensor(
        1. + dtype_util.eps(dtype), dtype=dtype)
    return softplus_bijector.Softplus(low=lower_bound)

  power_properties = parameter_properties.ParameterProperties(
      default_constraining_bijector_fn=_power_constraint)
  return dict(power=power_properties)
def _parameter_properties(cls, dtype):
  """Describes the `loc` and `scale` parameters.

  Args:
    cls: The class this is invoked on (not read by this function).
    dtype: Dtype whose epsilon is used as the `low` argument of the `Softplus`
      bijector constraining `scale`.

  Returns:
    A `dict` with keys `'loc'` and `'scale'`. `loc` gets a default-constructed
    `ParameterProperties`; `scale` gets one whose default constraining
    bijector is `Softplus(low=eps(dtype))`.
  """

  def _scale_constraint():
    # Built lazily, only when the constraining bijector is requested.
    return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

  loc_properties = parameter_properties.ParameterProperties()
  scale_properties = parameter_properties.ParameterProperties(
      default_constraining_bijector_fn=_scale_constraint)
  return dict(loc=loc_properties, scale=scale_properties)
def __init__(self,
             low=None,
             high=None,
             hinge_softness=None,
             validate_args=False,
             name='soft_clip'):
  """Builds a SoftClip bijector from optional soft bounds.

  Args:
    low: Optional float `Tensor` lower bound. Passing `None` omits the
      lower-bound constraint.
      Default value: `None`.
    high: Optional float `Tensor` upper bound. Passing `None` omits the
      upper-bound constraint.
      Default value: `None`.
    hinge_softness: Optional nonzero float `Tensor` controlling how soft the
      constraint is at the boundaries. Values outside of the constraint set
      are mapped into intervals of width approximately
      `log(2) * hinge_softness` on the interior of each boundary. A higher
      softness reserves more space for values outside of the constraint set,
      which distorts inputs *within* the constraint set more but improves
      numerical stability near the boundaries.
      Default value: `None` (`1.0`).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.
  """
  with tf.name_scope(name):
    dtype = dtype_util.common_dtype(
        [low, high, hinge_softness], dtype_hint=tf.float32)
    low = tensor_util.convert_nonref_to_tensor(low, name='low', dtype=dtype)
    high = tensor_util.convert_nonref_to_tensor(
        high, name='high', dtype=dtype)
    hinge_softness = tensor_util.convert_nonref_to_tensor(
        hinge_softness, name='hinge_softness', dtype=dtype)

    soft_plus = softplus.Softplus(hinge_softness=hinge_softness)
    minus_one = tf.convert_to_tensor(-1., dtype=dtype)

    def _deferred_negation(bound):
      # Negate lazily so reference tensors (e.g. Variables) are read on use.
      return tfp_util.DeferredTensor(bound, lambda x: -x)

    bounded_below = low is not None
    bounded_above = high is not None

    # Each stage list is written in `Chain` order, i.e. the last entry is the
    # first transformation applied to the input.
    if bounded_below and bounded_above:
      # Reference tensors (e.g. Variables) are supported for `high` and `low`
      # by deferring every computation on them until it is needed.
      span = tfp_util.DeferredTensor(high, lambda hi: hi - low)
      negated_shrinkage = tfp_util.DeferredTensor(
          span,
          lambda w: tf.cast(
              minus_one * w / soft_plus.forward(w), dtype=dtype))

      # Two-sided soft constraint:
      #   softclip(x) := -softplus(width - softplus(x - low)) *
      #                      width / softplus(width) + high
      stages = [
          shift.Shift(high),
          scale.Scale(negated_shrinkage),
          soft_plus,
          shift.Shift(span),
          scale.Scale(minus_one),
          soft_plus,
          shift.Shift(_deferred_negation(low)),
      ]
    elif bounded_below:
      # Soft lower bound only:
      #   softlower(x) := softplus(x - low) + low
      stages = [
          shift.Shift(low),
          soft_plus,
          shift.Shift(_deferred_negation(low)),
      ]
    elif bounded_above:
      # Soft upper bound only:
      #   softupper(x) := -softplus(high - x) + high
      stages = [
          shift.Shift(high),
          scale.Scale(minus_one),
          soft_plus,
          scale.Scale(minus_one),
          shift.Shift(high),
      ]
    else:
      # Neither bound given: the chain is left empty.
      stages = []

    self._low = low
    self._high = high
    self._hinge_softness = hinge_softness
    self._chain = chain.Chain(stages, validate_args=validate_args)

  # An empty stage list is reported as having a constant Jacobian.
  super(SoftClip, self).__init__(
      forward_min_event_ndims=0,
      dtype=dtype,
      validate_args=validate_args,
      is_constant_jacobian=not stages,
      name=name)
def _parameter_properties(cls, dtype):
  """Describes the `constant` parameter and its default constraining bijector.

  Args:
    cls: The class this is invoked on (not read by this function).
    dtype: Dtype whose epsilon is used as the `low` argument of the `Softplus`
      bijector constraining `constant`.

  Returns:
    A `dict` with a single key, `'constant'`, mapping to a
    `ParameterProperties` whose default constraining bijector is
    `Softplus(low=eps(dtype))`.
  """
  # Imported at function scope, as in the original code.
  from tensorflow_probability.python.bijectors import softplus  # pylint:disable=g-import-not-at-top

  def _constant_constraint():
    # Built lazily, only when the constraining bijector is requested.
    return softplus.Softplus(low=dtype_util.eps(dtype))

  constant_properties = parameter_properties.ParameterProperties(
      default_constraining_bijector_fn=_constant_constraint)
  return dict(constant=constant_properties)