def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                    df_factor_fn):
  """Helper to compute stddev, covariance and variance."""
  df = tf.reshape(
      self.df,
      tf.concat(
          [tf.shape(self.df),
           tf.ones([statistic_ndims], dtype=tf.int32)], -1))
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  statistic = statistic * df_factor_fn(df / denom)
  # When 1 < df <= 2, stddev/variance are infinite.
  result_where_defined = tf.where(
      df > 2., statistic,
      dtype_util.as_numpy_dtype(self.dtype)(np.inf))

  if self.allow_nan_stats:
    return tf.where(df > 1., result_where_defined,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.cast(1., self.dtype),
            df,
            message='{} not defined for components of df <= 1.'.format(
                statistic_name.capitalize())),
    ]):
      return tf.identity(result_where_defined)
def _uniform_unit_norm(dimension, shape, dtype, seed):
  """Returns a batch of points chosen uniformly from the unit hypersphere."""
  # This works because the Gaussian distribution is spherically symmetric.
  # raw shape: shape + [dimension]
  raw = normal.Normal(
      loc=dtype_util.as_numpy_dtype(dtype)(0),
      scale=dtype_util.as_numpy_dtype(dtype)(1)).sample(
          tf.concat([shape, [dimension]], axis=0), seed=seed())
  unit_norm = raw / tf.norm(raw, ord=2, axis=-1)[..., tf.newaxis]
  return unit_norm
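# The spherical-symmetry trick above is easy to check in isolation. Below is
# a minimal sketch using only core TensorFlow; the helper name and shapes are
# illustrative, not library API.
import tensorflow as tf

def uniform_unit_norm_demo(dimension, shape, dtype=tf.float32):
  # Sample isotropic Gaussians, then project onto the unit sphere.
  raw = tf.random.stateless_normal(
      tf.concat([shape, [dimension]], axis=0), seed=[42, 0], dtype=dtype)
  return raw / tf.norm(raw, ord=2, axis=-1, keepdims=True)

points = uniform_unit_norm_demo(3, [5])
print(tf.norm(points, axis=-1).numpy())  # ~[1. 1. 1. 1. 1.]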
def _assertions(self, t):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_greater(
          t, dtype_util.as_numpy_dtype(t.dtype)(-1),
          message="Inverse transformation input must be greater than -1."),
      assert_util.assert_less(
          t, dtype_util.as_numpy_dtype(t.dtype)(1),
          message="Inverse transformation input must be less than 1.")
  ]
def ndtri(p, name="ndtri"): """The inverse of the CDF of the Normal distribution function. Returns x such that the area under the pdf from minus infinity to x is equal to p. A piece-wise rational approximation is done for the function. This is a port of the implementation in netlib. Args: p: `Tensor` of type `float32`, `float64`. name: Python string. A name for the operation (default="ndtri"). Returns: x: `Tensor` with `dtype=p.dtype`. Raises: TypeError: if `p` is not floating-type. """ with tf.name_scope(name): p = tf.convert_to_tensor(p, name="p") if dtype_util.as_numpy_dtype(p.dtype) not in [np.float32, np.float64]: raise TypeError( "p.dtype=%s is not handled, see docstring for supported types." % p.dtype) return _ndtri(p)
def ndtr(x, name="ndtr"): """Normal distribution function. Returns the area under the Gaussian probability density function, integrated from minus infinity to x: ``` 1 / x ndtr(x) = ---------- | exp(-0.5 t**2) dt sqrt(2 pi) /-inf = 0.5 (1 + erf(x / sqrt(2))) = 0.5 erfc(x / sqrt(2)) ``` Args: x: `Tensor` of type `float32`, `float64`. name: Python string. A name for the operation (default="ndtr"). Returns: ndtr: `Tensor` with `dtype=x.dtype`. Raises: TypeError: if `x` is not floating-type. """ with tf.name_scope(name): x = tf.convert_to_tensor(x, name="x") if dtype_util.as_numpy_dtype(x.dtype) not in [np.float32, np.float64]: raise TypeError( "x.dtype=%s is not handled, see docstring for supported types." % x.dtype) return _ndtr(x)
def _log_survival_function(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # survival_function(y) := P[Y > y]
  #                       = 0, if y >= high,
  #                       = 1, if y < low,
  #                       = P[X > y], otherwise.

  # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
  # between.
  j = tf.math.ceil(y)

  # P[X > j], used when low < X < high.
  result_so_far = self.distribution.log_survival_function(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(j < low,
                             tf.zeros_like(result_so_far),
                             result_so_far)
  if high is not None:
    result_so_far = tf.where(
        j >= high,
        dtype_util.as_numpy_dtype(self.dtype)(-np.inf),
        result_so_far)

  return result_so_far
def _log_prob(self, x, power=None):
  # The log probability at positive integer points x is log(x^(-power) / Z)
  # where Z is the normalization constant. For x < 1 and non-integer points,
  # the log probability is -inf.
  #
  # However, if interpolate_nondiscrete is True, we return the natural
  # continuous relaxation for x >= 1 which agrees with the log probability at
  # positive integer points.
  #
  # If interpolate_nondiscrete is False and validate_args is True, we check
  # that the sample point x is in the support. That is, x is equivalent to a
  # positive integer.
  power = power if power is not None else tf.convert_to_tensor(self.power)
  x = tf.cast(x, power.dtype)
  if self.validate_args and not self.interpolate_nondiscrete:
    x = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=self.dtype, assert_positive=True)
  log_normalization = tf.math.log(tf.math.zeta(power, 1.))

  safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 1.)
  y = -power * tf.math.log(safe_x)
  log_unnormalized_prob = tf.where(
      tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

  return log_unnormalized_prob - log_normalization
def _get_shape(x, out_type=tf.int32):
  # Return the shape of a Tensor or a SparseTensor as an np.array if its shape
  # is known statically. Otherwise return a Tensor representing the shape.
  if tensorshape_util.is_fully_defined(x.shape):
    return np.array(
        tensorshape_util.as_list(x.shape),
        dtype=dtype_util.as_numpy_dtype(out_type))
  return tf.shape(x, out_type=out_type)
def _log_cdf(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # cdf(y) := P[Y <= y]
  #         = 1, if y >= high,
  #         = 0, if y < low,
  #         = P[X <= y], otherwise.

  # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
  # between.
  j = tf.floor(y)

  result_so_far = self.distribution.log_cdf(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(
        j < low,
        dtype_util.as_numpy_dtype(self.dtype)(-np.inf),
        result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high,
                             tf.zeros_like(result_so_far),
                             result_so_far)

  return result_so_far
def _kl_uniform_uniform(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Uniform.

  Note that the KL divergence is infinite if the support of `a` is not a
  subset of the support of `b`.

  Args:
    a: instance of a Uniform distribution object.
    b: instance of a Uniform distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_uniform_uniform".

  Returns:
    Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_uniform_uniform'):
    # Consistent with
    # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 60.
    # Watch out for the change in conventions--they use 'a' and 'b' to refer
    # to lower and upper bounds respectively there.
    dtype = dtype_util.common_dtype([a.low, a.high, b.low, b.high], tf.float32)
    a_low = tf.convert_to_tensor(a.low)
    b_low = tf.convert_to_tensor(b.low)
    a_high = tf.convert_to_tensor(a.high)
    b_high = tf.convert_to_tensor(b.high)
    return tf.where(
        (b_low <= a_low) & (a_high <= b_high),
        tf.math.log(b_high - b_low) - tf.math.log(a_high - a_low),
        dtype_util.as_numpy_dtype(dtype)(np.inf))
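# A quick check of the closed form: for supports [0, 1] and [-1, 2] the KL
# divergence is log(3). This sketch assumes the public TFP API
# (`tfp.distributions.Uniform`, `tfp.distributions.kl_divergence`).
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Uniform(low=0., high=1.)
b = tfd.Uniform(low=-1., high=2.)
print(tfd.kl_divergence(a, b).numpy(), np.log(3.))  # both ~1.0986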
def _cdf(self, k):
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output.
  should_be_zero = k < 0

  # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
  k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

  batch_shape = tf.shape(k)

  # tf.gather(..., batch_dims=batch_dims) requires a static batch_dims kwarg,
  # so to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  k_flat_batch = tf.reshape(k, [-1])
  probs_flat_batch = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))

  cdf_flat = tf.gather(
      tf.cumsum(probs_flat_batch, axis=-1),
      k_flat_batch[..., tf.newaxis],
      batch_dims=1)

  cdf = tf.reshape(cdf_flat, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(should_be_zero, zero, cdf)
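# The flatten-then-gather pattern above, in isolation: a sketch with made-up
# shapes, using only core TensorFlow.
import tensorflow as tf

probs = tf.constant([[0.2, 0.3, 0.5],
                     [0.1, 0.6, 0.3]])  # batch of 2, K=3 categories
k = tf.constant([[1], [2]])             # one index per batch member
cum = tf.cumsum(probs, axis=-1)         # per-row CDF table
cdf = tf.gather(cum, k, batch_dims=1)   # row i picks cum[i, k[i]]
print(cdf.numpy())  # [[0.5] [1.0]]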
def log1psquare(x, name=None):
  """Numerically stable calculation of `log(1 + x**2)` for small or large `|x|`.

  For sufficiently large `x` we use the following observation:

  ```none
  log(1 + x**2) = 2 log(|x|) + log(1 + 1 / x**2)
                --> 2 log(|x|)  as x --> inf
  ```

  Numerically, `log(1 + 1 / x**2)` is `0` when `1 / x**2` is small relative to
  machine epsilon.

  Args:
    x: Float `Tensor` input.
    name: Python string indicating the name of the TensorFlow operation.
      Default value: `'log1psquare'`.

  Returns:
    log1psq: Float `Tensor` representing `log(1. + x**2.)`.
  """
  with tf.name_scope(name or 'log1psquare'):
    x = tf.convert_to_tensor(x, dtype_hint=tf.float32, name='x')
    dtype = dtype_util.as_numpy_dtype(x.dtype)

    eps = np.finfo(dtype).eps.astype(np.float64)
    is_large = tf.abs(x) > (eps**-0.5).astype(dtype)

    # Mask out small x's so the gradient correctly propagates.
    abs_large_x = tf.where(is_large, tf.abs(x), tf.ones([], x.dtype))
    return tf.where(is_large, 2. * tf.math.log(abs_large_x),
                    tf.math.log1p(tf.square(x)))
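# Why the large-|x| branch matters: a sketch comparing the naive formula with
# the stable one at a value whose square overflows float32.
import tensorflow as tf

x = tf.constant(1e30, dtype=tf.float32)
naive = tf.math.log1p(tf.square(x))   # tf.square overflows to inf -> inf
stable = 2. * tf.math.log(tf.abs(x))  # ~138.16, the correct value
print(naive.numpy(), stable.numpy())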
def _kl_pareto_pareto(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Pareto.

  Args:
    a: instance of a Pareto distribution object.
    b: instance of a Pareto distribution object.
    name: (optional) Name to use for created operations.
      Default is 'kl_pareto_pareto'.

  Returns:
    Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_pareto_pareto'):
    # Consistent with
    # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 55.
    # Terminology is different from source to source for Pareto distributions.
    # The 'concentration' parameter corresponds to 'a' in that source, and the
    # 'scale' parameter corresponds to 'm'.
    a_scale = tf.convert_to_tensor(a.scale)
    b_scale = tf.convert_to_tensor(b.scale)
    a_concentration = tf.convert_to_tensor(a.concentration)
    b_concentration = tf.convert_to_tensor(b.concentration)
    return tf.where(
        a_scale >= b_scale,
        (b_concentration * (tf.math.log(a_scale) - tf.math.log(b_scale)) +
         tf.math.log(a_concentration) - tf.math.log(b_concentration) +
         b_concentration / a_concentration - 1.),
        dtype_util.as_numpy_dtype(a.dtype)(np.inf))
def _extend_support(self, x, scale, f, alt):
  """Returns `f(x)` if x is in the support, and `alt` otherwise.

  Given `f` which is defined on the support of this distribution
  (e.g. x > scale), extend the function definition to the real line
  by defining `f(x) = alt` for `x < scale`.

  Args:
    x: Floating-point Tensor to evaluate `f` at.
    scale: Floating-point Tensor by which to verify `x` validity.
    f: Lambda that takes in a tensor and returns a tensor. This represents the
      function whose domain of definition we want to extend.
    alt: Python or numpy literal representing the value to use for extending
      the domain.

  Returns:
    Tensor representing an extension of `f(x)`.
  """
  if self.validate_args:
    return f(x)
  scale = tf.convert_to_tensor(self.scale) if scale is None else scale
  is_invalid = x < scale
  # We need to do this to ensure gradients are sound.
  y = f(tf.where(is_invalid, scale, x))
  if alt == 0.:
    alt = tf.zeros([], dtype=y.dtype)
  elif alt == 1.:
    alt = tf.ones([], dtype=y.dtype)
  else:
    alt = dtype_util.as_numpy_dtype(self.dtype)(alt)
  return tf.where(is_invalid, alt, y)
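# The `f(tf.where(...))` indirection above is the standard "double-where"
# trick: `f` only ever sees clamped-safe inputs, so no inf/NaN can leak into
# the unselected gradient path. A standalone sketch with a hypothetical
# f = log, whose gradient 1/x is inf at 0:
import tensorflow as tf

x = tf.constant([0., 0.5, 2.])
with tf.GradientTape() as tape:
  tape.watch(x)
  is_invalid = x <= 0.
  safe_x = tf.where(is_invalid, tf.ones_like(x), x)  # harmless placeholder
  y = tf.where(is_invalid, tf.constant(float('-inf')), tf.math.log(safe_x))
print(tape.gradient(y, x).numpy())  # [0. 2. 0.5] -- no 0 * inf = nan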
def _variance(self):
  concentration = tf.convert_to_tensor(self.concentration)
  valid_variance = (self.scale**2 * concentration /
                    ((concentration - 1.)**2 * (concentration - 2.)))
  return tf.where(concentration > 2., valid_variance,
                  dtype_util.as_numpy_dtype(self.dtype)(np.inf))
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Computes the matrix rank; the number of non-zero SVD singular values.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
  with tf.name_scope(name or 'matrix_rank'):
    a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        a = tf.identity(a)
    s = tf.linalg.svd(a, compute_uv=False)
    if tol is None:
      if tensorshape_util.is_fully_defined(a.shape[-2:]):
        m = np.max(a.shape[-2:].as_list())
      else:
        m = tf.reduce_max(tf.shape(a)[-2:])
      eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
      tol = (eps * tf.cast(m, a.dtype) *
             tf.reduce_max(s, axis=-1, keepdims=True))
    return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
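# The default tolerance matches NumPy's reference implementation, which uses
# the same eps * max(rows, cols) * max(singular_value) rule. A quick sketch:
import numpy as np

a = np.array([[1., 2., 3.],
              [1., 2., 3.],   # duplicate row drops the rank to 2
              [4., 5., 6.]])
print(np.linalg.matrix_rank(a))  # 2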
def _entropy(self):
  logits, probs = self._logits_and_probs_no_checks()
  if not self.validate_args:
    assertions = []
  else:
    assertions = [
        assert_util.assert_less(
            probs,
            dtype_util.as_numpy_dtype(self.dtype)(1.),
            message='Entropy is undefined when logits = inf or probs = 1.')
    ]
  with tf.control_dependencies(assertions):
    # Claim: entropy(p) = softplus(s)/p - s
    # where s=logits and p=probs.
    #
    # Proof:
    #
    # entropy(p)
    # := -[(1-p)log(1-p) + plog(p)]/p
    #  = -[log(1-p) + plog(p/(1-p))]/p
    #  = -[-softplus(s) + ps]/p
    #  = softplus(s)/p - s
    #
    # since,
    # log[1-sigmoid(s)]
    # = log[1/(1+exp(s))]
    # = -log[1+exp(s)]
    # = -softplus(s)
    #
    # using the fact that,
    # 1-sigmoid(s) = sigmoid(-s) = 1/(1+exp(s))
    return tf.math.softplus(logits) / probs - logits
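# The closed form is easy to verify against the defining sum. A sketch in
# plain NumPy with an arbitrary p:
import numpy as np

p = 0.3
s = np.log(p / (1. - p))                                   # logits
closed_form = np.log1p(np.exp(s)) / p - s                  # softplus(s)/p - s
direct = -((1. - p) * np.log(1. - p) + p * np.log(p)) / p  # definition
print(closed_form, direct)  # both ~2.0362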
def _log_unnormalized_prob(self, x, log_rate):
  # The log-probability at negative points is always -inf.
  # Catch such x's and set the output value accordingly.
  safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
  y = safe_x * log_rate - tf.math.lgamma(1. + safe_x)
  return tf.where(tf.equal(x, safe_x), y,
                  dtype_util.as_numpy_dtype(y.dtype)(-np.inf))
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='Gumbel'):
  """Construct Gumbel distributions with location and scale `loc` and `scale`.

  The parameters `loc` and `scale` must be shaped in a way that supports
  broadcasting (e.g. `loc + scale` is a valid operation).

  Args:
    loc: Floating point tensor, the means of the distribution(s).
    scale: Floating point tensor, the scales of the distribution(s).
      scale must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the result
      is undefined. When `False`, an exception is raised if one or more of the
      statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `'Gumbel'`.

  Raises:
    TypeError: if loc and scale are different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    dtype_util.assert_same_float_dtype([loc, scale])
    # Positive scale is asserted by the incorporated Gumbel bijector.
    self._gumbel_bijector = gumbel_bijector.Gumbel(
        loc=loc, scale=scale, validate_args=validate_args)

    # Because the uniform sampler generates samples in `[0, 1)`, this would
    # cause samples to lie in `[-inf, inf)` instead of `(-inf, inf)`. To fix
    # this, we use `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` as
    # the lower bound because it is the smallest, positive, 'normal' number.
    super(Gumbel, self).__init__(
        distribution=uniform.Uniform(
            low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
            high=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats),
        # The Gumbel bijector encodes the CDF as its forward, and hence needs
        # to be inverted so that sampling applies the quantile function to the
        # uniform base samples.
        bijector=invert_bijector.Invert(self._gumbel_bijector),
        batch_shape=distribution_util.get_broadcast_shape(loc, scale),
        parameters=parameters,
        name=name)
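# The construction above is inverse-transform sampling. The same idea written
# out directly, with hypothetical loc/scale values (a sketch, not the class):
import tensorflow as tf

loc, scale = 0.5, 2.0
# Keep u strictly positive, mirroring the `tiny` lower bound above.
u = tf.random.uniform([5], minval=1e-38)
gumbel_samples = loc - scale * tf.math.log(-tf.math.log(u))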
def _log_prob(self, x):
  log_prob = -(0.5 * ((x - self.loc) / self.scale)**2 +
               0.5 * np.log(2. * np.pi) +
               tf.math.log(self.scale * self._normalizer))
  # p(x) is 0 outside the bounds.
  bounded_log_prob = tf.where((x > self._high) | (x < self._low),
                              dtype_util.as_numpy_dtype(x.dtype)(-np.inf),
                              log_prob)
  return bounded_log_prob
def _assertions(self, t):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_none_equal(
          t, dtype_util.as_numpy_dtype(t.dtype)(0.),
          message="All elements must be non-zero.")
  ]
def _log_prob(self, x):
  log_rate = self._log_rate_parameter_no_checks()
  log_probs = (self._log_unnormalized_prob(x, log_rate) -
               self._log_normalization(log_rate))
  if not self.interpolate_nondiscrete:
    # Ensure the gradient wrt `rate` is zero at non-integer points.
    log_probs = tf.where(
        tf.math.is_inf(log_probs),
        dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf),
        log_probs)
  return log_probs
def _log_cdf(self, x):
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  with tf.control_dependencies(self._maybe_assert_valid_sample(x, loc)):
    safe_x = self._get_safe_input(x, loc=loc, scale=scale)
    log_cdf = np.log(2 / np.pi) + tf.math.log(tf.atan((safe_x - loc) / scale))
    return tf.where(x < loc,
                    dtype_util.as_numpy_dtype(self.dtype)(-np.inf),
                    log_cdf)
def grad(dy):
  """Computes a derivative for the min and max parameters.

  This function implements the derivative wrt the truncation bounds, which
  get blocked by the sampler. We use a custom expression for numerical
  stability instead of automatic differentiation on CDF for implicit
  gradients.

  Args:
    dy: output gradients

  Returns:
    The gradients wrt the lower bound and the upper bound.
  """
  # std_samples has an extra dimension (the sample dimension), expand
  # lower and upper so they broadcast along this dimension.
  # See note above regarding parameterized_truncated_normal, the sample
  # dimension is the final dimension.
  lower_broadcast = lower[..., tf.newaxis]
  upper_broadcast = upper[..., tf.newaxis]

  cdf_samples = ((special_math.ndtr(std_samples) -
                  special_math.ndtr(lower_broadcast)) /
                 (special_math.ndtr(upper_broadcast) -
                  special_math.ndtr(lower_broadcast)))

  # tiny, eps are tolerance parameters to ensure we stay away from giving
  # a zero arg to the log CDF expression.
  tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
  eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
  cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

  du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
              tf.math.log(cdf_samples))
  dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
              tf.math.log1p(-cdf_samples))

  # Reduce the gradient across the samples.
  grad_u = tf.reduce_sum(dy * du, axis=-1)
  grad_l = tf.reduce_sum(dy * dl, axis=-1)
  return [grad_l, grad_u]
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if (self.hinge_softness is not None and
      is_init != tensor_util.is_ref(self.hinge_softness)):
    assertions.append(assert_util.assert_none_equal(
        dtype_util.as_numpy_dtype(self._hinge_softness.dtype)(0),
        self.hinge_softness,
        message='Argument `hinge_softness` must be non-zero.'))
  return assertions
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with tf.name_scope(name or 'softplus_inverse'):
    x = tf.convert_to_tensor(x, name='x')
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large
    # x. For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x}
    # will be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are
    # careful to overwrite `x` with ones only when we will never actually use
    # this value. Note that we use ones and not zeros since
    # `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps) + 2.
    is_too_small = x < np.exp(threshold)
    is_too_large = x > -threshold
    too_small_value = tf.math.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = tf.where(is_too_small | is_too_large, tf.ones([], x.dtype), x)
    y = x + tf.math.log(-tf.math.expm1(-x))  # == log(expm1(x))
    return tf.where(is_too_small,
                    too_small_value,
                    tf.where(is_too_large, too_large_value, y))
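# A round-trip sanity check of `softplus_inverse` as defined above, covering
# the too-small, ordinary, and too-large branches (a sketch).
import tensorflow as tf

x = tf.constant([1e-10, 1., 50.])
y = softplus_inverse(x)
print(tf.math.softplus(y).numpy())  # recovers ~[1e-10 1. 50.]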
def _validate_correlationness(self, x):
  if not self.validate_args or self.input_output_cholesky:
    return x
  checks = [
      assert_util.assert_less_equal(
          dtype_util.as_numpy_dtype(x.dtype)(-1),
          x,
          message='Correlations must be >= -1.'),
      assert_util.assert_less_equal(
          x,
          dtype_util.as_numpy_dtype(x.dtype)(1),
          message='Correlations must be <= 1.'),
      assert_util.assert_near(
          tf.linalg.diag_part(x),
          dtype_util.as_numpy_dtype(x.dtype)(1),
          message='Self-correlations must be = 1.'),
      assert_util.assert_near(
          x,
          tf.linalg.matrix_transpose(x),
          message='Correlation matrices must be symmetric.')
  ]
  with tf.control_dependencies(checks):
    return tf.identity(x)
def maybe_assert_categorical_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Categorical`-type distributions."""
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')

    if not dtype_util.is_floating(x.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    ndims = tensorshape_util.rank(x.shape)
    if ndims is not None:
      if ndims < 1:
        raise ValueError(msg)
    elif validate_args:
      x = tf.convert_to_tensor(x)
      probs = x if logits is None else None  # Retain tensor conversion.
      logits = x if probs is None else None
      assertions.append(assert_util.assert_rank_at_least(x, 1, message=msg))

  if not validate_args:
    assert not assertions  # Should never happen.
    return []

  if logits is not None:
    if is_init != tensor_util.is_ref(logits):
      logits = tf.convert_to_tensor(logits)
      assertions.extend(
          distribution_util.assert_categorical_event_shape(logits))

  if probs is not None:
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1),
              np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
              message='Argument `probs` must sum to 1.')
      ])
      assertions.extend(
          distribution_util.assert_categorical_event_shape(probs))

  return assertions
def _mean(self):
  mean = tf.broadcast_to(self.loc, self._sample_shape())
  if self.allow_nan_stats:
    return tf.where(self.df[..., tf.newaxis] > 1.,
                    mean,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.cast(1., self.dtype),
            self.df,
            message='Mean not defined for components of df <= 1.'),
    ]):
      return tf.identity(mean)
def _variance(self):
  df = tf.convert_to_tensor(self.df)
  scale = tf.convert_to_tensor(self.scale)
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  # Abs(scale) superfluous.
  var = (tf.ones(self._batch_shape_tensor(df=df, scale=scale),
                 dtype=self.dtype) * tf.square(scale) * df / denom)
  # When 1 < df <= 2, variance is infinite.
  result_where_defined = tf.where(
      df > 2., var, dtype_util.as_numpy_dtype(self.dtype)(np.inf))

  if self.allow_nan_stats:
    return tf.where(df > 1., result_where_defined,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    return distribution_util.with_dependencies([
        assert_util.assert_less(
            tf.ones([], dtype=self.dtype),
            df,
            message='variance not defined for components of df <= 1'),
    ], result_where_defined)