def _log_prob(self, x):
  """Log-probability of `x` scored against the logits parameter.

  `x` is cast to the logits dtype and passed through
  `self._maybe_assert_valid_sample`. When either shape is not fully defined,
  or the two shapes differ, `x` and the logits are broadcast against each
  other. The result is the negated softmax cross-entropy over the trailing
  axis, reshaped back to the non-trailing dims.
  """
  logits = self._logits_parameter_no_checks()
  event_size = self._event_size(logits)
  x = self._maybe_assert_valid_sample(
      tf.cast(x, logits.dtype), dtype=logits.dtype)

  shapes_static = (tensorshape_util.is_fully_defined(x.shape) and
                   tensorshape_util.is_fully_defined(logits.shape))
  if not shapes_static or x.shape != logits.shape:
    common_shape = tf.broadcast_dynamic_shape(tf.shape(logits), tf.shape(x))
    logits = tf.broadcast_to(logits, common_shape)
    x = tf.broadcast_to(x, common_shape)

  # Shape of everything except the trailing axis; used to undo the 2-D
  # flattening performed for the cross-entropy op.
  leading_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
  flat_logits = tf.reshape(logits, [-1, event_size])
  flat_x = tf.reshape(x, [-1, event_size])
  flat_log_prob = -tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.stop_gradient(flat_x), logits=flat_logits)
  return tf.reshape(flat_log_prob, leading_shape)
def _batch_shape_tensor(self):
  """Broadcast of the batch shapes of `scale`, `df`, and `loc`.

  The trailing axis of `loc` is dropped before broadcasting (presumably the
  event axis).
  """
  combined = self.scale.batch_shape_tensor()
  for other_shape in (tf.shape(self.df), tf.shape(self.loc)[:-1]):
    combined = tf.broadcast_dynamic_shape(combined, other_shape)
  return combined
def _assertions(self, t):
  """Runtime checks that `t` is a (batch of) square matrix with positive diagonal.

  Args:
    t: `Tensor` to validate.

  Returns:
    A list of assertion ops. The list is empty when `self.validate_args` is
    `False`, so no assertion ops are built unless validation was requested.
    (Previously the guard was inverted: assertions were built only when
    validation was disabled.)
  """
  if not self.validate_args:
    return []
  is_matrix = assert_util.assert_rank_at_least(t, 2)
  is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
  # Only the diagonal is checked. A positive diagonal is necessary (not
  # sufficient) for positive definiteness.
  is_positive_definite = assert_util.assert_positive(
      tf.linalg.diag_part(t), message="Input must be positive definite.")
  return [is_matrix, is_square, is_positive_definite]
def _batch_shape_tensor(self, concentration=None, total_count=None):
  """Broadcast batch shape of `total_count` and `concentration`.

  Args:
    concentration: Optional override; defaults to `self._concentration`
      (converted to a `Tensor`).
    total_count: Optional override; defaults to `self._total_count`
      (converted to a `Tensor`).

  Returns:
    The broadcast shape with the trailing axis of `concentration` removed.
  """
  if total_count is None:
    total_count = tf.convert_to_tensor(self._total_count)
  if concentration is None:
    concentration = tf.convert_to_tensor(self._concentration)
  # `total_count` gets a trailing unit axis so it lines up with the last axis
  # of `concentration`; that axis is then sliced off the result.
  full_shape = tf.broadcast_dynamic_shape(
      tf.shape(total_count[..., tf.newaxis]), tf.shape(concentration))
  return full_shape[:-1]
