def _max_precision_sum(a, b):
  """Adds `a` and `b` after promoting the narrower operand.

  When the base dtypes of the two operands differ, the operand whose dtype is
  smaller (as reported by `dtype_util.size`) is cast to the other operand's
  dtype. On a size tie, `b` is the one cast to `a.dtype`.

  Args:
    a: First summand; must expose a `dtype` attribute.
    b: Second summand; must expose a `dtype` attribute.

  Returns:
    `a + b`, evaluated after any cast described above.
  """
  # Matching base dtypes need no coercion at all.
  if dtype_util.base_equal(a.dtype, b.dtype):
    return a + b

  if dtype_util.size(a.dtype) < dtype_util.size(b.dtype):
    a = tf.cast(a, b.dtype)
  else:
    # `a` is at least as wide as `b` (ties included), so `b` is promoted.
    b = tf.cast(b, a.dtype)
  return a + b
def _maybe_assert_dtype(self, x):
  """Verifies that `x.dtype` agrees with `self.dtype`, when the latter is set.

  Nothing is checked when `SKIP_DTYPE_CHECKS` is truthy or when `self.dtype`
  is `None`.

  Args:
    x: Object exposing a `dtype` attribute.

  Raises:
    TypeError: if `self.dtype` is not `None` and its base dtype differs from
      the base dtype of `x`.
  """
  if SKIP_DTYPE_CHECKS:
    return

  expected = self.dtype
  if expected is None:
    return
  if dtype_util.base_equal(expected, x.dtype):
    return

  raise TypeError(
      'Input had dtype %s but expected %s.' % (x.dtype, expected))
def __init__(self,
             shift=None,
             scale=None,
             adjoint=False,
             validate_args=False,
             name="affine_linear_operator"):
  """Instantiates the `AffineLinearOperator` bijector.

  Args:
    shift: Floating-point `Tensor`. When `None`, no shift is stored.
    scale: Subclass of `LinearOperator`. Represents the (batch) positive
      definite matrix `M` in `R^{k x k}`. When `None`, no scale is stored.
    adjoint: Python `bool` indicating whether to use the `scale` matrix as
      specified or its adjoint.
      Default value: `False`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    TypeError: if `scale` is not a `LinearOperator`.
    TypeError: if `shift.dtype` does not match `scale.dtype`.
    ValueError: if `validate_args` is `True` and `scale.is_non_singular` is
      falsy.
  """
  with tf.name_scope(name) as name:
    # In the absence of `shift` and `scale`, we'll assume `dtype` is
    # `float32`.
    dtype = tf.float32

    if shift is not None:
      shift = tf.convert_to_tensor(value=shift, name="shift")
      dtype = dtype_util.base_dtype(shift.dtype)
    self._shift = shift

    if scale is not None:
      # Validate the type before touching any `LinearOperator` attribute:
      # otherwise an arbitrary object without `.dtype` would surface as an
      # `AttributeError` rather than the documented `TypeError`.
      if not isinstance(scale, tf.linalg.LinearOperator):
        raise TypeError(
            "scale is not an instance of tf.LinearOperator")
      if (shift is not None and
          not dtype_util.base_equal(shift.dtype, scale.dtype)):
        raise TypeError(
            "shift.dtype({}) is incompatible with scale.dtype({})."
            .format(shift.dtype, scale.dtype))
      if validate_args and not scale.is_non_singular:
        raise ValueError("Scale matrix must be non-singular.")
      # A known `scale.dtype` takes precedence over the dtype inferred from
      # `shift` (the two already agree whenever both are present).
      if scale.dtype is not None:
        dtype = dtype_util.base_dtype(scale.dtype)
    self._scale = scale
    self._adjoint = adjoint
    super(AffineLinearOperator, self).__init__(
        forward_min_event_ndims=1,
        is_constant_jacobian=True,
        dtype=dtype,
        validate_args=validate_args,
        name=name)
