def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param, 1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
      if tf.compat.dimension_value(param.shape[-1]) != 1:
        raise NotImplementedError(
            "Currently only bimixtures are supported; "
            "{}.shape[-1]={} is not 1.".format(
                name, tf.compat.dimension_value(param.shape[-1])))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(param)[-1], 1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param
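# A minimal plain-TF sketch of the static-vs-dynamic validation pattern used
# above (the helper name below is hypothetical, not part of the library):
# statically known shapes fail fast with a Python exception, while unknown
# shapes get a graph-time assertion attached via control dependencies.
import tensorflow as tf

def _check_last_dim_is_one(param, validate_args):
  if param.shape.rank is not None and param.shape[-1] is not None:
    # Shape known at trace time: raise immediately.
    if param.shape[-1] != 1:
      raise NotImplementedError("Currently only bimixtures are supported.")
  elif validate_args:
    # Shape unknown: defer the check to graph execution.
    with tf.control_dependencies(
        [tf.debugging.assert_equal(tf.shape(param)[-1], 1)]):
      param = tf.identity(param)
  return param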
def _variance(self):
  df = tf.convert_to_tensor(self.df)
  scale = tf.convert_to_tensor(self.scale)
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  # Abs(scale) superfluous.
  var = (tf.ones(self._batch_shape_tensor(df=df, scale=scale),
                 dtype=self.dtype) * tf.square(scale) * df / denom)
  # When 1 < df <= 2, variance is infinite.
  result_where_defined = tf.where(
      df > 2., var, dtype_util.as_numpy_dtype(self.dtype)(np.inf))

  if self.allow_nan_stats:
    return tf.where(
        df > 1., result_where_defined,
        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    return distribution_util.with_dependencies([
        assert_util.assert_less(
            tf.ones([], dtype=self.dtype), df,
            message='variance not defined for components of df <= 1'),
    ], result_where_defined)
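# Usage sketch of the piecewise variance above, assuming this method backs
# `tfp.distributions.StudentT`. With `allow_nan_stats=True` (the default),
# df <= 1 yields NaN and 1 < df <= 2 yields +inf instead of raising.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.StudentT(df=[0.5, 1.5, 3.], loc=0., scale=1.)
print(dist.variance())  # ~> [nan, inf, 3.]  (df / (df - 2) when df > 2)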
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                   validate_args, name):
  """Helper to __init__ which ensures override batch/event_shape are valid."""
  if override_shape is None:
    override_shape = []

  override_shape = tf.convert_to_tensor(override_shape, dtype=tf.int32,
                                        name=name)

  if not dtype_util.is_integer(override_shape.dtype):
    raise TypeError("shape override must be an integer")

  override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
  if tf.get_static_value(override_is_scalar):
    return self._empty

  dynamic_assertions = []

  if tensorshape_util.rank(override_shape.shape) is not None:
    if tensorshape_util.rank(override_shape.shape) != 1:
      raise ValueError("shape override must be a vector")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_rank(
            override_shape, 1, message="shape override must be a vector")
    ]

  if tf.get_static_value(override_shape) is not None:
    if any(s < 0 for s in tf.get_static_value(override_shape)):
      raise ValueError("shape override must have non-negative elements")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_non_negative(
            override_shape,
            message="shape override must have non-negative elements")
    ]

  is_both_nonscalar = prefer_static.logical_and(
      prefer_static.logical_not(base_is_scalar),
      prefer_static.logical_not(override_is_scalar))
  if tf.get_static_value(is_both_nonscalar) is not None:
    if tf.get_static_value(is_both_nonscalar):
      raise ValueError("base distribution not scalar")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_equal(
            is_both_nonscalar, False,
            message="base distribution not scalar")
    ]

  if not dynamic_assertions:
    return override_shape
  return distribution_util.with_dependencies(dynamic_assertions,
                                             override_shape)
def _assert_valid_sample(self, x):
  if not self.validate_args:
    return x
  return distribution_util.with_dependencies([
      assert_util.assert_non_positive(x),
      assert_util.assert_near(
          tf.zeros([], dtype=self.dtype),
          tf.reduce_logsumexp(x, axis=[-1])),
  ], x)
def _maybe_assert_valid_sample(self, x, dtype):
  if not self.validate_args:
    return x
  one = tf.ones([], dtype=dtype)
  return distribution_util.with_dependencies([
      assert_util.assert_non_negative(x),
      assert_util.assert_less_equal(x, one),
      assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
  ], x)
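# Example of the simplex constraint enforced above, assuming this validator
# backs `tfp.distributions.Dirichlet` with `validate_args=True`: samples must
# be non-negative, at most 1, and sum to 1 over the last axis.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Dirichlet(concentration=[1., 2., 3.], validate_args=True)
print(dist.log_prob([0.2, 0.3, 0.5]))  # a valid simplex point passes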
def _check_integer(self, value):
  with tf.name_scope("check_integer"):
    value = tf.convert_to_tensor(value, name="value")
    if not self.validate_args:
      return value
    dependencies = [
        distribution_util.assert_integer_form(
            value, message="value has non-integer components.")
    ]
    return distribution_util.with_dependencies(dependencies, value)
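# Sketch of the integer-form check above through the public API, assuming a
# counting distribution such as `tfp.distributions.Poisson` routes `log_prob`
# inputs through a check like this when `validate_args=True`.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Poisson(rate=3., validate_args=True)
print(dist.log_prob(2.))  # integer-valued input passes the check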
def _maybe_assert_valid(self, x):
  if not self.validate_args:
    return x
  return distribution_util.with_dependencies([
      assert_util.assert_non_negative(
          x, message="sample must be non-negative"),
      assert_util.assert_less_equal(
          x, tf.ones([], self.concentration0.dtype),
          message="sample must be no larger than `1`."),
  ], x)
def _maybe_assert_valid_sample(self, counts):
  """Check counts for proper shape, values, then return tensor version."""
  if not self.validate_args:
    return counts
  counts = distribution_util.embed_check_nonnegative_integer_form(counts)
  msg = ('Sampled counts must be itemwise less than '
         'or equal to `total_count` parameter.')
  return distribution_util.with_dependencies([
      assert_util.assert_less_equal(counts, self.total_count, message=msg),
  ], counts)
def _maybe_assert_valid_sample(self, counts):
  """Check counts for proper shape, values, then return tensor version."""
  if not self.validate_args:
    return counts
  counts = distribution_util.embed_check_nonnegative_integer_form(counts)
  return distribution_util.with_dependencies([
      assert_util.assert_equal(
          self.total_count, tf.reduce_sum(counts, axis=-1),
          message='counts last-dimension must sum to `self.total_count`'),
  ], counts)
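# Example of the constraint above, assuming it backs
# `tfp.distributions.Multinomial`: counts must be non-negative integers whose
# last dimension sums to `total_count`.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Multinomial(
    total_count=4., probs=[0.2, 0.3, 0.5], validate_args=True)
print(dist.prob([1., 1., 2.]))  # sums to 4 -> passes validation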
def __init__(self, temperature, validate_args=False, name="softfloor"):
  with tf.name_scope(name) as name:
    self._temperature = tf.convert_to_tensor(temperature, name="temperature")
    if validate_args:
      self._temperature = distribution_util.with_dependencies([
          assert_util.assert_positive(
              self._temperature,
              message="Argument temperature was not positive")
      ], self._temperature)
    super(Softfloor, self).__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        dtype=self._temperature.dtype,
        name=name)
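# Usage sketch for the constructor above, assuming the public alias
# `tfp.bijectors.Softfloor`. The bijector is a smooth approximation to
# `tf.floor`; with `validate_args=True` a non-positive temperature trips the
# assertion added in `__init__`.
import tensorflow_probability as tfp
tfb = tfp.bijectors

soft = tfb.Softfloor(temperature=0.05, validate_args=True)
# For small temperatures the forward values approach tf.floor(x).
print(soft.forward([0.2, 1.7]))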
def _mean(self):
  df = tf.convert_to_tensor(self.df)
  loc = tf.convert_to_tensor(self.loc)
  mean = loc * tf.ones(self._batch_shape_tensor(loc=loc), dtype=self.dtype)
  if self.allow_nan_stats:
    return tf.where(
        df > 1., mean, dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    return distribution_util.with_dependencies([
        assert_util.assert_less(
            tf.ones([], dtype=self.dtype), df,
            message='mean not defined for components of df <= 1'),
    ], mean)
def _mode(self):
  a = tf.convert_to_tensor(self.concentration1)
  b = tf.convert_to_tensor(self.concentration0)
  mode = ((a - 1) / (a * b - 1))**(1. / a)
  if self.allow_nan_stats:
    return tf.where((a > 1.) & (b > 1.), mode,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  return distribution_util.with_dependencies([
      assert_util.assert_less(
          tf.ones([], dtype=a.dtype), a,
          message="Mode undefined for concentration1 <= 1."),
      assert_util.assert_less(
          tf.ones([], dtype=b.dtype), b,
          message="Mode undefined for concentration0 <= 1.")
  ], mode)
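# Numeric check of the mode formula above, assuming this method backs
# `tfp.distributions.Kumaraswamy`. With a = concentration1 and
# b = concentration0, the mode ((a-1)/(a*b-1))**(1/a) exists only when both
# parameters exceed 1.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Kumaraswamy(concentration1=2., concentration0=3.)
print(dist.mode())  # ((2-1)/(2*3-1))**(1/2) = 0.2**0.5 ~= 0.447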
def __init__(self,
             loc=None,
             covariance_matrix=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalFullCovariance"):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and
  `covariance_matrix` arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `covariance_matrix`. The last dimension of `loc` (if provided) must
  broadcast with this.

  A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
  definite matrix. In other words it is (real) symmetric with all eigenvalues
  strictly positive.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    covariance_matrix: Floating-point, symmetric positive definite `Tensor`
      of same `dtype` as `loc`. The strict upper triangle of
      `covariance_matrix` is ignored, so if `covariance_matrix` is not
      symmetric no error will be raised (unless `validate_args is True`).
      `covariance_matrix` has shape `[B1, ..., Bb, k, k]` where `b >= 0` and
      `k` is the event size.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `covariance_matrix` are specified.
  """
  parameters = dict(locals())

  # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      dtype = dtype_util.common_dtype([loc, covariance_matrix], tf.float32)
      loc = loc if loc is None else tf.convert_to_tensor(
          loc, name="loc", dtype=dtype)
      if covariance_matrix is None:
        scale_tril = None
      else:
        covariance_matrix = tf.convert_to_tensor(
            covariance_matrix, name="covariance_matrix", dtype=dtype)
        if validate_args:
          covariance_matrix = distribution_util.with_dependencies([
              assert_util.assert_near(
                  covariance_matrix,
                  tf.linalg.matrix_transpose(covariance_matrix),
                  message="Matrix was not symmetric")
          ], covariance_matrix)
        # No need to validate that covariance_matrix is non-singular.
        # LinearOperatorLowerTriangular has an assert_non_singular method
        # that is called by the Bijector.
        # However, cholesky() ignores the upper triangular part, so we do
        # need to separately assert symmetric.
        scale_tril = tf.linalg.cholesky(covariance_matrix)
      super(MultivariateNormalFullCovariance, self).__init__(
          loc=loc,
          scale_tril=scale_tril,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
      self._parameters = parameters
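# Minimal construction sketch for the class above via the public alias
# `tfp.distributions.MultivariateNormalFullCovariance`; `validate_args=True`
# activates the symmetry assertion added in `__init__`.
import tensorflow_probability as tfp
tfd = tfp.distributions

mvn = tfd.MultivariateNormalFullCovariance(
    loc=[1., -1.],
    covariance_matrix=[[1., 0.5], [0.5, 2.]],
    validate_args=True)
print(mvn.log_prob([0., 0.]))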
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, shift,
                           validate_args, dtype):
  """Construct `scale` from various components.

  Args:
    identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
    diag: Floating-point `Tensor` representing the diagonal matrix. `diag`
      has shape `[N1, N2, ..., k]`, which represents a k x k diagonal matrix.
    tril: Floating-point `Tensor` representing the lower triangular matrix.
      `tril` has shape `[N1, N2, ..., k, k]`, which represents a k x k lower
      triangular matrix.
    perturb_diag: Floating-point `Tensor` representing the diagonal matrix of
      the low rank update.
    perturb_factor: Floating-point `Tensor` representing factor matrix.
    shift: Floating-point `Tensor` representing the `shift` in
      `scale @ X + shift`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    dtype: `DType` for arg `Tensor` conversions.

  Returns:
    scale. In the case of scaling by a constant, scale is a floating point
    `Tensor`. Otherwise, scale is a `LinearOperator`.

  Raises:
    ValueError: if all of `tril`, `diag` and `identity_multiplier` are
      `None`.
  """
  identity_multiplier = _as_tensor(identity_multiplier,
                                   "identity_multiplier", dtype)
  diag = _as_tensor(diag, "diag", dtype)
  tril = _as_tensor(tril, "tril", dtype)
  perturb_diag = _as_tensor(perturb_diag, "perturb_diag", dtype)
  perturb_factor = _as_tensor(perturb_factor, "perturb_factor", dtype)

  # If possible, use the low rank update to infer the shape of
  # the identity matrix, when scale represents a scaled identity matrix
  # with a low rank update.
  shape_hint = None
  if perturb_factor is not None:
    shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)

  if self._is_only_identity_multiplier:
    if validate_args:
      return distribution_util.with_dependencies([
          assert_util.assert_none_equal(
              identity_multiplier,
              tf.zeros([], identity_multiplier.dtype),
              ["identity_multiplier should be non-zero."])
      ], identity_multiplier)
    return identity_multiplier

  scale = distribution_util.make_tril_scale(
      loc=shift,
      scale_tril=tril,
      scale_diag=diag,
      scale_identity_multiplier=identity_multiplier,
      validate_args=validate_args,
      assert_positive=False,
      shape_hint=shape_hint)

  if perturb_factor is not None:
    return tf.linalg.LinearOperatorLowRankUpdate(
        scale,
        u=perturb_factor,
        diag_update=perturb_diag,
        is_diag_update_positive=perturb_diag is None,
        is_non_singular=True,  # Implied by is_positive_definite=True.
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)

  return scale
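# Usage sketch of the scale composition above via the (since-deprecated)
# public `tfp.bijectors.Affine`, assuming it delegates to this helper. The
# identity multiplier and diagonal add together:
# scale = 2 * I + diag([0.5, 0.5]) = diag([2.5, 2.5]).
import tensorflow_probability as tfp
tfb = tfp.bijectors

aff = tfb.Affine(
    shift=[1., -1.],
    scale_identity_multiplier=2.,
    scale_diag=[0.5, 0.5])
print(aff.forward([1., 1.]))  # shift + scale @ x = [3.5, 1.5]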
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator: `float` or `double` instance of `LinearOperator`.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output
      from `sample`, `mean`, and `mode` are of Cholesky form. Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if df < k, where scale operator event shape is `(k, k)`
  """
  parameters = dict(locals())
  self._input_output_cholesky = input_output_cholesky
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if not dtype_util.is_floating(scale_operator.dtype):
        raise TypeError(
            "scale_operator.dtype=%s is not a floating-point type" %
            scale_operator.dtype)
      if not scale_operator.is_square:
        raise ValueError("scale_operator must be square.")

      self._scale_operator = scale_operator
      self._df = tf.convert_to_tensor(
          df, dtype=scale_operator.dtype, name="df")
      dtype_util.assert_same_float_dtype([self._df, self._scale_operator])
      if tf.compat.dimension_value(self._scale_operator.shape[-1]) is None:
        self._dimension = tf.cast(
            self._scale_operator.domain_dimension_tensor(),
            dtype=self._scale_operator.dtype,
            name="dimension")
      else:
        self._dimension = tf.convert_to_tensor(
            tf.compat.dimension_value(self._scale_operator.shape[-1]),
            dtype=self._scale_operator.dtype,
            name="dimension")
      df_val = tf.get_static_value(self._df)
      dim_val = tf.get_static_value(self._dimension)
      if df_val is not None and dim_val is not None:
        df_val = np.asarray(df_val)
        if not df_val.shape:
          df_val = [df_val]
        if np.any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)" %
              (df_val, dim_val))
      elif validate_args:
        assertions = assert_util.assert_less_equal(
            self._dimension,
            self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._dimension, self._df)))
        self._df = distribution_util.with_dependencies(
            [assertions], self._df)
      super(_WishartLinearOperator, self).__init__(
          dtype=self._scale_operator.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          parameters=parameters,
          name=name)
def __init__(self,
             df,
             scale=None,
             scale_tril=None,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name="Wishart"):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
      than or equal to dimension of the scale matrix.
    scale: `float` or `double` `Tensor`. The symmetric positive definite
      scale matrix of the distribution. Exactly one of `scale` and
      `scale_tril` must be passed.
    scale_tril: `float` or `double` `Tensor`. The Cholesky factorization
      of the symmetric positive definite scale matrix of the distribution.
      Exactly one of `scale` and `scale_tril` must be passed.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output
      from `sample`, `mean`, and `mode` are of Cholesky form. Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if zero or both of `scale` and `scale_tril` are passed in.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if (scale is None) == (scale_tril is None):
        raise ValueError("Must pass scale or scale_tril, but not both.")

      dtype = dtype_util.common_dtype([df, scale, scale_tril], tf.float32)
      df = tf.convert_to_tensor(df, name="df", dtype=dtype)
      if scale is not None:
        scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
        if validate_args:
          scale = distribution_util.assert_symmetric(scale)
        scale_tril = tf.linalg.cholesky(scale)
      else:  # scale_tril is not None
        scale_tril = tf.convert_to_tensor(
            scale_tril, name="scale_tril", dtype=dtype)
        if validate_args:
          scale_tril = distribution_util.with_dependencies([
              assert_util.assert_positive(
                  tf.linalg.diag_part(scale_tril),
                  message="scale_tril must be positive definite"),
              assert_util.assert_equal(
                  tf.shape(scale_tril)[-1],
                  tf.shape(scale_tril)[-2],
                  message="scale_tril must be square"),
          ], scale_tril)

      super(Wishart, self).__init__(
          df=df,
          scale_operator=tf.linalg.LinearOperatorLowerTriangular(
              tril=scale_tril,
              is_non_singular=True,
              is_positive_definite=True,
              is_square=True),
          input_output_cholesky=input_output_cholesky,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters
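# Usage sketch for the constructor above, assuming a TFP release that still
# ships `tfp.distributions.Wishart` with the dense `scale` kwarg. `df` must
# be at least the event dimension (here 2); `validate_args=True` enables the
# symmetry and positive-definiteness checks.
import tensorflow_probability as tfp
tfd = tfp.distributions

w = tfd.Wishart(df=3., scale=[[2., 0.], [0., 2.]], validate_args=True)
print(w.mean())  # df * scale = [[6., 0.], [0., 6.]]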
def __init__(self,
             mix_loc,
             temperature,
             distribution,
             loc=None,
             scale=None,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorDiffeomixture"):
  """Constructs the VectorDiffeomixture on `R^d`.

  The vector diffeomixture (VDM) approximates the compound distribution

  ```none
  p(x) = int p(x | z) p(z) dz,
  where z is in the K-simplex, and
  p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
  ```

  Args:
    mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. In terms
      of samples, larger `mix_loc[..., k]` ==> `Z` is more likely to put more
      weight on its `kth` component.
    temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`. In
      terms of samples, smaller `temperature` means one component is more
      likely to dominate. I.e., smaller `temperature` makes the VDM look more
      like a standard mixture of `K` components.
    distribution: `tfp.distributions.Distribution`-like instance.
      Distribution from which `d` iid samples are used as input to the
      selected affine transformation. Must be a scalar-batch, scalar-event
      distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
      is a function of non-trainable parameters. WARNING: If you backprop
      through a VectorDiffeomixture sample and the `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
      represents the `shift` used for the `k`-th affine transformation. If
      the `k`-th item is `None`, `loc` is implicitly `0`. When specified,
      must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
      size.
    scale: Length-`K` list of `LinearOperator`s. Each should be
      positive-definite and operate on a `d`-dimensional vector space. The
      `k`-th element represents the `scale` used for the `k`-th affine
      transformation. `LinearOperator`s must have shape
      `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of
      `d x d` matrices.
    quadrature_size: Python `int` scalar representing number of quadrature
      points. Larger `quadrature_size` means `q_N(x)` better approximates
      `p(x)`.
    quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
      `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
      representing the SoftmaxNormal grid and corresponding normalized
      weight.
      Default value: `quadrature_scheme_softmaxnormal_quantiles`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `not scale or len(scale) < 2`.
    ValueError: if `len(loc) != len(scale)`
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) !=
      len(quadrature_grid_and_probs[1])`
    ValueError: if `validate_args` and any not scale.is_positive_definite.
    TypeError: if any scale.dtype != scale[0].dtype.
    TypeError: if any loc.dtype != scale[0].dtype.
    NotImplementedError: if `len(scale) != 2`.
    ValueError: if `not distribution.is_scalar_batch`.
    ValueError: if `not distribution.is_scalar_event`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if not scale or len(scale) < 2:
      raise ValueError("Must specify list (or list-like object) of scale "
                       "LinearOperators, one for each component with "
                       "num_component >= 2.")

    if loc is None:
      loc = [None] * len(scale)

    if len(loc) != len(scale):
      raise ValueError("loc/scale must be same-length lists "
                       "(or same-length list-like objects).")

    dtype = dtype_util.base_dtype(scale[0].dtype)

    loc = [
        tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
        if loc_ is not None else None for k, loc_ in enumerate(loc)
    ]

    for k, scale_ in enumerate(scale):
      if validate_args and not scale_.is_positive_definite:
        raise ValueError(
            "scale[{}].is_positive_definite = {} != True".format(
                k, scale_.is_positive_definite))
      if dtype_util.base_dtype(scale_.dtype) != dtype:
        raise TypeError(
            "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\"".format(
                k, dtype_util.name(scale_.dtype), dtype_util.name(dtype)))

    self._endpoint_affine = [
        affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
            shift=loc_,
            scale=scale_,
            validate_args=validate_args,
            name="endpoint_affine_{}".format(k))
        for k, (loc_, scale_) in enumerate(zip(loc, scale))
    ]

    # TODO(jvdillon): Remove once we support k-mixtures.
    # We make this assertion here because otherwise `grid` would need to be
    # a vector not a scalar.
    if len(scale) != 2:
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(scale)))

    mix_loc = tf.convert_to_tensor(mix_loc, dtype=dtype, name="mix_loc")
    temperature = tf.convert_to_tensor(
        temperature, dtype=dtype, name="temperature")
    self._grid, probs = tuple(
        quadrature_fn(mix_loc / temperature, 1. / temperature,
                      quadrature_size, validate_args))

    # Note: by creating the logits as `log(prob)` we ensure that
    # `self.mixture_distribution.logits` is equivalent to
    # `math_ops.log(self.mixture_distribution.probs)`.
    self._mixture_distribution = categorical.Categorical(
        logits=tf.math.log(probs),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)

    asserts = distribution_util.maybe_check_scalar_distribution(
        distribution, dtype, validate_args)
    if asserts:
      self._grid = distribution_util.with_dependencies(asserts, self._grid)
    self._distribution = distribution

    self._interpolated_affine = [
        affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
            shift=loc_,
            scale=scale_,
            validate_args=validate_args,
            name="interpolated_affine_{}".format(k))
        for k, (loc_, scale_) in enumerate(
            zip(interpolate_loc(self._grid, loc),
                interpolate_scale(self._grid, scale)))
    ]

    [
        self._batch_shape_,
        self._batch_shape_tensor_,
        self._event_shape_,
        self._event_shape_tensor_,
    ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

    super(VectorDiffeomixture, self).__init__(
        dtype=dtype,
        # We hard-code `FULLY_REPARAMETERIZED` because when
        # `validate_args=True` we verify that indeed
        # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
        # distribution which is a function of only non-trainable parameters
        # also implies we can use `FULLY_REPARAMETERIZED`. However, we
        # cannot easily test for that possibility thus we use
        # `validate_args=False` as a "back-door" to allow users a way to use
        # non `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE
        # USERS RESPONSIBILITY to verify that the base distribution is a
        # function of non-trainable parameters.
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
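# Construction sketch for the class above, assuming a TFP release that still
# exposes `tfp.distributions.VectorDiffeomixture`. Exactly two components (a
# bimixture) are supported, matching the NotImplementedError above.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],  # a batch of two K-1 = 1 mixing locations
    temperature=[1.],
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[None,             # first component centered at 0
         [2., 2.]],        # second component shifted to (2, 2)
    scale=[
        tf.linalg.LinearOperatorScaledIdentity(
            num_rows=2, multiplier=1.1, is_positive_definite=True),
        tf.linalg.LinearOperatorDiag(
            diag=[0.5, 1.5], is_positive_definite=True),
    ])
print(vdm.sample(3).shape)  # (3, 2, 2): [num samples, batch, event]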
def __init__(self,
             loc=None,
             scale_diag=None,
             scale_identity_multiplier=None,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorSinhArcsinhDiag"):
  """Construct VectorSinhArcsinhDiag distribution on `R^k`.

  The arguments `scale_diag` and `scale_identity_multiplier` combine to
  define the diagonal `scale` referred to in this class docstring:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
      matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
      and characterizes `b`-batches of `k x k` diagonal matrices added to
      `scale`. When both `scale_identity_multiplier` and `scale_diag` are
      `None` then `scale` is the `Identity`.
    scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
      a scale-identity-matrix added to `scale`. May have shape
      `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
      `k x k` identity matrices added to `scale`. When both
      `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
      is the `Identity`.
    skewness: Skewness parameter. Floating-point `Tensor` with shape
      broadcastable with `event_shape`.
    tailweight: Tailweight parameter. Floating-point `Tensor` with shape
      broadcastable with `event_shape`.
    distribution: `tf.Distribution`-like instance. Distribution from which
      `k` iid samples are used as input to transformation `F`. Default is
      `tfd.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event
      distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
      is a function of non-trainable parameters. WARNING: If you backprop
      through a VectorSinhArcsinhDiag sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if at most `scale_identity_multiplier` is specified.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [loc, scale_diag, scale_identity_multiplier, skewness, tailweight],
        tf.float32)
    loc = loc if loc is None else tf.convert_to_tensor(
        loc, name="loc", dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    skewness = 0. if skewness is None else skewness

    # Recall, with Z a random variable,
    #   Y := loc + C * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    #   C := 2 * scale / F_0(2)

    # Construct shapes and 'scale' out of the scale_* and loc kwargs.
    # scale_linop is only an intermediary to:
    # 1. get shapes from looking at loc and the two scale args.
    # 2. combine scale_diag with scale_identity_multiplier, which gives us
    #    'scale', which in turn gives us 'C'.
    scale_linop = distribution_util.make_diag_scale(
        loc=loc,
        scale_diag=scale_diag,
        scale_identity_multiplier=scale_identity_multiplier,
        validate_args=False,
        assert_positive=False,
        dtype=dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale_linop)

    # scale_linop.diag_part() is efficient since it is a diag type linop.
    scale_diag_part = scale_linop.diag_part()
    dtype = scale_diag_part.dtype

    if distribution is None:
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats)
    else:
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        scale_diag_part = distribution_util.with_dependencies(
            asserts, scale_diag_part)

    # Make the SAS bijector, 'F'.
    skewness = tf.convert_to_tensor(skewness, dtype=dtype, name="skewness")
    tailweight = tf.convert_to_tensor(
        tailweight, dtype=dtype, name="tailweight")
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=skewness, tailweight=tailweight)

    affine = affine_bijector.Affine(
        shift=loc,
        scale_diag=scale_diag_part,
        validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])

    super(VectorSinhArcsinhDiag, self).__init__(
        distribution=distribution,
        bijector=bijector,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
  self._loc = loc
  self._scale = scale_linop
  self._tailweight = tailweight
  self._skewness = skewness
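# Minimal sketch, assuming a TFP release that still ships
# `tfp.distributions.VectorSinhArcsinhDiag`. skewness shifts mass to one
# side; tailweight > 1 fattens both tails of the diagonal-scale base normal.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.VectorSinhArcsinhDiag(
    loc=[0., 0.], scale_diag=[1., 2.], skewness=0.5, tailweight=1.5)
print(dist.sample(4).shape)  # (4, 2)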
def __init__(self, permutation, axis=-1, validate_args=False, name=None):
  """Creates the `Permute` bijector.

  Args:
    permutation: An `int`-like vector-shaped `Tensor` representing the
      permutation to apply to the `axis` dimension of the transformed
      `Tensor`.
    axis: Scalar `int` `Tensor` representing the dimension over which to
      `tf.gather`. `axis` must be relative to the end (reading left to right)
      thus must be negative.
      Default value: `-1` (i.e., right-most).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str`, name given to ops managed by this object.

  Raises:
    TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
    ValueError: if `permutation` does not contain exactly one of each of
      `{0, 1, ..., d}`.
    NotImplementedError: if `axis` is not known prior to graph execution.
    NotImplementedError: if `axis` is not negative.
  """
  with tf.name_scope(name or "permute") as name:
    axis = tf.convert_to_tensor(axis, name="axis")
    if not dtype_util.is_integer(axis.dtype):
      raise TypeError("axis.dtype ({}) should be `int`-like.".format(
          dtype_util.name(axis.dtype)))

    permutation = tf.convert_to_tensor(permutation, name="permutation")
    if not dtype_util.is_integer(permutation.dtype):
      raise TypeError(
          "permutation.dtype ({}) should be `int`-like.".format(
              dtype_util.name(permutation.dtype)))

    p = tf.get_static_value(permutation)
    if p is not None:
      if set(p) != set(np.arange(p.size)):
        raise ValueError("Permutation over `d` must contain exactly one of "
                         "each of `{0, 1, ..., d}`.")
    elif validate_args:
      p, _ = tf.math.top_k(
          -permutation, k=tf.shape(permutation)[-1], sorted=True)
      permutation = distribution_util.with_dependencies([
          assert_util.assert_equal(
              -p, tf.range(tf.size(p)),
              message=("Permutation over `d` must contain exactly one of "
                       "each of `{0, 1, ..., d}`.")),
      ], permutation)

    axis_ = tf.get_static_value(axis)
    if axis_ is None:
      raise NotImplementedError(
          "`axis` must be known prior to graph execution.")
    elif axis_ >= 0:
      raise NotImplementedError(
          "`axis` must be relative the rightmost dimension, i.e., negative.")
    else:
      forward_min_event_ndims = int(np.abs(axis_))

    self._permutation = permutation
    self._axis = axis
    super(Permute, self).__init__(
        forward_min_event_ndims=forward_min_event_ndims,
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
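# Usage sketch for the bijector above via the public alias
# `tfp.bijectors.Permute`: forward gathers along the (negative) axis, and
# inverse applies the inverse permutation.
import tensorflow_probability as tfp
tfb = tfp.bijectors

reverse = tfb.Permute(permutation=[2, 1, 0])
print(reverse.forward([-1., 0., 1.]))  # [1., 0., -1.]
print(reverse.inverse([1., 0., -1.]))  # [-1., 0., 1.]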
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="SinhArcsinh"):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight)
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution. Default is
      `tfd.Normal(0., 1.)`. Must be a scalar-batch, scalar-event
      distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
      is a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name="loc", dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name="scale", dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if has_default_skewness else skewness
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name="tailweight", dtype=dtype)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name="skewness", dtype=dtype)

    batch_shape = distribution_util.get_broadcast_shape(
        self._loc, self._scale, self._tailweight, self._skewness)

    # Recall, with Z a random variable,
    #   Y := loc + scale * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
    #   C := 2 / F_0(2)
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    if distribution is None:
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args)
    else:
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        self._loc = distribution_util.with_dependencies(asserts, self._loc)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=self._skewness,
        tailweight=self._tailweight,
        validate_args=validate_args)

    # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
    affine = affine_scalar_bijector.AffineScalar(
        shift=self._loc,
        scale=self._scale,
        validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        batch_shape=batch_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
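# Construction sketch, assuming this __init__ backs the public
# `tfp.distributions.SinhArcsinh`. The identity parameters (skewness=0,
# tailweight=1) recover the base Normal exactly; other values skew and
# re-weight the tails.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.SinhArcsinh(loc=0., scale=1., skewness=1., tailweight=1.5)
print(dist.sample(5).shape)  # (5,)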