def validate_equal_last_dim(tensor_a, tensor_b, message):
  """Checks that `tensor_a` and `tensor_b` agree in their last dimension.

  When both sizes are statically known, a mismatch raises `ValueError` and a
  match returns `None`. Otherwise, a runtime assertion op is returned if the
  enclosing scope's `validate_args` is truthy, and `None` if it is not.
  """
  static_a = tf.compat.dimension_value(tensor_a.shape[-1])
  static_b = tf.compat.dimension_value(tensor_b.shape[-1])
  if static_a is not None and static_b is not None:
    if static_a != static_b:
      raise ValueError(message)
    return None
  if not validate_args:
    return None
  return assert_util.assert_equal(
      tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message)
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Compute the matrix rank; the number of non-zero SVD singular values.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Singular values exactly equal to `tol` also count as zero (the test is
      `s > tol`).
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`,
      where `eps` is the machine epsilon of `a.dtype`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
  with tf.name_scope(name or 'matrix_rank'):
    a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        a = tf.identity(a)
    s = tf.linalg.svd(a, compute_uv=False)
    if tol is None:
      # Prefer the static matrix size; fall back to the runtime shape.
      if tensorshape_util.is_fully_defined(a.shape[-2:]):
        m = np.max(tensorshape_util.as_list(a.shape[-2:]))
      else:
        m = tf.reduce_max(tf.shape(a)[-2:])
      eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
      # keepdims=True so each matrix's tolerance broadcasts against its `s`.
      tol = (eps * tf.cast(m, a.dtype) *
             tf.reduce_max(s, axis=-1, keepdims=True))
    return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions.

  Each property is checked statically when the needed shape information is
  available, raising `ValueError` on violation. Otherwise, if `validate_args`
  is `True`, an equivalent runtime assertion is appended instead.

  Args:
    lower_upper: `Tensor` holding the packed LU factors.
    perm: `Tensor` holding the permutation.
    validate_args: Python `bool`; whether to build runtime assertions for the
      properties that cannot be checked statically.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    ValueError: if a property is statically known to be violated.
  """
  assertions = []
  lu_rank = tensorshape_util.rank(lower_upper.shape)
  perm_rank = tensorshape_util.rank(perm.shape)

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if lu_rank is not None:
    if lu_rank < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
  if lu_rank is not None and perm_rank is not None:
    if lu_rank != perm_rank + 1:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank(
            lower_upper, rank=tf.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  # Squareness depends only on the two trailing (matrix) dims, so those -- not
  # the batch dims -- must be statically known for the static comparison.
  # (Checking `shape[:-2]` compared unknown sizes, e.g. `None != 3`, and raised
  # spuriously or skipped the runtime check.)
  if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(assert_util.assert_equal(m, n, message=message))

  return assertions
def _get_shape(x, out_type=tf.int32):
  """Shape of a `Tensor` or `SparseTensor`.

  Returns a NumPy array of dtype `out_type` when `x.shape` is fully defined;
  otherwise returns the dynamic `tf.shape(x, out_type=out_type)` `Tensor`.
  """
  if not tensorshape_util.is_fully_defined(x.shape):
    return tf.shape(x, out_type=out_type)
  static_dims = tensorshape_util.as_list(x.shape)
  return np.array(static_dims, dtype=dtype_util.as_numpy_dtype(out_type))
def _make_columnar(self, x):
  """Promotes a rank-1 input to a rank-2, single-row matrix.

  A leading axis of size 1 is inserted when `x` has rank 1. Inputs of any
  other rank, including scalars, are returned with their shape unchanged.
  Both the static-rank and dynamic-rank code paths follow this rule.

  NOTE(review): the name suggests a column orientation (`[[1], [2], [3]]`),
  but both paths insert the new axis *before* the last one. Callers
  presumably rely on the current row behavior; confirm before changing it.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]` (shape `[1, 3]`).
    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.
    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with the values of `x`. Rank-1 inputs gain a leading
      unit axis; all other inputs keep their shape.
  """
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 1:
      x = x[tf.newaxis, :]
    return x
  shape = tf.shape(x)
  # Rank unknown statically: splice a `[1]` in before the last dim only when
  # the runtime rank is 1 (for a scalar both slices below are empty).
  maybe_expanded_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return tf.reshape(x, maybe_expanded_shape)
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
  """Draws base samples, reshapes them, and pushes them through the bijector.

  `_call_sample_n` is overridden (rather than `_sample_n`) so that the tensor
  produced by `self.bijector.forward` is not modified afterwards, which keeps
  the bijector's cache effective.
  """
  with self._name_and_control_scope(name):
    shape_tensor = tf.convert_to_tensor(
        sample_shape, dtype=tf.int32, name="sample_shape")
    sample_shape, n = self._expand_sample_shape_to_vector(
        shape_tensor, "sample_shape")
    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

    # Draw `n` samples from the base distribution along a single leading axis.
    flat_draws = self._sample_n(n, seed, **distribution_kwargs)

    # Replace the leading `n` axis with `sample_shape` *before* the bijector
    # runs, so that the bijector's output can be returned untouched.
    draws = tf.reshape(
        flat_draws,
        tf.concat([sample_shape, tf.shape(flat_draws)[1:]], 0))

    # The forward transformation must be the last modification of the result
    # for bijector caching to work.
    transformed = self.bijector.forward(draws, **bijector_kwargs)
    return self._set_sample_static_shape(transformed, sample_shape)
def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts the event or samples.

  The samples axis of `samples` sits at position `-event_ndims - 1`. `event` is
  broadcast to the shape `samples` has without that axis, then gains a
  singleton axis at that position. Finally `samples` is broadcast against the
  expanded event.

  Returns:
    `(event, samples)` after broadcasting.
  """
  samples_shape = tf.shape(samples)
  # Slicing from `rank - event_ndims` rather than `-event_ndims` keeps the
  # event_ndims == 0 case correct: `[-0:]` would select the whole shape.
  shape_without_samples_axis = tf.concat(
      [samples_shape[:-event_ndims - 1],
       samples_shape[tf.rank(samples) - event_ndims:]],
      axis=0)
  broadcast_event = tf.expand_dims(
      event * tf.ones(shape_without_samples_axis, dtype=event.dtype),
      axis=-event_ndims - 1)
  broadcast_samples = samples * tf.ones_like(broadcast_event,
                                             dtype=samples.dtype)
  return broadcast_event, broadcast_samples