def poisson_and_mixture_distributions(self):
  """Builds the component Poisson and mixing Categorical distributions.

  The quadrature grid and probabilities are obtained by calling
  `self._quadrature_fn` on the tensor-converted `loc` and `scale`; the grid
  parameterizes the Poisson `log_rate` and the log of the probabilities
  parameterizes the Categorical `logits`.

  Returns:
    A `(poisson, categorical)` tuple of distributions.

  Raises:
    TypeError: if the quadrature grid and quadrature probs have different
      base dtypes.
  """
  grid, probs = tuple(self._quadrature_fn(
      tf.convert_to_tensor(self.loc),
      tf.convert_to_tensor(self.scale),
      self.quadrature_size,
      self.validate_args))

  grid_dtype = grid.dtype
  if not dtype_util.base_equal(grid_dtype, probs.dtype):
    message = ('Quadrature grid dtype ({}) does not match quadrature '
               'probs dtype ({}).')
    raise TypeError(message.format(
        dtype_util.name(grid_dtype), dtype_util.name(probs.dtype)))

  components = poisson.Poisson(
      log_rate=grid,
      validate_args=self.validate_args,
      allow_nan_stats=self.allow_nan_stats)
  mixing = categorical.Categorical(
      logits=tf.math.log(probs),
      validate_args=self.validate_args,
      allow_nan_stats=self.allow_nan_stats)
  return components, mixing
