def log1psquare(x, name=None):
  """Numerically stable calculation of `log(1 + x**2)` for small or large `|x|`.

  For sufficiently large `x` we use the following observation:

  ```none
  log(1 + x**2) = 2 log(|x|) + log(1 + 1 / x**2)
                --> 2 log(|x|)  as x --> inf
  ```

  Numerically, `log(1 + 1 / x**2)` is `0` when `1 / x**2` is small relative to
  machine epsilon.

  Args:
    x: Float `Tensor` input.
    name: Python string indicating the name of the TensorFlow operation.
      Default value: `'log1psquare'`.

  Returns:
    log1psq: Float `Tensor` representing `log(1. + x**2.)`.
  """
  with tf.name_scope(name or 'log1psquare'):
    x = tf.convert_to_tensor(x, dtype_hint=tf.float32, name='x')
    dtype = dtype_util.as_numpy_dtype(x.dtype)

    eps = np.finfo(dtype).eps.astype(np.float64)
    is_large = tf.abs(x) > (eps**-0.5).astype(dtype)

    # Mask out small x's so the gradient correctly propagates.
    abs_large_x = tf.where(is_large, tf.abs(x), tf.ones([], x.dtype))
    return tf.where(is_large, 2. * tf.math.log(abs_large_x),
                    tf.math.log1p(tf.square(x)))

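# Usage sketch (illustrative, not library code; assumes `log1psquare` above is
# in scope): in float32, squaring x = 1e20 overflows to inf, so the naive
# formula returns inf while the branched form stays finite.
import tensorflow as tf

x_demo = tf.constant(1e20)
naive_demo = tf.math.log(1. + tf.square(x_demo))  # inf: tf.square overflows.
stable_demo = log1psquare(x_demo)                 # ~92.103 == 2 * log(1e20).
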
def _survival_function(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # survival_function(y) := P[Y > y]
  #                       = 0, if y >= high,
  #                       = 1, if y < low,
  #                       = P[X > y], otherwise.

  # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
  # between.
  j = tf.math.ceil(y)

  # P[X > j], used when low < X < high.
  result_so_far = self.distribution.survival_function(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(j < low, tf.ones_like(result_so_far),
                             result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                             result_so_far)

  return result_so_far

def _extend_support(self, x, scale, f, alt):
  """Returns `f(x)` if x is in the support, and `alt` otherwise.

  Given `f` which is defined on the support of this distribution
  (e.g. x > scale), extend the function definition to the real
  line by defining `f(x) = alt` for `x < scale`.

  Args:
    x: Floating-point Tensor to evaluate `f` at.
    scale: Floating-point Tensor by which to verify `x` validity.
    f: Lambda that takes in a tensor and returns a tensor. This represents
      the function whose domain of definition we want to extend.
    alt: Python or numpy literal representing the value to use for extending
      the domain.

  Returns:
    Tensor representing an extension of `f(x)`.
  """
  if self.validate_args:
    return f(x)
  scale = tf.convert_to_tensor(self.scale) if scale is None else scale
  is_invalid = x < scale
  # We need to do this to ensure gradients are sound.
  y = f(tf.where(is_invalid, scale, x))
  if alt == 0.:
    alt = tf.zeros([], dtype=y.dtype)
  elif alt == 1.:
    alt = tf.ones([], dtype=y.dtype)
  else:
    alt = dtype_util.as_numpy_dtype(self.dtype)(alt)
  return tf.where(is_invalid, alt, y)

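# Standalone sketch of the double-where trick used above (illustrative, not
# library code): the gradient of tf.where sends a zero into the unselected
# branch, and 0 * nan == nan, so an invalid input still poisons the gradient
# unless it is first masked to a safe value.
import tensorflow as tf

x_demo = tf.constant([-1., 4.])
with tf.GradientTape() as tape:
  tape.watch(x_demo)
  bad = tf.where(x_demo > 0., tf.sqrt(x_demo), 0.)
grad_bad = tape.gradient(bad, x_demo)    # [nan, 0.25]: sqrt(-1.) leaks nan.

with tf.GradientTape() as tape:
  tape.watch(x_demo)
  safe_x = tf.where(x_demo > 0., x_demo, tf.ones_like(x_demo))
  good = tf.where(x_demo > 0., tf.sqrt(safe_x), 0.)
grad_good = tape.gradient(good, x_demo)  # [0., 0.25]: gradient is sound.
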
def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                    df_factor_fn):
  """Helper to compute stddev, covariance and variance."""
  df = tf.reshape(
      self.df,
      tf.concat(
          [tf.shape(self.df),
           tf.ones([statistic_ndims], dtype=tf.int32)], -1))
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  statistic = statistic * df_factor_fn(df / denom)
  # When 1 < df <= 2, stddev/variance are infinite.
  result_where_defined = tf.where(
      df > 2., statistic,
      dtype_util.as_numpy_dtype(self.dtype)(np.inf))

  if self.allow_nan_stats:
    return tf.where(df > 1., result_where_defined,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.cast(1., self.dtype),
            df,
            message='{} not defined for components of df <= 1.'.format(
                statistic_name.capitalize())),
    ]):
      return tf.identity(result_where_defined)

def _cdf(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # cdf(y) := P[Y <= y]
  #         = 1, if y >= high,
  #         = 0, if y < low,
  #         = P[X <= y], otherwise.

  # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
  # between.
  j = tf.floor(y)

  # P[X <= j], used when low < X < high.
  result_so_far = self.distribution.cdf(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(j < low, tf.zeros_like(result_so_far),
                             result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high, tf.ones_like(result_so_far),
                             result_so_far)

  return result_so_far

def _log_cdf(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # log_cdf(y) := log P[Y <= y]
  #             = 0, if y >= high,
  #             = -inf, if y < low,
  #             = log P[X <= y], otherwise.

  # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
  # between.
  j = tf.floor(y)

  result_so_far = self.distribution.log_cdf(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(
        j < low,
        dtype_util.as_numpy_dtype(self.dtype)(-np.inf),
        result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                             result_so_far)

  return result_so_far

def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims
  # for negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)

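# Sketch of the broadcasting rule handled above (illustrative, not library
# code): when a parameter has size 1 on a batch axis, every batch index shares
# that single element, so any batch slice maps to `0:1` on the parameter.
import numpy as np

loc_demo = np.zeros([1, 4])                    # Broadcasts over a batch of 5.
batch_demo = np.broadcast_to(loc_demo, [5, 4])
np.testing.assert_array_equal(
    batch_demo[2:4], np.broadcast_to(loc_demo[0:1], [2, 4]))
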
def _ndtr(x):
  """Implements ndtr core logic."""
  half_sqrt_2 = tf.constant(
      0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
  w = x * half_sqrt_2
  z = tf.abs(w)
  y = tf.where(z < half_sqrt_2,
               1. + tf.math.erf(w),
               tf.where(w > 0., 2. - tf.math.erfc(z), tf.math.erfc(z)))
  return 0.5 * y

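# Sanity sketch (assumes scipy is available, as in typical test environments):
# _ndtr evaluates Phi(x) = 0.5 * erfc(-x / sqrt(2)), split into three regimes
# so that erf is used near zero and erfc in the tails, avoiding cancellation
# in 1 + erf(w) for large negative w.
import numpy as np
import tensorflow as tf
from scipy import special

x_demo = tf.constant(np.linspace(-8., 8., 9), dtype=tf.float64)
np.testing.assert_allclose(_ndtr(x_demo).numpy(), special.ndtr(x_demo.numpy()))
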
def _cdf(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  broadcast_shape = tf.broadcast_dynamic_shape(
      tf.shape(x), self._batch_shape_tensor(low=low, high=high))
  zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
  ones = tf.ones(broadcast_shape, dtype=self.dtype)
  result_if_not_big = tf.where(x < low, zeros,
                               (x - low) / self._range(low=low, high=high))
  return tf.where(x >= high, ones, result_if_not_big)

def _prob(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  return tf.where(
      tf.math.is_nan(x),
      x,
      tf.where(
          # This > is only sound for continuous uniform
          (x < low) | (x > high),
          tf.zeros_like(x),
          tf.ones_like(x) / self._range(low=low, high=high)))

def _log_prob(self, x):
  scale = tf.convert_to_tensor(self.scale)
  concentration = tf.convert_to_tensor(self.concentration)
  z = self._z(x, scale, concentration)
  eq_zero = tf.equal(concentration, 0)  # Concentration = 0 ==> Exponential.
  nonzero_conc = tf.where(eq_zero, tf.constant(1, self.dtype), concentration)
  where_nonzero = (1 / nonzero_conc + 1) * tf.math.log1p(nonzero_conc * z)
  return -tf.math.log(scale) - tf.where(eq_zero, z, where_nonzero)

def _log_cdf(self, x):
  scale = tf.convert_to_tensor(self.scale)
  concentration = tf.convert_to_tensor(self.concentration)
  z = self._z(x, scale, concentration)
  eq_zero = tf.equal(concentration, 0)  # Concentration = 0 ==> Exponential.
  nonzero_conc = tf.where(eq_zero, tf.constant(1, self.dtype), concentration)
  where_nonzero = tf.math.log1p(-(1 + nonzero_conc * z)**(-1 / nonzero_conc))
  where_zero = tf.math.log1p(-tf.exp(-z))
  return tf.where(eq_zero, where_zero, where_nonzero)

def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with tf.name_scope(name or 'softplus_inverse'):
    x = tf.convert_to_tensor(x, name='x')
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large
    # x. For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x}
    # will be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are
    # careful to overwrite `x` with ones only when we will never actually use
    # this value. Note that we use ones and not zeros since
    # `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps) + 2.
    is_too_small = x < np.exp(threshold)
    is_too_large = x > -threshold
    too_small_value = tf.math.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = tf.where(is_too_small | is_too_large, tf.ones([], x.dtype), x)
    y = x + tf.math.log(-tf.math.expm1(-x))  # == log(expm1(x))
    return tf.where(is_too_small, too_small_value,
                    tf.where(is_too_large, too_large_value, y))

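# Round-trip sketch (illustrative): softplus_inverse(softplus(x)) recovers x
# even where the naive log(exp(x) - 1) fails at both ends. In float32,
# softplus(-20.) is ~2e-9, so exp of it rounds to 1. and the naive form gives
# log(0.) = -inf; at x = 100., exp overflows to inf.
import tensorflow as tf

x_demo = tf.constant([-20., 0.5, 100.])
sp_demo = tf.math.softplus(x_demo)
y_demo = softplus_inverse(sp_demo)                    # ~[-20., 0.5, 100.]
naive_demo = tf.math.log(tf.math.exp(sp_demo) - 1.)   # [-inf, 0.5, inf]
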
def _mean(self):
  # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError('event shape must be statically known for _bessel_ive')
  safe_conc = tf.where(self.concentration > 0, self.concentration,
                       tf.ones_like(self.concentration))
  safe_mean = self.mean_direction * (
      _bessel_ive(event_dim / 2, safe_conc) /
      _bessel_ive(event_dim / 2 - 1, safe_conc))[..., tf.newaxis]
  return tf.where(
      self.concentration[..., tf.newaxis] > tf.zeros_like(safe_mean),
      safe_mean, tf.zeros_like(safe_mean))

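# Ratio sketch (assumes scipy): the exponential scaling in `ive` cancels in
# the ratio, so ive(d/2, k) / ive(d/2 - 1, k) == I_{d/2}(k) / I_{d/2-1}(k).
# For d == 3 this equals coth(k) - 1/k, the mean resultant length of a
# 3-dimensional von Mises-Fisher distribution.
import numpy as np
from scipy import special

dim_demo, kappa_demo = 3, 2.5
ratio_demo = (special.ive(dim_demo / 2, kappa_demo) /
              special.ive(dim_demo / 2 - 1, kappa_demo))
np.testing.assert_allclose(
    ratio_demo, 1. / np.tanh(kappa_demo) - 1. / kappa_demo)
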
def _log_prob(self, x, power=None):
  # The log probability at positive integer points x is log(x^(-power) / Z)
  # where Z is the normalization constant. For x < 1 and non-integer points,
  # the log-probability is -inf.
  #
  # However, if interpolate_nondiscrete is True, we return the natural
  # continuous relaxation for x >= 1 which agrees with the log probability at
  # positive integer points.
  #
  # If interpolate_nondiscrete is False and validate_args is True, we check
  # that the sample point x is in the support. That is, x is equivalent to a
  # positive integer.
  power = power if power is not None else tf.convert_to_tensor(self.power)
  x = tf.cast(x, power.dtype)
  if self.validate_args and not self.interpolate_nondiscrete:
    x = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=self.dtype, assert_positive=True)
  log_normalization = tf.math.log(tf.math.zeta(power, 1.))

  safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 1.)
  y = -power * tf.math.log(safe_x)
  log_unnormalized_prob = tf.where(
      tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

  return log_unnormalized_prob - log_normalization

def _mean(self):
  concentration = tf.convert_to_tensor(self.concentration)
  lim = tf.ones([], dtype=self.dtype)
  valid = concentration < lim
  safe_conc = tf.where(valid, concentration, tf.constant(.5, self.dtype))
  result = lambda: self.loc + self.scale / (1 - safe_conc)
  if self.allow_nan_stats:
    return tf.where(valid, result(), tf.constant(float('nan'), self.dtype))
  with tf.control_dependencies([
      assert_util.assert_less(
          concentration,
          lim,
          message='`mean` is undefined when `concentration >= 1`')
  ]):
    return result()

def loop_body(should_continue, k):
  """Resample the non-accepted points."""
  # The range of U is chosen so that the resulting sample K lies in
  # [0, tf.int64.max). The final sample, if accepted, is K + 1.
  u = tf.random.uniform(
      shape, minval=minval_u, maxval=maxval_u, dtype=power.dtype,
      seed=seed())

  # Sample the point X from the continuous density h(x) \propto x^(-power).
  x = self._hat_integral_inverse(u, power=power)

  # Rejection-inversion requires a `hat` function, h(x) such that
  # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
  # support. A natural hat function for us is h(x) = x^(-power).
  #
  # After sampling X from h(x), suppose it lies in the interval
  # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
  # it lies to the left of x_K, where x_K is defined by:
  #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
  # where H(x) = \int_x^inf h(x) dx. Solving for x_K, we find that
  # x_K = H_inverse(H(K + .5) + pmf(K + 1)). Or, the acceptance condition is
  # X <= H_inverse(H(K + .5) + pmf(K + 1)). Since X = H_inverse(U), this
  # simplifies to U <= H(K + .5) + pmf(K + 1).

  # Update the non-accepted points.
  # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
  k = tf.where(should_continue, tf.floor(x + 0.5), k)
  accept = (u <= self._hat_integral(k + .5, power=power) + tf.exp(
      self._log_prob(k + 1, power=power)))

  return [should_continue & (~accept), k]

def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
  batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
  if tensorshape_util.is_fully_defined(batch_shape_static):
    return np.int32(batch_shape_static), batch_shape_static, []
  with tf.name_scope(name or 'calculate_reshape'):
    original_size = tf.reduce_prod(original_shape)
    implicit_dim = tf.equal(new_shape, -1)
    size_implicit_dim = (
        original_size // tf.maximum(1, -tf.reduce_prod(new_shape)))
    expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
        implicit_dim, size_implicit_dim, new_shape)
    validations = [] if not validate else [  # pylint: disable=g-long-ternary
        assert_util.assert_rank(
            original_shape, 1, message='Original shape must be a vector.'),
        assert_util.assert_rank(
            new_shape, 1, message='New shape must be a vector.'),
        assert_util.assert_less_equal(
            tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
            1,
            message='At most one dimension can be unknown.'),
        assert_util.assert_positive(
            expanded_new_shape, message='Shape elements must be >=-1.'),
        assert_util.assert_equal(
            tf.reduce_prod(expanded_new_shape),
            original_size,
            message='Shape sizes do not match.'),
    ]
    return expanded_new_shape, batch_shape_static, validations

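# Worked example of the `-1` resolution above (illustrative, not library
# code): with original_shape [6, 4] (24 elements) and new_shape [-1, 8], the
# product of new_shape is -8, so the implicit dim is 24 // 8 = 3 and the
# expanded shape is [3, 8].
import tensorflow as tf

original_demo = tf.constant([6, 4])
new_demo = tf.constant([-1, 8])
size_demo = tf.reduce_prod(original_demo) // tf.maximum(
    1, -tf.reduce_prod(new_demo))                                  # 3
expanded_demo = tf.where(tf.equal(new_demo, -1), size_demo, new_demo)  # [3, 8]
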
def _cdf(self, k):
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output.
  should_be_zero = k < 0

  # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
  k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

  batch_shape = tf.shape(k)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg,
  # so to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  k_flat_batch = tf.reshape(k, [-1])
  probs_flat_batch = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))

  cdf_flat = tf.gather(
      tf.cumsum(probs_flat_batch, axis=-1),
      k_flat_batch[..., tf.newaxis],
      batch_dims=1)

  cdf = tf.reshape(cdf_flat, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(should_be_zero, zero, cdf)

def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
  shaped per-vector indices `i`, this function swaps elements `m` and `i` in
  each vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1, prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs

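# Usage sketch (illustrative): swap position m=1 with a per-batch position i
# in each permutation vector.
import tensorflow as tf

vecs_demo = tf.constant([[0, 1, 2, 3], [3, 2, 1, 0]], dtype=tf.int64)
i_demo = tf.constant([[3], [2]], dtype=tf.int64)
swapped_demo = _swap_m_with_i(vecs_demo, 1, i_demo)
# ==> [[0, 3, 2, 1], [3, 1, 2, 0]]
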
def _variance(self):
  concentration = tf.convert_to_tensor(self.concentration)
  valid_variance = (
      self.scale**2 * concentration /
      ((concentration - 1.)**2 * (concentration - 2.)))
  return tf.where(concentration > 2., valid_variance,
                  dtype_util.as_numpy_dtype(self.dtype)(np.inf))

def _sample_n(self, n, seed=None):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)

  stream = SeedStream(seed, salt='triangular')
  shape = tf.concat(
      [[n], self._batch_shape_tensor(low=low, high=high, peak=peak)], axis=0)
  samples = tf.random.uniform(shape=shape, dtype=self.dtype, seed=stream())
  # We use Inverse CDF sampling here. Because the CDF is a quadratic
  # function, we must use sqrts here.
  interval_length = high - low
  return tf.where(
      # Note the CDF on the left side of the peak is
      # (x - low) ** 2 / ((high - low) * (peak - low)).
      # If we plug in peak for x, we get that the CDF at the peak
      # is (peak - low) / (high - low). Because of this we decide
      # which part of the piecewise CDF we should use based on the cdf
      # samples we drew.
      samples < (peak - low) / interval_length,
      # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
      low + tf.sqrt(samples * interval_length * (peak - low)),
      # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
      high - tf.sqrt((1. - samples) * interval_length * (high - peak)))

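# Standalone NumPy sketch of the same inverse-CDF transform (illustrative):
# a uniform draw below the CDF value at the peak, (peak - low) / (high - low),
# maps through the left quadratic branch, otherwise through the right.
import numpy as np

rng = np.random.default_rng(0)
low_d, peak_d, high_d = 0., 1., 4.
u = rng.uniform(size=100000)
left = low_d + np.sqrt(u * (high_d - low_d) * (peak_d - low_d))
right = high_d - np.sqrt((1. - u) * (high_d - low_d) * (high_d - peak_d))
samples_d = np.where(u < (peak_d - low_d) / (high_d - low_d), left, right)
# np.mean(samples_d) ~ (low_d + peak_d + high_d) / 3 = 5/3.
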
def _kl_pareto_pareto(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Pareto.

  Args:
    a: instance of a Pareto distribution object.
    b: instance of a Pareto distribution object.
    name: (optional) Name to use for created operations.
      Default is 'kl_pareto_pareto'.

  Returns:
    Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_pareto_pareto'):
    # Consistent with
    # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 55
    # Terminology is different from source to source for Pareto
    # distributions. The 'concentration' parameter corresponds to 'a' in that
    # source, and the 'scale' parameter corresponds to 'm'.
    a_scale = tf.convert_to_tensor(a.scale)
    b_scale = tf.convert_to_tensor(b.scale)
    a_concentration = tf.convert_to_tensor(a.concentration)
    b_concentration = tf.convert_to_tensor(b.concentration)
    return tf.where(
        a_scale >= b_scale,
        (b_concentration * (tf.math.log(a_scale) - tf.math.log(b_scale)) +
         tf.math.log(a_concentration) - tf.math.log(b_concentration) +
         b_concentration / a_concentration - 1.),
        dtype_util.as_numpy_dtype(a.dtype)(np.inf))

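# Monte-Carlo sanity check of the closed form above (a sketch, not library
# code). Uses NumPy's Lomax-based sampler: (rng.pareto(c) + 1) * m draws from
# a classical Pareto with concentration c and scale m. The finite branch
# applies here since a_scale >= b_scale.
import numpy as np

rng = np.random.default_rng(0)
a_conc, a_scale_d, b_conc, b_scale_d = 3., 2., 2., 1.5
x_d = (rng.pareto(a_conc, size=100000) + 1.) * a_scale_d  # Samples from `a`.
log_pdf = lambda x, c, m: np.log(c) + c * np.log(m) - (c + 1.) * np.log(x)
kl_mc = np.mean(log_pdf(x_d, a_conc, a_scale_d) -
                log_pdf(x_d, b_conc, b_scale_d))
kl_closed = (b_conc * (np.log(a_scale_d) - np.log(b_scale_d)) +
             np.log(a_conc) - np.log(b_conc) + b_conc / a_conc - 1.)
# kl_mc ~ kl_closed (~0.65 here).
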
def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
  if not self.batchnorm.built:
    # Create variables.
    self.batchnorm.build(y.shape)

  event_dims = self.batchnorm.axis
  reduction_axes = [i for i in range(len(y.shape)) if i not in event_dims]

  # At training-time, ildj is computed from the mean and log-variance across
  # the current minibatch.
  # We use multiplication instead of tf.where() to get easier broadcasting.
  log_variance = tf.math.log(
      tf.where(
          tf.logical_or(use_saved_statistics,
                        tf.logical_not(self._training)),
          self.batchnorm.moving_variance,
          tf.nn.moments(x=y, axes=reduction_axes, keepdims=True)[1]) +
      self.batchnorm.epsilon)

  # TODO(b/137216713): determine whether it's unsafe for the reduce_sums
  # below to happen across all axes.
  # `gamma` and `log Var(y)` reductions over event_dims.
  # Log(total change in area from gamma term).
  log_total_gamma = tf.reduce_sum(tf.math.log(self.batchnorm.gamma))

  # Log(total change in area from log-variance term).
  log_total_variance = tf.reduce_sum(log_variance)
  # The ildj is scalar: it does not depend on the values of x and is constant
  # across minibatch elements.
  return log_total_gamma - 0.5 * log_total_variance

def _kl_uniform_uniform(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Uniform.

  Note that the KL divergence is infinite if the support of `a` is not a
  subset of the support of `b`.

  Args:
    a: instance of a Uniform distribution object.
    b: instance of a Uniform distribution object.
    name: (optional) Name to use for created operations.
      Default is 'kl_uniform_uniform'.

  Returns:
    Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_uniform_uniform'):
    # Consistent with
    # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 60
    # Watch out for the change in conventions--they use 'a' and 'b' to refer
    # to lower and upper bounds respectively there.
    dtype = dtype_util.common_dtype([a.low, a.high, b.low, b.high],
                                    tf.float32)
    a_low = tf.convert_to_tensor(a.low)
    b_low = tf.convert_to_tensor(b.low)
    a_high = tf.convert_to_tensor(a.high)
    b_high = tf.convert_to_tensor(b.high)
    return tf.where(
        (b_low <= a_low) & (a_high <= b_high),
        tf.math.log(b_high - b_low) - tf.math.log(a_high - a_low),
        dtype_util.as_numpy_dtype(dtype)(np.inf))

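# Worked example via the public tfp.distributions API: with [0, 1] a subset
# of [-1, 2], KL(a || b) = log((2 - (-1)) / (1 - 0)) = log(3) ~ 1.0986; it is
# +inf whenever the support of `a` sticks out of `b`'s.
import tensorflow_probability as tfp

tfd = tfp.distributions
kl_demo = tfd.kl_divergence(tfd.Uniform(low=0., high=1.),
                            tfd.Uniform(low=-1., high=2.))  # log(3.)
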
def _log_normalization(self):
  """Computes the log-normalizer of the distribution."""
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError('vMF _log_normalizer currently only supports '
                     'statically known event shape')
  safe_conc = tf.where(self.concentration > 0, self.concentration,
                       tf.ones_like(self.concentration))
  safe_lognorm = ((event_dim / 2 - 1) * tf.math.log(safe_conc) -
                  (event_dim / 2) * np.log(2 * np.pi) -
                  tf.math.log(_bessel_ive(event_dim / 2 - 1, safe_conc)) -
                  tf.abs(safe_conc))
  log_nsphere_surface_area = (
      np.log(2.) + (event_dim / 2) * np.log(np.pi) -
      tf.math.lgamma(tf.cast(event_dim / 2, self.dtype)))
  return tf.where(self.concentration > 0, -safe_lognorm,
                  log_nsphere_surface_area)

def _log_unnormalized_prob(self, x, log_rate):
  # The log-probability at negative points is always -inf.
  # Catch such x's and set the output value accordingly.
  safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
  y = safe_x * log_rate - tf.math.lgamma(1. + safe_x)
  return tf.where(
      tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

def _cdf(self, x):
  df = tf.convert_to_tensor(self.df)
  # Take Abs(scale) to make subsequent where work correctly.
  y = (x - self.loc) / tf.abs(self.scale)
  x_t = df / (y**2. + df)
  neg_cdf = 0.5 * tf.math.betainc(
      0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
  return tf.where(y < 0., neg_cdf, 1. - neg_cdf)

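# Identity sketch (assumes scipy): for y < 0 the standard Student-t CDF
# satisfies T_df(y) = 0.5 * I_x(df/2, 1/2) with x = df / (y**2 + df), where
# I_x is the regularized incomplete beta function that tf.math.betainc
# computes.
import numpy as np
from scipy import special, stats

df_d, y_d = 5., -1.3
x_t_d = df_d / (y_d**2 + df_d)
np.testing.assert_allclose(0.5 * special.betainc(0.5 * df_d, 0.5, x_t_d),
                           stats.t.cdf(y_d, df_d))
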
def _pick_scalar_condition(pred, cond_true, cond_false):
  """Convenience function which chooses the condition based on the predicate."""
  # Note: This function is only valid if all of pred, cond_true, and
  # cond_false are scalars. This means its semantics are arguably more like
  # tf.cond than tf.where even though we use tf.where to implement it.
  pred_ = tf.get_static_value(tf.convert_to_tensor(pred))
  if pred_ is None:
    return tf.where(pred, cond_true, cond_false)
  return cond_true if pred_ else cond_false

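# Usage sketch: when the predicate has a static value (e.g. eagerly, or a
# constant in a graph), the Python branch is returned directly and no tf.where
# op is created; otherwise (e.g. a data-dependent predicate inside
# tf.function) it falls back to tf.where. All three arguments must be scalars.
import tensorflow as tf

picked_demo = _pick_scalar_condition(tf.constant(True), 1., 0.)  # ==> 1.0
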
def _log_prob(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    probs = self._probs_parameter_no_checks()
    if not self.validate_args:
      # For consistency with cdf, we take the floor.
      x = tf.floor(x)

    safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs), probs)
    return x * tf.math.log1p(-safe_domain) + tf.math.log(probs)
