Example #1
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = ps.shape(param)
  insert_ones = ps.ones(
      [ps.size(dist_batch_shape) + param_event_ndims - ps.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = ps.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = ps.where(is_broadcast, 0, start)
      if stop is not None:
        stop = ps.where(is_broadcast, 1, stop)
      if step is not None:
        step = ps.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(ps.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(tuple(param_slices))
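Standalone NumPy sketch (not part of the snippet above; names are illustrative) of the broadcast-aware clamping that `_slice_single_param` performs: when a parameter dimension has size 1 but the distribution batch dimension is larger, the user's slice is replaced by `0:1:1` so the dimension keeps broadcasting.

import numpy as np

dist_batch_shape = np.array([5])   # hypothetical distribution batch shape
param = np.array([2.0])            # parameter with batch shape [1]; broadcasts to [5]

slc = slice(1, 4)                  # user slice over the distribution batch dim
is_broadcast = dist_batch_shape[0] > param.shape[0]
safe_slc = slice(0, 1, 1) if is_broadcast else slc
print(param[safe_slc].shape)       # (1,): the size-1 dim survives, so the sliced
                                   # parameter still broadcasts against batch [3]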
def _design_matrix_for_one_seasonal_effect(num_steps, duration, period, dtype):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose([
      ps.where(current_period == p,  # pylint: disable=g-complex-comprehension
               ps.ones([], dtype=dtype),
               ps.zeros([], dtype=dtype))
      for p in range(period)])
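A quick NumPy check (plain `np` in place of `ps`, so it runs without TFP internals) of the seasonal design matrix built above: each row one-hot encodes the season active at that step.

import numpy as np

def seasonal_design(num_steps, duration, period, dtype=np.float32):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose([(current_period == p).astype(dtype) for p in range(period)])

print(seasonal_design(num_steps=6, duration=2, period=3))
# [[1. 0. 0.]
#  [1. 0. 0.]
#  [0. 1. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]
#  [0. 0. 1.]]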
Example #3
 def expand_dims_(x):
     """Implementation of `expand_dims`."""
     with tf.name_scope(name or 'expand_dims'):
         x = tf.convert_to_tensor(x, name='x')
         new_axis = tf.convert_to_tensor(axis,
                                         dtype_hint=tf.int32,
                                         name='axis')
         nx = prefer_static.rank(x)
         na = prefer_static.size(new_axis)
         is_neg_axis = new_axis < 0
         k = prefer_static.reduce_sum(
             prefer_static.cast(is_neg_axis, new_axis.dtype))
         new_axis = prefer_static.where(is_neg_axis, new_axis + nx,
                                        new_axis)
         new_axis = prefer_static.sort(new_axis)
         axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
         idx = prefer_static.argsort(prefer_static.concat([
             axis_pos,
             prefer_static.range(nx),
             axis_neg,
         ],
                                                          axis=0),
                                     stable=True)
         shape = prefer_static.pad(prefer_static.shape(x),
                                   paddings=[[na - k, k]],
                                   constant_values=1)
         shape = prefer_static.gather(shape, idx)
         return tf.reshape(x, shape)
def _interleave(a, b, axis):
  """Interleaves two `Tensor`s along the given axis."""
  # [a b c ...] [d e f ...] -> [a d b e c f ...]
  num_elems_a = ps.shape(a)[axis]
  num_elems_b = ps.shape(b)[axis]

  # Note that interleaving implies rank(a)==rank(b).
  axis = ps.where(axis >= 0, axis, ps.rank(a) + axis)
  axis = (int(axis)  # Avoid ndarray values.
          if tf.get_static_value(axis) is not None
          else axis)

  def _interleave_with_b(a):
    return tf.reshape(
        # Work around lack of support for Tensor axes in `tf.stack` by using
        # `concat` and `expand_dims` instead.
        tf.concat([tf.expand_dims(a, axis=axis + 1),
                   tf.expand_dims(b, axis=axis + 1)],
                  axis=axis + 1),
        ps.concat(
            [
                ps.shape(a)[:axis],
                [2 * num_elems_b],
                ps.shape(a)[axis + 1:]
            ],
            axis=0))
  return ps.cond(
      ps.equal(num_elems_a, num_elems_b + 1),
      lambda: tf.concat([  # pylint: disable=g-long-lambda
          _interleave_with_b(_slice_along_axis(a, None, -1, axis=axis)),
          _slice_along_axis(a, -1, None, axis=axis)], axis=axis),
      lambda: _interleave_with_b(a))
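The interleaving trick used by `_interleave` (stack along a new trailing axis, then reshape), shown as a minimal NumPy sketch:

import numpy as np

a = np.array([1, 3, 5])
b = np.array([2, 4, 6])
print(np.stack([a, b], axis=1).reshape(-1))  # [1 2 3 4 5 6]

# When `a` has one extra element, interleave a[:-1] with b, then append a[-1]:
a = np.array([1, 3, 5, 7])
print(np.concatenate([np.stack([a[:-1], b], axis=1).reshape(-1), a[-1:]]))
# [1 2 3 4 5 6 7]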
Example #5
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand."""
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(axis)
        is_neg_axis = axis < 0
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, axis.dtype))
        axis = prefer_static.where(is_neg_axis, axis + nx, axis)
        axis = prefer_static.sort(axis)
        axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
        idx = prefer_static.argsort(prefer_static.concat([
            axis_pos,
            prefer_static.range(nx),
            axis_neg,
        ],
                                                         axis=0),
                                    stable=True)
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
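Usage sketch, assuming the `expand_dims` above is in scope together with `import tensorflow as tf` and TFP's `prefer_static` module; for a single axis it should agree with `tf.expand_dims`:

x = tf.zeros([2, 3])
print(expand_dims(x, axis=[0]).shape)   # (1, 2, 3), like tf.expand_dims(x, 0)
print(expand_dims(x, axis=[-1]).shape)  # (2, 3, 1), like tf.expand_dims(x, -1)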
Example #6
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[-1]

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)

        observation_log_weights = _compute_observation_log_weights(
            step, proposed_particles, observation, observation_fn)
        unnormalized_log_weights = (log_weights + proposal_log_weights +
                                    observation_log_weights)
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=-1)
        log_weights = (unnormalized_log_weights -
                       step_log_marginal_likelihood[..., tf.newaxis])

        # Adaptive resampling: resample particles iff the specified criterion is met.
        do_resample = tf.convert_to_tensor(
            resample_criterion_fn(unnormalized_log_weights))[
                ..., tf.newaxis]  # Broadcast over particles.

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = _resample(proposed_particles,
                                                          log_weights,
                                                          seed=seed)
        dummy_indices = tf.broadcast_to(prefer_static.range(num_particles),
                                        prefer_static.shape(resample_indices))
        uniform_weights = (prefer_static.zeros_like(log_weights) -
                           prefer_static.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: prefer_static.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, dummy_indices, log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
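A NumPy sketch of the weight-normalization step above: subtracting the log-sum-exp of the unnormalized log weights (the per-step marginal likelihood estimate) makes the weights sum to one in probability space.

import numpy as np

unnormalized = np.log(np.array([0.2, 0.5, 0.1]))
step_log_marginal_likelihood = np.log(np.sum(np.exp(unnormalized)))  # logsumexp
log_weights = unnormalized - step_log_marginal_likelihood
print(np.exp(log_weights).sum())  # 1.0 (up to float error)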
Example #7
    def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
        if event_shape is None:
            event_shape = self.event_shape_tensor()

        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

        # Continuing the example from `_augment_sample_shape`, suppose we have:
        #   - sample shape of `[n]`,
        #   - underlying distribution batch shape of `[2, 1]`,
        #   - final broadcast batch shape of `[4, 2, 3]`.
        # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
        # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

        # First, we reshape to expand out the batch elements:
        # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
        # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
        # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype), underlying_batch_shape
        ],
                                         axis=0)
        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
        x_with_doubled_batch = tf.reshape(
            x,
            ps.concat([
                sample_shape,
                ps.where(is_dim_bcast, batch_shape, 1), underlying_bcast_shp,
                event_shape
            ],
                      axis=0))

        # Next, construct the permutation that interleaves the batch dimensions,
        # resulting in samples with shape
        # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
        # Note that each interleaved pair of batch dimensions contains exactly one
        # dim of size `1` and one of size `>= 1`.
        sample_ndims = ps.rank_from_shape(sample_shape)
        x_with_interleaved_batch = tf.transpose(
            x_with_doubled_batch,
            perm=ps.concat([
                ps.range(sample_ndims), sample_ndims + ps.reshape(
                    ps.stack([
                        ps.range(batch_rank),
                        ps.range(batch_rank) + batch_rank
                    ],
                             axis=-1), [-1]), sample_ndims + 2 * batch_rank +
                ps.range(ps.rank_from_shape(event_shape))
            ],
                           axis=0))

        # Final reshape to remove the spurious `1` dimensions.
        return tf.reshape(
            x_with_interleaved_batch,
            ps.concat([sample_shape, batch_shape, event_shape], axis=0))
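The batch-interleaving permutation built above, evaluated with NumPy for `batch_rank == 3` (illustrative only): original batch axes [0, 1, 2] interleave with their broadcast counterparts [3, 4, 5].

import numpy as np

batch_rank = 3
perm = np.reshape(
    np.stack([np.arange(batch_rank), np.arange(batch_rank) + batch_rank],
             axis=-1), [-1])
print(perm)  # [0 3 1 4 2 5]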
Example #8
 def body_fn(vecs, i):
     # Slice out the vector w.r.t. which we're orthogonalizing the rest.
     vecs_ndims = ps.rank(vecs)
     select_axis = (ps.range(vecs_ndims) == vecs_ndims - 1)
     start = ps.where(select_axis, i, ps.zeros([vecs_ndims], i.dtype))
     size = ps.where(select_axis, 1, ps.shape(vecs))
     u = tf.math.l2_normalize(tf.slice(vecs, start, size), axis=-2)
     # TODO(b/171730305): XLA can't handle this line...
     # u = tf.math.l2_normalize(vecs[..., i, tf.newaxis], axis=-2)
     # Find weights by dotting the d x 1 against the d x n.
     weights = tf.einsum('...dm,...dn->...n', u, vecs)
     # Project out vector `u` from the trailing vectors.
     masked_weights = tf.where(tf.range(n) > i, weights,
                               0.)[..., tf.newaxis, :]
     vecs = vecs - tf.math.multiply_no_nan(u, masked_weights)
     tensorshape_util.set_shape(vecs, vectors.shape)
     return vecs, i + 1
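A NumPy sketch of the dynamic-slice trick in `body_fn` above: `start`/`size` vectors built with `where` select column `i` while keeping every other axis whole (shown for a 2-D `vecs`; names illustrative).

import numpy as np

vecs = np.arange(12.).reshape(3, 4)   # d=3 rows, n=4 columns
i = 2
select_axis = np.arange(vecs.ndim) == vecs.ndim - 1
start = np.where(select_axis, i, 0)
size = np.where(select_axis, 1, vecs.shape)
u = vecs[start[0]:start[0] + size[0], start[1]:start[1] + size[1]]
print(u.ravel())  # column 2 of vecs: [ 2.  6. 10.]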
Example #9
def _sanitize_slices(slices, intended_shape, deficient_shape):
    """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.
  Returns:
    sanitized_slices: Python `list` of sanitized slices.
  """
    sanitized_slices = []
    idx = 0
    for slc in slices:
        if slc is Ellipsis:  # Switch over to negative indexing.
            if idx < 0:
                raise ValueError(
                    'Found multiple `...` in slices {}'.format(slices))
            num_remaining_non_newaxis_slices = sum([
                s is not tf.newaxis
                for s in slices[slices.index(Ellipsis) + 1:]
            ])
            idx = -num_remaining_non_newaxis_slices
        elif slc is tf.newaxis:
            pass
        else:
            is_broadcast = intended_shape[idx] > deficient_shape[idx]
            if isinstance(slc, slice):
                # Slices are denoted by start:stop:step.
                start, stop, step = slc.start, slc.stop, slc.step
                if start is not None:
                    start = ps.where(is_broadcast, 0, start)
                if stop is not None:
                    stop = ps.where(is_broadcast, 1, stop)
                if step is not None:
                    step = ps.where(is_broadcast, 1, step)
                slc = slice(start, stop, step)
            else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
                slc = ps.where(is_broadcast, 0, slc)
            idx += 1
        sanitized_slices.append(slc)
    return sanitized_slices
Example #10
 def _calculate_batch_shape(self):
   """Computes fully defined batch shape for the new distribution."""
   all_batch_shapes = [d.batch_shape.as_list()
                       if tensorshape_util.is_fully_defined(d.batch_shape)
                       else d.batch_shape_tensor() for d in self.distributions]
   original_shape = ps.stack(all_batch_shapes, axis=0)
   index_mask = ps.cast(
       ps.one_hot(self._axis, ps.shape(original_shape)[1]),
       dtype=tf.bool)
   new_concat_dim = ps.cast(
       ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32)
   return ps.where(index_mask, new_concat_dim,
                   ps.reduce_max(original_shape, axis=0))
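NumPy sketch of the shape arithmetic above: concatenating two distributions with batch shapes [2, 3] and [4, 3] along axis 0 should give batch shape [6, 3] (the concat axis sums, the remaining axes take the max).

import numpy as np

all_batch_shapes = np.stack([[2, 3], [4, 3]], axis=0)
axis = 0
index_mask = np.eye(all_batch_shapes.shape[1], dtype=bool)[axis]  # one-hot mask
new_concat_dim = all_batch_shapes.sum(axis=0)[axis]
print(np.where(index_mask, new_concat_dim, all_batch_shapes.max(axis=0)))  # [6 3]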
Example #11
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

    Args:
      operands: tuple of tensors, all having the same shape.
      inits: tuple of scalar tensors, with dtypes aligned to those of operands.
      axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
        `int`. `None` is taken to mean "reduce all axes".
      keepdims: When `True`, we do not squeeze away the reduced dims, instead
        returning values with singleton dims in those axes.

    Returns:
      reduced: A tuple of the reduced operands.
    """
        # Static shape consistency checks.
        args_shape = operands[0].shape
        for arg in operands[1:]:
            args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
        ndims = tensorshape_util.rank(args_shape)
        if ndims is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')
        # Ensure the 'axis' arg is a tuple of non-negative ints.
        axis = np.arange(ndims) if axis is None else np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        axis = np.reshape(axis, [-1])
        axis = np.where(axis < 0, axis + ndims, axis)
        axis = tuple(int(ax) for ax in axis)

        axis_nhot = ps.reduce_sum(ps.one_hot(axis,
                                             depth=ndims,
                                             on_value=True,
                                             off_value=False,
                                             dtype=tf.bool),
                                  axis=0)
        in_shape = args_shape
        if not tensorshape_util.is_fully_defined(in_shape):
            in_shape = tf.shape(operands[0])
        unsqueezed_shape = ps.where(axis_nhot, 1, in_shape)

        result = _variadic_reduce_custom_grad(operands, inits, axis, reducer,
                                              unsqueezed_shape)

        if keepdims:
            result = tf.nest.map_structure(
                lambda t: tf.reshape(t, unsqueezed_shape), result)
        return result
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
    """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
    step_indices_to_trace = tf.convert_to_tensor(
        step_indices_to_trace,
        dtype_hint=tf.int32)  # Warning: breaks gradients.
    traced_steps_have_rank_zero = ps.equal(
        ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
    # Canonicalize negative step indices as positive.
    step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                     num_timesteps + step_indices_to_trace,
                                     step_indices_to_trace)
    # Canonicalize scalars as length-one vectors.
    return (ps.reshape(step_indices_to_trace,
                       [ps.size(step_indices_to_trace)]),
            traced_steps_have_rank_zero)
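NumPy sketch of the canonicalization above: negative step indices wrap around `num_timesteps`, and a scalar becomes a length-one vector.

import numpy as np

num_timesteps = 10
steps = np.array([-2, -1])
print(np.where(steps < 0, num_timesteps + steps, steps))  # [8 9]
print(np.reshape(np.array(3), [np.size(3)]))              # [3]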
Example #13
def _compute_observation_log_weights(step,
                                     particles,
                                     observations,
                                     observation_fn,
                                     num_transitions_per_observation=1):
  """Computes particle importance weights from an observation step.

  Args:
    step: int `Tensor` current step.
    particles: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], event_shape])`, where
      `b1, ..., bN` are optional batch dimensions and `event_shape` may
      differ across `Tensor`s.
    observations: Nested structure of `Tensor`s, each of shape
      `concat([[num_observations, b1, ..., bN], event_shape])`
      where `b1, ..., bN` are optional batch dimensions and `event_shape` may
      differ across `Tensor`s.
    observation_fn: callable with signature
      `observation_dist = observation_fn(step, particles)`, producing
      a batch of distributions over the `observation` at the given `step`,
      one for each particle.
    num_transitions_per_observation: optional int `Tensor` number of times
      to apply the transition model between successive observation steps.
      Default value: `1`.
  Returns:
    log_weights: `Tensor` of shape `concat([num_particles, b1, ..., bN])`.
  """
  with tf.name_scope('compute_observation_log_weights'):

    step_has_observation = (
        # The second of these conditions subsumes the first, but both are
        # useful because the first can often be evaluated statically.
        ps.equal(num_transitions_per_observation, 1) |
        ps.equal(step % num_transitions_per_observation, 0))
    observation_idx = step // num_transitions_per_observation
    observation = tf.nest.map_structure(
        lambda x, step=step: tf.gather(x, observation_idx), observations)

    log_weights = observation_fn(step, particles).log_prob(observation)
    return ps.where(step_has_observation,
                    log_weights,
                    tf.zeros_like(log_weights))
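NumPy sketch of the observation gating above with a hypothetical `num_transitions_per_observation = 3`: only every third step carries an observation, and `observation_idx` advances once per observation.

import numpy as np

num_transitions_per_observation = 3
steps = np.arange(7)
has_obs = ((num_transitions_per_observation == 1) |
           (steps % num_transitions_per_observation == 0))
print(has_obs)                                   # [ True False False  True False False  True]
print(steps // num_transitions_per_observation)  # observation_idx: [0 0 0 1 1 1 2]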
    def _sample_n(self, n, seed=None):
        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)
        n_batch = ps.reduce_prod(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
        underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

        # Left pad underlying shape with any necessary ones.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype), underlying_batch_shape
        ],
                                         axis=0)

        # Determine how many underlying samples to produce.
        n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
        samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

        event_shape = self.event_shape_tensor()
        event_rank = ps.rank_from_shape(event_shape)
        shp = ps.concat([[n],
                         ps.where(is_dim_bcast, batch_shape, 1),
                         underlying_bcast_shp, event_shape],
                        axis=0)
        # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
        samps = tf.reshape(samps, shp)
        # Interleave broadcast and underlying axis indices for transpose.
        interleaved_batch_axes = ps.reshape(
            ps.stack([ps.range(batch_rank),
                      ps.range(batch_rank) + batch_rank],
                     axis=-1), [-1]) + 1

        event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
        perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
        samps = tf.transpose(samps, perm=perm)
        # Finally, reshape to the fully-broadcast batch shape.
        return tf.reshape(samps,
                          ps.concat([[n], batch_shape, event_shape], axis=0))
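NumPy sketch of the sample-count arithmetic in `_sample_n`, matching the worked example in `_transpose_and_reshape_result`: a broadcast batch of [4, 2, 3] over an underlying batch of [2, 1] needs 24 // 2 == 12 independent underlying draws per sample index.

import numpy as np

batch_shape = np.array([4, 2, 3])
underlying_batch_shape = np.array([2, 1])
n_bcast_samples = max(0, batch_shape.prod() // underlying_batch_shape.prod())
print(n_bcast_samples)  # 12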
Example #15
 def _calculate_new_shape(self):
     # Try to get the old shape statically if available.
     original_shape = self._distribution.batch_shape
     if not tensorshape_util.is_fully_defined(original_shape):
         original_shape = self._distribution.batch_shape_tensor()
     # This is not a check for falseness, it's a check for exactly that shape.
     if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
         # Force the size to be an integer, not a float, when the shape contains no
         # dtype information.
         original_size = 1
     else:
         original_size = ps.reduce_prod(original_shape)
     original_size = ps.cast(original_size, tf.int32)
     # Compute the new shape, filling in the `-1` dimension if present.
     new_shape = self._batch_shape_unexpanded
     implicit_dim_mask = ps.equal(new_shape, -1)
     size_implicit_dim = (original_size //
                          ps.maximum(1, -ps.reduce_prod(new_shape)))
     expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
         implicit_dim_mask, size_implicit_dim, new_shape)
     # Return the original size on the side because one caller would otherwise
     # have to recompute it.
     return expanded_new_shape, original_size
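NumPy sketch of the implicit-dimension fill above: reshaping a batch of 6 elements to `[-1, 3]` resolves the `-1` to `6 // 3 == 2`.

import numpy as np

original_size = 6
new_shape = np.array([-1, 3])
implicit_dim_mask = new_shape == -1
size_implicit_dim = original_size // max(1, -new_shape.prod())
print(np.where(implicit_dim_mask, size_implicit_dim, new_shape))  # [2 3]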
    def _get_conditional_posterior(self, sampler_state):
        """Builds the joint posterior for a sparsity pattern (eqn (7) from [1])."""
        indices = ps.where(sampler_state.nonzeros)[:, 0]
        conditional_posterior_precision_chol = tf.linalg.cholesky(
            tf.gather(tf.gather(sampler_state.weights_posterior_precision,
                                indices),
                      indices,
                      axis=1))
        conditional_weights_mean = tf.linalg.cholesky_solve(
            conditional_posterior_precision_chol,
            tf.gather(sampler_state.x_transpose_y,
                      indices)[..., tf.newaxis])[..., 0]

        @joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched
        def posterior_jd():
            observation_noise_variance = yield InverseGammaWithSampleUpperBound(
                concentration=(
                    self.observation_noise_variance_posterior_concentration),
                scale=sampler_state.observation_noise_variance_posterior_scale,
                upper_bound=self.observation_noise_variance_upper_bound,
                name='observation_noise_variance')

            yield MVNPrecisionFactorHardZeros(
                loc=conditional_weights_mean,
                # Note that the posterior precision varies inversely with the
                # noise variance: in worlds with high noise we're also
                # more uncertain about the values of the weights.
                # TODO(colcarroll): Tests pass even without a square root on the
                # observation_noise_variance. Should add a test that would fail.
                precision_factor=tf.linalg.LinearOperatorLowerTriangular(
                    conditional_posterior_precision_chol /
                    tf.sqrt(observation_noise_variance[..., tf.newaxis,
                                                       tf.newaxis])),
                nonzeros=sampler_state.nonzeros,
                name='weights')

        return posterior_jd
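NumPy sketch of the double `gather` above: indexing rows and then columns at the nonzero indices extracts the dense sub-precision matrix for the currently active weights.

import numpy as np

precision = np.arange(16.).reshape(4, 4)
nonzeros = np.array([True, False, True, False])
indices = np.where(nonzeros)[0]
print(precision[indices][:, indices])
# [[ 0.  2.]
#  [ 8. 10.]]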
    def _initialize_sampler_state(self, targets, nonzeros,
                                  observation_noise_variance):
        """Precompute quantities needed to sample with given targets.

    This method computes a sampler state (including factorized precision
    matrices) from scratch for a given sparsity pattern. This requires
    time proportional to `num_features**3`. If a sampler state is already
    available for an off-by-one sparsity pattern, the `_flip_feature` method
    (which takes time proportional to `num_features**2`) is
    generally more efficient.

    Args:
      targets: float Tensor regression outputs of shape `[num_outputs]`.
      nonzeros: boolean Tensor vector of shape `[num_features]`.
      observation_noise_variance: float Tensor used to scale the posterior
        precision.

    Returns:
      sampler_state: instance of `DynamicSpikeSlabSamplerState` collecting
        Tensor quantities relevant to the sampler. See
        `DynamicSpikeSlabSamplerState` for details.
    """
        with tf.name_scope('initialize_sampler_state'):
            targets = tf.convert_to_tensor(targets, dtype=self.dtype)
            nonzeros = tf.convert_to_tensor(nonzeros, dtype=tf.bool)
            indices = ps.where(nonzeros)[:, 0]

            x_transpose_y = tf.linalg.matvec(self.design_matrix,
                                             targets,
                                             adjoint_a=True)

            weights_posterior_precision = (
                self.x_transpose_x +
                self.weights_prior_precision * observation_noise_variance)
            y_transpose_y = tf.reduce_sum(targets**2, axis=-1)
            conditional_prior_precision_chol = tf.linalg.cholesky(
                tf.gather(tf.gather(self.weights_prior_precision, indices),
                          indices,
                          axis=1))
            conditional_posterior_precision_chol = tf.linalg.cholesky(
                tf.gather(tf.gather(weights_posterior_precision, indices),
                          indices,
                          axis=1))
            sub_x_transpose_y = tf.gather(x_transpose_y, indices)
            conditional_weights_mean = tf.linalg.cholesky_solve(
                conditional_posterior_precision_chol,
                sub_x_transpose_y[..., tf.newaxis])[..., 0]
            return self._compute_log_prob(
                x_transpose_y=x_transpose_y,
                y_transpose_y=y_transpose_y,
                nonzeros=nonzeros,
                conditional_prior_precision_chol=
                conditional_prior_precision_chol,
                conditional_posterior_precision_chol=
                conditional_posterior_precision_chol,
                weights_posterior_precision=weights_posterior_precision,
                observation_noise_variance_posterior_scale=(
                    self.observation_noise_variance_prior_scale +  # ss / 2
                    (
                        y_transpose_y -
                        tf.reduce_sum(  # beta_gamma' V_gamma^{-1} beta_gamma
                            conditional_weights_mean * sub_x_transpose_y,
                            axis=-1)) / 2))
    def one_step(self, state, kernel_results, seed=None):
        """Takes one Sequential Monte Carlo inference step.

    Args:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        the current particles with (log) weights. The `log_weights` must be
        a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
        `particles` may be any structure of `Tensor`s, each of which
        must have shape `concat([log_weights.shape, event_shape])` for some
        `event_shape`, which may vary across components.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
        from a previous step.
      seed: Optional seed for reproducible sampling.

    Returns:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        new particles with (log) weights.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults`.
    """
        with tf.name_scope(self.name):
            with tf.name_scope('one_step'):
                seed = samplers.sanitize_seed(seed)
                proposal_seed, resample_seed = samplers.split_seed(seed)

                state = WeightedParticles(*state)  # Canonicalize.
                num_particles = ps.size0(state.log_weights)

                # Propose new particles and update weights for this step, unless it's
                # the initial step, in which case, use the user-provided initial
                # particles and weights.
                proposed_state = self.propose_and_update_log_weights_fn(
                    # Propose state[t] from state[t - 1].
                    ps.maximum(0, kernel_results.steps - 1),
                    state,
                    seed=proposal_seed)
                is_initial_step = ps.equal(kernel_results.steps, 0)
                # TODO(davmre): this `where` assumes the state size didn't change.
                state = tf.nest.map_structure(
                    lambda a, b: tf.where(is_initial_step, a, b), state,
                    proposed_state)

                normalized_log_weights = tf.nn.log_softmax(state.log_weights,
                                                           axis=0)
                # Every entry of `log_weights` differs from `normalized_log_weights`
                # by the same normalizing constant. We extract that constant by
                # examining an arbitrary entry.
                incremental_log_marginal_likelihood = (
                    state.log_weights[0] - normalized_log_weights[0])

                do_resample = self.resample_criterion_fn(state)

                # Some batch elements may require resampling and others not, so
                # we first do the resampling for all elements, then select whether to
                # use the resampled values for each batch element according to
                # `do_resample`. If there were no batching, we might prefer to use
                # `tf.cond` to avoid the resampling computation on steps where it's not
                # needed---but we're ultimately interested in adaptive resampling
                # for statistical (not computational) purposes, so this isn't a
                # dealbreaker.
                resampled_particles, resample_indices = weighted_resampling.resample(
                    state.particles,
                    state.log_weights,
                    self.resample_fn,
                    seed=resample_seed)
                uniform_weights = tf.fill(
                    ps.shape(state.log_weights),
                    value=-tf.math.log(
                        tf.cast(num_particles, state.log_weights.dtype)))
                (resampled_particles, resample_indices,
                 log_weights) = tf.nest.map_structure(
                     lambda r, p: ps.where(do_resample, r, p),
                     (resampled_particles, resample_indices, uniform_weights),
                     (state.particles, _dummy_indices_like(resample_indices),
                      normalized_log_weights))

            return (
                WeightedParticles(particles=resampled_particles,
                                  log_weights=log_weights),
                SequentialMonteCarloResults(
                    steps=kernel_results.steps + 1,
                    parent_indices=resample_indices,
                    incremental_log_marginal_likelihood=(
                        incremental_log_marginal_likelihood),
                    accumulated_log_marginal_likelihood=(
                        kernel_results.accumulated_log_marginal_likelihood +
                        incremental_log_marginal_likelihood),
                    seed=seed))
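NumPy sketch of the batched select-after-resampling pattern used in `one_step`: resampled values are computed for every batch element, then kept only where `do_resample` fired (values here are illustrative).

import numpy as np

do_resample = np.array([True, False])[:, np.newaxis]   # one flag per batch element
resampled = np.array([[7., 7., 7.], [8., 8., 8.]])
original = np.array([[1., 2., 3.], [4., 5., 6.]])
print(np.where(do_resample, resampled, original))
# [[7. 7. 7.]
#  [4. 5. 6.]]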
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

    Args:
      operands: tuple of tensors, all having the same shape.
      inits: tuple of scalar tensors, with dtypes aligned to those of operands.
      axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
        `int`. `None` is taken to mean "reduce all axes".
      keepdims: When `True`, we do not squeeze away the reduced dims, instead
        returning values with singleton dims in those axes.

    Returns:
      reduced: A tuple of the reduced operands.
    """
        # Static shape consistency checks.
        args_shape = operands[0].shape
        for arg in operands[1:]:
            args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
        ndims = tensorshape_util.rank(args_shape)
        if ndims is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')
        # Ensure the 'axis' arg is a tuple of non-negative ints.
        axis = np.arange(ndims) if axis is None else np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        axis = np.reshape(axis, [-1])
        axis = np.where(axis < 0, axis + ndims, axis)
        axis = tuple(int(ax) for ax in axis)

        if JAX_MODE:
            from jax import lax  # pylint: disable=g-import-not-at-top
            result = lax.reduce(operands,
                                init_values=inits,
                                dimensions=axis,
                                computation=reducer)
        elif (tf.executing_eagerly()
              or not control_flow_util.GraphOrParentsInXlaContext(
                  tf1.get_default_graph())):
            result = _variadic_reduce(operands,
                                      init=inits,
                                      axis=axis,
                                      reducer=reducer)
        else:
            result = _xla_reduce(operands, inits, axis)

        if keepdims:
            axis_nhot = ps.reduce_sum(ps.one_hot(axis,
                                                 depth=ndims,
                                                 on_value=True,
                                                 off_value=False,
                                                 dtype=tf.bool),
                                      axis=0)
            in_shape = args_shape
            if not tensorshape_util.is_fully_defined(in_shape):
                in_shape = tf.shape(operands[0])
            final_shape = ps.where(axis_nhot, 1, in_shape)
            result = tf.nest.map_structure(
                lambda t: tf.reshape(t, final_shape), result)
        return result
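NumPy sketch of the `keepdims` shape computed above: an n-hot mask over the reduced axes turns shape [2, 3, 4], reduced over axes (0, 2), into the unsqueezed shape [1, 3, 1].

import numpy as np

in_shape = np.array([2, 3, 4])
axis = (0, 2)
axis_nhot = np.zeros(len(in_shape), dtype=bool)
axis_nhot[list(axis)] = True
print(np.where(axis_nhot, 1, in_shape))  # [1 3 1]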
Example #20
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
    """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
          Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as the `x`, and rank equal to
      `rank(x) - len(sample_axis) + 2 * len(event_axis)`.

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

    with tf.name_scope(name or 'covariance'):
        x = tf.convert_to_tensor(x, name='x')
        # Covariance *only* uses the centered versions of x (and y).
        x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

        if y is None:
            y = x
        else:
            y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
            # If x and y have different shape, sample_axis and event_axis will likely
            # be wrong for one of them!
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

        if event_axis is None:
            return tf.reduce_mean(x * tf.math.conj(y),
                                  axis=sample_axis,
                                  keepdims=keepdims)

        if sample_axis is None:
            raise ValueError(
                'sample_axis was None, which means all axes hold events, and this '
                'overlaps with event_axis ({})'.format(event_axis))

        event_axis = _make_positive_axis(event_axis, ps.rank(x))
        sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

        # If we get lucky and axis is statically defined, we can do some checks.
        if _is_list_like(event_axis) and _is_list_like(sample_axis):
            event_axis = tuple(map(int, event_axis))
            sample_axis = tuple(map(int, sample_axis))
            if set(event_axis).intersection(sample_axis):
                raise ValueError(
                    'sample_axis ({}) and event_axis ({}) overlapped'.format(
                        sample_axis, event_axis))
            if (np.diff(np.array(sorted(event_axis))) > 1).any():
                raise ValueError(
                    'event_axis must be contiguous. Found: {}'.format(
                        event_axis))
            batch_axis = list(
                sorted(
                    set(range(tensorshape_util.rank(
                        x.shape))).difference(sample_axis + event_axis)))
        else:
            batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                      ps.concat((sample_axis, event_axis), 0))

        event_axis = ps.cast(event_axis, dtype=tf.int32)
        sample_axis = ps.cast(sample_axis, dtype=tf.int32)
        batch_axis = ps.cast(batch_axis, dtype=tf.int32)

        # Permute x/y until shape = B + E + S
        perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
        x_permed = tf.transpose(a=x, perm=perm_for_xy)
        y_permed = tf.transpose(a=y, perm=perm_for_xy)

        batch_ndims = ps.size(batch_axis)
        batch_shape = ps.shape(x_permed)[:batch_ndims]
        event_ndims = ps.size(event_axis)
        event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
        sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
        sample_ndims = ps.size(sample_shape)
        n_samples = ps.reduce_prod(sample_shape)
        n_events = ps.reduce_prod(event_shape)

        # Flatten the sample axes into one long dim and the event axes into
        # another, giving shape `batch_shape + [n_events, n_samples]`.
        x_permed_flat = tf.reshape(
            x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
        y_permed_flat = tf.reshape(
            y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

        # After matmul, cov.shape = batch_shape + [n_events, n_events]
        cov = tf.matmul(x_permed_flat, y_permed_flat,
                        adjoint_b=True) / ps.cast(n_samples, x.dtype)

        # Insert some singletons to make
        # cov.shape = batch_shape + event_shape**2 + [1,...,1]
        # This is just like x_permed.shape, except the sample_axis is all 1's, and
        # the [n_events] became event_shape**2.
        cov = tf.reshape(
            cov,
            ps.concat(
                (
                    batch_shape,
                    # event_shape**2 used here because it is the same length as
                    # event_shape, and has the same number of elements as one
                    # batch of covariance.
                    event_shape**2,
                    ps.ones([sample_ndims], tf.int32)),
                0))
        # Permuting by the argsort inverts the permutation, making
        # cov.shape have ones in the position where there were samples, and
        # [n_events * n_events] in the event position.
        cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

        # Now expand event_shape**2 into event_shape + event_shape.
        # We here use (for the first time) the fact that we require event_axis to be
        # contiguous.
        e_start = event_axis[0]
        e_len = 1 + event_axis[-1] - event_axis[0]
        cov = tf.reshape(
            cov,
            ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                       ps.shape(cov)[e_start + e_len:]), 0))

        # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
        # require our axis args to be constants.
        if not keepdims:
            squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                                    sample_axis + e_len)
            cov = _squeeze(cov, axis=squeeze_axis)

        return cov
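A NumPy cross-check of the scalar-event estimator above (`event_axis=None`): centering both inputs and averaging their product matches `np.cov` with `bias=True`, i.e. the divide-by-N normalization noted in the docstring.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(100,))
y = rng.normal(size=(100,))
manual = np.mean((x - x.mean()) * (y - y.mean()))
print(np.allclose(manual, np.cov(x, y, bias=True)[0, 1]))  # True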
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[0]

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)
        log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights,
                                        axis=-1)

        # If this step has an observation, compute its weights and marginal
        # likelihood (and otherwise, leave weights unchanged).
        observation_log_weights = prefer_static.cond(
            has_observation,
            lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
                _compute_observation_log_weights(step, proposed_particles,
                                                 observation, observation_fn),
                prefer_static.shape(log_weights)),
            lambda: tf.zeros_like(log_weights))

        unnormalized_log_weights = log_weights + observation_log_weights
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=0)
        log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

        # Adaptive resampling: resample particles iff the specified criterion is met.
        do_resample = resample_criterion_fn(unnormalized_log_weights)

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = _resample(proposed_particles,
                                                          log_weights,
                                                          resample_independent,
                                                          seed=seed)

        uniform_weights = (prefer_static.zeros_like(log_weights) -
                           prefer_static.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: prefer_static.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, _dummy_indices_like(resample_indices),
              log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
 def __init__(self, loc, precision_factor, nonzeros, **kwargs):
     self._indices = ps.where(nonzeros)
     self._size = ps.dimension_size(nonzeros, -1)
     super().__init__(loc=loc, precision_factor=precision_factor, **kwargs)