def __init__(self,
             power,
             dtype=tf.int32,
             interpolate_nondiscrete=True,
             sample_maximum_iterations=100,
             validate_args=False,
             allow_nan_stats=False,
             name='Zipf'):
  """Initialize a batch of Zipf distributions.

  Args:
    power: `Float` like `Tensor` representing the power parameter. Must be
      strictly greater than `1`.
    dtype: The `dtype` of `Tensor` returned by `sample`.
      Default value: `tf.int32`.
    interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
      `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
      `log_prob` evaluates the continuous function
      `-power log(k) - log(zeta(power))`, which matches the Zipf pmf at
      integer arguments `k` (note that this function is not itself a
      normalized probability log-density).
      Default value: `True`.
    sample_maximum_iterations: Maximum number of iterations allowed in
      `sample`. When `validate_args=True`, samples which fail to reach
      convergence (subject to this cap) are masked out with
      `self.dtype.min` or `nan` depending on `self.dtype.is_integer`.
      Default value: `100`.
    validate_args: Python `bool`. When `True` distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False` invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value "`NaN`" to indicate the result is
      undefined. When `False`, an exception is raised if one or more of the
      statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `'Zipf'`.

  Raises:
    TypeError: if `power.dtype` is not a floating-point dtype, or is
      `float16`.
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    power_dtype_hint = dtype_util.common_dtype(
        [power], dtype_hint=tf.float32)
    self._power = tensor_util.convert_nonref_to_tensor(
        power, name='power', dtype=power_dtype_hint)

    # Only floating dtypes other than float16 are accepted for `power`.
    is_supported = (
        dtype_util.is_floating(self._power.dtype) and
        not dtype_util.base_equal(self._power.dtype, tf.float16))
    if not is_supported:
      raise TypeError(
          'power.dtype ({}) is not a supported `float` type.'.format(
              dtype_util.name(self._power.dtype)))

    self._interpolate_nondiscrete = interpolate_nondiscrete
    self._sample_maximum_iterations = sample_maximum_iterations
    super(Zipf, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def __init__(self,
             shift=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             adjoint=False,
             validate_args=False,
             name="affine",
             dtype=None):
  """Instantiates the `Affine` bijector.

  The bijector is built from a `shift` `Tensor` and a collection of `scale`
  arguments, and implements the forward map:

  ```none
  Y = g(X) = scale @ X + shift
  ```

  in which `scale` is logically equivalent to:

  ```python
  scale = (
    scale_identity_multiplier * tf.diag(tf.ones(d)) +
    tf.diag(scale_diag) +
    scale_tril +
    scale_perturb_factor @ diag(scale_perturb_diag) @
      tf.transpose([scale_perturb_factor])
  )
  ```

  When none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` is
  given, `scale += IdentityMatrix`. Otherwise each supplied `scale` argument
  contributes `scale += Expand(arg)`; e.g. `scale_diag != None` means
  `scale += tf.diag(scale_diag)`.

  Args:
    shift: Floating-point `Tensor`. If this is set to `None`, no shift is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
      When `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is
      added to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape `[N1, N2, ...  k]`, which represents a k x k
      diagonal matrix.
      When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape `[N1, N2, ...  k, k]`, which represents
      a k x k lower triangular matrix.
      When `None` no `scale_tril` term is added to `scale`.
      The upper triangular elements above the diagonal are ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor matrix
      with last two dimensions of shape `(k, r)`. When `None`, no rank-r
      update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape `[N1, N2, ...  r]`, which
      represents an `r x r` diagonal matrix. When `None` low rank updates
      will take the form `scale_perturb_factor * scale_perturb_factor.T`.
    adjoint: Python `bool` indicating whether to use the `scale` matrix as
      specified or its adjoint.
      Default value: `False`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.
    dtype: `tf.DType` to prefer when converting args to `Tensor`s. Else, we
      fall back to a common dtype inferred from the args, finally falling
      back to float32.

  Raises:
    ValueError: if `scale_perturb_diag` is specified but not
      `scale_perturb_factor`.
    TypeError: if `shift` has different `dtype` from `scale` arguments.
  """
  # A diagonal for the low-rank update is meaningless without its factor.
  if scale_perturb_factor is None and scale_perturb_diag is not None:
    raise ValueError("When scale_perturb_diag is specified, "
                     "scale_perturb_factor must be specified.")

  # A scale made only of a scaled identity is special-cased because its
  # dimensions are unknown. `scale_identity_multiplier` itself is not
  # inspected here: it is defaulted to 1. below when every other scale
  # argument is `None`.
  self._is_only_identity_multiplier = all(
      arg is None
      for arg in (scale_tril, scale_diag, scale_perturb_factor))

  with tf.name_scope(name) as name:
    self._name = name
    self._validate_args = validate_args

    if dtype is None:
      dtype_candidates = [
          shift,
          scale_identity_multiplier,
          scale_diag,
          scale_tril,
          scale_perturb_diag,
          scale_perturb_factor,
      ]
      dtype = dtype_util.common_dtype(dtype_candidates, tf.float32)

    if shift is not None:
      shift = tf.convert_to_tensor(shift, name="shift", dtype=dtype)
    self._shift = shift

    # With no scale arguments at all, act as if the scale matrix were the
    # identity.
    if (scale_identity_multiplier is None and
        self._is_only_identity_multiplier):
      scale_identity_multiplier = tf.convert_to_tensor(1., dtype=dtype)

    # `_create_scale_operator` yields a `LinearOperator` in every case but
    # the identity-multiplier-only one, where it yields a scalar `Tensor`.
    scale = self._create_scale_operator(
        identity_multiplier=scale_identity_multiplier,
        diag=scale_diag,
        tril=scale_tril,
        perturb_diag=scale_perturb_diag,
        perturb_factor=scale_perturb_factor,
        shift=shift,
        validate_args=validate_args,
        dtype=dtype)

    should_compare_dtypes = (
        scale is not None and
        not self._is_only_identity_multiplier and
        not dtype_util.SKIP_DTYPE_CHECKS and
        shift is not None)
    if (should_compare_dtypes and
        not dtype_util.base_equal(shift.dtype, scale.dtype)):
      raise TypeError(
          "shift.dtype ({}) is incompatible with scale.dtype ({})."
          .format(shift.dtype, scale.dtype))

    self._scale = scale
    self._adjoint = adjoint
    super(Affine, self).__init__(
        forward_min_event_ndims=1,
        is_constant_jacobian=True,
        dtype=dtype,
        validate_args=validate_args,
        name=name)
def log_ndtr(x, series_order=3, name="log_ndtr"):
  """Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  Computes `(log o ndtr)(x)` piecewise, either through `log(ndtr(x))` or
  through an asymptotic series:

  - For `x > upper_segment`, the approximation `-ndtr(-x)` is used, based on
    `log(1-x) ~= -x, x << 1`.
  - For `lower_segment < x <= upper_segment`, the existing `ndtr` technique
    is used and its log is taken.
  - For `x <= lower_segment`, the series approximation of erf is used to
    compute the log CDF directly.

  The segment boundaries depend on the precision of the input:

  ```
  lower_segment = { -20,  x.dtype=float64
                  { -10,  x.dtype=float32
  upper_segment = {   8,  x.dtype=float64
                  {   5,  x.dtype=float32
  ```

  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  ```
     ndtr(x) = scale * (1 + sum) + R_N
     scale   = exp(-0.5 x**2) / (-x sqrt(2 pi))
     sum     = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
     R_N     = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
  ```

  where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  [double-factorial](https://en.wikipedia.org/wiki/Double_factorial).

  Args:
    x: `Tensor` of type `float32`, `float64`.
    series_order: Positive Python `integer`. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.
    name: Python string. A name for the operation (default="log_ndtr").

  Returns:
    log_ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python `integer`.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  # Argument validation happens before any op is created.
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name="x")

    # Select the segment boundaries matching the input precision.
    if dtype_util.base_equal(x.dtype, tf.float64):
      np_dtype = np.float64
      lower_const = LOGNDTR_FLOAT64_LOWER
      upper_const = LOGNDTR_FLOAT64_UPPER
    elif dtype_util.base_equal(x.dtype, tf.float32):
      np_dtype = np.float32
      lower_const = LOGNDTR_FLOAT32_LOWER
      upper_const = LOGNDTR_FLOAT32_UPPER
    else:
      raise TypeError("x.dtype=%s is not supported." % x.dtype)
    lower_segment = np.array(lower_const, np_dtype)
    upper_segment = np.array(upper_const, np_dtype)

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # with a few changes:
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #   which extends the range of validity of this function.
    # * One fixed series_order is used for all of 'x', rather than an
    #   adaptive one.
    # * The docstring reflects that this is an asymptotic series, not a
    #   Taylor series, and gives a bound on the remainder.
    # * The max/min clamps on the branch arguments avoid nan when x=0. This
    #   happens even though the branch is unchosen because when x=0 the
    #   gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. The clamp is a NOP whenever its
    #   branch is the one chosen.
    in_upper_tail = x > upper_segment
    upper_tail_value = -_ndtr(-x)  # log(1-x) ~= -x, x << 1
    above_lower = x > lower_segment
    middle_value = tf.math.log(_ndtr(tf.maximum(x, lower_segment)))
    lower_tail_value = _log_ndtr_lower(
        tf.minimum(x, lower_segment), series_order)
    below_upper_value = tf.where(above_lower, middle_value, lower_tail_value)
    return tf.where(in_upper_tail, upper_tail_value, below_upper_value)
