def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed)
  seed1, seed2 = samplers.split_seed(seed, salt='Skellam')
  log_rate1 = self._log_rate1_parameter_no_checks()
  log_rate2 = self._log_rate2_parameter_no_checks()
  batch_shape = self._batch_shape_tensor(
      log_rate1=log_rate1, log_rate2=log_rate2)
  log_rate1 = ps.broadcast_to(log_rate1, batch_shape)
  log_rate2 = ps.broadcast_to(log_rate2, batch_shape)
  sample1 = poisson_lib.random_poisson(
      [n], log_rates=log_rate1, seed=seed1)[0]
  sample2 = poisson_lib.random_poisson(
      [n], log_rates=log_rate2, seed=seed2)[0]
  return sample1 - sample2
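
# A minimal NumPy sketch (not the TFP implementation) of the same
# difference-of-two-Poissons construction, useful for checking the first
# two moments of the Skellam distribution:
import numpy as np

rng = np.random.default_rng(42)
rate1, rate2 = 3.0, 1.5
skellam = rng.poisson(rate1, size=100000) - rng.poisson(rate2, size=100000)
print(skellam.mean())  # ~ rate1 - rate2 = 1.5
print(skellam.var())   # ~ rate1 + rate2 = 4.5
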
def _batch_gather(params, indices, axis=0):
  """Gathers a batch of indices from `params` along the given axis.

  Args:
    params: `Tensor` of shape `[d[0], d[1], ..., d[N - 1]]`.
    indices: int `Tensor` of shape broadcastable to that of `params`.
    axis: int `Tensor` dimension of `params` (and of the broadcast indices)
      to gather over.

  Returns:
    result: `Tensor` of the same type and shape as `params`.
  """
  params_rank = prefer_static.rank_from_shape(prefer_static.shape(params))
  indices_rank = prefer_static.rank_from_shape(prefer_static.shape(indices))
  params_with_axis_on_right = dist_util.move_dimension(
      params, source_idx=axis, dest_idx=-1)
  indices_with_axis_on_right = prefer_static.broadcast_to(
      dist_util.move_dimension(
          indices,
          source_idx=axis - (params_rank - indices_rank),
          dest_idx=-1),
      prefer_static.shape(params_with_axis_on_right))
  result = tf.gather(
      params_with_axis_on_right,
      indices_with_axis_on_right,
      axis=params_rank - 1,
      batch_dims=params_rank - 1)
  return dist_util.move_dimension(result, source_idx=-1, dest_idx=axis)
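
# For intuition, a NumPy analogue of `_batch_gather` via
# `np.take_along_axis`, assuming `indices` is already broadcast to the shape
# of `params` (hypothetical shapes, illustration only):
import numpy as np

params = np.arange(24).reshape(2, 3, 4)
indices = np.broadcast_to(np.array([3, 2, 1, 0]), params.shape)
result = np.take_along_axis(params, indices, axis=2)
print(result[0, 0])  # [3 2 1 0] -- each batch row reordered along axis 2
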
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
  """Helper which processes `Tensor`s to tuples in standard form."""
  # Short-circuiting incoming lists and tuples here avoids both Tensor
  # packing / unpacking and numpy 1.20+ pickiness about
  # np.array(tuple of Tensor).
  if isinstance(arg, (tuple, list)):
    if len(arg) == n:
      return tuple(arg)
    if len(arg) == 1:
      return (arg[0],) * n

  arg_size = ps.size(arg)
  arg_size_ = tf.get_static_value(arg_size)
  assertions = []
  if arg_size_ is not None:
    if arg_size_ not in (1, n):
      raise ValueError(
          'The size of `{}` must be equal to `1` or to the rank '
          'of the convolution (={}). Saw size = {}'.format(
              arg_name, n, arg_size_))
  elif validate_args:
    assertions.append(
        assert_util.assert_equal(
            ps.logical_or(arg_size == 1, arg_size == n),
            True,
            message=('The size of `{}` must be equal to `1` or to the rank '
                     'of the convolution (={})'.format(arg_name, n))))

  with tf.control_dependencies(assertions):
    arg = ps.broadcast_to(arg, shape=[n])
    arg = ps.unstack(arg, num=n)
  return arg
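
# Pure-Python sketch of the short-circuit behavior above (TF-free, for
# intuition only): a scalar or length-1 argument is replicated to length
# `n`; other lengths must equal `n` (the real helper instead falls through
# to the tensor path and its size assertions).
def _prepare_tuple_sketch(arg, n):
  if isinstance(arg, (tuple, list)):
    if len(arg) == n:
      return tuple(arg)
    if len(arg) == 1:
      return (arg[0],) * n
    raise ValueError('size must be 1 or {}'.format(n))
  return (arg,) * n

assert _prepare_tuple_sketch(3, 2) == (3, 3)
assert _prepare_tuple_sketch([3], 2) == (3, 3)
assert _prepare_tuple_sketch((3, 5), 2) == (3, 5)
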
def _sample_n(self, n, seed=None):
  total_count = tf.cast(self.total_count, tf.int32)
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  return _sample_bates(
      ps.broadcast_to(total_count, self._batch_shape_tensor()),
      low, high, n, seed=seed)
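
# NumPy sketch of what `_sample_bates` produces: a Bates(total_count, low,
# high) draw is the mean of `total_count` iid Uniform(low, high) draws.
import numpy as np

rng = np.random.default_rng(0)
total_count, low, high = 5, -1.0, 1.0
u = rng.uniform(low, high, size=(100000, total_count))
bates = u.mean(axis=-1)
print(bates.mean())  # ~ (low + high) / 2 = 0
print(bates.var())   # ~ (high - low)**2 / (12 * total_count) = 1/15
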
def _get_flattened_marginal_distribution(self, index_points=None):
  # This returns a MVN of event size [N * E], where N is the number of
  # tasks and E is the number of index points.
  with self._name_and_control_scope('get_flattened_marginal_distribution'):
    index_points = self._get_index_points(index_points)
    covariance = self._compute_flattened_covariance(index_points)

    batch_shape = self._batch_shape_tensor(index_points=index_points)
    event_shape = self._event_shape_tensor(index_points=index_points)

    # Take the Cholesky decomposition, specializing to the block-diagonal
    # and Kronecker cases where possible.
    covariance_cholesky = cholesky_util.cholesky_from_fn(
        covariance, self._cholesky_fn)
    loc = self._mean_fn(index_points)
    # Broadcast the mean function's result so that we support constant mean
    # functions (constant over all tasks, or constant per task).
    loc = ps.broadcast_to(
        loc, ps.concat([batch_shape, event_shape], axis=0))
    loc = _vec(loc)
    return mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=loc,
        scale=covariance_cholesky,
        validate_args=self._validate_args,
        allow_nan_stats=self._allow_nan_stats,
        name='marginal_distribution')
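
# Sketch of the `_vec` flattening used above, assuming `_vec` reshapes the
# trailing [num_index_points, num_tasks] matrix dimensions into a single
# event vector (names here are illustrative):
import numpy as np

E, N = 4, 3                                # E index points, N tasks
mean = np.arange(E * N, dtype=np.float64).reshape(E, N)
flat = mean.reshape(*mean.shape[:-2], -1)  # row-major: task varies fastest
print(flat.shape)                          # (12,) -- the MVN event size N * E
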
def _get_flattened_marginal_distribution(self, index_points=None):
  # This returns a MVN of event size [N * E], where N is the number of
  # tasks and E is the number of index points.
  with self._name_and_control_scope('get_flattened_marginal_distribution'):
    index_points = self._get_index_points(index_points)
    scale = _compute_flattened_scale(
        kernel=self.kernel,
        index_points=index_points,
        cholesky_fn=self._cholesky_fn,
        observation_noise_variance=self.observation_noise_variance)

    batch_shape = self._batch_shape_tensor(index_points=index_points)
    event_shape = self._event_shape_tensor(index_points=index_points)

    loc = self._mean_fn(index_points)
    # Broadcast the mean function's result so that we support constant mean
    # functions (constant over all tasks, or constant per task).
    loc = ps.broadcast_to(
        loc, ps.concat([batch_shape, event_shape], axis=0))
    loc = _vec(loc)
    return mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=loc,
        scale=scale,
        validate_args=self._validate_args,
        allow_nan_stats=self._allow_nan_stats,
        name='marginal_distribution')
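
# The returned object is an ordinary TFP MVN over the flattened event; a toy
# stand-in showing the same construction with a hand-built LinearOperator
# scale (values here are arbitrary):
import tensorflow as tf
import tensorflow_probability as tfp

scale = tf.linalg.LinearOperatorLowerTriangular(
    tf.linalg.cholesky([[2.0, 0.5], [0.5, 1.0]]))
mvn = tfp.distributions.MultivariateNormalLinearOperator(
    loc=[0., 0.], scale=scale)
print(mvn.sample(3))
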
def _matrix(self, x1, x2):
  locs = util.pad_shape_with_ones(self.locs, ndims=1, start=-2)
  slopes = util.pad_shape_with_ones(self.slopes, ndims=1, start=-2)

  weights_x1 = tf.math.sigmoid(
      slopes * (self.weight_fn(x1, self.feature_ndims)[..., tf.newaxis] -
                locs))
  weights_x1 = weights_x1[..., tf.newaxis, :]
  weights_x2 = tf.math.sigmoid(
      slopes * (self.weight_fn(x2, self.feature_ndims)[..., tf.newaxis] -
                locs))
  weights_x2 = weights_x2[..., tf.newaxis, :, :]

  initial_weights = (1. - weights_x1) * (1. - weights_x2)
  initial_weights = tf.concat(
      [initial_weights,
       tf.ones_like(initial_weights[..., 0])[..., tf.newaxis]],
      axis=-1)
  end_weights = weights_x1 * weights_x2
  end_weights = tf.concat(
      [tf.ones_like(end_weights[..., 0])[..., tf.newaxis], end_weights],
      axis=-1)

  results = [k.matrix(x1, x2)[..., tf.newaxis] for k in self.kernels]
  broadcasted_shape = distribution_util.get_broadcast_shape(*results)
  results = tf.concat(
      [ps.broadcast_to(r, broadcasted_shape) for r in results], axis=-1)
  return tf.math.reduce_sum(
      initial_weights * results * end_weights, axis=-1)
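
# Two-kernel NumPy sketch of the sigmoid changepoint weighting above: far
# left of `loc` the first kernel dominates, far right the second (stand-in
# kernels and parameter values are illustrative):
import numpy as np

def sigmoid(z):
  return 1. / (1. + np.exp(-z))

loc, slope = 0.0, 5.0
x1 = np.array([-2.0, 2.0])[:, None]       # [e1, 1]
x2 = np.array([-2.0, 2.0])[None, :]       # [1, e2]
w1, w2 = sigmoid(slope * (x1 - loc)), sigmoid(slope * (x2 - loc))
k_a = np.exp(-0.5 * (x1 - x2) ** 2)       # stand-in kernel A
k_b = 2. * np.exp(-np.abs(x1 - x2))       # stand-in kernel B
k = (1. - w1) * (1. - w2) * k_a + w1 * w2 * k_b
print(np.round(k, 3))  # top-left corner ~ k_a, bottom-right corner ~ k_b
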
def _conditional_mean_fn(self, x):
  """Conditional mean."""
  k_x_obs_linop = self.kernel.matrix_over_all_tasks(
      x, self._observation_index_points)
  if self._observations_is_missing is not None:
    k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
        tf.where(_vec(tf.math.logical_not(
            self._observations_is_missing))[..., tf.newaxis, :],
                 k_x_obs_linop.to_dense(),
                 tf.zeros([], dtype=k_x_obs_linop.dtype)))
  mean_x = self.mean_fn(x)  # pylint:disable=not-callable
  batch_shape = self._batch_shape_tensor(index_points=x)
  event_shape = self._event_shape_tensor(index_points=x)
  mean_x = ps.broadcast_to(
      mean_x, ps.concat([batch_shape, event_shape], axis=0))
  mean_x = _vec(mean_x)
  return mean_x + k_x_obs_linop.matvec(self._solve_on_obs)
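
# NumPy sketch of the conditional-mean identity this method implements,
#   m*(x) = m(x) + K(x, X) @ (K(X, X) + noise I)^{-1} (y - m(X)),
# where the precomputed solve against the observations plays the role of
# `_solve_on_obs` (single task, zero mean function, toy RBF kernel):
import numpy as np

def rbf(a, b):
  return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2)

X = np.array([-1.0, 0.0, 1.0])
y = np.sin(X)                             # observations; mean_fn == 0
solve_on_obs = np.linalg.solve(rbf(X, X) + 1e-4 * np.eye(3), y)
x = np.array([0.5])
print(rbf(x, X) @ solve_on_obs)           # conditional mean at x
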
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
  """Helper which processes `Tensor`s to tuples in standard form."""
  arg_size = ps.size(arg)
  arg_size_ = tf.get_static_value(arg_size)
  assertions = []
  if arg_size_ is not None:
    if arg_size_ not in (1, n):
      raise ValueError(
          'The size of `{}` must be equal to `1` or to the rank '
          'of the convolution (={}). Saw size = {}'.format(
              arg_name, n, arg_size_))
  elif validate_args:
    assertions.append(
        assert_util.assert_equal(
            ps.logical_or(arg_size == 1, arg_size == n),
            True,
            message=('The size of `{}` must be equal to `1` or to the rank '
                     'of the convolution (={})'.format(arg_name, n))))

  with tf.control_dependencies(assertions):
    arg = ps.broadcast_to(arg, shape=[n])
    arg = ps.unstack(arg, num=n)
  return arg
def _bates_cdf(total_count, low, high, dtype, value):
  """Compute the Bates cdf.

  Internally, the (standard, unnormalized) cdf is computed by the formula

  ```none
  cdf = sum_{k=0}^j (-1)^k (n choose k) (nx - k)^n
  ```

  where
  * `n = total_count`,
  * `x = value` the value to compute the cumulative probability of, and
  * `j = floor(nx)`.

  This is shifted to `[low, high]` and normalized. Since the pdf is
  symmetric, we have `cdf(x) = 1 - cdf(1 - x)` for `x > .5`, hence we only
  compute the left half, which keeps the number of terms lower.

  Computation is batched, using `tf.math.segment_sum()`. For this reason
  this is not compatible with `tf.vectorized_map()`.

  All input parameters should have compatible dtypes and shapes.

  Args:
    total_count: `Tensor` with integer values, as given to the `Bates`
      constructor.
    low: Float `Tensor`, as given to the `Bates` constructor.
    high: Float `Tensor`, as given to the `Bates` constructor.
    dtype: The dtype of the output.
    value: Float `Tensor`. Input value to `cdf()`.

  Returns:
    cdf: Float `Tensor`. See above formula.
  """
  total_count = tf.cast(total_count, dtype)
  low = tf.convert_to_tensor(low)
  high = tf.convert_to_tensor(high)

  # Warn the user if they try to compute a cdf with high `total_count`. This
  # warning is here instead of `_parameter_control_dependencies()` because
  # nested calls to `_name_and_control_scope` (e.g. `log_survival_function`)
  # can result in multiple warnings being added and multiple tensor
  # conversions. Also `sample()` does not have the same numerical issues.
  with tf.control_dependencies([_stability_limit_tensor(total_count, dtype)]):
    # Center and adjust `value` using limits and symmetry.
    value_centered = (value - low) / (high - low)
    value_adj = tf.clip_by_value(value_centered, 0., 1.)
    value_adj = tf.where(value_adj < .5, value_adj, 1. - value_adj)
    value_adj = tf.where(tf.math.is_finite(value_adj), value_adj, low)
    # Flatten to make segments; need to broadcast before flattening.
    shape = ps.broadcast_shape(ps.shape(value_adj), ps.shape(total_count))
    total_count_b = ps.broadcast_to(total_count, shape)
    total_count_x_value_adj_b = total_count * value_adj
    total_count_f = tf.reshape(total_count_b, [-1])
    total_count_x_value_adj_f = tf.reshape(total_count_x_value_adj_b, [-1])
    # Create segmented terms of summation.
    num_terms_f = tf.cast(tf.math.floor(total_count_x_value_adj_f + 1),
                          dtype=tf.int32)
    term_idx_s = tf.cast(_segmented_range(num_terms_f), dtype)  # aka `k`
    total_count_s = tf.repeat(total_count_f, num_terms_f)
    total_count_x_value_adj_s = tf.repeat(total_count_x_value_adj_f,
                                          num_terms_f)
    terms = (tf.cast(-1., dtype)**term_idx_s
             * (1. / ((total_count_s + 1.) * tf.math.exp(
                 tfp_math.lbeta(total_count_s - term_idx_s + 1.,
                                term_idx_s + 1.))))
             * (total_count_x_value_adj_s - term_idx_s)**total_count_s)
    # Segment sum.
    segment_ids = tf.repeat(tf.range(tf.size(num_terms_f)), num_terms_f)
    cdf_s = tf.math.segment_sum(terms, segment_ids)
    # Reshape back.
    cdf = tf.reshape(cdf_s, shape)
    # Normalize.
    cdf = cdf / tf.math.exp(
        tf.math.lgamma(total_count_b + tf.cast(1., dtype)))
    # cdf symmetry adjustment: cdf(x) = 1 - cdf(1 - x) for x > 0.5
    cdf = tf.where(value_centered > .5, 1. - cdf, cdf)
    # Fix out-of-support queries.
    cdf = tf.where(value_centered < 0., tf.cast(0., dtype), cdf)
    cdf = tf.where(value_centered > 1., tf.cast(1., dtype), cdf)
    cdf = tf.where(tf.math.is_finite(value_centered), cdf, np.nan)
    return cdf
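
# Direct (unbatched) Python transcription of the docstring formula,
# normalized by n!, handy for spot-checking the segmented TF computation on
# the standard Bates distribution over [0, 1]:
import math
import numpy as np

def bates_cdf_reference(n, x):
  j = math.floor(n * x)
  s = sum((-1) ** k * math.comb(n, k) * (n * x - k) ** n
          for k in range(j + 1))
  return s / math.factorial(n)

rng = np.random.default_rng(1)
mc = (rng.uniform(size=(200000, 5)).mean(axis=-1) < 0.4).mean()
print(bates_cdf_reference(5, 0.4), mc)  # both ~ 0.225
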
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
  """Advances the particle filter by a single time step."""
  with tf.name_scope('filter_one_step'):
    seed = SeedStream(seed, 'filter_one_step')
    num_particles = prefer_static.shape(log_weights)[0]

    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed)
    log_weights = tf.nn.log_softmax(
        proposal_log_weights + log_weights, axis=-1)

    # If this step has an observation, compute its weights and marginal
    # likelihood (and otherwise, leave weights unchanged).
    observation_log_weights = prefer_static.cond(
        has_observation,
        lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
            _compute_observation_log_weights(
                step, proposed_particles, observation, observation_fn),
            prefer_static.shape(log_weights)),
        lambda: tf.zeros_like(log_weights))

    unnormalized_log_weights = log_weights + observation_log_weights
    step_log_marginal_likelihood = tf.math.reduce_logsumexp(
        unnormalized_log_weights, axis=0)
    log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

    # Adaptive resampling: resample particles iff the specified criterion
    # is met.
    do_resample = resample_criterion_fn(unnormalized_log_weights)

    # Some batch elements may require resampling and others not, so
    # we first do the resampling for all elements, then select whether to
    # use the resampled values for each batch element according to
    # `do_resample`. If there were no batching, we might prefer to use
    # `tf.cond` to avoid the resampling computation on steps where it's not
    # needed---but we're ultimately interested in adaptive resampling
    # for statistical (not computational) purposes, so this isn't a
    # dealbreaker.
    resampled_particles, resample_indices = _resample(
        proposed_particles, log_weights, resample_independent, seed=seed)

    uniform_weights = (prefer_static.zeros_like(log_weights) -
                       prefer_static.log(num_particles))
    (resampled_particles,
     resample_indices,
     log_weights) = tf.nest.map_structure(
         lambda r, p: prefer_static.where(do_resample, r, p),
         (resampled_particles, resample_indices, uniform_weights),
         (proposed_particles, _dummy_indices_like(resample_indices),
          log_weights))

  return ParticleFilterStepResults(
      particles=resampled_particles,
      log_weights=log_weights,
      parent_indices=resample_indices,
      step_log_marginal_likelihood=step_log_marginal_likelihood)
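
# Toy 1-D NumPy sketch of the weight-update / adaptive-resampling cycle
# above (illustrative names, not the TFP API; unnormalized unit-variance
# Gaussian observation log-density assumed, ESS resampling criterion):
import numpy as np

rng = np.random.default_rng(7)
num_particles = 1000
particles = rng.normal(size=num_particles)
log_weights = np.full(num_particles, -np.log(num_particles))

observation = 0.3
unnormalized = log_weights - 0.5 * (observation - particles) ** 2
step_log_marginal_likelihood = np.logaddexp.reduce(unnormalized)
log_weights = unnormalized - step_log_marginal_likelihood

# Resample iff the effective sample size drops below half the particles.
ess = 1. / np.exp(np.logaddexp.reduce(2. * log_weights))
if ess < num_particles / 2:
  idx = rng.choice(num_particles, size=num_particles, p=np.exp(log_weights))
  particles = particles[idx]
  log_weights = np.full(num_particles, -np.log(num_particles))
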
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
  """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when their
  `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:  [p[0], ..., ..., p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape: [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many leading dimensions as
  `indices` (`axis >= indices_axis`), and at least as many trailing
  dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension of
  size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  #   [p[0], ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`,
  and `indices_axis = 1`, the result is given by

  ```python
  # alignment is:                       v axis
  # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
  # indices.shape   ==         [i[0], i[1], i[2]]
  #                                    ^ indices_axis
  result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params: `N-D` `Tensor` (`N > 0`) from which to gather values. Number of
      dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`,
      whose shape broadcasts to that of `params` as described above.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met.
  """
  with tf.name_scope(name):
    params = tf.convert_to_tensor(params, name='params')
    indices = tf.convert_to_tensor(indices, name='indices')

    params_ndims = tensorshape_util.rank(params.shape)
    indices_ndims = tensorshape_util.rank(indices.shape)
    # `axis` dtype must match ndims, which are 64-bit Python ints.
    axis = tf.get_static_value(
        ps.convert_to_shape_tensor(axis, dtype=tf.int64))
    indices_axis = tf.get_static_value(
        ps.convert_to_shape_tensor(indices_axis, dtype=tf.int64))

    if params_ndims is None:
      raise ValueError(
          'Rank of `params` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if axis is None:
      raise ValueError(
          '`axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis is None:
      raise ValueError(
          '`indices_axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis > axis:
      raise ValueError(
          '`indices_axis` should be <= `axis`, but was {} > {}'.format(
              indices_axis, axis))

    if params_ndims < 1:
      raise ValueError(
          'Rank of params should be `> 0`, but was {}'.format(params_ndims))

    if indices_ndims is not None and indices_ndims < 1:
      raise ValueError(
          'Rank of indices should be `> 0`, but was {}'.format(
              indices_ndims))

    if (indices_ndims is not None and
        (indices_ndims - indices_axis > params_ndims - axis)):
      raise ValueError(
          '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
          'indices_axis` ({} - {}), but was not.'.format(
              params_ndims, axis, indices_ndims, indices_axis))

    # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
    # transpose `indices_axis` to be the rightmost dimension of `indices`...
    transposed_indices = dist_util.move_dimension(indices,
                                                  source_idx=indices_axis,
                                                  dest_idx=-1)
    # ... and `axis` to be the corresponding (aligned as in the docstring)
    # dimension of `params`.
    broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
    transposed_params = dist_util.move_dimension(
        params,
        source_idx=axis,
        dest_idx=broadcast_indices_ndims - 1)

    # Next we broadcast `indices` so that its shape has the same prefix as
    # `params.shape`.
    transposed_params_shape = ps.shape(transposed_params)
    result_shape = ps.concat([
        transposed_params_shape[:broadcast_indices_ndims - 1],
        ps.shape(indices)[indices_axis:indices_axis + 1],
        transposed_params_shape[broadcast_indices_ndims:]], axis=0)
    broadcast_indices = ps.broadcast_to(
        transposed_indices,
        result_shape[:broadcast_indices_ndims])

    result_t = tf.gather(transposed_params,
                         broadcast_indices,
                         batch_dims=broadcast_indices_ndims - 1,
                         axis=broadcast_indices_ndims - 1)
    return dist_util.move_dimension(result_t,
                                    source_idx=broadcast_indices_ndims - 1,
                                    dest_idx=axis)
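
# Brute-force NumPy check of the docstring example (rank(params) == 5,
# rank(indices) == 3, axis == 2, indices_axis == 1):
#   result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
import itertools
import numpy as np

p = np.arange(2 * 2 * 3 * 2 * 2).reshape(2, 2, 3, 2, 2)
idx = np.random.default_rng(0).integers(0, 3, size=(2, 4, 2))
result = np.empty((2, 2, 4, 2, 2), dtype=p.dtype)
for i, j, k, l, m in itertools.product(*(range(s) for s in result.shape)):
  result[i, j, k, l, m] = p[i, j, idx[j, k, l], l, m]
print(result.shape)  # p[axis] == 3 replaced by i[indices_axis] == 4
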