def __init__(self,
             loc=None,
             covariance_matrix=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalFullCovariance'):
  """Build a Multivariate Normal on `R^k` from a full covariance matrix.

  The covariance matrix is Cholesky-factored and the resulting lower
  triangular factor is handed to the parent constructor as `scale_tril`.

  The `batch_shape` is the broadcast shape between `loc` and
  `covariance_matrix`. The `event_shape` is given by the last dimension of
  the matrix implied by `covariance_matrix`; the last dimension of `loc` (if
  provided) must broadcast with it.

  A non-batch `covariance_matrix` is a `k x k` symmetric positive definite
  matrix, i.e. real symmetric with strictly positive eigenvalues. Additional
  leading dimensions (if any) index batches.

  Args:
    loc: Floating-point `Tensor`, or `None` (implicitly `0`). When given, may
      have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size.
    covariance_matrix: Floating-point, symmetric positive definite `Tensor`
      of the same `dtype` as `loc`, with shape `[B1, ..., Bb, k, k]`. The
      strict upper triangle is ignored by the Cholesky factorization, so a
      non-symmetric input raises no error unless `validate_args` is `True`.
      If `None`, `scale_tril=None` is forwarded to the parent constructor.
    validate_args: Python `bool`, default `False`. When `True`, a symmetry
      assertion is attached to `covariance_matrix` (possibly degrading
      runtime performance). When `False`, invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value '`NaN`' to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `covariance_matrix` are specified.
      (Not raised in this method directly; presumably by the parent
      constructor.)
  """
  # Must stay the first statement: it snapshots exactly the constructor args.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    with tf.name_scope('init'):
      dtype = dtype_util.common_dtype([loc, covariance_matrix], tf.float32)
      if loc is not None:
        loc = tf.convert_to_tensor(loc, name='loc', dtype=dtype)

      scale_tril = None
      if covariance_matrix is not None:
        covariance_matrix = tf.convert_to_tensor(
            covariance_matrix, name='covariance_matrix', dtype=dtype)
        if validate_args:
          # cholesky() only reads the lower triangle, so symmetry has to be
          # asserted explicitly. Non-singularity is not checked here; the
          # original notes that LinearOperatorLowerTriangular's
          # assert_non_singular is called by the Bijector.
          symmetry_check = assert_util.assert_near(
              covariance_matrix,
              tf.linalg.matrix_transpose(covariance_matrix),
              message='Matrix was not symmetric')
          covariance_matrix = distribution_util.with_dependencies(
              [symmetry_check], covariance_matrix)
        scale_tril = tf.linalg.cholesky(covariance_matrix)

    super(MultivariateNormalFullCovariance, self).__init__(
        loc=loc,
        scale_tril=scale_tril,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
  self._parameters = parameters
def __init__(self,
             loc=None,
             scale_diag=None,
             scale_identity_multiplier=None,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorSinhArcsinhDiag"):
  """Construct VectorSinhArcsinhDiag distribution on `R^k`.

  The arguments `scale_diag` and `scale_identity_multiplier` combine to
  define the diagonal `scale` referred to in this class docstring:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  With `Z` a random variable drawn from `distribution`, the constructed
  transformation is

  ```none
  Y      := loc + C * F(Z)
  F(Z)   := Sinh( (Arcsinh(Z) + skewness) * tailweight )
  F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
  C      := 2 * scale / F_0(2)
  ```

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
      matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
      and characterizes `b`-batches of `k x k` diagonal matrices added to
      `scale`. When both `scale_identity_multiplier` and `scale_diag` are
      `None` then `scale` is the `Identity`.
    scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
      a scale-identity-matrix added to `scale`. May have shape
      `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
      `k x k` identity matrices added to `scale`. When both
      `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
      is the `Identity`.
    skewness: Skewness parameter. Floating-point `Tensor` with shape
      broadcastable with `event_shape`. Defaults to `0.` when `None`.
    tailweight: Tailweight parameter. Floating-point `Tensor` with shape
      broadcastable with `event_shape`. Defaults to `1.` when `None`.
    distribution: `tf.Distribution`-like instance. Distribution from which `k`
      iid samples are used as input to transformation `F`. Default is
      `tfd.Normal(loc=0., scale=1.)`.
      Must be a scalar-batch, scalar-event distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop through
      a VectorSinhArcsinhDiag sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if at most `scale_identity_multiplier` is specified.
  """
  # Must stay the first statement: it snapshots exactly the constructor args.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([
        loc, scale_diag, scale_identity_multiplier, skewness, tailweight
    ], tf.float32)
    loc = loc if loc is None else tf.convert_to_tensor(
        loc, name="loc", dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    # Remember whether skewness was defaulted: in that case `F` already is
    # `F_0` and no second (unskewed) bijector has to be built below.
    has_default_skewness = skewness is None
    skewness = 0. if skewness is None else skewness

    # Construct shapes and 'scale' out of the scale_* and loc kwargs.
    # scale_linop is only an intermediary to:
    # 1. get shapes from looking at loc and the two scale args.
    # 2. combine scale_diag with scale_identity_multiplier, which gives us
    #    'scale', which in turn gives us 'C'.
    scale_linop = distribution_util.make_diag_scale(
        loc=loc,
        scale_diag=scale_diag,
        scale_identity_multiplier=scale_identity_multiplier,
        validate_args=False,
        assert_positive=False,
        dtype=dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale_linop)
    # scale_linop.diag_part() is efficient since it is a diag type linop.
    scale_diag_part = scale_linop.diag_part()
    dtype = scale_diag_part.dtype

    if distribution is None:
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats)
    else:
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        # Tie the scalar-distribution checks to a tensor that is always used.
        scale_diag_part = distribution_util.with_dependencies(
            asserts, scale_diag_part)

    # Make the SAS bijector, 'F'.
    skewness = tf.convert_to_tensor(skewness, dtype=dtype, name="skewness")
    tailweight = tf.convert_to_tensor(
        tailweight, dtype=dtype, name="tailweight")
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=skewness, tailweight=tailweight)
    # 'F_0' is 'F' with zero skewness; it is what normalizes the scale.
    if has_default_skewness:
      f_noskew = f
    else:
      f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=skewness.dtype.as_numpy_dtype(0.),
          tailweight=tailweight)

    # Make the Affine bijector, Z --> loc + C * Z, with C = 2 * scale / F_0(2).
    # Previously the raw `scale_diag_part` was used here, so the documented
    # normalization by F_0(2) was never applied.
    c = 2 * scale_diag_part / f_noskew.forward(
        tf.convert_to_tensor(2, dtype=dtype))
    affine = affine_bijector.Affine(
        shift=loc, scale_diag=c, validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])

    super(VectorSinhArcsinhDiag, self).__init__(
        distribution=distribution,
        bijector=bijector,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
  self._loc = loc
  self._scale = scale_linop
  self._tailweight = tailweight
  self._skewness = skewness
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keepdims=False,
               validate_args=False,
               preserve_gradients=True,
               keep_dims=None,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  Compare to `numpy.percentile`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
      It is cast to `float64` internally.
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation: {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional part
          of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keepdims:  Python `bool`. If `True`, the last dimension is kept with size 1
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity (rank
      and `[0, 100]` range of `q`). If False, and arguments are incorrect,
      correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that gradient w.r.t
      the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q` corresponds
      to a point in `x`.
    keep_dims: deprecated, use keepdims instead. When not `None` it overrides
      `keepdims`.
    name:  A Python string name to give this `Op`.  Default is 'percentile'

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed type.
    TypeError:  If interpolation type ('linear' or 'midpoint') is not
      compatible with an integer `dtype` of `x`.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  """
  # The deprecated `keep_dims` wins over `keepdims` whenever it is supplied.
  keepdims = keepdims if keep_dims is None else keep_dims
  del keep_dims
  name = name or 'percentile'
  allowed_interpolations = {
      'linear', 'lower', 'higher', 'nearest', 'midpoint'
  }

  if interpolation is None:
    interpolation = 'nearest'
  else:
    if interpolation not in allowed_interpolations:
      raise ValueError(
          'Argument `interpolation` must be in {}. Found {}.'.format(
              allowed_interpolations, interpolation))

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    # 'linear' and 'midpoint' produce non-integer results, so integer inputs
    # are rejected up front.
    if (interpolation in {'linear', 'midpoint'} and
        dtype_util.is_integer(x.dtype)):
      raise TypeError(
          '{} interpolation not allowed with dtype {}'.format(
              interpolation, x.dtype))

    # Double is needed here and below, else we get the wrong index if the array
    # is huge along axis.
    q = tf.cast(q, tf.float64)
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = distribution_util.with_dependencies([
          assert_util.assert_rank_in(q, [0, 1]),
          assert_util.assert_greater_equal(q, tf.cast(0., tf.float64)),
          assert_util.assert_less_equal(q, tf.cast(100., tf.float64))
      ], q)

    # Move `axis` dims of `x` to the rightmost, call it `y`.
    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

    frac_at_q_or_below = q / 100.

    # Sort (in ascending order) everything which allows multiple calls to sort
    # only once (under the hood) and use CSE.
    sorted_y = tf.sort(y, axis=-1, direction='ASCENDING')

    # Number of samples along the flattened sample axis, as a double.
    d = ps.cast(ps.shape(y)[-1], tf.float64)

    def _get_indices(interp_type):
      """Get values of y at the indices implied by interp_type."""
      # Only called with 'lower', 'higher' or 'nearest'; any other value would
      # leave `indices` unbound.
      if interp_type == 'lower':
        indices = tf.math.floor((d - 1) * frac_at_q_or_below)
      elif interp_type == 'higher':
        indices = tf.math.ceil((d - 1) * frac_at_q_or_below)
      elif interp_type == 'nearest':
        indices = tf.round((d - 1) * frac_at_q_or_below)
      # d - 1 will be distinct from d in int32, but not necessarily double.
      # So clip to avoid out of bounds errors.
      return tf.clip_by_value(
          tf.cast(indices, tf.int32), 0, ps.shape(y)[-1] - 1)

    if interpolation in ['nearest', 'lower', 'higher']:
      gathered_y = tf.gather(sorted_y, _get_indices(interpolation), axis=-1)
    elif interpolation == 'midpoint':
      gathered_y = 0.5 * (
          tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
          tf.gather(sorted_y, _get_indices('higher'), axis=-1))
    elif interpolation == 'linear':
      # Copy-paste of docstring on interpolation:
      # linear: i + (j - i) * fraction, where fraction is the fractional part
      # of the index surrounded by i and j.
      larger_y_idx = _get_indices('higher')
      exact_idx = (d - 1) * frac_at_q_or_below
      if preserve_gradients:
        # If q corresponds to a point in x, we will initially have
        # larger_y_idx == smaller_y_idx.
        # This results in the gradient w.r.t. fraction being zero (recall `q`
        # enters only through `fraction`...and see that things cancel).
        # The fix is to ensure that smaller_y_idx and larger_y_idx are always
        # separated by exactly 1.
        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
        larger_y_idx = tf.minimum(smaller_y_idx + 1, tf.shape(y)[-1] - 1)
        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
      else:
        smaller_y_idx = _get_indices('lower')
        fraction = tf.math.ceil(
            (d - 1) * frac_at_q_or_below) - exact_idx

      # `fraction` is the weight given to the smaller neighbor.
      fraction = tf.cast(fraction, y.dtype)
      gathered_y = (
          tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
          tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

    # Propagate NaNs
    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
      # Apparently tf.is_nan doesn't like other dtypes
      nan_batch_members = tf.reduce_any(tf.math.is_nan(x), axis=axis)
      # Append rank(q) singleton dims so the mask broadcasts against the
      # trailing `q` dimension of `gathered_y`.
      right_rank_matched_shape = ps.pad(
          ps.shape(nan_batch_members),
          paddings=[[0, ps.rank(q)]],
          constant_values=1)
      nan_batch_members = tf.reshape(
          nan_batch_members, shape=right_rank_matched_shape)
      nan = np.array(np.nan, dtype_util.as_numpy_dtype(gathered_y.dtype))
      gathered_y = tf.where(nan_batch_members, nan, gathered_y)

    # Expand dimensions if requested
    if keepdims:
      if axis is None:
        # Broadcast against an all-ones tensor whose shape is a vector of 1s
        # of length ndims(x) + ndims(q).
        ones_vec = tf.ones(
            shape=[
                _get_best_effort_ndims(x) + _get_best_effort_ndims(q)
            ],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keepdims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return distribution_util.rotate_transpose(gathered_y, ps.rank(q))
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="SinhArcsinh"):
  """Build a SinhArcsinh distribution on `(-inf, inf)`.

  With `Z` drawn from `distribution`, the distribution is that of

  ```none
  Y      := loc + C * F(Z)
  F(Z)   := Sinh( (Arcsinh(Z) + skewness) * tailweight )
  F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
  C      := 2 * scale / F_0(2)
  ```

  `(loc, scale, skewness, tailweight)` must have broadcastable shape
  (indexing batch dimensions) and share one `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight).
    distribution: `tf.Distribution`-like instance that is transformed to
      produce this distribution. Default is `tfd.Normal(0., 1.)`.
      Must be a scalar-batch, scalar-event distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop through
      a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must stay the first statement: it snapshots exactly the constructor args.
  parameters = dict(locals())

  with tf.compat.v1.name_scope(
      name, values=[loc, scale, skewness, tailweight]) as name:
    dtype = dtype_util.common_dtype(
        [loc, scale, skewness, tailweight], tf.float32)
    loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
    scale = tf.convert_to_tensor(value=scale, name="scale", dtype=dtype)

    # Fill in defaults, remembering whether skewness was left unspecified so
    # the unskewed bijector below can be shared instead of rebuilt.
    if tailweight is None:
      tailweight = 1.
    has_default_skewness = skewness is None
    if has_default_skewness:
      skewness = 0.
    tailweight = tf.convert_to_tensor(
        value=tailweight, name="tailweight", dtype=dtype)
    skewness = tf.convert_to_tensor(
        value=skewness, name="skewness", dtype=dtype)

    batch_shape = distribution_util.get_broadcast_shape(
        loc, scale, tailweight, skewness)

    if distribution is not None:
      scalar_checks = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if scalar_checks:
        # Hang the checks on `loc`, which always feeds the affine bijector.
        loc = distribution_util.with_dependencies(scalar_checks, loc)
    else:
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats)

    # The SAS bijector 'F', and its zero-skewness counterpart 'F_0'.
    sas_bijector = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=skewness, tailweight=tailweight)
    if has_default_skewness:
      unskewed_bijector = sas_bijector
    else:
      unskewed_bijector = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=skewness.dtype.as_numpy_dtype(0.),
          tailweight=tailweight)

    # The AffineScalar bijector, Z --> loc + scale * Z * (2 / F_0(2)).
    two = tf.convert_to_tensor(value=2, dtype=dtype)
    normalized_scale = 2 * scale / unskewed_bijector.forward(two)
    affine = affine_scalar_bijector.AffineScalar(
        shift=loc, scale=normalized_scale, validate_args=validate_args)

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=chain_bijector.Chain([affine, sas_bijector]),
        batch_shape=batch_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
  self._loc = loc
  self._scale = scale
  self._tailweight = tailweight
  self._skewness = skewness