def _sample_n(self, n, seed=None):
  """Draws `n` samples via `draw_sample` using the logits and `total_count`."""
  num_trials = tf.cast(self.total_count, dtype=tf.int32)
  logits = self._logits_parameter_no_checks()
  # Prefer the static size of the trailing logits axis; fall back to runtime.
  num_classes = tf.compat.dimension_value(logits.shape[-1])
  if num_classes is None:
    num_classes = tf.shape(logits)[-1]
  return draw_sample(n, num_classes, logits, num_trials, self.dtype, seed)
def _log_prob(self, x, **kwargs):
  """Log-probability of `x`, summed over the extra `sample_shape` dims.

  After left-padding, `x` is read as `[S..., B..., E..., V...]`:
    - S: leading sample dims,
    - B: `self.distribution`'s batch dims,
    - E: the dims of `self.sample_shape`,
    - V: the event dims.
  The axes are permuted to `[S..., E..., B..., V...]`, scored with
  `self.distribution.log_prob`, and the E axes are then summed out.

  Args:
    x: `Tensor` of samples.
    **kwargs: Forwarded to `self.distribution.log_prob`.

  Returns:
    `Tensor` of log-probabilities with the E axes reduced away.
  """
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)
  # (1) Expand x's dims. If `x` has fewer than B + E + V dims, left-pad its
  # shape with ones so that every axis group is present.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=tf.pad(
      tf.shape(x),
      paddings=[[prefer_static.maximum(0, -d), 0]],
      constant_values=1))
  # Padding may have increased the rank, so refresh `ndims`. With the stale
  # value, `event_dims` (and hence `perm`) dropped the trailing event axes
  # whenever padding occurred and event_ndims > 0.
  ndims = prefer_static.rank(x)
  sample_ndims = prefer_static.maximum(0, d)
  # (2) Transpose x's dims to [S..., E..., B..., V...].
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Sum out the E axes, which now directly follow the S axes.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args.

  Requires `param` to have rank >= 1 and a trailing dimension of size 1. Each
  condition is checked statically when possible (raising immediately).
  Otherwise, if `validate_args` is `True`, it is enforced with a runtime
  assertion attached to the returned tensor.

  Args:
    param: `Tensor` to check.
    name: Python `str` used in error messages and the name scope.
    validate_args: Python `bool`; whether to add runtime assertions.

  Returns:
    `param`, with runtime assertions attached as dependencies if any were
    created.

  Raises:
    ValueError: if `param` is statically known to be a scalar.
    NotImplementedError: if the trailing dimension is statically known and
      is not 1.
  """
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param, 1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    # `dimension_value` maps both TF1 `Dimension` objects and TF2 ints/None to
    # int-or-None. A bare `Dimension(None)` is never `is None`, which would make
    # the runtime branch unreachable and reject every unknown size.
    last_dim = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(param.shape, 1)[-1])
    if last_dim is not None:
      if last_dim != 1:
        raise NotImplementedError(
            "Currently only bimixtures are supported; "
            "{}.shape[-1]={} is not 1.".format(name, last_dim))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(param)[-1], 1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param
def _forward(self, x):
  """Reshapes `x`, swapping `self._event_shape_in` for `self._event_shape_out`."""
  dynamic_shape, static_shape = _replace_event_shape_in_shape_tensor(
      tf.shape(x),
      self._event_shape_in,
      self._event_shape_out,
      self.validate_args)
  reshaped = tf.reshape(x, dynamic_shape)
  tensorshape_util.set_shape(reshaped, static_shape)
  return reshaped
def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                    df_factor_fn):
  """Helper to compute stddev, covariance and variance.

  Where `df > 2` the result is `statistic * df_factor_fn(df / (df - 2))`, and
  where `1 < df <= 2` it is `inf`. Where `df <= 1` it is `nan` if
  `self.allow_nan_stats`; otherwise a runtime assertion fails.
  """
  np_dtype = dtype_util.as_numpy_dtype(self.dtype)
  # Append `statistic_ndims` unit axes to `df` so it broadcasts against
  # `statistic`.
  expanded_df_shape = tf.concat(
      [tf.shape(self.df), tf.ones([statistic_ndims], dtype=tf.int32)], -1)
  df = tf.reshape(self.df, expanded_df_shape)
  # Use 1 as the denominator where df <= 2, so the division never yields NaN.
  # This safe inner tf.where keeps NaNs out of the gradient of the outer
  # tf.where.
  safe_denominator = tf.where(df > 2., df - 2., tf.ones_like(df))
  scaled = statistic * df_factor_fn(df / safe_denominator)
  # When 1 < df <= 2, stddev/variance are infinite.
  finite_or_inf = tf.where(df > 2., scaled, np_dtype(np.inf))
  if self.allow_nan_stats:
    return tf.where(df > 1., finite_or_inf, np_dtype(np.nan))
  df_above_one = assert_util.assert_less(
      tf.cast(1., self.dtype), df,
      message='{} not defined for components of df <= 1.'.format(
          statistic_name.capitalize()))
  with tf.control_dependencies([df_above_one]):
    return tf.identity(finite_or_inf)
def _inverse(self, y):
  """Reshapes `y`, swapping `self._event_shape_out` for `self._event_shape_in`."""
  dynamic_shape, static_shape = _replace_event_shape_in_shape_tensor(
      tf.shape(y),
      self._event_shape_out,
      self._event_shape_in,
      self.validate_args)
  restored = tf.reshape(y, dynamic_shape)
  tensorshape_util.set_shape(restored, static_shape)
  return restored
def _cdf(self, k):
  """P(X <= k), read from the cumulative sum of `probs` along the last axis."""
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)
  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))
  # The support starts at 0, so every k < 0 maps to a CDF of zero.
  below_support = k < 0
  # `k` becomes a gather index, so clip it into {0, ..., K-1}.
  index = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)
  out_shape = tf.shape(index)
  # tf.gather needs a static `batch_dims`. Flattening all batch dims into one
  # axis lets `batch_dims=1` be used even when the batch shape is dynamic.
  flat_index = tf.reshape(index, [-1])
  flat_probs = tf.reshape(probs, tf.concat(([-1], [num_categories]), axis=0))
  flat_cdf = tf.gather(tf.cumsum(flat_probs, axis=-1),
                       flat_index[..., tf.newaxis],
                       batch_dims=1)
  cdf = tf.reshape(flat_cdf, shape=out_shape)
  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(below_support, zero, cdf)
def _expand_mix_distribution_probs(self):
  """Reshapes mixture probs to `batch_shape + [1] * event_ndims + [deg]`."""
  mix_probs = self.mixture_distribution.probs_parameter()  # [B, deg]
  num_components = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(mix_probs.shape, 1)[-1])
  if num_components is None:
    num_components = tf.shape(mix_probs)[-1]
  event_rank = tensorshape_util.rank(self.event_shape)
  if event_rank is None:
    event_rank = tf.shape(self.event_shape_tensor())[0]
  # Unit axes in the event positions let the probs broadcast over the event.
  target_shape = tf.concat(
      [self.mixture_distribution.batch_shape_tensor(),
       tf.ones([event_rank], dtype=tf.int32),
       [num_components]],
      axis=0)
  return tf.reshape(mix_probs, shape=target_shape)
def _mode(self, samples=None):
  """Most frequent sample for each batch member.

  Args:
    samples: Optional `Tensor` of samples; defaults to `self._samples`
      (converted to a `Tensor`).

  Returns:
    `Tensor` reshaped to the batch shape (plus the event shape when
    `self._event_ndims != 0`). For each batch member it holds the value that
    occurs most often among that member's samples.
  """
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(batch_samples):
    # unique_with_counts_v2 returns the distinct values `y` (in order of first
    # occurrence) and their counts. `argmax(count)` indexes into `y`, NOT into
    # `batch_samples`, so the mode must be gathered from `y`. (Gathering it
    # from the samples returned the wrong value, e.g. 0 for [0, 0, 1, 1, 1].)
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    unique = gen_array_ops.unique_with_counts_v2(batch_samples, axis=[0])
    return tf.gather(unique.y, tf.argmax(unique.count))

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat([
        self._batch_shape_tensor(samples),
        self._event_shape_tensor(samples)
    ], axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  mode = tf.map_fn(_get_mode, flattened_samples,
                   dtype=flattened_samples.dtype)
  return tf.reshape(mode, mode_shape)
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  `param` is first left-padded with unit dims so its batch rank matches
  `dist_batch_shape`. Each slice in `slices` is then translated into a slice
  of the padded parameter. Wherever the parameter's dim is smaller than the
  distribution's batch dim (i.e. it is being broadcast), the slice collapses
  to index 0 (or `0:1:1`) so the broadcast dim is kept. The
  `param_event_ndims` trailing event dims are left unsliced.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.

  Raises:
    ValueError: if more than one `...` is found in `slices` (as detected by the
      `batch_dim_idx < 0` check below).
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    # `tf.newaxis` consumes no existing dim, so neither index advances.
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      # The remaining non-newaxis slices index the trailing batch dims, so the
      # batch index restarts at -(their count). The param index is shifted
      # further left by the param's event dims.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    # True when this param dim is smaller than the distribution's batch dim,
    # i.e. the param is being broadcast along it (presumably size 1).
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    # For broadcast dims, each given component is replaced so the slice keeps
    # exactly element 0.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  # Leave the parameter's event dims untouched.
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)
def _forward(self, x):
  """Inverse of the lower-triangular matrix `x`, via a triangular solve against I."""
  with tf.control_dependencies(self._assertions(x)):
    x_shape = tf.shape(x)
    identity = tf.eye(x_shape[-1], batch_shape=x_shape[:-2], dtype=x.dtype)
    return tf.linalg.triangular_solve(x, identity, lower=True)
def _entropy(self):
  """Entropy computed from the concentration `c` with K = size of its last axis.

  H = lbeta(c) + (c0 - K) * digamma(c0) - sum_i (c_i - 1) * digamma(c_i),
  where c0 = sum_i c_i.
  """
  alpha = tf.convert_to_tensor(self.concentration)
  num_components = tf.cast(tf.shape(alpha)[-1], self.dtype)
  alpha_sum = tf.reduce_sum(alpha, axis=-1)
  log_beta = tf.math.lbeta(alpha)
  total_term = (alpha_sum - num_components) * tf.math.digamma(alpha_sum)
  component_term = tf.reduce_sum(
      (alpha - 1.) * tf.math.digamma(alpha), axis=-1)
  return log_beta + total_term - component_term
def _event_size(self, param=None):
  """Size of the trailing axis of `param`.

  `param` defaults to `self._logits`, or to `self._probs` when logits are
  `None`. Returns a Python int when the size is statically known; otherwise a
  scalar `Tensor`.
  """
  if param is None:
    param = self._probs if self._logits is None else self._logits
  if param.shape is None:
    return tf.shape(param)[-1]
  static_size = tf.compat.dimension_value(param.shape[-1])
  return tf.shape(param)[-1] if static_size is None else static_size
def _forward_log_det_jacobian(self, x):
  """Returns `-(n + 1) * sum(log|diag(x)|)`, where `n = x.shape[-1]`."""
  # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
  # sections leading up to it, for context) in
  # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
  with tf.control_dependencies(self._assertions(x)):
    n = tf.cast(tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
    sum_log_abs_diag = tf.reduce_sum(
        tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
    return -(n + 1) * sum_log_abs_diag
def _num_categories(self, x=None):
  """Number of categories: size of the trailing axis of `x`.

  `x` defaults to `self._logits`, or to `self._probs` when logits are `None`.
  Returns a Python int when statically known; otherwise a scalar `int32`
  `Tensor`.
  """
  with tf.name_scope('num_categories'):
    if x is None:
      x = self._logits if self._logits is not None else self._probs
    static_count = tf.compat.dimension_value(x.shape[-1])
    if static_count is None:
      return tf.shape(x)[-1]
    return static_count
def _sample_n(self, n, seed=None):
  """Draws `n` samples as `|N(0, 1) * scale|`, with shape `[n] + shape(scale)`."""
  scale = tf.convert_to_tensor(self.scale)
  output_shape = tf.concat([[n], tf.shape(scale)], 0)
  standard_draws = tf.random.normal(
      shape=output_shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
  return tf.abs(standard_draws * scale)
def _shape(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Shape of `input`: a NumPy array when static, else a `tf.shape` `Tensor`.

  Objects without a `.shape` attribute are first converted with `np.array`.
  If that produces an object-dtype array (i.e. NumPy could not pack the items
  into a regular dtype), the original value is converted with
  `tf.convert_to_tensor` instead.

  Args:
    input: Array-like, `Tensor`, or any object exposing `.shape`.
    out_type: dtype of the returned shape.
    name: Optional op name for the dynamic `tf.shape` path.

  Returns:
    NumPy array of dtype `out_type` if the shape is fully defined; otherwise
    `tf.shape(input, out_type=out_type, name=name)`.
  """
  if not hasattr(input, 'shape'):
    x = np.array(input)
    # `x.dtype is np.object` was always False: `np.object` is just the builtin
    # `object` type, and a dtype is never identical to it. The alias was also
    # removed in NumPy 1.24, where the check raises AttributeError. Test the
    # dtype kind instead.
    input = tf.convert_to_tensor(input) if x.dtype.kind == 'O' else x
  input_shape = tf.TensorShape(input.shape)
  if tensorshape_util.is_fully_defined(input.shape):
    return np.array(tensorshape_util.as_list(input_shape)).astype(
        _numpy_dtype(out_type))
  return tf.shape(input, out_type=out_type, name=name)
def _cdf(self, x):
  """CDF: 0 below `low`, 1 at or above `high`, `(x - low) / range` between."""
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  out_shape = tf.broadcast_dynamic_shape(
      tf.shape(x), self._batch_shape_tensor(low=low, high=high))
  zeros = tf.zeros(out_shape, dtype=self.dtype)
  ones = tf.ones(out_shape, dtype=self.dtype)
  interior = (x - low) / self._range(low=low, high=high)
  clamped_below = tf.where(x < low, zeros, interior)
  return tf.where(x >= high, ones, clamped_below)
def _inverse(self, y):
  """Inverse transform following the Stan reference manual.

  With N = y.shape[-1] and k = 1, ..., N - 1:
    z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
    x_k = logit(z_k) - log(1 / (N - k))
  """
  # log(N - k) for k = 1, ..., N - 1, i.e. log(N-1), ..., log(1).
  log_remaining = tf.math.log(
      tf.cast(tf.range(tf.shape(y)[-1] - 1, 0, delta=-1),
              dtype=dtype_util.base_dtype(y.dtype)))
  z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
  z_head = z[..., :-1]
  # logit(z) = log(z) - log1p(-z).
  return tf.math.log(z_head) - tf.math.log1p(-z_head) + log_remaining