Example 1
def _sample_n(self, n, seed=None):
  seed = samplers.sanitize_seed(seed)
  seed1, seed2 = samplers.split_seed(seed, salt='Skellam')
  log_rate1 = self._log_rate1_parameter_no_checks()
  log_rate2 = self._log_rate2_parameter_no_checks()
  batch_shape = self._batch_shape_tensor(
      log_rate1=log_rate1, log_rate2=log_rate2)
  log_rate1 = ps.broadcast_to(log_rate1, batch_shape)
  log_rate2 = ps.broadcast_to(log_rate2, batch_shape)
  sample1 = poisson_lib.random_poisson(
      [n], log_rates=log_rate1, seed=seed1)[0]
  sample2 = poisson_lib.random_poisson(
      [n], log_rates=log_rate2, seed=seed2)[0]
  return sample1 - sample2
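
The snippet draws two stateless Poisson samples from log-rates and returns their difference, which is Skellam-distributed. For orientation only, a minimal sketch of the same idea with stock TensorFlow ops, assuming plain rates instead of log-rates, stateful seeds, and no batch-shape handling:

import tensorflow as tf

# Hypothetical rates; the method above parameterizes by log-rates and uses
# stateless, salted seeds instead.
rate1, rate2 = 3.0, 1.5
n = 5

sample1 = tf.random.poisson([n], lam=rate1, seed=1)
sample2 = tf.random.poisson([n], lam=rate2, seed=2)
skellam_samples = sample1 - sample2  # difference of Poissons ~ Skellam(rate1, rate2)
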
Example 2
def _batch_gather(params, indices, axis=0):
    """Gathers a batch of indices from `params` along the given axis.

  Args:
    params: `Tensor` of shape `[d[0], d[1], ..., d[N - 1]]`.
    indices: int `Tensor` of shape broadcastable to that of `params`.
    axis: int `Tensor` dimension of `params` (and of the broadcast indices) to
      gather over.
  Returns:
    result: `Tensor` of the same type and shape as `params`.
  """
    params_rank = prefer_static.rank_from_shape(prefer_static.shape(params))
    indices_rank = prefer_static.rank_from_shape(prefer_static.shape(indices))
    params_with_axis_on_right = dist_util.move_dimension(params,
                                                         source_idx=axis,
                                                         dest_idx=-1)
    indices_with_axis_on_right = prefer_static.broadcast_to(
        dist_util.move_dimension(indices,
                                 source_idx=axis -
                                 (params_rank - indices_rank),
                                 dest_idx=-1),
        prefer_static.shape(params_with_axis_on_right))

    result = tf.gather(params_with_axis_on_right,
                       indices_with_axis_on_right,
                       axis=params_rank - 1,
                       batch_dims=params_rank - 1)
    return dist_util.move_dimension(result, source_idx=-1, dest_idx=axis)
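
The helper moves the gather axis to the rightmost position, broadcasts `indices` to match, and then relies on `tf.gather` with `batch_dims`. A toy, hand-checkable illustration of that final gather step (made-up values):

import tensorflow as tf

params  = tf.constant([[10, 11, 12],
                       [20, 21, 22]])   # shape [2, 3]
indices = tf.constant([[2, 0, 0],
                       [1, 1, 2]])      # same shape, values index the last axis

# With the gather axis on the right and batch_dims == rank - 1, each row of
# `indices` selects from the corresponding row of `params`.
result = tf.gather(params, indices, axis=1, batch_dims=1)
# result == [[12, 10, 10],
#            [21, 21, 22]]
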
Example 3
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    # Short-circuiting incoming lists and tuples here avoids both
    # Tensor packing / unpacking and numpy 1.20.+ pickiness about
    # np.array(tuple of Tensor).
    if isinstance(arg, (tuple, list)):
        if len(arg) == n:
            return tuple(arg)
        if len(arg) == 1:
            return (arg[0], ) * n

    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
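
The rule being implemented: an argument of size 1 is replicated to an `n`-tuple, an argument of size `n` is passed through, and anything else is rejected. A NumPy-only sketch of that rule, using a hypothetical helper name rather than the TFP API:

import numpy as np

def as_tuple(arg, n):
    arg = np.reshape(np.asarray(arg), [-1])
    if arg.size not in (1, n):
        raise ValueError('size of arg must be 1 or {}, saw {}'.format(n, arg.size))
    return tuple(np.broadcast_to(arg, [n]))

as_tuple(3, n=2)        # -> (3, 3): scalar kernel size replicated per dimension
as_tuple([1, 2], n=2)   # -> (1, 2): full-length argument passed through
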
Example 4
def _sample_n(self, n, seed=None):
  total_count = tf.cast(self.total_count, tf.int32)
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  return _sample_bates(
      ps.broadcast_to(total_count, self._batch_shape_tensor()),
      low, high, n, seed=seed)
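
A Bates(`total_count`, `low`, `high`) variate is the mean of `total_count` iid uniforms on `[low, high]`; the snippet delegates the batched version of this to an internal `_sample_bates` helper. For intuition only, with made-up scalar parameters:

import tensorflow as tf

total_count, low, high, n = 5, 0., 2., 3
u = tf.random.uniform([n, total_count], minval=low, maxval=high)
bates_samples = tf.reduce_mean(u, axis=-1)  # shape [n]
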
Example 5
    def _get_flattened_marginal_distribution(self, index_points=None):
        # This returns a MVN of event size [N * E], where N is the number of tasks
        # and E is the number of index points.
        with self._name_and_control_scope(
                'get_flattened_marginal_distribution'):
            index_points = self._get_index_points(index_points)
            covariance = self._compute_flattened_covariance(index_points)

            batch_shape = self._batch_shape_tensor(index_points=index_points)
            event_shape = self._event_shape_tensor(index_points=index_points)

            # Now take the cholesky but specialize to cases where we have block-diag
            # and kronecker.
            covariance_cholesky = cholesky_util.cholesky_from_fn(
                covariance, self._cholesky_fn)
            loc = self._mean_fn(index_points)
            # Broadcast the mean function result so that constant mean functions
            # (constant over all tasks, or constant per-task) are supported.
            loc = ps.broadcast_to(
                loc, ps.concat([batch_shape, event_shape], axis=0))
            loc = _vec(loc)
            return mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=loc,
                scale=covariance_cholesky,
                validate_args=self._validate_args,
                allow_nan_stats=self._allow_nan_stats,
                name='marginal_distribution')
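
The broadcast-then-`_vec` step at the end flattens the per-task, per-index-point mean into a single event vector of size `N * E`, matching the flattened covariance. A shape-only sketch of that flattening; the `[batch, E, N]` layout and the reshape-based `_vec` are assumptions here, since `_vec` is defined elsewhere in the module:

import tensorflow as tf

loc = tf.zeros([4, 3, 2])   # assumed [batch, E index points, N tasks]
flat_loc = tf.reshape(loc, tf.concat([tf.shape(loc)[:-2], [-1]], axis=0))
flat_loc.shape              # TensorShape([4, 6]), i.e. event size N * E
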
Example 6
    def _get_flattened_marginal_distribution(self, index_points=None):
        # This returns a MVN of event size [N * E], where N is the number of tasks
        # and E is the number of index points.
        with self._name_and_control_scope(
                'get_flattened_marginal_distribution'):
            index_points = self._get_index_points(index_points)
            scale = _compute_flattened_scale(
                kernel=self.kernel,
                index_points=index_points,
                cholesky_fn=self._cholesky_fn,
                observation_noise_variance=self.observation_noise_variance)

            batch_shape = self._batch_shape_tensor(index_points=index_points)
            event_shape = self._event_shape_tensor(index_points=index_points)

            loc = self._mean_fn(index_points)
            # Broadcast the mean function result so that constant mean functions
            # (constant over all tasks, or constant per-task) are supported.
            loc = ps.broadcast_to(
                loc, ps.concat([batch_shape, event_shape], axis=0))
            loc = _vec(loc)
            return mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=loc,
                scale=scale,
                validate_args=self._validate_args,
                allow_nan_stats=self._allow_nan_stats,
                name='marginal_distribution')
Example 7
    def _matrix(self, x1, x2):
        locs = util.pad_shape_with_ones(self.locs, ndims=1, start=-2)
        slopes = util.pad_shape_with_ones(self.slopes, ndims=1, start=-2)

        weights_x1 = tf.math.sigmoid(
            slopes *
            (self.weight_fn(x1, self.feature_ndims)[..., tf.newaxis] - locs))
        weights_x1 = weights_x1[..., tf.newaxis, :]
        weights_x2 = tf.math.sigmoid(
            slopes *
            (self.weight_fn(x2, self.feature_ndims)[..., tf.newaxis] - locs))
        weights_x2 = weights_x2[..., tf.newaxis, :, :]

        initial_weights = (1. - weights_x1) * (1. - weights_x2)
        initial_weights = tf.concat(
            [initial_weights,
             tf.ones_like(initial_weights[..., 0])[..., tf.newaxis]],
            axis=-1)
        end_weights = weights_x1 * weights_x2
        end_weights = tf.concat(
            [tf.ones_like(end_weights[..., 0])[..., tf.newaxis], end_weights],
            axis=-1)

        results = [k.matrix(x1, x2)[..., tf.newaxis] for k in self.kernels]
        broadcasted_shape = distribution_util.get_broadcast_shape(*results)
        results = tf.concat(
            [ps.broadcast_to(r, broadcasted_shape) for r in results], axis=-1)
        return tf.math.reduce_sum(initial_weights * results * end_weights,
                                  axis=-1)
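
Collapsed to two kernels and scalar inputs, the weighting above reduces to a sigmoid changepoint blend: the first kernel dominates where both inputs lie left of the changepoint and the second where both lie right of it. A scalar sketch with made-up numbers:

import tensorflow as tf

loc, slope = 0.0, 2.0                      # changepoint location and steepness
x1, x2 = -1.5, 0.5
w1 = tf.math.sigmoid(slope * (x1 - loc))   # weight of the right-hand regime at x1
w2 = tf.math.sigmoid(slope * (x2 - loc))
k0, k1 = 0.9, 0.4                          # k0(x1, x2) and k1(x1, x2) from the two kernels
k = (1. - w1) * (1. - w2) * k0 + w1 * w2 * k1
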
Example 8
  def _conditional_mean_fn(self, x):
    """Conditional mean."""
    k_x_obs_linop = self.kernel.matrix_over_all_tasks(
        x, self._observation_index_points)
    if self._observations_is_missing is not None:
      k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
          tf.where(_vec(tf.math.logical_not(
              self._observations_is_missing))[..., tf.newaxis, :],
                   k_x_obs_linop.to_dense(),
                   tf.zeros([], dtype=k_x_obs_linop.dtype)))

    mean_x = self.mean_fn(x)  # pylint:disable=not-callable
    batch_shape = self._batch_shape_tensor(index_points=x)
    event_shape = self._event_shape_tensor(index_points=x)
    mean_x = ps.broadcast_to(mean_x,
                             ps.concat([batch_shape, event_shape], axis=0))
    mean_x = _vec(mean_x)
    return mean_x + k_x_obs_linop.matvec(self._solve_on_obs)
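
This is the usual GP conditional-mean update, `m(x) + K(x, X_obs) @ solve`, where the solve against the observation covariance is precomputed elsewhere as `self._solve_on_obs` and entries of `K(x, X_obs)` for missing observations are zeroed out. A dense, single-task sketch of the final line with made-up numbers (treating `solve_on_obs` as the assumed `K_obs^{-1} (y - m(X_obs))` vector):

import tensorflow as tf

k_x_obs = tf.constant([[0.8, 0.1],
                       [0.3, 0.6]])        # K(x, X_obs)
solve_on_obs = tf.constant([0.5, -0.25])   # assumed K_obs^{-1} (y - m(X_obs))
mean_x = tf.zeros([2])
conditional_mean = mean_x + tf.linalg.matvec(k_x_obs, solve_on_obs)
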
Example 9
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
Example 10
def _bates_cdf(total_count, low, high, dtype, value):
    """Compute the Bates cdf.

  Internally, the (standard, unnormalized) cdf is computed by the formula

  ```none
  cdf = sum_{k=0}^j (-1)^k (n choose k) (nx - k)^n
  ```

  where
  * `n = total_count`,
  * `x = value` the value to compute the cumulative probability of, and
  * `j = floor(nx)`.

  This is shifted to `[low, high]` and normalized. Since the pdf is symmetric,
  we have `cdf(x) = 1 - cdf(1 - x)` for `x > .5`, hence we only compute the left
  half, which keeps the number of terms lower.

  Computation is batched, using `tf.math.segment_sum()`. For this reason this is
  not compatible with `tf.vectorized_map()`.

  All input parameters should have compatible dtypes and shapes.

  Args:
    total_count: `Tensor` with integer values, as given to the `Bates`
      constructor.
    low: Float `Tensor`, as given to the `Bates` constructor.
    high: Float `Tensor`, as given to the `Bates` constructor.
    dtype: The dtype of the output.
    value: Float `Tensor`. Input value to `cdf()`.
  Returns:
    cdf: Float `Tensor`. See above formula.
  """
    total_count = tf.cast(total_count, dtype)
    low = tf.convert_to_tensor(low)
    high = tf.convert_to_tensor(high)

    # Warn the user if they try to compute a cdf with high `total_count`. This
    # warning lives here instead of in `_parameter_control_dependencies()`
    # because nested calls to `_name_and_control_scope` (e.g. from
    # `log_survival_function`) can result in multiple warnings being added and
    # multiple tensor conversions. Also, `sample()` does not have the same
    # numerical issues.
    with tf.control_dependencies([_stability_limit_tensor(total_count,
                                                          dtype)]):
        # Center and adjust `value` using limits and symmetry.
        value_centered = (value - low) / (high - low)
        value_adj = tf.clip_by_value(value_centered, 0., 1.)
        value_adj = tf.where(value_adj < .5, value_adj, 1. - value_adj)
        value_adj = tf.where(tf.math.is_finite(value_adj), value_adj, low)
        # Flatten to make segments; need to broadcast before flattening.
        shape = ps.broadcast_shape(ps.shape(value_adj), ps.shape(total_count))
        total_count_b = ps.broadcast_to(total_count, shape)
        total_count_x_value_adj_b = total_count * value_adj
        total_count_f = tf.reshape(total_count_b, [-1])
        total_count_x_value_adj_f = tf.reshape(total_count_x_value_adj_b, [-1])
        # Create segmented terms of summation.
        num_terms_f = tf.cast(tf.math.floor(total_count_x_value_adj_f + 1),
                              dtype=tf.int32)
        term_idx_s = tf.cast(_segmented_range(num_terms_f), dtype)  # aka `k`
        total_count_s = tf.repeat(total_count_f, num_terms_f)
        total_count_x_value_adj_s = tf.repeat(total_count_x_value_adj_f,
                                              num_terms_f)
        terms = (tf.cast(-1., dtype)**term_idx_s *
                 (1. / ((total_count_s + 1.) * tf.math.exp(
                     tfp_math.lbeta(total_count_s - term_idx_s + 1.,
                                    term_idx_s + 1.)))) *
                 (total_count_x_value_adj_s - term_idx_s)**total_count_s)
        # Segment sum.
        segment_ids = tf.repeat(tf.range(tf.size(num_terms_f)), num_terms_f)
        cdf_s = tf.math.segment_sum(terms, segment_ids)
        # Reshape back.
        cdf = tf.reshape(cdf_s, shape)
        # Normalize.
        cdf = cdf / tf.math.exp(
            tf.math.lgamma(total_count_b + tf.cast(1., dtype)))
        # cdf symmetry adjustment: cdf(x) = 1 - cdf(1 - x) for x > 0.5
        cdf = tf.where(value_centered > .5, 1. - cdf, cdf)
        # Fix out-of-support queries.
        cdf = tf.where(value_centered < 0., tf.cast(0., dtype), cdf)
        cdf = tf.where(value_centered > 1., tf.cast(1., dtype), cdf)
        cdf = tf.where(tf.math.is_finite(value_centered), cdf, np.nan)
        return cdf
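
The docstring formula can be hand-checked on a tiny case before worrying about the segmented, batched evaluation: for `total_count = 2` on `[0, 1]`, the cdf at `x = 0.25` should equal `2 * x**2 = 0.125`.

from math import comb, factorial, floor

n, x = 2, 0.25
j = floor(n * x)
unnormalized = sum((-1)**k * comb(n, k) * (n * x - k)**n for k in range(j + 1))
cdf = unnormalized / factorial(n)   # 0.125, matching the triangular Bates(2) cdf
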
Example 11
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[0]

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)
        log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights,
                                        axis=-1)

        # If this step has an observation, compute its weights and marginal
        # likelihood (and otherwise, leave weights unchanged).
        observation_log_weights = prefer_static.cond(
            has_observation,
            lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
                _compute_observation_log_weights(step, proposed_particles,
                                                 observation, observation_fn),
                prefer_static.shape(log_weights)),
            lambda: tf.zeros_like(log_weights))

        unnormalized_log_weights = log_weights + observation_log_weights
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=0)
        log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

        # Adaptive resampling: resample particles iff the specified criterion is met.
        do_resample = resample_criterion_fn(unnormalized_log_weights)

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = _resample(proposed_particles,
                                                          log_weights,
                                                          resample_independent,
                                                          seed=seed)

        uniform_weights = (prefer_static.zeros_like(log_weights) -
                           prefer_static.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: prefer_static.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, _dummy_indices_like(resample_indices),
              log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
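
The weight update in the middle of the step is worth isolating: per-particle observation log-likelihoods are added to the current log-weights, the log-sum-exp of the result is the step's contribution to the log marginal likelihood, and subtracting it renormalizes the weights. With made-up numbers:

import tensorflow as tf

log_weights = tf.math.log([0.25, 0.25, 0.25, 0.25])               # uniform over 4 particles
observation_log_weights = tf.constant([-1.0, -2.0, -0.5, -3.0])   # per-particle log-likelihoods
unnormalized = log_weights + observation_log_weights
step_log_marginal_likelihood = tf.math.reduce_logsumexp(unnormalized, axis=0)
log_weights = unnormalized - step_log_marginal_likelihood         # exp(.) sums to 1 again
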
Example 12
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
  """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ...,         p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

   ```python
   # alignment is:                       v axis
   # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
   # indices.shape   ==         [i[0], i[1], i[2]]
   #                                     ^ indices_axis
   result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met.
  """
  with tf.name_scope(name):
    params = tf.convert_to_tensor(params, name='params')
    indices = tf.convert_to_tensor(indices, name='indices')

    params_ndims = tensorshape_util.rank(params.shape)
    indices_ndims = tensorshape_util.rank(indices.shape)
    # `axis` dtype must match ndims, which are 64-bit Python ints.
    axis = tf.get_static_value(ps.convert_to_shape_tensor(axis, dtype=tf.int64))
    indices_axis = tf.get_static_value(
        ps.convert_to_shape_tensor(indices_axis, dtype=tf.int64))

    if params_ndims is None:
      raise ValueError(
          'Rank of `params` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if axis is None:
      raise ValueError(
          '`axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis is None:
      raise ValueError(
          '`indices_axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis > axis:
      raise ValueError(
          '`indices_axis` should be <= `axis`, but was {} > {}'.format(
              indices_axis, axis))

    if params_ndims < 1:
      raise ValueError(
          'Rank of params should be `> 0`, but was {}'.format(params_ndims))

    if indices_ndims is not None and indices_ndims < 1:
      raise ValueError(
          'Rank of indices should be `> 0`, but was {}'.format(indices_ndims))

    if (indices_ndims is not None and
        (indices_ndims - indices_axis > params_ndims - axis)):
      raise ValueError(
          '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
          'indices_axis` ({} - {}), but was not.'.format(
              params_ndims, axis, indices_ndims, indices_axis))

    # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
    # transpose `indices_axis` to be the rightmost dimension of `indices`...
    transposed_indices = dist_util.move_dimension(indices,
                                                  source_idx=indices_axis,
                                                  dest_idx=-1)

    # ... and `axis` to be the corresponding (aligned as in the docstring)
    # dimension of `params`.
    broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
    transposed_params = dist_util.move_dimension(
        params,
        source_idx=axis,
        dest_idx=broadcast_indices_ndims - 1)

    # Next we broadcast `indices` so that its shape has the same prefix as
    # `params.shape`.
    transposed_params_shape = ps.shape(transposed_params)
    result_shape = ps.concat([
        transposed_params_shape[:broadcast_indices_ndims - 1],
        ps.shape(indices)[indices_axis:indices_axis + 1],
        transposed_params_shape[broadcast_indices_ndims:]], axis=0)
    broadcast_indices = ps.broadcast_to(
        transposed_indices,
        result_shape[:broadcast_indices_ndims])

    result_t = tf.gather(transposed_params,
                         broadcast_indices,
                         batch_dims=broadcast_indices_ndims - 1,
                         axis=broadcast_indices_ndims - 1)
    return dist_util.move_dimension(result_t,
                                    source_idx=broadcast_indices_ndims - 1,
                                    dest_idx=axis)
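
To sanity-check the indexing identity from the docstring on a smaller configuration (`rank(params) = 3`, `rank(indices) = 2`, `axis = 1`, `indices_axis = 0`), the expected result is `result[i, j, k] = params[i, indices[j, k], k]`. A brute-force reference built directly from that identity (illustrative only, not the library call):

import tensorflow as tf

params = tf.random.normal([2, 5, 3])
indices = tf.random.uniform([4, 3], maxval=5, dtype=tf.int32)

# Enumerate the identity element by element; p[axis] is replaced by i[indices_axis].
expected = tf.stack([
    tf.stack([
        tf.stack([params[i, indices[j, k], k] for k in range(3)])
        for j in range(4)])
    for i in range(2)])
expected.shape   # TensorShape([2, 4, 3])
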