def _maybe_assert_valid_y(self, y):
  """Return `y`, guarded by a positivity assertion when validating.

  Args:
    y: Value passed to the inverse transformation.

  Returns:
    `y` unchanged when `self.validate_args` is falsy; otherwise `y` with an
    `assert_positive` check attached as a dependency.
  """
  if self.validate_args:
    positivity_check = assert_util.assert_positive(
        y, message="Inverse transformation input must be greater than 0.")
    return distribution_util.with_dependencies([positivity_check], y)
  return y
def __init__(self,
             loc,
             atol=None,
             rtol=None,
             is_vector=False,
             validate_args=False,
             allow_nan_stats=True,
             parameters=None,
             name="_BaseDeterministic"):
  """Set up a batch of `_BaseDeterministic` distributions.

  `atol` and `rtol` give some slack to `pmf` / `cdf` computations, e.g. to
  absorb floating-point error:

  ```
  pmf(x; loc)
    = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
    = 0, otherwise.
  ```

  Args:
    loc: Numeric `Tensor`. The point (or batch of points) on which this
      distribution is supported.
    atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
      shape. The absolute tolerance for comparing closeness to `loc`.
      Default is `0`.
    rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
      shape. The relative tolerance for comparing closeness to `loc`.
      Default is `0`.
    is_vector: Python `bool`. If `True`, this is for `VectorDeterministic`,
      else `Deterministic`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    parameters: Dict of locals to facilitate copy construction.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If `is_vector` and `validate_args` are both true and `loc`
      is statically known to be a scalar (rank < 1).
  """
  with tf.compat.v2.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [loc, atol, rtol], preferred_dtype=tf.float32)
    loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)

    if is_vector and validate_args:
      msg = "Argument loc must be at least rank 1."
      static_rank = loc.shape.ndims
      if static_rank is None:
        # Rank unknown until runtime: defer the check to a graph assertion.
        rank_check = assert_util.assert_rank_at_least(loc, 1, message=msg)
        loc = distribution_util.with_dependencies([rank_check], loc)
      elif static_rank < 1:
        raise ValueError(msg)

    self._loc = loc
    self._atol = _get_tol(atol, self._loc.dtype, validate_args)
    self._rtol = _get_tol(rtol, self._loc.dtype, validate_args)

    super(_BaseDeterministic, self).__init__(
        dtype=self._loc.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._loc, self._atol, self._rtol],
        name=name)

    # With no relative tolerance the slack is just `atol`, which avoids the
    # (possibly large) broadcast against `loc`.
    self._slack = (
        self.atol if rtol is None
        else self.atol + self.rtol * tf.abs(self.loc))
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, shift,
                           validate_args, dtype):
  """Assemble `scale` from its optional components.

  Args:
    identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
    diag: Floating-point `Tensor` representing the diagonal matrix. `diag`
      has shape `[N1, N2, ...  k]`, which represents a k x k diagonal matrix.
    tril: Floating-point `Tensor` representing the lower triangular matrix.
      `tril` has shape `[N1, N2, ...  k, k]`, which represents a k x k lower
      triangular matrix.
    perturb_diag: Floating-point `Tensor` representing the diagonal matrix
      of the low rank update.
    perturb_factor: Floating-point `Tensor` representing factor matrix.
    shift: Floating-point `Tensor` representing `shift` in
      `scale @ X + shift`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    dtype: `DType` for arg `Tensor` conversions.

  Returns:
    scale. In the case of scaling by a constant, scale is a floating point
    `Tensor`. Otherwise, scale is a `LinearOperator`.

  Raises:
    ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`.
      (Not raised in this method directly; presumably by
      `distribution_util.make_tril_scale`.)
  """
  identity_multiplier = _as_tensor(
      identity_multiplier, "identity_multiplier", dtype)
  diag = _as_tensor(diag, "diag", dtype)
  tril = _as_tensor(tril, "tril", dtype)
  perturb_diag = _as_tensor(perturb_diag, "perturb_diag", dtype)
  perturb_factor = _as_tensor(perturb_factor, "perturb_factor", dtype)

  has_low_rank_update = perturb_factor is not None

  # When a low rank update is present, its second-to-last dimension tells us
  # the size of the identity matrix for the scaled-identity case.
  if has_low_rank_update:
    shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)
  else:
    shape_hint = None

  # Pure scalar scaling: hand back the multiplier itself, not an operator.
  if self._is_only_identity_multiplier:
    if not validate_args:
      return identity_multiplier
    nonzero_check = assert_util.assert_none_equal(
        identity_multiplier,
        tf.zeros([], identity_multiplier.dtype),
        ["identity_multiplier should be non-zero."])
    return distribution_util.with_dependencies(
        [nonzero_check], identity_multiplier)

  base_scale = distribution_util.make_tril_scale(
      loc=shift,
      scale_tril=tril,
      scale_diag=diag,
      scale_identity_multiplier=identity_multiplier,
      validate_args=validate_args,
      assert_positive=False,
      shape_hint=shape_hint)

  if not has_low_rank_update:
    return base_scale

  return tf.linalg.LinearOperatorLowRankUpdate(
      base_scale,
      u=perturb_factor,
      diag_update=perturb_diag,
      is_diag_update_positive=perturb_diag is None,
      is_non_singular=True,  # Implied by is_positive_definite=True.
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)
def __init__(self,
             learning_rate,
             preconditioner_decay_rate=0.95,
             data_size=1,
             burnin=25,
             diagonal_bias=1e-8,
             name=None,
             parallel_iterations=10):
  """Convert and validate the SGLD hyperparameters, then init the base class.

  Every hyperparameter is converted to a `Tensor`; the range assertions
  listed below are attached to the stored tensors as dependencies.

  Args:
    learning_rate: Value convertible to a `Tensor`; stored unvalidated.
    preconditioner_decay_rate: Must be non-negative and at most 1.
    data_size: Must be greater than zero.
    burnin: Must be a non-negative integer. The dtype is chosen with an
      `int64` hint.
    diagonal_bias: Must be non-negative.
    name: Optional name for the optimizer and its name scope. Defaults to
      'StochasticGradientLangevinDynamics'.
    parallel_iterations: Stored as-is on the instance.

  Raises:
    NotImplementedError: If called while executing eagerly.
  """
  default_name = 'StochasticGradientLangevinDynamics'
  with tf.name_scope(name or default_name):
    if tf.executing_eagerly():
      raise NotImplementedError(
          'Eager execution currently not supported for '
          ' SGLD optimizer.')

    # Step 1: tensor conversion of every hyperparameter.
    self._preconditioner_decay_rate = tf.convert_to_tensor(
        preconditioner_decay_rate, name='preconditioner_decay_rate')
    self._data_size = tf.convert_to_tensor(data_size, name='data_size')
    burnin_dtype = dtype_util.common_dtype([burnin], dtype_hint=tf.int64)
    self._burnin = tf.convert_to_tensor(
        burnin, name='burnin', dtype=burnin_dtype)
    self._diagonal_bias = tf.convert_to_tensor(
        diagonal_bias, name='diagonal_bias')
    # TODO(b/124800185): Consider migrating `learning_rate` to be a
    # hyperparameter handled by the base Optimizer class. This would allow
    # users to plug in a `tf.keras.optimizers.schedules.LearningRateSchedule`
    # object in addition to Tensors.
    self._learning_rate = tf.convert_to_tensor(
        learning_rate, name='learning_rate')
    self._parallel_iterations = parallel_iterations

    # Step 2: attach range checks to the converted tensors.
    decay_rate_checks = [
        assert_util.assert_non_negative(
            self._preconditioner_decay_rate,
            message='`preconditioner_decay_rate` must be non-negative'),
        assert_util.assert_less_equal(
            self._preconditioner_decay_rate,
            1.,
            message='`preconditioner_decay_rate` must be at most 1.'),
    ]
    self._preconditioner_decay_rate = distribution_util.with_dependencies(
        decay_rate_checks, self._preconditioner_decay_rate)

    data_size_checks = [
        assert_util.assert_greater(
            self._data_size,
            0,
            message='`data_size` must be greater than zero'),
    ]
    self._data_size = distribution_util.with_dependencies(
        data_size_checks, self._data_size)

    burnin_checks = [
        assert_util.assert_non_negative(
            self._burnin, message='`burnin` must be non-negative'),
        assert_util.assert_integer(
            self._burnin, message='`burnin` must be an integer'),
    ]
    self._burnin = distribution_util.with_dependencies(
        burnin_checks, self._burnin)

    diagonal_bias_checks = [
        assert_util.assert_non_negative(
            self._diagonal_bias,
            message='`diagonal_bias` must be non-negative'),
    ]
    self._diagonal_bias = distribution_util.with_dependencies(
        diagonal_bias_checks, self._diagonal_bias)

    super(StochasticGradientLangevinDynamics, self).__init__(
        name=name or default_name)
