def _bijector_fn(x, **condition_kwargs):
  params = shift_and_log_scale_fn(x, **condition_kwargs)
  if tf.is_tensor(params):
    shift, log_scale = tf.unstack(params, num=2, axis=-1)
  else:
    shift, log_scale = params
  return affine_scalar.AffineScalar(shift=shift, log_scale=log_scale)
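# A minimal, self-contained sketch of what the `_bijector_fn` above produces
# (assuming an older TFP release that still ships `tfb.AffineScalar`): a
# params tensor with trailing dimension 2 is unstacked into `shift` and
# `log_scale`, and the resulting bijector computes y = shift +
# exp(log_scale) * x. The toy values below are illustrative only.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

x = tf.constant([0.5, 1.0, 2.0])
params = tf.stack(
    [tf.ones_like(x),                     # shift = 1
     tf.math.log(2.) * tf.ones_like(x)],  # log_scale = log(2)
    axis=-1)
shift, log_scale = tf.unstack(params, num=2, axis=-1)
bij = tfb.AffineScalar(shift=shift, log_scale=log_scale)
print(bij.forward(x))  # ==> [2., 3., 5.], i.e. 1 + 2 * x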
def _transformed_beta(self, low=None, peak=None, high=None, temperature=None):
  low = tf.convert_to_tensor(self.low) if low is None else low
  peak = tf.convert_to_tensor(self.peak) if peak is None else peak
  high = tf.convert_to_tensor(self.high) if high is None else high
  temperature = (tf.convert_to_tensor(self.temperature)
                 if temperature is None else temperature)
  scale = high - low
  concentration1 = (1. + temperature * (peak - low) / scale)
  concentration0 = (1. + temperature * (high - peak) / scale)
  return transformed_distribution.TransformedDistribution(
      distribution=beta.Beta(
          concentration1=concentration1,
          concentration0=concentration0,
          allow_nan_stats=self.allow_nan_stats),
      bijector=affine_scalar.AffineScalar(
          shift=low,
          # Broadcasting scale on affine bijector to match batch dimension.
          # This prevents dimension mismatch for operations like cdf.
          # Note that `concentration1` incorporates the broadcast of all four
          # parameters.
          scale=tf.broadcast_to(scale, prefer_static.shape(concentration1))))
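# A small numeric check of the reparameterization above (a sketch, assuming a
# TFP release that ships `tfd.PERT`): with low=1., peak=7., high=11.,
# temperature=4.,
#   scale          = 11 - 1                = 10
#   concentration1 = 1 + 4 * (7 - 1) / 10  = 3.4
#   concentration0 = 1 + 4 * (11 - 7) / 10 = 2.6
# so the underlying Beta(3.4, 2.6) is shifted by 1 and scaled by 10, putting
# the support on (1, 11).
import tensorflow_probability as tfp

tfd = tfp.distributions

pert = tfd.PERT(low=1., peak=7., high=11., temperature=4.)
print(pert.prob(7.))  # Density at the peak; zero outside (1, 11).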
def _bijector_fn(x, **condition_kwargs):
  if conditioning is not None:
    x = tf.concat([conditioning, x], axis=-1)
    cond_depth = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(conditioning.shape, 1)[-1])
  else:
    cond_depth = 0
  params = shift_and_log_scale_fn(x, **condition_kwargs)
  if tf.is_tensor(params):
    shift, log_scale = tf.unstack(params, num=2, axis=-1)
  else:
    shift, log_scale = params
  # The network saw `cond_depth` conditioning features prepended to `x`, so
  # drop the outputs corresponding to the conditioning slots.
  shift = shift[..., cond_depth:]
  log_scale = log_scale[..., cond_depth:]
  return affine_scalar.AffineScalar(shift=shift, log_scale=log_scale)
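# A toy illustration of the slicing above (illustrative values, not library
# code): with cond_depth=2 conditioning features prepended to a 3-feature
# `x`, the network emits 5 outputs per parameter; only the trailing 3
# correspond to `x`.
import tensorflow as tf

cond_depth = 2
shift = tf.range(5.)            # outputs for [c0, c1, x0, x1, x2]
print(shift[..., cond_depth:])  # ==> [2., 3., 4.]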
def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name="scale_tril"):
  """Instantiates the `ScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance, used to transform the output diagonal
      to be positive.
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float value broadcastable and added to all diagonal entries
      after applying the `diag_bijector`. Setting a positive value forces the
      output diagonal entries to be positive, but prevents inverting the
      transformation for matrices with diagonal entries less than this value.
      Default value: `1e-5`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name given to ops managed by this object.
      Default value: `scale_tril`.
  """
  with tf.compat.v1.name_scope(name, values=[diag_shift]) as name:
    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)
    if diag_shift is not None:
      diag_shift = tf.convert_to_tensor(
          value=diag_shift, dtype=diag_bijector.dtype, name="diag_shift")
      diag_bijector = chain.Chain(
          [affine_scalar.AffineScalar(shift=diag_shift), diag_bijector])
    super(ScaleTriL, self).__init__(
        [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
         fill_triangular.FillTriangular()],
        validate_args=validate_args,
        name=name)
    self._use_tf_function = False  # So input bijectors cache EagerTensors.
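# A short usage sketch for `ScaleTriL` (assuming a TFP release where it is
# exported as `tfb.ScaleTriL`): the chain applies `FillTriangular` first,
# packing 3 unconstrained parameters into a 2x2 lower-triangular matrix, then
# `TransformDiagonal` maps the diagonal through softplus plus `diag_shift` so
# it is strictly positive.
import tensorflow_probability as tfp

tfb = tfp.bijectors

b = tfb.ScaleTriL()
scale_tril = b.forward([0.5, -1.0, 2.0])
# ==> a 2x2 lower-triangular matrix with positive diagonal entries, e.g.
#     suitable as the `scale_tril` argument of MultivariateNormalTriL.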
def _bijector_fn(x0, input_depth, **condition_kwargs):
  shift, log_scale = shift_and_log_scale_fn(
      x0, input_depth, **condition_kwargs)
  return affine_scalar.AffineScalar(shift=shift, log_scale=log_scale)
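# A sketch of where this glue code sits (assuming an older TFP release that
# ships `tfb.real_nvp_default_template`): RealNVP calls `_bijector_fn` on the
# masked half of the input to build the per-coupling affine transform.
import tensorflow_probability as tfp

tfb = tfp.bijectors

nvp = tfb.RealNVP(
    num_masked=2,
    shift_and_log_scale_fn=tfb.real_nvp_default_template(
        hidden_layers=[32, 32]))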
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="SinhArcsinh"):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged
      tailweight).
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution. Default is
      `tfd.Normal(0., 1.)`. Must be a scalar-batch, scalar-event
      distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  with tf.compat.v2.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
    scale = tf.convert_to_tensor(value=scale, name="scale", dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if skewness is None else skewness
    tailweight = tf.convert_to_tensor(
        value=tailweight, name="tailweight", dtype=dtype)
    skewness = tf.convert_to_tensor(
        value=skewness, name="skewness", dtype=dtype)
    batch_shape = distribution_util.get_broadcast_shape(
        loc, scale, tailweight, skewness)

    # Recall, with Z a random variable,
    #   Y := loc + C * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    #   C := 2 * scale / F_0(2)
    if distribution is None:
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats)
    else:
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        loc = distribution_util.with_dependencies(asserts, loc)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=skewness, tailweight=tailweight)
    if has_default_skewness:
      f_noskew = f
    else:
      f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=skewness.dtype.as_numpy_dtype(0.),
          tailweight=tailweight)

    # Make the AffineScalar bijector, Z --> loc + scale * Z * (2 / F_0(2)).
    c = 2 * scale / f_noskew.forward(
        tf.convert_to_tensor(value=2, dtype=dtype))
    affine = affine_scalar_bijector.AffineScalar(
        shift=loc, scale=c, validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        batch_shape=batch_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale
    self._tailweight = tailweight
    self._skewness = skewness
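# A brief usage sketch for the distribution constructed above (illustrative
# values): with the default skewness=0 and tailweight=1, F is the identity
# and C = scale (since F_0(2) = 2), so SinhArcsinh reduces to
# Normal(loc, scale).
import tensorflow_probability as tfp

tfd = tfp.distributions

d = tfd.SinhArcsinh(loc=2., scale=3.)
n = tfd.Normal(loc=2., scale=3.)
# d.log_prob(x) matches n.log_prob(x) for any x; a nonzero skewness or a
# tailweight != 1 then skews the density or fattens/thins its tails.
print(d.log_prob(2.), n.log_prob(2.))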
def _bijector_fn(x0, input_depth, **condition_kwargs):
  shift, log_scale = shift_and_log_scale_fn(
      x0, input_depth, **condition_kwargs)
  # ** First modification is here: the second network output is used directly
  # as `scale` rather than as a `log_scale`.
  return affine_scalar.AffineScalar(shift=shift, scale=log_scale)
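# Note on the modification above: passing the raw network output as `scale`
# gives y = shift + scale * x with an unconstrained (possibly zero or
# negative) multiplier, whereas `log_scale=log_scale` guarantees a strictly
# positive multiplier exp(log_scale). A toy contrast (illustrative values,
# assuming an older TFP release with `tfb.AffineScalar`):
import tensorflow_probability as tfp

tfb = tfp.bijectors

tfb.AffineScalar(shift=0., log_scale=0.).forward(3.)  # ==> 3.  (scale = e**0)
tfb.AffineScalar(shift=0., scale=0.5).forward(3.)     # ==> 1.5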
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name='SinhArcsinh'):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged
      tailweight).
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution. Must have a batch shape to
      which the shapes of `loc`, `scale`, `skewness`, and `tailweight` all
      broadcast. Default is `tfd.Normal(batch_shape, 1.)`, where
      `batch_shape` is the broadcasted shape of the parameters. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if has_default_skewness else skewness
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)

    # Recall, with Z a random variable,
    #   Y := loc + scale * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
    #   C := 2 / F_0(2)
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    if distribution is None:
      batch_rank = tf.reduce_max([
          distribution_util.prefer_static_rank(x)
          for x in (self._skewness, self._tailweight,
                    self._loc, self._scale)
      ])
      # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
      # `batch_shape` and `batch_shape_tensor` when TransformedDistribution's
      # bijector can modify its `batch_shape`.
      distribution = normal.Normal(
          loc=tf.zeros(tf.ones(batch_rank, tf.int32), dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=self._skewness,
        tailweight=self._tailweight,
        validate_args=validate_args)

    # Make the AffineScalar bijector, Z --> loc + scale * Z. (Unlike the
    # older construction, no 2 / F_0(2) factor is folded into the scale.)
    affine = affine_scalar_bijector.AffineScalar(
        shift=self._loc, scale=self._scale, validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
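# A quick shape check for the newer construction above (illustrative values):
# the parameters broadcast to a common batch shape, and the base Normal is
# built with a matching batch rank by the `distribution is None` branch.
import tensorflow_probability as tfp

tfd = tfp.distributions

d = tfd.SinhArcsinh(loc=[0., 1.], scale=1., skewness=0.5, tailweight=1.2)
print(d.batch_shape)      # ==> [2]
print(d.sample(3).shape)  # ==> [3, 2]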