def _default_event_space_bijector(self):
  # The chain applies right-to-left: y = softplus(-log(softplus(x))), which
  # maps the unconstrained reals onto (0, inf).
  return chain_bijector.Chain([
      softplus_bijector.Softplus(validate_args=self.validate_args),
      scale_bijector.Scale(scale=-1., validate_args=self.validate_args),
      exp_bijector.Log(validate_args=self.validate_args),
      softplus_bijector.Softplus(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _transformed_beta(self, low=None, peak=None, high=None, temperature=None):
  low = tf.convert_to_tensor(self.low) if low is None else low
  peak = tf.convert_to_tensor(self.peak) if peak is None else peak
  high = tf.convert_to_tensor(self.high) if high is None else high
  temperature = (tf.convert_to_tensor(self.temperature)
                 if temperature is None else temperature)
  scale = high - low
  concentration1 = (1. + temperature * (peak - low) / scale)
  concentration0 = (1. + temperature * (high - peak) / scale)
  return transformed_distribution.TransformedDistribution(
      distribution=beta.Beta(
          concentration1=concentration1,
          concentration0=concentration0,
          allow_nan_stats=self.allow_nan_stats),
      bijector=chain_bijector.Chain([
          shift_bijector.Shift(shift=low),
          # Broadcasting scale on affine bijector to match batch dimension.
          # This prevents dimension mismatch for operations like cdf.
          # Note that `concentration1` incorporates the broadcast of all four
          # parameters.
          scale_bijector.Scale(
              scale=tf.broadcast_to(scale, ps.shape(concentration1)))
      ]))
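# Illustrative sketch (an assumption, not part of the snippet above): the
# (low, peak, high, temperature) parameterization matches TFP's PERT
# distribution, whose samples live on (low, high) via the shifted/scaled Beta
# constructed by this helper.
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   d = tfd.PERT(low=1., peak=7., high=11., temperature=4.)
#   x = d.sample(5)     # values in (1., 11.)
#   d.log_prob(x)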
def __init__(self,
             concentration,
             scale=None,
             log_scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='ExpInverseGamma'):
  """Construct ExpInverseGamma with `concentration` and `scale` parameters.

  The parameters `concentration` and `scale` (or `log_scale`) must be shaped
  in a way that supports broadcasting (e.g. `concentration + scale` is a
  valid operation).

  Args:
    concentration: Floating point tensor, the concentration params of the
      distribution(s). Must contain only positive values.
    scale: Floating point tensor, the scale params of the distribution(s).
      Must contain only positive values. Mutually exclusive with `log_scale`.
    log_scale: Floating point tensor, the natural logarithm of the scale
      params of the distribution(s). Mutually exclusive with `scale`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if `concentration`, `scale`, or `log_scale` are different
      dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([concentration, scale, log_scale],
                                    dtype_hint=tf.float32)
    concentration = tensor_util.convert_nonref_to_tensor(
        concentration, dtype=dtype, name='concentration')
    scale = tensor_util.convert_nonref_to_tensor(
        scale, dtype=dtype, name='scale')
    log_scale = tensor_util.convert_nonref_to_tensor(
        log_scale, dtype=dtype, name='log_scale')
    bijector = scale_bijector.Scale(scale=-tf.ones([], dtype=dtype))
    to_transform = ExpGamma(
        concentration=concentration,
        rate=scale,
        log_rate=log_scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
    super(ExpInverseGamma, self).__init__(
        bijector=bijector,
        distribution=to_transform,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
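# Illustrative sketch (assumes this class is exposed as
# `tfp.distributions.ExpInverseGamma`): per the constructor above, the
# distribution is a Scale(-1) transform of ExpGamma, i.e. the law of log(X)
# for X ~ InverseGamma(concentration, scale).
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   d = tfd.ExpInverseGamma(concentration=3., scale=2.)
#   y = d.sample(4)     # unconstrained real values
#   d.log_prob(y)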
def _bijector_fn(x0, input_depth, **condition_kwargs):
  shift, log_scale = shift_and_log_scale_fn(x0, input_depth,
                                            **condition_kwargs)
  bijectors = []
  if shift is not None:
    bijectors.append(shift_lib.Shift(shift))
  if log_scale is not None:
    bijectors.append(scale_lib.Scale(log_scale=log_scale))
  return chain_lib.Chain(bijectors)
def _default_event_space_bijector(self):
  low = tfp_util.DeferredTensor(self.low, lambda x: x)
  scale = tfp_util.DeferredTensor(self.high, lambda x: x - self.low)
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
  if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
    scale = DeferredTensor(self.high, lambda x: x - self.low)
  else:
    scale = self.high - self.low
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
  # TODO(b/145620027): Finalize choice of bijector.
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=-np.pi, validate_args=self.validate_args),
      scale_bijector.Scale(scale=2. * np.pi,
                           validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
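# Illustrative sketch (not part of the source): each chain above reads
# right-to-left, mapping an unconstrained real x to low + width * sigmoid(x),
# i.e. onto an open interval: (low, high) for the first two snippets and
# (-pi, pi) for the one directly above. A standalone equivalent using the
# public bijector API:
#
#   import tensorflow_probability as tfp
#   tfb = tfp.bijectors
#   b = tfb.Chain([tfb.Shift(2.), tfb.Scale(3.), tfb.Sigmoid()])
#   b.forward(0.)    # 2. + 3. * 0.5 = 3.5, inside (2., 5.)
#   b.inverse(3.5)   # 0.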
def __init__(self, shift, scale, tailweight, validate_args=False,
             name="lambertw_tail"):
  """Construct a location scale heavy-tail Lambert W bijector.

  The parameters `shift`, `scale`, and `tailweight` must be shaped in a way
  that supports broadcasting (e.g. `shift + scale + tailweight` is a valid
  operation).

  Args:
    shift: Floating point tensor; the shift for centering (uncentering) the
      input (output) random variable(s).
    scale: Floating point tensor; the scaling (unscaling) of the input
      (output) random variable(s). Must contain only positive values.
    tailweight: Floating point tensor; the tail behaviors of the output
      random variable(s). Must contain only non-negative values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if `shift`, `scale`, and `tailweight` have different `dtype`s.
  """
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([tailweight, shift, scale], tf.float32)
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name="tailweight", dtype=dtype)
    self._shift = tensor_util.convert_nonref_to_tensor(
        shift, name="shift", dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name="scale", dtype=dtype)
    dtype_util.assert_same_float_dtype(
        (self._tailweight, self._shift, self._scale))

    self._shift_and_scale = chain.Chain(
        [tfb_shift.Shift(self._shift), tfb_scale.Scale(self._scale)])
    # The `bijectors` argument of the `tfb.Chain` superclass is applied in
    # reverse(!) order. Hence the ordering in the list must be (3, 2, 1),
    # not (1, 2, 3).
    super(LambertWTail, self).__init__(
        bijectors=[
            self._shift_and_scale,
            _HeavyTailOnly(tailweight=self._tailweight),
            invert.Invert(self._shift_and_scale)
        ],
        validate_args=validate_args)
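# Illustrative sketch (assumes the class is exposed as
# `tfp.bijectors.LambertWTail`): the sandwich built above first standardizes
# the input with the inverted shift-and-scale, applies the heavy-tail
# transform, then restores the original location and scale.
#
#   import tensorflow_probability as tfp
#   tfb = tfp.bijectors
#   bij = tfb.LambertWTail(shift=0., scale=1., tailweight=0.5)
#   bij.forward(2.)    # maps 2. to a heavier-tailed value (identity when
#                      # tailweight == 0.)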
def _negative_concentration_bijector(self):
  # Constructed dynamically so that `scale * reciprocal(concentration)` is
  # tape-safe.
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
      # TODO(b/146568897): Resolve numerical issues by implementing a new
      # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
      scale_bijector.Scale(
          scale=-(self.scale * tf.math.reciprocal(self.concentration) *
                  (1. - 1e-6)),
          validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
  # TODO(b/146568897): Resolve numerical issues by implementing a new
  # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
  if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
    scale = DeferredTensor(
        self.high,
        lambda x: (x - self.low) * (1. - 1e-6),
        shape=tf.broadcast_static_shape(self.high.shape, self.low.shape))
  else:
    scale = (self.high - self.low) * (1. - 1e-6)
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
  parameters = dict(locals())
  # `_initialized` flags whether `_m` and `_s` have been initialized; `_m` is
  # the per-channel shift and `_s` the per-channel scale, kept positive by
  # parameterizing it through `Exp`.
  self._initialized = tf.Variable(False, trainable=False)
  self._m = tf.Variable(tf.zeros(nchan, dtype))
  self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())
  # The inner chain computes s * (x + m); inverting it makes the forward pass
  # normalize: x -> x / s - m.
  self._bijector = invert.Invert(
      chain.Chain([
          scale.Scale(self._s),
          shift.Shift(self._m),
      ]))
  super(ActivationNormalization, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=1,
      parameters=parameters,
      name=name or 'ActivationNormalization')
def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones."""
  with tf.name_scope('as_trainable_family'):

    if isinstance(distribution, half_normal.HalfNormal):
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    elif isinstance(distribution, uniform.Uniform):
      return shift.Shift(distribution.low)(
          scale_lib.Scale(distribution.high - distribution.low)(
              beta.Beta(
                  concentration0=tf.ones(
                      distribution.event_shape_tensor(),
                      dtype=distribution.dtype),
                  concentration1=1.)))
    else:
      return distribution
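# Worked sketch (illustrative only, no new behavior): for a Uniform prior the
# branch above returns Shift(low)(Scale(high - low)(Beta(1., 1.))). Since
# Beta(1., 1.) is uniform on (0, 1), the substitute keeps the same support
# and initial law but exposes Beta concentrations that a surrogate can train.
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   u = tfd.Uniform(low=-1., high=3.)
#   t = _as_trainable_family(u)   # shifted/scaled Beta on (-1., 3.)
#   t.sample(3)                   # values in (-1., 3.)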
def bijector_fn(inputs, ignored_input):
  """Decorated function to get the RealNVP bijector."""
  # Build this so we can handle a user passing a NN that returns a tensor
  # OR an NN that returns a bijector.
  possible_output = layer(inputs)

  # We need to produce a bijector, but we do not know if the layer has done
  # so. We are setting this up to handle 2 possibilities:
  # 1) The layer outputs a bijector --> all is good.
  # 2) The layer outputs a tensor --> we need to turn it into a bijector.
  if isinstance(possible_output, bijector.Bijector):
    output = possible_output
  elif isinstance(possible_output, tf.Tensor):
    input_shape = inputs.get_shape().as_list()
    output_shape = possible_output.get_shape().as_list()
    assert input_shape[:-1] == output_shape[:-1]
    c = input_shape[-1]

    # For layers which output a tensor, we have two possibilities:
    # 1) There are twice as many output channels as inputs --> the coupling
    #    is affine, meaning there is a scale followed by a shift.
    # 2) There are an equal number of input and output channels --> the
    #    coupling is additive, meaning there is just a shift.
    if input_shape[-1] == output_shape[-1] // 2:
      this_scale = scale.Scale(scale_fn(possible_output[..., :c] + 2.))
      this_shift = shift.Shift(possible_output[..., c:])
      output = this_shift(this_scale)
    elif input_shape[-1] == output_shape[-1]:
      output = shift.Shift(possible_output[..., :c])
    else:
      raise ValueError(
          'Shape inconsistent with input. Expected shape '
          '{0} or {1} but tensor was shape {2}'.format(
              input_shape,
              tf.concat([input_shape[:-1], [2 * input_shape[-1]]], 0),
              output_shape))
  else:
    raise ValueError(
        'Expected a bijector or a tensor, but instead got '
        '{}'.format(possible_output.__class__))
  return output
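# Illustrative sketch of the two coupling cases above (assumptions: `layer`
# and `scale_fn` come from the enclosing scope, with `scale_fn` mapping reals
# to positive scales). For an input block x1 and conditioning input x0 with
# c channels:
#
#   nn_out = layer(x0)        # shape [..., 2 * c] (affine) or [..., c]
#   # affine coupling (2 * c output channels):
#   #   y = nn_out[..., c:] + scale_fn(nn_out[..., :c] + 2.) * x1
#   # additive coupling (c output channels):
#   #   y = nn_out[..., :c] + x1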
def __init__(self,
             skewness,
             tailweight,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Johnson's SU distributions.

  The distributions have shape parameters `tailweight` and `skewness`,
  mean `loc`, and scale `scale`.

  The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
  in a way that supports broadcasting (e.g. `skewness + tailweight + loc +
  scale` is a valid operation).

  Args:
    skewness: Floating-point `Tensor`. Skewness of the distribution(s).
    tailweight: Floating-point `Tensor`. Tail weight of the distribution(s).
      `tailweight` must contain only positive values.
    loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
    scale: Floating-point `Tensor`. The scaling factor(s) for the
      distribution(s). Note that `scale` is not technically the standard
      deviation of this distribution but has semantics more similar to
      standard deviation than variance.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value '`NaN`' to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if any of `skewness`, `tailweight`, `loc` and `scale` are
      different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name or 'JohnsonSU') as name:
    dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                    tf.float32)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)

    norm_shift = invert_bijector.Invert(
        shift_bijector.Shift(shift=self._skewness,
                             validate_args=validate_args))

    norm_scale = invert_bijector.Invert(
        scale_bijector.Scale(scale=self._tailweight,
                             validate_args=validate_args))

    sinh = sinh_bijector.Sinh(validate_args=validate_args)

    scale = scale_bijector.Scale(scale=self._scale,
                                 validate_args=validate_args)

    shift = shift_bijector.Shift(shift=self._loc,
                                 validate_args=validate_args)

    bijector = shift(scale(sinh(norm_scale(norm_shift))))

    batch_rank = ps.reduce_max([
        distribution_util.prefer_static_rank(x)
        for x in (self._skewness, self._tailweight, self._loc, self._scale)
    ])

    super(JohnsonSU, self).__init__(
        # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
        # `batch_shape` and `batch_shape_tensor` when
        # TransformedDistribution's bijector can modify its `batch_shape`.
        distribution=normal.Normal(
            loc=tf.zeros(ps.ones(batch_rank, tf.int32), dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats),
        bijector=bijector,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
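# Illustrative sketch (assumes the public alias `tfp.distributions.JohnsonSU`):
# reading the bijector chain above, a sample is
#   Y = loc + scale * sinh((Z - skewness) / tailweight),  Z ~ Normal(0, 1).
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   d = tfd.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
#   d.sample(3)
#   d.log_prob(0.)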
  global ASVI_SURROGATE_SUBSTITUTIONS
  if inspect.isclass(condition):
    condition = lambda distribution, cls=condition: isinstance(  # pylint: disable=g-long-lambda
        distribution, cls)
  ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn


# Default substitutions attempt to express distributions using the most
# flexible available parameterization.
# pylint: disable=g-long-lambda
register_asvi_substitution_rule(
    half_normal.HalfNormal,
    lambda dist: truncated_normal.TruncatedNormal(
        loc=0., scale=dist.scale, low=0., high=dist.scale * 10.))
register_asvi_substitution_rule(
    uniform.Uniform,
    lambda dist: shift.Shift(dist.low)(
        scale_lib.Scale(dist.high - dist.low)(
            beta.Beta(concentration0=tf.ones_like(dist.mean()),
                      concentration1=1.))))
register_asvi_substitution_rule(
    exponential.Exponential,
    lambda dist: gamma.Gamma(concentration=1., rate=dist.rate))
register_asvi_substitution_rule(
    chi2.Chi2,
    lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5))
# pylint: enable=g-long-lambda


# TODO(kateslin): Add support for models with prior+likelihood written as
# a single JointDistribution.
def build_asvi_surrogate_posterior(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5,
                                   seed=None,
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name='SinhArcsinh'):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged
      tailweight).
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution. Must have a batch shape to
      which the shapes of `loc`, `scale`, `skewness`, and `tailweight` all
      broadcast. Default is `tfd.Normal(batch_shape, 1.)`, where
      `batch_shape` is the broadcasted shape of the parameters. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
      is a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if has_default_skewness else skewness
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)

    # Recall, with Z a random variable,
    #   Y := loc + scale * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
    #   C := 2 / F_0(2)
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    if distribution is None:
      batch_shape = functools.reduce(
          ps.broadcast_shape,
          [ps.shape(x) for x in (self._skewness, self._tailweight,
                                 self._loc, self._scale)])
      distribution = normal.Normal(
          loc=tf.zeros(batch_shape, dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=self._skewness,
        tailweight=self._tailweight,
        validate_args=validate_args)

    # Make the affine bijector, Z --> loc + scale * Z.
    affine = shift_bijector.Shift(shift=self._loc)(
        scale_bijector.Scale(scale=self._scale))

    bijector = chain_bijector.Chain([affine, f])

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
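# Illustrative sketch (assumes the public alias
# `tfp.distributions.SinhArcsinh`): with the default Normal base,
# skewness=0 and tailweight=1 recover an ordinary Normal(loc, scale);
# a larger tailweight fattens both tails and nonzero skewness tilts the mass
# to one side.
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   d = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.5)
#   d.sample(4)
#   d.log_prob(0.)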
def __init__(self,
             low=None,
             high=None,
             hinge_softness=None,
             validate_args=False,
             name='soft_clip'):
  """Instantiates the SoftClip bijector.

  Args:
    low: Optional float `Tensor` lower bound. If `None`, the lower-bound
      constraint is omitted.
      Default value: `None`.
    high: Optional float `Tensor` upper bound. If `None`, the upper-bound
      constraint is omitted.
      Default value: `None`.
    hinge_softness: Optional nonzero float `Tensor`. Controls the softness
      of the constraint at the boundaries; values outside of the constraint
      set are mapped into intervals of width approximately
      `log(2) * hinge_softness` on the interior of each boundary. High
      softness reserves more space for values outside of the constraint set,
      leading to greater distortion of inputs *within* the constraint set,
      but improved numerical stability near the boundaries.
      Default value: `None` (`1.0`).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    dtype = dtype_util.common_dtype([low, high, hinge_softness],
                                    dtype_hint=tf.float32)
    low = tensor_util.convert_nonref_to_tensor(low, name='low', dtype=dtype)
    high = tensor_util.convert_nonref_to_tensor(
        high, name='high', dtype=dtype)
    hinge_softness = tensor_util.convert_nonref_to_tensor(
        hinge_softness, name='hinge_softness', dtype=dtype)

    softplus_bijector = softplus.Softplus(hinge_softness=hinge_softness)
    negate = tf.convert_to_tensor(-1., dtype=dtype)

    components = []
    if low is not None and high is not None:
      # Support reference tensors (e.g. Variables) for `high` and `low` by
      # deferring all computation on them until needed.
      width = tfp_util.DeferredTensor(
          pretransformed_input=high,
          transform_fn=lambda high: high - low)
      negated_shrinkage_factor = tfp_util.DeferredTensor(
          pretransformed_input=width,
          transform_fn=lambda w: tf.cast(  # pylint: disable=g-long-lambda
              negate * w / softplus_bijector.forward(w), dtype=dtype))

      # Implement the soft constraint from 'Mathematical Details' above:
      #  softclip(x) := -softplus(width - softplus(x - low)) *
      #                 (width) / (softplus(width)) + high
      components = [
          shift.Shift(high),
          scale.Scale(negated_shrinkage_factor),
          softplus_bijector,
          shift.Shift(width),
          scale.Scale(negate),
          softplus_bijector,
          shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))]
    elif low is not None:
      # Implement a soft lower bound:
      #  softlower(x) := softplus(x - low) + low
      components = [
          shift.Shift(low),
          softplus_bijector,
          shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))]
    elif high is not None:
      # Implement a soft upper bound:
      #  softupper(x) := -softplus(high - x) + high
      # Note: components are listed in the order applied *last* to *first*
      # (Chain semantics), so the innermost step is `high - x`.
      components = [
          shift.Shift(high),
          scale.Scale(negate),
          softplus_bijector,
          shift.Shift(high),
          scale.Scale(negate)]

    self._low = low
    self._high = high
    self._hinge_softness = hinge_softness
    self._chain = chain.Chain(components, validate_args=validate_args)

    super(SoftClip, self).__init__(
        forward_min_event_ndims=0,
        dtype=dtype,
        validate_args=validate_args,
        parameters=parameters,
        is_constant_jacobian=not components,
        name=name)
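# Illustrative sketch (assumes the public alias `tfp.bijectors.SoftClip`):
# the two-sided chain above maps all of the reals smoothly into (low, high),
# staying close to the identity well inside the interval.
#
#   import tensorflow_probability as tfp
#   tfb = tfp.bijectors
#   b = tfb.SoftClip(low=-1., high=1., hinge_softness=0.1)
#   b.forward([-5., 0., 5.])   # ~[-1., 0., 1.]: softly clipped into (-1., 1.)
#   b.inverse(b.forward(0.3))  # ~0.3: invertible, nearly identity inside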