def _forward_log_det_jacobian(self, x):
  """Log-determinant of the Jacobian of `Y = X X.T` for lower-triangular `X`.

  Derivation sketch. Let `X`, `Y` be `p x p` with `X` lower-triangular and
  `Y = X X.T` symmetric positive definite. Then

    dY[i,j]/dX[a,b]
      = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }.

  Vectorize both matrices over their lower triangle in row-major order,
  `k(i,j) = i (i + 1) / 2 + j` for `i >= j` (zero-based; entries with
  `i < j` are dropped), e.g. for `p = 4`:

        j
        0 1 2 3
    0 [ 0 . . . ]
  i 1 [ 1 2 . . ]
    2 [ 3 4 5 . ]
    3 [ 6 7 8 9 ]

  `k(i,j) < k(a,b)` iff (1) `i < a`, or (2) `i = a` and `j < b`. In both
  cases neither `i` nor `j` can pair with `(a, b)` to give a nonzero
  derivative, so `d vec[Y] / d vec[X]` is lower triangular and its
  determinant is the product of its diagonal entries:

    d vec[Y] / d vec[X] @[k(i,j), k(i,j)] = X[j,j] + I[i=j] X[i,j]
                                          = 2 X[j,j].

  Collecting one such term per lower-triangular element gives

    |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.

  Args:
    x: (Batch of) square matrices, assumed lower-triangular.

  Returns:
    `p * log(2) + sum_j (p - j) * log(x[j, j])`, one value per batch member.
  """
  # `_make_columnar` turns a vector diag, e.g. `[1, 2, 3]`, into
  # `[[1], [2], [3]]`, and leaves e.g. `[[1, 2, 3], [4, 5, 6]]` unchanged.
  diag = self._make_columnar(tf.linalg.diag_part(x))

  if self.validate_args:
    rank_check = tf.compat.v1.assert_rank_at_least(
        x, 2, message="Input must be a (batch of) matrix.")
    x_dims = tf.shape(input=x)
    square_check = tf.compat.v1.assert_equal(
        x_dims[-2],
        x_dims[-1],
        message="Input must be a (batch of) square matrix.")
    # Assuming lower-triangular means we only need check diag>0.
    positive_check = tf.compat.v1.assert_positive(
        diag, message="Input must be positive definite.")
    x = distribution_util.with_dependencies(
        [rank_check, square_check, positive_check], x)

  # Exponent vector [p, p-1, ..., 2, 1]; use the static size when known.
  static_p = tf.compat.dimension_value(x.shape[-1])
  if static_p is not None:
    p_int = static_p
    p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
  else:
    p_int = tf.shape(input=x)[-1]
    p_float = tf.cast(p_int, dtype=x.dtype)
  exponents = tf.linspace(p_float, 1., p_int)

  weighted_log_diag = tf.matmul(
      tf.math.log(diag), exponents[..., tf.newaxis])
  fldj = p_float * np.log(2.) + tf.squeeze(weighted_log_diag, axis=-1)

  # Undo the extra column added by `_make_columnar` when the input was a
  # single (non-batched) matrix.
  static_ndims = x.shape.ndims
  if static_ndims is None:
    # Rank only known at runtime: pick the squeezed or unsqueezed shape
    # dynamically.
    fldj_dims = tf.shape(input=fldj)
    trailing_dims = distribution_util.pick_vector(
        tf.equal(tf.rank(x), 2),
        np.array([], dtype=np.int32),
        fldj_dims[-1:])
    target_shape = tf.concat([fldj_dims[:-1], trailing_dims], 0)
    return tf.reshape(fldj, target_shape)

  if static_ndims == 2:
    fldj = tf.squeeze(fldj, axis=-1)
  return fldj