def __init__(self,
             loc,
             scale,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_lognormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="PoissonLogNormalQuadratureCompound"):
  """Constructs the `PoissonLogNormalQuadratureCompound`.

  Note: `probs` returned by (optional) `quadrature_fn` are presumed to be
  either a length-`quadrature_size` vector or a batch of vectors in 1-to-1
  correspondence with the returned `grid`. (I.e., broadcasting is only
  partially supported.)

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of
      quadrature points.
    quadrature_fn: Python callable taking `loc`, `scale`,
      `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
      representing the LogNormal grid and corresponding normalized weight.
      Default value: `quadrature_scheme_lognormal_quantiles`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if `quadrature_grid` and `quadrature_probs` have different
      base `dtype`.
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], tf.float32)
    if loc is not None:
      loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
    if scale is not None:
      scale = tf.convert_to_tensor(scale, dtype=dtype, name="scale")

    grid, probs = tuple(
        quadrature_fn(loc, scale, quadrature_size, validate_args))
    self._quadrature_grid = grid
    self._quadrature_probs = probs

    dt = grid.dtype
    if not dtype_util.base_equal(dt, probs.dtype):
      message = ("Quadrature grid dtype ({}) does not match quadrature "
                 "probs dtype ({}).")
      raise TypeError(message.format(
          dtype_util.name(dt), dtype_util.name(probs.dtype)))

    # The grid supplies the Poisson log-rates; the log of the quadrature
    # probabilities supplies the mixing logits.
    self._distribution = poisson.Poisson(
        log_rate=grid,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
    self._mixture_distribution = categorical.Categorical(
        logits=tf.math.log(probs),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)

    self._loc = loc
    self._scale = scale
    self._quadrature_size = quadrature_size
    super(PoissonLogNormalQuadratureCompound, self).__init__(
        dtype=dt,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[loc, scale],
        name=name)