def _log_prob(self, x):
  logits = self._logits_parameter_no_checks()
  event_size = self._event_size(logits)
  x = tf.cast(x, logits.dtype)
  x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

  # Broadcast logits or x if need be.
  if (not tensorshape_util.is_fully_defined(x.shape) or
      not tensorshape_util.is_fully_defined(logits.shape) or
      x.shape != logits.shape):
    broadcast_shape = tf.broadcast_dynamic_shape(tf.shape(logits), tf.shape(x))
    logits = tf.broadcast_to(logits, broadcast_shape)
    x = tf.broadcast_to(x, broadcast_shape)

  logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
  logits_2d = tf.reshape(logits, [-1, event_size])
  x_2d = tf.reshape(x, [-1, event_size])
  ret = -tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.stop_gradient(x_2d), logits=logits_2d)

  # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
  ret = tf.reshape(ret, logits_shape)
  return ret

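# A minimal standalone sketch (hypothetical values, not part of the class
# above) of the identity the code relies on: for a one-hot (or otherwise
# simplex-valued) x, the negated softmax cross entropy equals
# sum(x * log_softmax(logits)), i.e. the categorical log-probability.
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
x = tf.constant([[0., 1., 0.]])  # one-hot sample

lp_xent = -tf.nn.softmax_cross_entropy_with_logits(labels=x, logits=logits)
lp_direct = tf.reduce_sum(x * tf.nn.log_softmax(logits), axis=-1)
print(lp_xent.numpy(), lp_direct.numpy())  # both ~= [-1.741]
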
def _batch_shape_tensor(self, concentration=None, total_count=None):
  if concentration is None:
    concentration = tf.convert_to_tensor(self._concentration)
  if total_count is None:
    total_count = tf.convert_to_tensor(self._total_count)
  return tf.broadcast_dynamic_shape(
      tf.shape(total_count[..., tf.newaxis]),
      tf.shape(concentration))[:-1]

def _assertions(self, t):
  if not self.validate_args:
    return []
  is_matrix = assert_util.assert_rank_at_least(t, 2)
  is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
  is_positive_definite = assert_util.assert_positive(
      tf.linalg.diag_part(t), message="Input must be positive definite.")
  return [is_matrix, is_square, is_positive_definite]

def validate_equal_last_dim(tensor_a, tensor_b, message):
  event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
  event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
  if event_size_a is not None and event_size_b is not None:
    if event_size_a != event_size_b:
      raise ValueError(message)
  elif validate_args:  # `validate_args` is captured from the enclosing scope.
    return assert_util.assert_equal(
        tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message)

def _log_moment(self, n, concentration1=None, concentration0=None):
  """Compute the n'th (uncentered) moment."""
  concentration0 = tf.convert_to_tensor(
      self.concentration0) if concentration0 is None else concentration0
  concentration1 = tf.convert_to_tensor(
      self.concentration1) if concentration1 is None else concentration1
  total_concentration = concentration1 + concentration0
  expanded_concentration1 = tf.broadcast_to(
      concentration1, tf.shape(total_concentration))
  expanded_concentration0 = tf.broadcast_to(
      concentration0, tf.shape(total_concentration))
  beta_arg0 = 1 + n / expanded_concentration1
  beta_arg = tf.stack([beta_arg0, expanded_concentration0], -1)
  return tf.math.log(expanded_concentration0) + tf.math.lbeta(beta_arg)

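# A quick standalone check of the identity used above (hypothetical values,
# not part of the class): with a = concentration1 and b = concentration0, the
# n'th raw moment of Kumaraswamy(a, b) is b * B(1 + n/a, b), so for
# a = b = 1 (the Uniform(0, 1) case) the n'th moment must be 1 / (n + 1).
import numpy as np
import tensorflow as tf

n, a, b = 3.0, 1.0, 1.0
log_moment = tf.math.log(b) + tf.math.lbeta(tf.stack([1. + n / a, b], axis=-1))
print(np.exp(log_moment.numpy()), 1. / (n + 1.))  # both 0.25
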
def _cdf(self, k):
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output.
  should_be_zero = k < 0

  # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
  k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

  batch_shape = tf.shape(k)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
  # to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  k_flat_batch = tf.reshape(k, [-1])
  probs_flat_batch = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))

  cdf_flat = tf.gather(
      tf.cumsum(probs_flat_batch, axis=-1),
      k_flat_batch[..., tf.newaxis],
      batch_dims=1)

  cdf = tf.reshape(cdf_flat, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(should_be_zero, zero, cdf)

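# A small standalone illustration (hypothetical values) of the
# flatten-then-gather trick used above: once the batch dims are flattened into
# one axis, batch_dims=1 is statically known, and each row's running CDF is
# indexed by that row's k.
import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.25, 0.25]])
k = tf.constant([1, 2])  # one category index per (flattened) batch member
cdf = tf.gather(tf.cumsum(probs, axis=-1), k[..., tf.newaxis], batch_dims=1)
print(tf.squeeze(cdf, -1).numpy())  # ~[0.3, 1.0]
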
def _mode(self, samples=None):
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    return tf.argmax(count)

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat(
        [self._batch_shape_tensor(samples),
         self._event_shape_tensor(samples)],
        axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
  full_indices = tf.stack(
      [tf.range(tf.shape(indices)[0]),
       tf.cast(indices, tf.int32)],
      axis=1)

  mode = tf.gather_nd(flattened_samples, full_indices)
  return tf.reshape(mode, mode_shape)

def _log_prob(self, x, **kwargs):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)

  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=tf.pad(
      tf.shape(x),
      paddings=[[prefer_static.maximum(0, -d), 0]],
      constant_values=1))
  sample_ndims = prefer_static.maximum(0, d)

  # (2) Transpose x's dims.
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)

  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)

  # (4) Make the final reduction in x.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)

def _forward(self, x):
  with tf.control_dependencies(self._assertions(x)):
    shape = tf.shape(x)
    return tf.linalg.triangular_solve(
        x,
        tf.eye(shape[-1], batch_shape=shape[:-2], dtype=x.dtype),
        lower=True)

def _sample_n(self, n, seed=None):
  n_draws = tf.cast(self.total_count, dtype=tf.int32)
  logits = self._logits_parameter_no_checks()
  k = tf.compat.dimension_value(logits.shape[-1])
  if k is None:
    k = tf.shape(logits)[-1]
  return draw_sample(n, k, logits, n_draws, self.dtype, seed)

def _make_columnar(self, x):
  """Ensures non-scalar input has at least one column.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`.

    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with at least two dimensions.
  """
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 1:
      x = x[tf.newaxis, :]
    return x
  shape = tf.shape(x)
  maybe_expanded_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return tf.reshape(x, maybe_expanded_shape)

def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Compute the matrix rank; the number of non-zero SVD singular values.

  Arguments:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
  with tf.name_scope(name or 'matrix_rank'):
    a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        a = tf.identity(a)
    s = tf.linalg.svd(a, compute_uv=False)
    if tol is None:
      if tensorshape_util.is_fully_defined(a.shape[-2:]):
        m = np.max(a.shape[-2:].as_list())
      else:
        m = tf.reduce_max(tf.shape(a)[-2:])
      eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
      tol = (eps * tf.cast(m, a.dtype) *
             tf.reduce_max(s, axis=-1, keepdims=True))
    return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)

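# A minimal standalone sketch of the same SVD-based rank computation on a
# rank-deficient matrix (the third row is the sum of the first two), using the
# default tolerance eps * max(rows, cols) * max(singular value).
import numpy as np
import tensorflow as tf

a = tf.constant([[1., 0., 0.],
                 [0., 1., 0.],
                 [1., 1., 0.]])
s = tf.linalg.svd(a, compute_uv=False)
tol = np.finfo(np.float32).eps * 3. * tf.reduce_max(s, axis=-1, keepdims=True)
rank = tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
print(rank.numpy())  # 2
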
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions."""
  assertions = []

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if tensorshape_util.rank(lower_upper.shape) is not None:
    if tensorshape_util.rank(lower_upper.shape) < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
  if (tensorshape_util.rank(lower_upper.shape) is not None and
      tensorshape_util.rank(perm.shape) is not None):
    if (tensorshape_util.rank(lower_upper.shape) !=
        tensorshape_util.rank(perm.shape) + 1):
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank(
            lower_upper, rank=tf.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(assert_util.assert_equal(m, n, message=message))

  return assertions

def _inverse(self, y):
  output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
      tf.shape(y), self._event_shape_out, self._event_shape_in,
      self.validate_args)
  x = tf.reshape(y, output_shape)
  tensorshape_util.set_shape(x, output_tensorshape)
  return x

def _expand_mix_distribution_probs(self):
  p = self.mixture_distribution.probs_parameter()  # [B, deg]
  deg = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(p.shape, 1)[-1])
  if deg is None:
    deg = tf.shape(p)[-1]
  event_ndims = tensorshape_util.rank(self.event_shape)
  if event_ndims is None:
    event_ndims = tf.shape(self.event_shape_tensor())[0]
  expand_shape = tf.concat([
      self.mixture_distribution.batch_shape_tensor(),
      tf.ones([event_ndims], dtype=tf.int32),
      [deg],
  ], axis=0)
  return tf.reshape(p, shape=expand_shape)

def _call_sample_n(self, sample_shape, seed, name, **kwargs):
  # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
  # the result of `self.bijector.forward` is not modified (and thus caching
  # works).
  with self._name_and_control_scope(name):
    sample_shape = tf.convert_to_tensor(
        sample_shape, dtype=tf.int32, name="sample_shape")
    sample_shape, n = self._expand_sample_shape_to_vector(
        sample_shape, "sample_shape")

    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

    # First, generate samples. We will possibly generate extra samples in the
    # event that we need to reinterpret the samples as part of the
    # event_shape.
    x = self._sample_n(n, seed, **distribution_kwargs)

    # Next, we reshape `x` into its final form. We do this prior to the call
    # to the bijector to ensure that the bijector caching works.
    batch_event_shape = tf.shape(x)[1:]
    final_shape = tf.concat([sample_shape, batch_event_shape], 0)
    x = tf.reshape(x, final_shape)

    # Finally, we apply the bijector's forward transformation. For caching to
    # work, it is imperative that this is the last modification to the
    # returned result.
    y = self.bijector.forward(x, **bijector_kwargs)
    y = self._set_sample_static_shape(y, sample_shape)
    return y

def _forward(self, x):
  output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
      tf.shape(x), self._event_shape_in, self._event_shape_out,
      self.validate_args)
  y = tf.reshape(x, output_shape)
  tensorshape_util.set_shape(y, output_tensorshape)
  return y

def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param, 1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
      if tf.compat.dimension_value(param.shape[-1]) != 1:
        raise NotImplementedError(
            "Currently only bimixtures are supported; "
            "{}.shape[-1]={} is not 1.".format(
                name, tf.compat.dimension_value(param.shape[-1])))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(param)[-1], 1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param

def _get_shape(x, out_type=tf.int32):
  # Return the shape of a Tensor or a SparseTensor as an np.array if its shape
  # is known statically. Otherwise return a Tensor representing the shape.
  if tensorshape_util.is_fully_defined(x.shape):
    return np.array(tensorshape_util.as_list(x.shape),
                    dtype=dtype_util.as_numpy_dtype(out_type))
  return tf.shape(x, out_type=out_type)

def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts the event or samples."""
  # This is the shape of self.samples, without the samples axis, i.e. the shape
  # of the result of a call to dist.sample(). This way we can broadcast it with
  # event to get a properly-sized event, then add the singleton dim back at
  # -event_ndims - 1.
  samples_shape = tf.concat(
      [
          tf.shape(samples)[:-event_ndims - 1],
          tf.shape(samples)[tf.rank(samples) - event_ndims:]
      ],
      axis=0)
  event = event * tf.ones(samples_shape, dtype=event.dtype)
  event = tf.expand_dims(event, axis=-event_ndims - 1)
  samples = samples * tf.ones_like(event, dtype=samples.dtype)

  return event, samples

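# Shape illustration (hypothetical shapes, assuming the helper above is in
# scope): samples is [batch=2, num_samples=5, event=3] and event is [3] with
# event_ndims=1; the event is broadcast to [2, 1, 3] so it lines up against
# the samples axis, while samples keep their shape [2, 5, 3].
import tensorflow as tf

samples = tf.zeros([2, 5, 3])
event = tf.zeros([3])
event_b, samples_b = _broadcast_event_and_samples(event, samples, event_ndims=1)
print(event_b.shape, samples_b.shape)  # (2, 1, 3) (2, 5, 3)
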
def _entropy(self):
  concentration = tf.convert_to_tensor(self.concentration)
  k = tf.cast(tf.shape(concentration)[-1], self.dtype)
  total_concentration = tf.reduce_sum(concentration, axis=-1)
  return (tf.math.lbeta(concentration) +
          ((total_concentration - k) * tf.math.digamma(total_concentration)) -
          tf.reduce_sum((concentration - 1.) * tf.math.digamma(concentration),
                        axis=-1))

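# Standalone sanity check (hypothetical values, not part of the class):
# assuming `concentration` is a Dirichlet concentration vector alpha, the
# expression above is the standard Dirichlet differential entropy
#   log B(alpha) + (alpha0 - K) * digamma(alpha0)
#     - sum_j (alpha_j - 1) * digamma(alpha_j),   with alpha0 = sum_j alpha_j,
# which is 0 for the uniform case alpha = [1, 1].
import tensorflow as tf

alpha = tf.constant([1., 1.])
k = tf.cast(tf.shape(alpha)[-1], alpha.dtype)
alpha0 = tf.reduce_sum(alpha, axis=-1)
entropy = (tf.math.lbeta(alpha) +
           (alpha0 - k) * tf.math.digamma(alpha0) -
           tf.reduce_sum((alpha - 1.) * tf.math.digamma(alpha), axis=-1))
print(entropy.numpy())  # 0.0
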
def _event_size(self, param=None):
  if param is None:
    param = self._logits if self._logits is not None else self._probs
  if param.shape is not None:
    event_size = tf.compat.dimension_value(param.shape[-1])
    if event_size is not None:
      return event_size
  return tf.shape(param)[-1]

def _shape(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  if not hasattr(input, 'shape'):
    x = np.array(input)
    input = tf.convert_to_tensor(input) if x.dtype is np.object else x
  input_shape = tf.TensorShape(input.shape)
  if tensorshape_util.is_fully_defined(input.shape):
    return np.array(tensorshape_util.as_list(input_shape)).astype(
        _numpy_dtype(out_type))
  return tf.shape(input, out_type=out_type, name=name)

def _forward_log_det_jacobian(self, x):
  # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
  # sections leading up to it, for context) in
  # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
  with tf.control_dependencies(self._assertions(x)):
    matrix_dim = tf.cast(tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
    return -(matrix_dim + 1) * tf.reduce_sum(
        tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)

def _num_categories(self, x=None):
  """Scalar `int32` tensor: the number of categories."""
  with tf.name_scope('num_categories'):
    if x is None:
      x = self._probs if self._logits is None else self._logits
    num_categories = tf.compat.dimension_value(x.shape[-1])
    if num_categories is not None:
      return num_categories
    return tf.shape(x)[-1]

def _sample_n(self, n, seed=None):
  scale = tf.convert_to_tensor(self.scale)
  shape = tf.concat([[n], tf.shape(scale)], 0)
  sampled = tf.random.normal(
      shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
  return tf.abs(sampled * scale)

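# Standalone sanity check (hypothetical values, not part of the class): taking
# the absolute value of a standard normal draw and scaling by `scale` yields a
# half-normal sample, whose mean is scale * sqrt(2 / pi) ~= 0.798 * scale.
import numpy as np
import tensorflow as tf

scale = 2.0
z = tf.random.normal([100000], seed=42)
samples = tf.abs(z * scale)
print(samples.numpy().mean(), scale * np.sqrt(2. / np.pi))  # both ~= 1.6
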
def _forward(self, x):
  with tf.control_dependencies(self._assertions(x)):
    x_shape = tf.shape(x)
    identity_matrix = tf.eye(
        x_shape[-1],
        batch_shape=x_shape[:-2],
        dtype=dtype_util.base_dtype(x.dtype))
    # Note `matrix_triangular_solve` implicitly zeros upper triangular of `x`.
    y = tf.linalg.triangular_solve(x, identity_matrix)
    y = tf.matmul(y, y, adjoint_a=True)
    return tf.linalg.cholesky(y)

def _cdf(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  broadcast_shape = tf.broadcast_dynamic_shape(
      tf.shape(x), self._batch_shape_tensor(low=low, high=high))
  zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
  ones = tf.ones(broadcast_shape, dtype=self.dtype)
  result_if_not_big = tf.where(
      x < low, zeros, (x - low) / self._range(low=low, high=high))
  return tf.where(x >= high, ones, result_if_not_big)

def _pad_sample_dims(self, x):
  with tf.name_scope("pad_sample_dims"):
    ndims = tensorshape_util.rank(x.shape)
    if ndims is None:
      ndims = tf.rank(x)
    shape = tf.shape(x)
    d = ndims - self._event_ndims
    x = tf.reshape(x, shape=tf.concat([shape[:d], [1], shape[d:]], axis=0))
    return x

def _inverse(self, y):
  # As specified in the Stan reference manual, the procedure is as follows:
  # N = y.shape[-1]
  # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
  # x_k = logit(z_k) - log(1 / (N - k))
  offset = tf.math.log(
      tf.cast(
          tf.range(tf.shape(y)[-1] - 1, 0, delta=-1),
          dtype=dtype_util.base_dtype(y.dtype)))
  z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
  return tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1]) + offset