def _potential_scale_reduction_single_state(state, independent_chain_ndims,
                                            split_chains, validate_args):
    """potential_scale_reduction for one single state `Tensor`.

    Args:
      state: `Tensor` (or value convertible to one). The leading dimension
        indexes samples; the next `independent_chain_ndims` dimensions index
        independent chains.
      independent_chain_ndims: Integer number of dimensions, directly after
        the sample dimension, that index independent chains.
      split_chains: Python `bool`. If `True`, each chain is split in half
        (dropping the last sample when the count is odd) and the halves are
        treated as separate chains.
      validate_args: Python `bool`. If `True` and the number of samples is not
        known statically, a runtime assertion on the sample count is added.

    Returns:
      `Tensor` computed as
      `((m + 1) / m) * (W + B/n) / W - (n - 1) / (m * n)`, with `n` and `m`
      the sizes of the sample and chain axes.

    Raises:
      ValueError: if the number of samples is known statically and is less
        than 4 (when `split_chains`) or less than 2 (otherwise).
    """
    with tf.name_scope('potential_scale_reduction_single_state'):
        # We assume exactly one leading dimension indexes e.g. correlated
        # samples from each Markov chain.
        state = tf.convert_to_tensor(state, name='state')

        n_samples_ = tf.compat.dimension_value(state.shape[0])
        if n_samples_ is not None:  # If available statically.
            if split_chains and n_samples_ < 4:
                raise ValueError(
                    'Must provide at least 4 samples when splitting chains. '
                    'Found {}'.format(n_samples_))
            if not split_chains and n_samples_ < 2:
                raise ValueError(
                    'Must provide at least 2 samples. Found {}'.format(
                        n_samples_))
        elif validate_args:
            # Use >= so the dynamic check accepts exactly the same sample
            # counts as the static check above ("at least" 4 / 2).
            if split_chains:
                state = distribution_util.with_dependencies([
                    tf1.assert_greater_equal(
                        tf.shape(state)[0],
                        4,
                        message='Must provide at least 4 samples when '
                        'splitting chains.')
                ], state)
            else:
                state = distribution_util.with_dependencies([
                    tf1.assert_greater_equal(
                        tf.shape(state)[0],
                        2,
                        message='Must provide at least 2 samples.')
                ], state)

        # Define so it's not a magic number.
        # Warning! `if split_chains` logic assumes this is 1!
        sample_ndims = 1

        if split_chains:
            # Split the sample dimension in half, doubling the number of
            # independent chains.

            # For odd number of samples, keep all but the last sample.
            state_shape = prefer_static.shape(state)
            n_samples = state_shape[0]
            state = state[:n_samples - n_samples % 2]

            # Suppose state = [0, 1, 2, 3, 4, 5]
            # Step 1: reshape into [[0, 1, 2], [3, 4, 5]]
            # E.g. reshape states of shape [a, b] into [2, a//2, b].
            state = tf.reshape(
                state,
                prefer_static.concat(
                    [[2, n_samples // 2], state_shape[1:]], axis=0))
            # Step 2: Put the size `2` dimension in the right place to be
            # treated as a chain, changing [[0, 1, 2], [3, 4, 5]] into
            # [[0, 3], [1, 4], [2, 5]], reshaping [2, a//2, b] into
            # [a//2, 2, b].
            state = tf.transpose(
                a=state,
                perm=prefer_static.concat(
                    [[1, 0], tf.range(2, tf.rank(state))], axis=0))

            # We're treating the new dim as indexing 2 chains, so increment.
            independent_chain_ndims += 1

        sample_axis = tf.range(0, sample_ndims)
        chain_axis = tf.range(sample_ndims,
                              sample_ndims + independent_chain_ndims)
        sample_and_chain_axis = tf.range(
            0, sample_ndims + independent_chain_ndims)

        n = _axis_size(state, sample_axis)
        m = _axis_size(state, chain_axis)

        # In the language of Brooks and Gelman (1998),
        # B / n is the between chain variance, the variance of the chain
        # means. W is the within sequence variance, the mean of the chain
        # variances.
        b_div_n = _reduce_variance(
            tf.reduce_mean(state, axis=sample_axis, keepdims=True),
            sample_and_chain_axis,
            biased=False)
        w = tf.reduce_mean(
            _reduce_variance(state, sample_axis, keepdims=True, biased=True),
            axis=sample_and_chain_axis)

        # sigma^2_+ is an estimate of the true variance, which would be
        # unbiased if each chain was drawn from the target. c.f. "law of
        # total variance."
        sigma_2_plus = w + b_div_n
        return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
def __init__(self, permutation, axis=-1, validate_args=False, name=None):
    """Creates the `Permute` bijector.

    Args:
      permutation: An `int`-like vector-shaped `Tensor` representing the
        permutation to apply to the `axis` dimension of the transformed
        `Tensor`.
      axis: Scalar `int` `Tensor` representing the dimension over which to
        `tf.gather`. `axis` must be relative to the end (reading left to
        right) thus must be negative.
        Default value: `-1` (i.e., right-most).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if `not axis.dtype.is_integer`.
      TypeError: if `not permutation.dtype.is_integer`.
      ValueError: if `permutation` (of length `d`, statically known) does not
        contain exactly one of each of `{0, 1, ..., d-1}`.
      NotImplementedError: if `axis` is not known prior to graph execution.
      NotImplementedError: if `axis` is not negative.
    """
    with tf.compat.v1.name_scope(name, "permute", values=[permutation, axis]):
        axis = tf.convert_to_tensor(value=axis, name="axis")
        if not axis.dtype.is_integer:
            raise TypeError("axis.dtype ({}) should be `int`-like.".format(
                axis.dtype.name))
        permutation = tf.convert_to_tensor(
            value=permutation, name="permutation")
        if not permutation.dtype.is_integer:
            raise TypeError(
                "permutation.dtype ({}) should be `int`-like.".format(
                    permutation.dtype.name))

        p = tf.get_static_value(permutation)
        if p is not None:
            # Static value available: validate eagerly in Python.
            if set(p) != set(np.arange(p.size)):
                raise ValueError(
                    "Permutation over `d` must contain exactly one of "
                    "each of `{0, 1, ..., d}`.")
        elif validate_args:
            # Sorting `-permutation` descending yields the negated ascending
            # sort of `permutation`, which must equal `[0, 1, ..., d-1]`.
            p, _ = tf.nn.top_k(
                -permutation,
                k=tf.shape(input=permutation)[-1],
                sorted=True)
            # `tf.range(tf.size(...))` is int32; cast so the comparison also
            # works when `permutation` is another integer dtype (e.g. int64).
            expected = tf.cast(tf.range(tf.size(input=p)), p.dtype)
            permutation = distribution_util.with_dependencies([
                tf.compat.v1.assert_equal(
                    -p,
                    expected,
                    message=(
                        "Permutation over `d` must contain exactly one of "
                        "each of `{0, 1, ..., d}`.")),
            ], permutation)

        axis_ = tf.get_static_value(axis)
        if axis_ is None:
            raise NotImplementedError(
                "`axis` must be known prior to graph "
                "execution.")
        elif axis_ >= 0:
            raise NotImplementedError(
                "`axis` must be relative the rightmost "
                "dimension, i.e., negative.")
        else:
            # A negative axis `-k` means the bijector acts on the last `k`
            # dimensions.
            forward_min_event_ndims = int(np.abs(axis_))

        self._permutation = permutation
        self._axis = axis
        super(Permute, self).__init__(
            forward_min_event_ndims=forward_min_event_ndims,
            is_constant_jacobian=True,
            validate_args=validate_args,
            name=name or "permute")
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
    """Construct Wishart distributions.

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
        distribution(s). `df` must be greater than or equal to `k`.
      scale_operator: `float` or `double` instance of `LinearOperator`.
      input_output_cholesky: Python `bool`. If `True`, functions whose input
        or output have the semantics of samples assume inputs are in Cholesky
        form and return outputs in Cholesky form. In particular, if this flag
        is `True`, input to `log_prob` is presumed of Cholesky form and output
        from `sample`, `mean`, and `mode` are of Cholesky form. Setting this
        argument to `True` is purely a computational optimization and does
        not change the underlying distribution; for instance, `mean` returns
        the Cholesky of the mean, not the mean of Cholesky factors. The
        `variance` and `stddev` methods are unaffected by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `scale_operator.dtype` is not floating-type.
      TypeError: if `scale_operator.dtype != df.dtype`.
      ValueError: if `scale_operator` is not square.
      ValueError: if df < k (both statically known), where scale operator
        event shape is `(k, k)`.
    """
    # Must run before any other local is bound so it captures only arguments.
    parameters = dict(locals())
    self._input_output_cholesky = input_output_cholesky
    with tf.name_scope(name) as name:
        with tf.name_scope("init"):
            if not dtype_util.is_floating(scale_operator.dtype):
                raise TypeError(
                    "scale_operator.dtype=%s is not a floating-point type" %
                    scale_operator.dtype)
            if not scale_operator.is_square:
                raise ValueError("scale_operator must be square.")

            self._scale_operator = scale_operator
            self._df = tf.convert_to_tensor(
                value=df, dtype=scale_operator.dtype, name="df")
            tf.debugging.assert_same_float_dtype(
                [self._df, self._scale_operator])
            # Prefer the static dimension; fall back to the dynamic one.
            if tf.compat.dimension_value(
                    self._scale_operator.shape[-1]) is None:
                self._dimension = tf.cast(
                    self._scale_operator.domain_dimension_tensor(),
                    dtype=self._scale_operator.dtype,
                    name="dimension")
            else:
                self._dimension = tf.convert_to_tensor(
                    value=tf.compat.dimension_value(
                        self._scale_operator.shape[-1]),
                    dtype=self._scale_operator.dtype,
                    name="dimension")

        df_val = tf.get_static_value(self._df)
        dim_val = tf.get_static_value(self._dimension)
        if df_val is not None and dim_val is not None:
            # Both known statically: validate now, regardless of
            # `validate_args`.
            df_val = np.asarray(df_val)
            if not df_val.shape:
                df_val = [df_val]
            if np.any(df_val < dim_val):
                raise ValueError(
                    "Degrees of freedom (df = %s) cannot be less than "
                    "dimension of scale matrix (scale.dimension = %s)" %
                    (df_val, dim_val))
        elif validate_args:
            assertions = assert_util.assert_less_equal(
                self._dimension,
                self._df,
                message=("Degrees of freedom (df = %s) cannot be "
                         "less than dimension of scale matrix "
                         "(scale.dimension = %s)" %
                         (self._df, self._dimension)))
            self._df = distribution_util.with_dependencies(
                [assertions], self._df)
    super(_WishartLinearOperator, self).__init__(
        dtype=self._scale_operator.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=([self._df, self._dimension] +
                       self._scale_operator.graph_parents),
        name=name)
def __init__(self,
             df,
             scale=None,
             scale_tril=None,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name="Wishart"):
    """Construct Wishart distributions.

    Exactly one of `scale` and `scale_tril` must be given. A full `scale`
    matrix is reduced to its Cholesky factor; either way the parent class is
    initialized with a `LinearOperatorLowerTriangular` built from the factor.

    Args:
      df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
        than or equal to dimension of the scale matrix.
      scale: `float` or `double` `Tensor`. The symmetric positive definite
        scale matrix of the distribution. Symmetry is asserted only when
        `validate_args` is `True`.
      scale_tril: `float` or `double` `Tensor`. The Cholesky factorization of
        the symmetric positive definite scale matrix of the distribution.
        When `validate_args` is `True`, its diagonal is asserted positive and
        its last two dimensions are asserted equal.
      input_output_cholesky: Python `bool`. If `True`, functions whose input
        or output have the semantics of samples assume inputs are in Cholesky
        form and return outputs in Cholesky form (`log_prob` input; `sample`,
        `mean` and `mode` output). This is purely a computational
        optimization and does not change the underlying distribution; e.g.
        `mean` returns the Cholesky of the mean, not the mean of Cholesky
        factors. `variance` and `stddev` are unaffected.
        Default value: `False`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if zero or both of `scale` and `scale_tril` are passed in.
    """
    # Snapshot the constructor arguments before any other local exists.
    parameters = dict(locals())

    with tf.name_scope(name) as name:
        with tf.name_scope("init"):
            got_scale = scale is not None
            got_scale_tril = scale_tril is not None
            if got_scale == got_scale_tril:
                raise ValueError(
                    "Must pass scale or scale_tril, but not both.")

            dtype = dtype_util.common_dtype(
                [df, scale, scale_tril], tf.float32)
            df = tf.convert_to_tensor(value=df, name="df", dtype=dtype)

            if got_scale:
                scale = tf.convert_to_tensor(
                    value=scale, name="scale", dtype=dtype)
                if validate_args:
                    scale = distribution_util.assert_symmetric(scale)
                scale_tril = tf.linalg.cholesky(scale)
            else:
                # Only `scale_tril` was supplied.
                scale_tril = tf.convert_to_tensor(
                    value=scale_tril, name="scale_tril", dtype=dtype)
                if validate_args:
                    tril_shape_last = tf.shape(input=scale_tril)[-1]
                    tril_shape_second_last = tf.shape(input=scale_tril)[-2]
                    tril_checks = [
                        assert_util.assert_positive(
                            tf.linalg.diag_part(scale_tril),
                            message="scale_tril must be positive definite"),
                        assert_util.assert_equal(
                            tril_shape_last,
                            tril_shape_second_last,
                            message="scale_tril must be square"),
                    ]
                    scale_tril = distribution_util.with_dependencies(
                        tril_checks, scale_tril)

    scale_operator = tf.linalg.LinearOperatorLowerTriangular(
        tril=scale_tril,
        is_non_singular=True,
        is_positive_definite=True,
        is_square=True)
    super(Wishart, self).__init__(
        df=df,
        scale_operator=scale_operator,
        input_output_cholesky=input_output_cholesky,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
def __init__(self,
             batch_size,
             total_num_examples,
             max_learning_rate=1.,
             preconditioner_decay_rate=0.95,
             burnin=25,
             burnin_max_learning_rate=1e-6,
             use_single_learning_rate=False,
             name=None):
    """Builds the optimizer: converts hyperparameters to tensors and guards them.

    Every numeric argument is first converted with `tf.convert_to_tensor`
    and then wrapped so that the listed assertion runs before the value is
    used.

    Args:
      batch_size: Asserted to be greater than zero.
      total_num_examples: Asserted to be greater than zero.
      max_learning_rate: Asserted to be non-negative.
      preconditioner_decay_rate: Asserted to lie in `[0, 1]`.
      burnin: Asserted to be a non-negative integer; `tf.int64` is the
        preferred dtype when converting.
      burnin_max_learning_rate: Asserted to be non-negative.
      use_single_learning_rate: Stored unchanged on the instance.
      name: Python `str` name for the optimizer; defaults to
        `'VariationalSGD'`.
    """
    default_name = 'VariationalSGD'
    scope_values = [
        max_learning_rate, preconditioner_decay_rate, batch_size, burnin,
        burnin_max_learning_rate
    ]
    with tf.compat.v1.name_scope(name, default_name, scope_values):
        # Phase 1: tensor conversion.
        self._preconditioner_decay_rate = tf.convert_to_tensor(
            value=preconditioner_decay_rate,
            name='preconditioner_decay_rate')
        self._batch_size = tf.convert_to_tensor(
            value=batch_size, name='batch_size')
        self._total_num_examples = tf.convert_to_tensor(
            value=total_num_examples, name='total_num_examples')

        burnin_dtype = dtype_util.common_dtype(
            [burnin], preferred_dtype=tf.int64)
        self._burnin = tf.convert_to_tensor(
            value=burnin, name='burnin', dtype=burnin_dtype)

        self._burnin_max_learning_rate = tf.convert_to_tensor(
            value=burnin_max_learning_rate, name='burnin_max_learning_rate')
        self._max_learning_rate = tf.convert_to_tensor(
            value=max_learning_rate, name='max_learning_rate')
        self._use_single_learning_rate = use_single_learning_rate

        # Phase 2: attach runtime validation to each tensor.
        decay_rate_checks = [
            tf.compat.v1.assert_non_negative(
                self._preconditioner_decay_rate,
                message='`preconditioner_decay_rate` must be non-negative'),
            tf.compat.v1.assert_less_equal(
                self._preconditioner_decay_rate,
                1.,
                message='`preconditioner_decay_rate` must be at most 1.'),
        ]
        self._preconditioner_decay_rate = (
            distribution_util.with_dependencies(
                decay_rate_checks, self._preconditioner_decay_rate))

        batch_size_checks = [
            tf.compat.v1.assert_greater(
                self._batch_size,
                0,
                message='`batch_size` must be greater than zero')
        ]
        self._batch_size = distribution_util.with_dependencies(
            batch_size_checks, self._batch_size)

        num_examples_checks = [
            tf.compat.v1.assert_greater(
                self._total_num_examples,
                0,
                message='`total_num_examples` must be greater than zero')
        ]
        self._total_num_examples = distribution_util.with_dependencies(
            num_examples_checks, self._total_num_examples)

        burnin_checks = [
            tf.compat.v1.assert_non_negative(
                self._burnin, message='`burnin` must be non-negative'),
            tf.compat.v1.assert_integer(
                self._burnin, message='`burnin` must be an integer')
        ]
        self._burnin = distribution_util.with_dependencies(
            burnin_checks, self._burnin)

        burnin_lr_checks = [
            tf.compat.v1.assert_non_negative(
                self._burnin_max_learning_rate,
                message='`burnin_max_learning_rate` must be non-negative')
        ]
        self._burnin_max_learning_rate = distribution_util.with_dependencies(
            burnin_lr_checks, self._burnin_max_learning_rate)

        max_lr_checks = [
            tf.compat.v1.assert_non_negative(
                self._max_learning_rate,
                message='`max_learning_rate` must be non-negative')
        ]
        self._max_learning_rate = distribution_util.with_dependencies(
            max_lr_checks, self._max_learning_rate)

        super(VariationalSGD, self).__init__(name=name or default_name)
def __init__(self,
             mix_loc,
             temperature,
             distribution,
             loc=None,
             scale=None,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorDiffeomixture"):
    """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. In terms
        of samples, larger `mix_loc[..., k]` ==> `Z` is more likely to put
        more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`. In
        terms of samples, smaller `temperature` means one component is more
        likely to dominate. I.e., smaller `temperature` makes the VDM look
        more like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance.
        Distribution from which `d` iid samples are used as input to the
        selected affine transformation. Must be a scalar-batch, scalar-event
        distribution. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
        is a function of non-trainable parameters. WARNING: If you backprop
        through a VectorDiffeomixture sample and the `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation. If
        the `k`-th item is `None`, `loc` is implicitly `0`. When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the
        event size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape
        `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of
        `d x d` matrices.
      quadrature_size: Python `int` scalar representing number of quadrature
        points. Larger `quadrature_size` means `q_N(x)` better approximates
        `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized
        weight.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`.
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
    """
    # Snapshot of the constructor arguments; must precede all other locals.
    parameters = dict(locals())

    with tf.compat.v2.name_scope(name) as name:
        if not scale or len(scale) < 2:
            raise ValueError(
                "Must specify list (or list-like object) of scale "
                "LinearOperators, one for each component with "
                "num_component >= 2.")

        if loc is None:
            loc = [None] * len(scale)

        if len(loc) != len(scale):
            raise ValueError("loc/scale must be same-length lists "
                             "(or same-length list-like objects).")

        # All components share the dtype of the first scale operator.
        dtype = scale[0].dtype.base_dtype

        # Convert every provided shift; keep `None` placeholders as they are.
        converted_loc = []
        for k, loc_ in enumerate(loc):
            if loc_ is None:
                converted_loc.append(None)
            else:
                converted_loc.append(
                    tf.convert_to_tensor(
                        value=loc_, dtype=dtype, name="loc{}".format(k)))
        loc = converted_loc

        for k, scale_ in enumerate(scale):
            if validate_args and not scale_.is_positive_definite:
                raise ValueError(
                    "scale[{}].is_positive_definite = {} != True".format(
                        k, scale_.is_positive_definite))
            if scale_.dtype.base_dtype != dtype:
                raise TypeError(
                    "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                    .format(k, scale_.dtype.base_dtype.name, dtype.name))

        endpoint_affine = []
        for k, (loc_, scale_) in enumerate(zip(loc, scale)):
            endpoint_affine.append(
                affine_linear_operator_bijector.AffineLinearOperator(
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="endpoint_affine_{}".format(k)))
        self._endpoint_affine = endpoint_affine

        # TODO(jvdillon): Remove once we support k-mixtures.
        # We make this assertion here because otherwise `grid` would need to
        # be a vector not a scalar.
        if len(scale) != 2:
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(scale)))

        mix_loc = tf.convert_to_tensor(
            value=mix_loc, dtype=dtype, name="mix_loc")
        temperature = tf.convert_to_tensor(
            value=temperature, dtype=dtype, name="temperature")
        grid_and_probs = quadrature_fn(
            mix_loc / temperature,
            1. / temperature,
            quadrature_size,
            validate_args)
        self._grid, probs = tuple(grid_and_probs)

        # Note: by creating the logits as `log(prob)` we ensure that
        # `self.mixture_distribution.logits` is equivalent to
        # `math_ops.log(self.mixture_distribution.probs)`.
        self._mixture_distribution = categorical.Categorical(
            logits=tf.math.log(probs),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats)

        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
            # Gate the grid on the scalar-distribution checks.
            self._grid = distribution_util.with_dependencies(
                asserts, self._grid)
        self._distribution = distribution

        interpolated_pairs = zip(
            interpolate_loc(self._grid, loc),
            interpolate_scale(self._grid, scale))
        interpolated_affine = []
        for k, (loc_, scale_) in enumerate(interpolated_pairs):
            interpolated_affine.append(
                affine_linear_operator_bijector.AffineLinearOperator(
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="interpolated_affine_{}".format(k)))
        self._interpolated_affine = interpolated_affine

        [
            self._batch_shape_,
            self._batch_shape_tensor_,
            self._event_shape_,
            self._event_shape_tensor_,
        ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

        loc_parents = [loc_ for loc_ in loc if loc_ is not None]
        scale_parents = [p for scale_ in scale for p in scale_.graph_parents]

        super(VectorDiffeomixture, self).__init__(
            dtype=dtype,
            # We hard-code `FULLY_REPARAMETERIZED` because when
            # `validate_args=True` we verify that indeed
            # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`.
            # A distribution which is a function of only non-trainable
            # parameters also implies we can use `FULLY_REPARAMETERIZED`.
            # However, we cannot easily test for that possibility thus we use
            # `validate_args=False` as a "back-door" to allow users a way to
            # use non `FULLY_REPARAMETERIZED` distribution. In such cases IT
            # IS THE USERS RESPONSIBILITY to verify that the base
            # distribution is a function of non-trainable parameters.
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=(
                distribution._graph_parents  # pylint: disable=protected-access
                + loc_parents + scale_parents),
            name=name)
def _maybe_assert_valid(self, t):
    """Returns `t`, gated on an all-non-zero assertion when validating.

    Args:
      t: Value to check.

    Returns:
      `t` itself when `self.validate_args` is falsy; otherwise `t` wrapped so
      that the "all elements non-zero" assertion runs first.
    """
    if self.validate_args:
        nonzero_check = tf.compat.v1.assert_none_equal(
            t, 0., message="All elements must be non-zero.")
        return distribution_util.with_dependencies([nonzero_check], t)
    return t
def _maybe_assert_valid_x(self, x):
    """Returns `x`, gated on a non-negativity assertion when validating.

    Args:
      x: Input to the forward transformation.

    Returns:
      `x` itself when `self.validate_args` is falsy; otherwise `x` wrapped so
      that the "at least 0" assertion runs first.
    """
    if self.validate_args:
        non_negative_check = assert_util.assert_non_negative(
            x, message="Forward transformation input must be at least 0.")
        return distribution_util.with_dependencies([non_negative_check], x)
    return x
def __init__(self,
             learning_rate,
             preconditioner_decay_rate=0.95,
             data_size=1,
             burnin=25,
             diagonal_bias=1e-8,
             name=None,
             parallel_iterations=10,
             variable_scope=None):
    """Builds the SGLD optimizer: converts and validates hyperparameters.

    Args:
      learning_rate: Converted to a tensor; no assertion is attached here.
      preconditioner_decay_rate: Asserted to lie in `[0, 1]`.
      data_size: Asserted to be greater than zero.
      burnin: Asserted to be a non-negative integer.
      diagonal_bias: Asserted to be non-negative.
      name: Python `str` name for the optimizer; defaults to
        `'StochasticGradientLangevinDynamics'`.
      parallel_iterations: Stored unchanged on the instance.
      variable_scope: Variable scope in which the step counter is created. If
        `None`, a fresh scope with a graph-unique name is created.

    Raises:
      NotImplementedError: if eager execution is enabled.
    """
    default_name = 'StochasticGradientLangevinDynamics'
    # Use the v1 name_scope explicitly: the `(name, default_name, values)`
    # signature does not exist on the TF2 `tf.name_scope`, and the rest of
    # this constructor already goes through `tf.compat.v1`.
    with tf.compat.v1.name_scope(name, default_name, [
            learning_rate, preconditioner_decay_rate, data_size, burnin,
            diagonal_bias
    ]):
        if tf.executing_eagerly():
            raise NotImplementedError(
                'Eager execution currently not supported for '
                ' SGLD optimizer.')
        if variable_scope is None:
            var_scope_name = tf.compat.v1.get_default_graph().unique_name(
                name or default_name)
            with tf.compat.v1.variable_scope(var_scope_name) as scope:
                self._variable_scope = scope
        else:
            self._variable_scope = variable_scope

        self._preconditioner_decay_rate = tf.convert_to_tensor(
            value=preconditioner_decay_rate,
            name='preconditioner_decay_rate')
        self._data_size = tf.convert_to_tensor(
            value=data_size, name='data_size')
        self._burnin = tf.convert_to_tensor(value=burnin, name='burnin')
        self._diagonal_bias = tf.convert_to_tensor(
            value=diagonal_bias, name='diagonal_bias')
        self._learning_rate = tf.convert_to_tensor(
            value=learning_rate, name='learning_rate')
        self._parallel_iterations = parallel_iterations

        # Non-trainable step counter lives in the optimizer's variable scope.
        with tf.compat.v1.variable_scope(self._variable_scope):
            self._counter = tf.compat.v1.get_variable(
                'counter', initializer=0, trainable=False)

        self._preconditioner_decay_rate = (
            distribution_util.with_dependencies([
                tf.compat.v1.assert_non_negative(
                    self._preconditioner_decay_rate,
                    message=
                    '`preconditioner_decay_rate` must be non-negative'),
                tf.compat.v1.assert_less_equal(
                    self._preconditioner_decay_rate,
                    1.,
                    message='`preconditioner_decay_rate` must be at most 1.'),
            ], self._preconditioner_decay_rate))

        self._data_size = distribution_util.with_dependencies([
            tf.compat.v1.assert_greater(
                self._data_size,
                0,
                message='`data_size` must be greater than zero')
        ], self._data_size)

        self._burnin = distribution_util.with_dependencies([
            tf.compat.v1.assert_non_negative(
                self._burnin, message='`burnin` must be non-negative'),
            tf.compat.v1.assert_integer(
                self._burnin, message='`burnin` must be an integer')
        ], self._burnin)

        self._diagonal_bias = distribution_util.with_dependencies([
            tf.compat.v1.assert_non_negative(
                self._diagonal_bias,
                message='`diagonal_bias` must be non-negative')
        ], self._diagonal_bias)

        super(StochasticGradientLangevinDynamics, self).__init__(
            use_locking=False, name=name or default_name)