def _inverse(self, y):
  map_values = tf.convert_to_tensor(self.map_values)
  flat_y = tf.reshape(y, shape=[-1])
  # Search for the indices of map_values that are closest to flat_y.
  # Since map_values is strictly increasing, the closest is either the
  # first one that is strictly greater than flat_y, or the one before it.
  upper_candidates = tf.minimum(
      tf.size(map_values) - 1,
      tf.searchsorted(map_values, values=flat_y, side='right'))
  lower_candidates = tf.maximum(0, upper_candidates - 1)
  candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
  lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
  upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_near(
            tf.minimum(lower_cand_diff, upper_cand_diff),
            0,
            message='inverse value not found')
    ]):
      candidates = tf.identity(candidates)
  candidate_selector = tf.stack([
      tf.range(tf.size(flat_y), dtype=tf.int32),
      tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
  ], axis=-1)
  return tf.reshape(
      tf.gather_nd(candidates, candidate_selector), shape=y.shape)
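# A minimal NumPy sketch of the two-candidate nearest-value search used
# above, assuming a strictly increasing `map_values` array. The helper name
# `nearest_index` is illustrative, not part of the original code.
import numpy as np

def nearest_index(map_values, y):
  # The first index strictly greater than y (clamped into range) and the
  # index just before it are the only candidates for the nearest value.
  upper = np.minimum(len(map_values) - 1,
                     np.searchsorted(map_values, y, side='right'))
  lower = np.maximum(0, upper - 1)
  # Pick whichever candidate is closer to y.
  return np.where(
      np.abs(y - map_values[lower]) <= np.abs(y - map_values[upper]),
      lower, upper)

# Example: with values [0., 1., 10.], both 0.6 and 4.0 map to index 1.
print(nearest_index(np.array([0., 1., 10.]), np.array([0.6, 4.0])))  # [1 1]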
def _inverse(self, y):
  with tf.control_dependencies(self._maybe_assert_valid_y(y)):
    if self.power == 0.:
      return tf.math.log(y)
    # If large y accuracy is an issue, consider using:
    # (y**self.power - 1.) / self.power when y >> 1.
    return tf.math.expm1(tf.math.log(y) * self.power) / self.power
def vector_size_to_square_matrix_size(d, validate_args, name=None):
  """Convert a vector size to a matrix size."""
  if isinstance(d, (float, int, np.generic, np.ndarray)):
    n = (-1 + np.sqrt(1 + 8 * d)) / 2.
    if float(int(n)) != n:
      raise ValueError(
          'Vector length {} is not a triangular number.'.format(d))
    return int(n)
  else:
    with tf.name_scope(name or 'vector_size_to_square_matrix_size') as name:
      n = (-1. + tf.sqrt(1 + 8. * tf.cast(d, dtype=tf.float32))) / 2.
      if validate_args:
        with tf.control_dependencies([
            assert_util.assert_equal(
                tf.cast(tf.cast(n, dtype=tf.int32), dtype=tf.float32),
                n,
                data=['Vector length is not a triangular number: ', d],
                message='Vector length is not a triangular number')
        ]):
          n = tf.identity(n)
      return tf.cast(n, d.dtype)
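# A quick standalone check of the triangular-number inversion above: a
# vector of length d = n * (n + 1) / 2 fills the lower triangle (including
# the diagonal) of an n x n matrix, so n = (-1 + sqrt(1 + 8 * d)) / 2.
import numpy as np

for n in range(1, 6):
  d = n * (n + 1) // 2
  recovered = (-1 + np.sqrt(1 + 8 * d)) / 2.
  assert int(recovered) == n  # e.g. d = 6 recovers n = 3.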
def _forward(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_x(x)):
    if self.power == 0.:
      return tf.exp(x)
    # If large x accuracy is an issue, consider using:
    # (1. + x * self.power)**(1. / self.power) when x >> 1.
    return tf.exp(tf.math.log1p(x * self.power) / self.power)
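# A minimal round-trip sketch of the power-transform pair above (this
# `_forward`, x -> (1 + x * power)**(1 / power), and the `_inverse` earlier,
# y -> (y**power - 1) / power), written in NumPy with the same log1p/expm1
# tricks for numerical stability. The `power` value is only an example.
import numpy as np

power = 0.5
x = np.array([0.1, 1.0, 5.0])
y = np.exp(np.log1p(x * power) / power)        # forward
x_back = np.expm1(np.log(y) * power) / power   # inverse
assert np.allclose(x, x_back)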
def _batch_shape_tensor(self):
  with tf.control_dependencies(self._runtime_assertions):
    return tf.broadcast_dynamic_shape(
        self._initial_distribution.batch_shape_tensor(),
        tf.broadcast_dynamic_shape(
            self._transition_distribution.batch_shape_tensor()[:-1],
            self._observation_distribution.batch_shape_tensor()[:-1]))
def _inverse_log_det_jacobian(self, y):
  # If event_ndims = 2,
  #   F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
  #   so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
  with tf.control_dependencies(self._assertions(y)):
    zero = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
    return zero, zero
def _call_and_reshape_output(self,
                             fn,
                             event_shape_list=None,
                             static_event_shape_list=None,
                             extra_kwargs=None):
  """Calls `fn` and appropriately reshapes its output."""
  # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
  # because it is possible the user-provided extra kwargs would themselves
  # have `fn`, `event_shape_list`, `static_event_shape_list` and/or
  # `extra_kwargs` as keys.
  with tf.control_dependencies(self._runtime_assertions):
    if event_shape_list is None:
      event_shape_list = [self._event_shape_tensor()]
    if static_event_shape_list is None:
      static_event_shape_list = [self.event_shape]
    new_shape = tf.concat(
        [self._batch_shape_unexpanded] + event_shape_list, axis=0)
    result = tf.reshape(
        fn(**extra_kwargs) if extra_kwargs else fn(), new_shape)
    if (tensorshape_util.rank(self.batch_shape) is not None and
        tensorshape_util.rank(self.event_shape) is not None):
      event_shape = tf.TensorShape([])
      for rss in static_event_shape_list:
        event_shape = tensorshape_util.concatenate(event_shape, rss)
      static_shape = tensorshape_util.concatenate(
          self.batch_shape, event_shape)
      tensorshape_util.set_shape(result, static_shape)
    return result
def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                    df_factor_fn):
  """Helper to compute stddev, covariance and variance."""
  df = tf.reshape(
      self.df,
      tf.concat(
          [tf.shape(self.df),
           tf.ones([statistic_ndims], dtype=tf.int32)], -1))
  # We need to put this tf.where inside the outer tf.where below to ensure
  # we never hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  statistic = statistic * df_factor_fn(df / denom)
  # When 1 < df <= 2, stddev/variance are infinite.
  result_where_defined = tf.where(
      df > 2., statistic,
      dtype_util.as_numpy_dtype(self.dtype)(np.inf))
  if self.allow_nan_stats:
    return tf.where(df > 1., result_where_defined,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.cast(1., self.dtype),
            df,
            message='{} not defined for components of df <= 1.'.format(
                statistic_name.capitalize())),
    ]):
      return tf.identity(result_where_defined)
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Computes the matrix rank; the number of non-zero SVD singular values.

  Arguments:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to
      be computed.
    tol: Threshold below which a singular value is counted as 'zero'.
      Default value: `None` (i.e.,
      `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in
      the graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of
      non-zero singular values.
  """
  with tf.name_scope(name or 'matrix_rank'):
    a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        a = tf.identity(a)
    s = tf.linalg.svd(a, compute_uv=False)
    if tol is None:
      if tensorshape_util.is_fully_defined(a.shape[-2:]):
        m = np.max(a.shape[-2:].as_list())
      else:
        m = tf.reduce_max(tf.shape(a)[-2:])
      eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
      tol = (eps * tf.cast(m, a.dtype) *
             tf.reduce_max(s, axis=-1, keepdims=True))
    return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
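# A standalone sketch of the default tolerance rule used by `matrix_rank`
# above, applied to a rank-1 2x2 matrix. Only `tf` and `np` from the
# surrounding module are assumed.
import numpy as np
import tensorflow as tf

a = tf.constant([[1., 2.],
                 [2., 4.]])  # second row is 2x the first, so rank is 1.
s = tf.linalg.svd(a, compute_uv=False)
eps = np.finfo(np.float32).eps
tol = eps * 2 * tf.reduce_max(s)  # eps * max(rows, cols) * max(singular_val)
rank = tf.reduce_sum(tf.cast(s > tol, tf.int32))
print(rank)  # tf.Tensor(1, shape=(), dtype=int32)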
def _entropy(self):
  logits, probs = self._logits_and_probs_no_checks()
  if not self.validate_args:
    assertions = []
  else:
    assertions = [
        assert_util.assert_less(
            probs,
            dtype_util.as_numpy_dtype(self.dtype)(1.),
            message='Entropy is undefined when logits = inf or probs = 1.')
    ]
  with tf.control_dependencies(assertions):
    # Claim: entropy(p) = softplus(s)/p - s
    # where s=logits and p=probs.
    #
    # Proof:
    #
    # entropy(p)
    # := -[(1-p)log(1-p) + p log(p)]/p
    # = -[log(1-p) + p log(p/(1-p))]/p
    # = -[-softplus(s) + ps]/p
    # = softplus(s)/p - s
    #
    # since,
    # log[1-sigmoid(s)]
    # = log[1/(1+exp(s))]
    # = -log[1+exp(s)]
    # = -softplus(s)
    #
    # using the fact that,
    # 1-sigmoid(s) = sigmoid(-s) = 1/(1+exp(s))
    return tf.math.softplus(logits) / probs - logits
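# A small numeric check of the identity derived above,
# entropy(p) = softplus(s)/p - s with s = logit(p). Standalone sketch;
# the probability value is arbitrary.
import numpy as np

p = 0.3
s = np.log(p / (1. - p))                       # logits
softplus = np.log1p(np.exp(s))
lhs = -((1. - p) * np.log(1. - p) + p * np.log(p)) / p
rhs = softplus / p - s
assert np.isclose(lhs, rhs)  # both ~2.036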
def _inverse_log_det_jacobian(self, y):
  with tf.control_dependencies(self._maybe_assert_valid_y(y)):
    scale = tf.convert_to_tensor(self.scale)
    concentration = tf.convert_to_tensor(self.concentration)
    return (-tf.math.log1p(-y) +
            tf.math.xlogy(1 / concentration - 1, -tf.math.log1p(-y)) +
            tf.math.log(scale / concentration))
def _log_prob(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    concentration = tf.convert_to_tensor(self.concentration)
    loc = tf.convert_to_tensor(self.loc)
    return (0.5 * (tf.math.log(concentration) -
                   np.log(2. * np.pi) -
                   3. * tf.math.log(x)) +
            (-concentration * (x - loc)**2.) / (2. * loc**2. * x))
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) !=
         len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes
  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same '
               'length as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      return tf.identity(block_sizes)
  else:
    return block_sizes
def assert_finite(x, data=None, summarize=None, message=None, name=None):
  """Assert all elements of `x` are finite.

  Args:
    x: Numeric `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_finite".

  Returns:
    Op raising `InvalidArgumentError` unless all elements of `x` are
    finite. If static checks determine `x` is finite, `x` is returned
    unchanged.

  Raises:
    ValueError: If static checks determine `x` has non-finite elements.
  """
  with tf.name_scope(name or 'assert_finite'):
    x_ = tf.get_static_value(x)
    if x_ is not None:
      if not np.all(np.isfinite(x_)):
        raise ValueError(message)
      return x
    assertion = tf1.assert_equal(
        tf.math.is_finite(x), tf.ones_like(x, tf.bool),
        data=data, summarize=summarize, message=message)
    with tf.control_dependencies([assertion]):
      return tf.identity(x)
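# Usage sketch for `assert_finite` above, assuming eager TensorFlow and the
# aliases the function uses (`tf` for tensorflow, `tf1` for
# tensorflow.compat.v1, `np` for numpy).
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1

x = assert_finite(tf.constant([1., 2., 3.]))  # passes; returns x unchanged.

try:
  assert_finite(tf.constant([1., np.inf]), message='x must be finite')
except ValueError as e:
  print(e)  # 'x must be finite': the bad value is statically known.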
def _log_prob(self, counts):
  with tf.control_dependencies(self._maybe_assert_valid_sample(counts)):
    log_p = (tf.math.log(self._probs) if self._logits is None else
             tf.math.log_softmax(self._logits))
    k = tf.convert_to_tensor(self.total_count)
    return (tf.reduce_sum(counts * log_p, axis=-1) +  # log_unnorm_prob
            tfp_math.log_combinations(k, counts))     # -log_normalization
def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
  """Calls `fn`, appropriately reshaping its input `x` and output."""
  # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
  # because it is possible the user-provided extra kwargs would themselves
  # have `fn` and/or `x` as a key.
  with tf.control_dependencies(
      self._runtime_assertions + self._validate_sample_arg(x)):
    sample_shape, static_sample_shape = self._sample_shape(x)
    old_shape = tf.concat([
        sample_shape,
        self.distribution.batch_shape_tensor(),
        self.event_shape_tensor(),
    ], axis=0)
    x_reshape = tf.reshape(x, old_shape)
    result = fn(x_reshape, **extra_kwargs) if extra_kwargs else fn(x_reshape)
    new_shape = tf.concat([
        sample_shape,
        self._batch_shape_unexpanded,
    ], axis=0)
    result = tf.reshape(result, new_shape)
    if (tensorshape_util.rank(static_sample_shape) is not None and
        tensorshape_util.rank(self.batch_shape) is not None):
      new_shape = tensorshape_util.concatenate(
          static_sample_shape, self.batch_shape)
      tensorshape_util.set_shape(result, new_shape)
    return result
def _mean(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = distribution_utils.pad_mixture_dimensions(
        self.mixture_distribution.probs_parameter(),
        self, self.mixture_distribution,
        self._event_ndims)                             # [B, k, [1]*e]
    return tf.reduce_sum(
        probs * self.components_distribution.mean(),
        axis=-1 - self._event_ndims)                   # [B, E]
def _forward_log_det_jacobian(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_x(x)):
    scale = tf.convert_to_tensor(self.scale)
    concentration = tf.convert_to_tensor(self.concentration)
    return (-(x / scale)**concentration +
            tf.math.xlogy(concentration - 1, x) +
            tf.math.log(concentration) -
            concentration * tf.math.log(scale))
def _cdf(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    concentration1 = tf.convert_to_tensor(self.concentration1)
    concentration0 = tf.convert_to_tensor(self.concentration0)
    shape = self._batch_shape_tensor(concentration1, concentration0)
    concentration1 = tf.broadcast_to(concentration1, shape)
    concentration0 = tf.broadcast_to(concentration0, shape)
    return tf.math.betainc(concentration1, concentration0, x)
def _prob(self, x):
  if self.validate_args:
    is_vector_check = assert_util.assert_rank_at_least(x, 1)
    right_vec_space_check = assert_util.assert_equal(
        self.event_shape_tensor(),
        tf.gather(tf.shape(x), tf.rank(x) - 1),
        message="Argument 'x' not defined in the same space R^k as this "
                "distribution")
    with tf.control_dependencies([is_vector_check]):
      with tf.control_dependencies([right_vec_space_check]):
        x = tf.identity(x)
  loc = tf.convert_to_tensor(self.loc)
  return tf.cast(
      tf.reduce_all(tf.abs(x - loc) <= self._slack(loc), axis=-1),
      dtype=self.dtype)
def _forward(self, x):
  with tf.control_dependencies(self._assertions(x)):
    shape = tf.shape(x)
    return tf.linalg.triangular_solve(
        x,
        tf.eye(shape[-1], batch_shape=shape[:-2], dtype=x.dtype),
        lower=True)
def _variance(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat([
        self.batch_shape_tensor(),
        [self._num_states],
        self._observation_distribution.event_shape_tensor()
    ], axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, 1, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size 1 num_states observation_event_size
    flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = self._observation_distribution.variance()
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_states observation_event_shape
    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size 1 num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    #   var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])
    flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat([
        self.batch_shape_tensor(),
        [self._num_steps],
        observation_event_shape
    ], axis=0)
    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
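# A small numeric check of the mixture-variance identity used above (the
# law of total variance): for weights p[i] and component means/variances
# mean[i], var[i], var = sum_i p[i] * ((mean[i] - mean)**2 + var[i]).
# Standalone sketch with made-up numbers.
import numpy as np

p = np.array([0.25, 0.75])
mean_i = np.array([0., 2.])
var_i = np.array([1., 4.])

mean = np.sum(p * mean_i)                       # 1.5
var = np.sum(p * ((mean_i - mean)**2 + var_i))  # 4.0

# Compare against a Monte Carlo estimate of the mixture variance.
rng = np.random.default_rng(0)
z = rng.choice(2, size=200_000, p=p)
x = rng.normal(mean_i[z], np.sqrt(var_i[z]))
assert np.isclose(var, np.var(x), rtol=0.05)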
def _forward_log_det_jacobian(self, x):
  # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
  # sections leading up to it, for context) in
  # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
  with tf.control_dependencies(self._assertions(x)):
    matrix_dim = tf.cast(tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
    return -(matrix_dim + 1) * tf.reduce_sum(
        tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
def _forward(self, x):
  y = x
  if self.scale is not None:
    with tf.control_dependencies(
        self._maybe_collect_assertions() if self.validate_args else []):
      y = self.scale.matvec(y, adjoint=self.adjoint)
  if self.shift is not None:
    y = y + self.shift
  return y
def _log_prob(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    probs = self._probs_parameter_no_checks()
    if not self.validate_args:
      # For consistency with cdf, we take the floor.
      x = tf.floor(x)
    safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs), probs)
    return x * tf.math.log1p(-safe_domain) + tf.math.log(probs)
def _cdf(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    probs = self._probs_parameter_no_checks()
    if not self.validate_args:
      # Whether or not x is integer-form, the following is well-defined.
      # However, scipy takes the floor, so we do too.
      x = tf.floor(x)
    return tf.where(x < 0., tf.zeros_like(x),
                    -tf.math.expm1((1. + x) * tf.math.log1p(-probs)))
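# A quick check of the numerically stable geometric-CDF form used above:
#   CDF(k) = 1 - (1 - p)**(k + 1) = -expm1((1 + k) * log1p(-p)).
# Standalone sketch with arbitrary parameter values.
import numpy as np

p, k = 0.2, 3.
direct = 1. - (1. - p)**(k + 1.)
stable = -np.expm1((1. + k) * np.log1p(-p))
assert np.isclose(direct, stable)  # both ~0.5904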
def _log_prob(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    concentration = tf.convert_to_tensor(self.concentration)
    scale = tf.convert_to_tensor(self.scale)
    unnormalized_prob = -(1. + concentration) * tf.math.log(x) - scale / x
    normalization = (tf.math.lgamma(concentration) -
                     concentration * tf.math.log(scale))
    return unnormalized_prob - normalization
def _log_prob(self, x):
  with tf.control_dependencies(self._runtime_assertions):
    x = self._pad_sample_dims(x)
    log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
    log_mix_prob = tf.math.log_softmax(
        self.mixture_distribution.logits_parameter(), axis=-1)  # [B, k]
    return tf.reduce_logsumexp(log_prob_x + log_mix_prob, axis=-1)  # [S, B]
def _inverse_event_shape_tensor(self, output_shape_tensor):
  batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
  if self.validate_args:
    is_square_matrix = assert_util.assert_equal(
        n, output_shape_tensor[-2], message='Matrix must be square.')
    with tf.control_dependencies([is_square_matrix]):
      n = tf.identity(n)
  d = tf.cast(n * (n + 1) / 2, output_shape_tensor.dtype)
  return tf.concat([batch_shape, [d]], axis=0)
def _log_prob(self, x):
  concentration = 0.5 * self.df
  rate = tf.convert_to_tensor(0.5, dtype=self.dtype)
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    log_unnormalized_prob = tf.math.xlogy(concentration - 1., x) - rate * x
    log_normalization = (tf.math.lgamma(concentration) -
                         concentration * tf.math.log(rate))
    return log_unnormalized_prob - log_normalization