def _design_matrix_for_one_seasonal_effect(num_steps, duration, period, dtype):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose([
      ps.where(current_period == p,  # pylint: disable=g-complex-comprehension
               ps.ones([], dtype=dtype),
               ps.zeros([], dtype=dtype))
      for p in range(period)])
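
A NumPy-only sketch (hypothetical values num_steps=6, duration=2, period=3) of the one-hot design matrix the function above produces; np.where stands in for ps.where here:

import numpy as np

num_steps, duration, period = 6, 2, 3
current_period = np.int32(np.arange(num_steps) / duration) % period  # [0 0 1 1 2 2]
design = np.transpose([np.where(current_period == p, 1., 0.) for p in range(period)])
print(design)
# [[1. 0. 0.]
#  [1. 0. 0.]
#  [0. 1. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]
#  [0. 0. 1.]]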
Example #2
 def _inverse(self, y):
     ndims = prefer_static.rank(y)
     indices = prefer_static.reshape(prefer_static.add(self.axis, ndims),
                                     shape=[-1, 1])
     num_left, num_right = prefer_static.unstack(self.paddings,
                                                 num=2,
                                                 axis=-1)
     x = tf.slice(y,
                  begin=prefer_static.tensor_scatter_nd_update(
                      prefer_static.zeros(ndims, dtype=tf.int32), indices,
                      num_left),
                  size=prefer_static.tensor_scatter_nd_sub(
                      prefer_static.shape(y), indices,
                      num_left + num_right))
     if not self.validate_args:
         return x
     assertions = [
         assert_util.assert_equal(
             self._forward(x),
             y,
             message=('Argument `y` to `inverse` was not padded with '
                      '`constant_values`.')),
     ]
     with tf.control_dependencies(assertions):
         return tf.identity(x)
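
A hedged numeric sketch of the slice built above: with assumed values ndims=2, axis=[-1] and paddings=[[1, 2]], `begin` becomes [0, 1] and `size` becomes shape(y) - [0, 3], so one leading and two trailing padded entries are removed along the last axis.

import tensorflow as tf

y = tf.constant([[0., 1., 2., 3., 0., 0.]])   # assumed padded: 1 zero left, 2 zeros right
x = tf.slice(y, begin=[0, 1], size=[1, 3])
print(x.numpy())                              # [[1. 2. 3.]]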
Example #3
 def _init(shape_and_dtype):
     """Allocate TensorArray for storing state and momentum."""
     return [  # pylint: disable=g-complex-comprehension
         prefer_static.zeros(prefer_static.concat(
             [[max(self._write_instruction) + 1], s], axis=0),
                             dtype=d) for (s, d) in shape_and_dtype
     ]
 def init_velocity_state_memory(self, input_tensors):
   """Allocate TensorArray for storing state and momentum."""
   shape_and_dtype = [(ps.shape(x_), x_.dtype) for x_ in input_tensors]
   return [  # pylint: disable=g-complex-comprehension
       ps.zeros(
           ps.concat([[max(self._write_instruction) + 1], s], axis=0),
           dtype=d) for (s, d) in shape_and_dtype
   ]
def _squeeze(x, axis):
    """A version of squeeze that works with dynamic axis."""
    x = tf.convert_to_tensor(x, name='x')
    if axis is None:
        return tf.squeeze(x, axis=None)
    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
    keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
    return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
Example #6
 def _forward(self, x):
   ndims = ps.rank(x)
   indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
   return tf.pad(
       x,
       paddings=ps.tensor_scatter_nd_update(
           ps.zeros([ndims, 2], dtype=tf.int32),
           indices, self.paddings),
       mode=self.mode,
       constant_values=ps.cast(self.constant_values, dtype=x.dtype))
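
A small sketch (assumed ndims=3, axis=[-1], paddings=[[1, 2]]) of how the scatter above builds the full `[ndims, 2]` paddings matrix consumed by tf.pad:

import tensorflow as tf

ndims = 3
paddings = tf.tensor_scatter_nd_update(
    tf.zeros([ndims, 2], dtype=tf.int32),
    indices=[[ndims - 1]],                    # axis -1 + ndims == 2
    updates=[[1, 2]])
print(paddings.numpy())                       # [[0 0] [0 0] [1 2]] -> pads the last axis only
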
    def _entropy(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError('`entropy` is not implemented.')
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError('`entropy` is not implemented when '
                                      '`bijector` is not injective.')
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        override_event_shape = tf.convert_to_tensor(self._override_event_shape)
        override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
        base_batch_shape_tensor = self.distribution.batch_shape_tensor()
        base_event_shape_tensor = self.distribution.event_shape_tensor()
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
         # where c can be anything.
        entropy = self.distribution.entropy(**distribution_kwargs)
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy = entropy * tf.cast(tf.reduce_prod(override_event_shape),
                                        dtype=dtype_util.base_dtype(
                                            entropy.dtype))
        if self._is_maybe_batch_override:
            new_shape = tf.concat([
                prefer_static.ones_like(override_batch_shape),
                base_batch_shape_tensor
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                override_batch_shape,
                prefer_static.ones_like(base_batch_shape_tensor)
            ], 0)
            entropy = tf.tile(entropy, multiples)
        dummy = prefer_static.zeros(shape=tf.concat([
            self._batch_shape_tensor(override_batch_shape,
                                     base_batch_shape_tensor),
            self._event_shape_tensor(override_event_shape,
                                     base_event_shape_tensor)
        ], 0),
                                    dtype=self.dtype)
        event_ndims = (
            tensorshape_util.rank(self.event_shape)  # pylint: disable=g-long-ternary
            if tensorshape_util.rank(self.event_shape) is not None else
            tf.size(
                self._event_shape_tensor(override_event_shape,
                                         base_event_shape_tensor)))
        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
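
A hedged numeric check of the identity in the comments above, H[Y] = H[X] + E_X[log |det J_g(X)|], using a constant-Jacobian Scale bijector so the expectation is simply log 3:

import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

base = tfd.Normal(loc=0., scale=1.)
y = tfd.TransformedDistribution(base, bijector=tfb.Scale(3.))
print(y.entropy().numpy())                          # ~2.5176
print((base.entropy() + tf.math.log(3.)).numpy())   # same value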
Example #8
    def _inverse(self, y):
        ndims = ps.rank(y)
        shifted_y = ps.pad(
            ps.slice(
                y, ps.zeros(ndims, dtype=tf.int32),
                ps.shape(y) -
                ps.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
            ),  # Remove the last entry of y in the chosen dimension.
            paddings=ps.one_hot(
                ps.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
                2,
                dtype=tf.int32
            )  # Insert zeros at the beginning of the chosen dimension.
        )

        return y - shifted_y
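
Along the chosen axis this is a first difference, i.e. the inverse of a cumulative sum; a minimal sketch with axis=-1:

import tensorflow as tf

y = tf.constant([1., 3., 6., 10.])             # cumsum of [1, 2, 3, 4]
shifted_y = tf.pad(y[:-1], paddings=[[1, 0]])  # drop last entry, insert 0 at the front
print((y - shifted_y).numpy())                 # [1. 2. 3. 4.]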
Example #9
 def body_fn(vecs, i):
     # Slice out the vector w.r.t. which we're orthogonalizing the rest.
     vecs_ndims = ps.rank(vecs)
     select_axis = (ps.range(vecs_ndims) == vecs_ndims - 1)
     start = ps.where(select_axis, i, ps.zeros([vecs_ndims], i.dtype))
     size = ps.where(select_axis, 1, ps.shape(vecs))
     u = tf.math.l2_normalize(tf.slice(vecs, start, size), axis=-2)
     # TODO(b/171730305): XLA can't handle this line...
     # u = tf.math.l2_normalize(vecs[..., i, tf.newaxis], axis=-2)
     # Find weights by dotting the d x 1 against the d x n.
     weights = tf.einsum('...dm,...dn->...n', u, vecs)
     # Project out vector `u` from the trailing vectors.
     masked_weights = tf.where(tf.range(n) > i, weights,
                               0.)[..., tf.newaxis, :]
     vecs = vecs - tf.math.multiply_no_nan(u, masked_weights)
     tensorshape_util.set_shape(vecs, vectors.shape)
     return vecs, i + 1
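
A hedged, simplified sketch of the projection step above (without the masking): the einsum dots a d x 1 unit vector against a d x n matrix to get per-column weights, which are then subtracted out so the remaining columns are orthogonal to `u`.

import tensorflow as tf

u = tf.math.l2_normalize(tf.constant([[1.], [2.], [2.]]), axis=-2)   # d x 1 unit vector
vecs = tf.constant([[1., 0.], [0., 1.], [0., 0.]])                   # d x n
weights = tf.einsum('...dm,...dn->...n', u, vecs)                    # shape (n,)
projected = vecs - u * weights                                       # columns orthogonal to u
print(tf.einsum('dm,dn->n', u, projected).numpy())                   # ~[0. 0.]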
Example #10
    def _entropy(self):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("entropy is not implemented")
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError("entropy is not implemented when "
                                      "bijector is not injective.")
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
        # where c can be anything.
        entropy = self.distribution.entropy()
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy *= tf.cast(
                tf.reduce_prod(input_tensor=self._override_event_shape),
                dtype=entropy.dtype.base_dtype)
        if self._is_maybe_batch_override:
            new_shape = tf.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor()
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                self._override_batch_shape,
                prefer_static.ones_like(self.distribution.batch_shape_tensor())
            ], 0)
            entropy = tf.tile(entropy, multiples)
        dummy = prefer_static.zeros(shape=tf.concat(
            [self.batch_shape_tensor(),
             self.event_shape_tensor()], 0),
                                    dtype=self.dtype)
        event_ndims = (self.event_shape.ndims
                       if self.event_shape.ndims is not None else tf.size(
                           input=self.event_shape_tensor()))
        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims)

        entropy -= tf.cast(ildj, entropy.dtype)
        entropy.set_shape(self.batch_shape)
        return entropy
Example #11
def interpolate_backward_differences(backward_differences, order,
                                     step_size_ratio):
    """Updates backward differences when a change in the step size occurs."""
    state_dtype = backward_differences.dtype
    interpolation_matrix_ = interpolation_matrix(state_dtype, order,
                                                 step_size_ratio)
    interpolation_matrix_unit_step_size_ratio = interpolation_matrix(
        state_dtype, order, 1.)
    interpolated_backward_differences_orders_one_to_five = tf.matmul(
        interpolation_matrix_unit_step_size_ratio,
        tf.matmul(interpolation_matrix_,
                  backward_differences[1:MAX_ORDER + 1]))
    interpolated_backward_differences = tf.concat([
        tf.gather(backward_differences, [0]),
        interpolated_backward_differences_orders_one_to_five,
        ps.zeros(ps.stack([2, ps.shape(backward_differences)[1]]),
                 dtype=state_dtype),
    ], 0)
    return interpolated_backward_differences
    def _forward(self, x):
        x = tf.convert_to_tensor(x, name='x')
        batch_shape = ps.shape(x)[:-1]

        # Pad zeros on the top row and right column.
        y = fill_triangular.FillTriangular().forward(x)
        rank = ps.rank(y)
        paddings = ps.concat(
            [ps.zeros([rank - 2, 2], dtype=tf.int32), [[1, 0], [0, 1]]],
            axis=0)
        y = tf.pad(y, paddings)

        # Set diagonal to 1s.
        n = ps.shape(y)[-1]
        diag = tf.ones(ps.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
        y = tf.linalg.set_diag(y, diag)

        # Normalize each row to have Euclidean (L2) norm 1.
        y /= tf.norm(y, axis=-1)[..., tf.newaxis]
        return y
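
Hedged usage sketch: these steps mirror tfp.bijectors.CorrelationCholesky; because each row is L2-normalized, L @ L^T has a unit diagonal and is a valid correlation matrix.

import tensorflow as tf
import tensorflow_probability as tfp

b = tfp.bijectors.CorrelationCholesky()
chol = b.forward([0.5, -1.0, 2.0])              # 3 unconstrained inputs -> 3x3 lower-triangular
corr = tf.matmul(chol, chol, transpose_b=True)
print(tf.linalg.diag_part(corr).numpy())        # ~[1. 1. 1.]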
Example #13
def _particle_filter_initial_weighted_particles(observations,
                                                observation_fn,
                                                initial_state_prior,
                                                initial_state_proposal,
                                                num_particles,
                                                seed=None):
  """Initialize a set of weighted particles including the first observation."""
  # Initial particles all have the same weight, `1. / num_particles`.
  broadcast_batch_shape = tf.convert_to_tensor(
      functools.reduce(
          ps.broadcast_shape,
          tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
          []), dtype=tf.int32)
  initial_log_weights = ps.zeros(
      ps.concat([[num_particles], broadcast_batch_shape], axis=0),
      dtype=tf.float32) - ps.log(num_particles)

  # Propose an initial state.
  if initial_state_proposal is None:
    initial_state = initial_state_prior.sample(num_particles, seed=seed)
  else:
    initial_state = initial_state_proposal.sample(num_particles, seed=seed)
    initial_log_weights += (initial_state_prior.log_prob(initial_state) -
                            initial_state_proposal.log_prob(initial_state))
    # The initial proposal weights are normalized in expectation, but actually
    # normalizing them reduces variance in the initial marginal
    # likelihood.
    initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)

  # Return particles weighted by the initial observation.
  return smc_kernel.WeightedParticles(
      particles=initial_state,
      log_weights=initial_log_weights + _compute_observation_log_weights(
          step=0,
          particles=initial_state,
          observations=observations,
          observation_fn=observation_fn))
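
A minimal check (sketch) that the uniform initial log-weights, zeros minus log(num_particles), are already normalized in probability space:

import tensorflow as tf

num_particles = 4
log_weights = tf.zeros([num_particles]) - tf.math.log(float(num_particles))
print(tf.exp(tf.reduce_logsumexp(log_weights)).numpy())   # ~1.0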
Example #14
def make_convolution_transpose_fn_with_subkernels_matrix(
        filter_shape,
        strides,
        padding,
        rank=2,
        dilations=None,
        dtype=tf.int32,
        validate_args=False,
        name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))

        strides = tf.get_static_value(strides)
        if not isinstance(strides, int):
            raise ValueError(
                'Argument `strides` must be a statically known integer. '
                'Saw: {}'.format(strides))

        [
            filter_shape,
            rank,
            _,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              is_transpose=True,
                              validate_args=validate_args)

        fh, fw = filter_shape
        dh, dw = dilations

        # Determine maximum filter height and filter width of sub-kernels.
        sub_fh = (fh - 1) // strides + 1
        sub_fw = (fw - 1) // strides + 1

        def loop_body(i_, event_ind):
            i = i_ // strides
            j = i_ % strides

            i_ind = ps.range(i * fw, fw * fh, delta=strides * fw, dtype=dtype)
            j_ind = ps.range(j, fw, delta=strides, dtype=dtype)

            nc = cartesian_add([i_ind, j_ind])
            ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

            k = ps.reshape(cartesian_add([
                ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
                ps.range(ps.shape(nc)[1], dtype=dtype)
            ]),
                           shape=[-1])
            last_j = strides - (fw - j - 1) % strides - 1
            last_i = strides - (fh - i - 1) % strides - 1
            kernel_ind = ps.stack(
                [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
            event_ind = ps.tensor_scatter_nd_update(event_ind, ind[...,
                                                                   tf.newaxis],
                                                    kernel_ind)

            return i_ + 1, event_ind

        event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
        _, event_ind = tf.while_loop(lambda i, _: i < strides**2, loop_body,
                                     [tf.zeros([], dtype=dtype), event_ind])

        tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
            fh, stride=strides, dilation=dh, padding=padding)
        tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
            fw, stride=strides, dilation=dw, padding=padding)

        pad_bottom = (tot_pad_bottom - 1) // strides + 1
        pad_top = (tot_pad_top - 1) // strides + 1
        pad_right = (tot_pad_right - 1) // strides + 1
        pad_left = (tot_pad_left - 1) // strides + 1
        padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

        truncate_top = pad_top * strides - tot_pad_top
        truncate_left = pad_left * strides - tot_pad_left

        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out

        return op
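
A hedged aside on the `strides > 1` branch above: tf.nn.depth_to_space is what interleaves the strides**2 sub-kernel outputs back onto the upsampled spatial grid.

import tensorflow as tf

x = tf.reshape(tf.range(16.), [1, 2, 2, 4])          # strides**2 == 4 channels per cell
print(tf.nn.depth_to_space(x, block_size=2).shape)   # (1, 4, 4, 1)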
Example #15
 def _batched_isotropic_normal_like(state_part):
   return sample.Sample(
       normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
       ps.shape(state_part)[batch_rank:])
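
Hedged sketch of the helper above: tfd.Sample wraps a scalar standard Normal so the event shape matches a state part's non-batch dims (batch_rank assumed to be 1 here).

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

state_part = tf.zeros([8, 3])                  # batch of 8, event dim 3
momentum_dist = tfd.Sample(tfd.Normal(0., 1.),
                           sample_shape=state_part.shape[1:].as_list())
print(momentum_dist.event_shape)               # (3,)
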
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        num_steps_state_history_to_pass=None,
        num_steps_observation_history_to_pass=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.mcmc.experimental.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights, axis=-1)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        num_observation_steps = prefer_static.shape(
            tf.nest.flatten(observations)[0])[0]
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = prefer_static.zeros(
            prefer_static.concat([[num_particles], broadcast_batch_shape],
                                 axis=0),
            dtype=tf.float32) - prefer_static.log(num_particles)

        # Initialize from the prior, and incorporate the first observation.
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_step_results,
                       state_history):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                prefer_static.equal(num_transitions_per_observation, 1) |
                prefer_static.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            history_to_pass_into_fns = {}
            if num_steps_observation_history_to_pass:
                history_to_pass_into_fns[
                    'observation_history'] = _gather_history(
                        observations, observation_idx,
                        num_steps_observation_history_to_pass)
            if num_steps_state_history_to_pass:
                history_to_pass_into_fns['state_history'] = state_history

            new_step_results = _filter_one_step(
                step=step,
                previous_particles=previous_step_results.particles,
                log_weights=previous_step_results.log_weights,
                observation=current_observation,
                transition_fn=functools.partial(transition_fn,
                                                **history_to_pass_into_fns),
                observation_fn=functools.partial(observation_fn,
                                                 **history_to_pass_into_fns),
                proposal_fn=(None
                             if proposal_fn is None else functools.partial(
                                 proposal_fn, **history_to_pass_into_fns)),
                resample_criterion_fn=resample_criterion_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(step, new_step_results,
                                          accumulated_step_results,
                                          state_history)

        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results, num_steps_state_history_to_pass,
                num_timesteps))

        results = tf.nest.map_structure(lambda ta: ta.stack(),
                                        loop_results.accumulated_step_results)
        if num_transitions_per_observation != 1:
            # Return a log-prob for each observed step.
            observed_steps = prefer_static.range(
                0, num_timesteps, num_transitions_per_observation)
            results = results._replace(step_log_marginal_likelihood=tf.gather(
                results.step_log_marginal_likelihood, observed_steps))
        return results
Example #17
def pad_tensor_with_trailing_zeros(x, num_zeros):
    return tf.pad(
        x,
        ps.concat(
            [ps.zeros([ps.rank(x) - 1, 2], dtype=np.int32), [[0, num_zeros]]],
            axis=0))
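
For a rank-2 input, the helper above builds paddings [[0, 0], [0, num_zeros]]; a standalone sketch of the equivalent call:

import tensorflow as tf

x = tf.ones([2, 3])
print(tf.pad(x, paddings=[[0, 0], [0, 2]]).shape)   # (2, 5): two trailing zeros on the last axis
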
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_fn=weighted_resampling.resample_systematic,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        trace_fn=_default_trace_fn,
        step_indices_to_trace=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.mcmc.experimental.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step.
      It takes a `ParticleFilterStepResults` tuple and returns a structure of
      `Tensor`s. The default function returns
      `(particles, log_weights, parent_indices, step_log_likelihood)`.
    step_indices_to_trace: optional `int` `Tensor` listing, in increasing order,
      the indices of steps at which to record the values traced by `trace_fn`.
      If `None`, the default behavior is to trace at every timestep,
      equivalent to specifying `step_indices_to_trace=tf.range(num_timesteps)`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights, axis=-1)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Canonicalize the list of steps to trace as a rank-1 tensor of (sorted)
        # positive integers. E.g., `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`.
        if step_indices_to_trace is not None:
            (step_indices_to_trace,
             traced_steps_have_rank_zero) = _canonicalize_steps_to_trace(
                 step_indices_to_trace, num_timesteps)

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            ps.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = ps.zeros(
            ps.concat([[num_particles], broadcast_batch_shape], axis=0),
            dtype=tf.float32) - ps.log(num_particles)

        # Initialize from the prior and incorporate the first observation.
        dummy_previous_step = ParticleFilterStepResults(
            particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            parent_indices=None,
            incremental_log_marginal_likelihood=0.,
            accumulated_log_marginal_likelihood=0.)
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_step_results=dummy_previous_step,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_fn=resample_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_traced_results,
                       num_steps_traced):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                ps.equal(num_transitions_per_observation, 1)
                | ps.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            new_step_results = _filter_one_step(
                step=step,
                previous_step_results=previous_step_results,
                observation=current_observation,
                transition_fn=transition_fn,
                observation_fn=observation_fn,
                proposal_fn=proposal_fn,
                resample_criterion_fn=resample_criterion_fn,
                resample_fn=resample_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(
                step=step,
                current_step_results=new_step_results,
                accumulated_traced_results=accumulated_traced_results,
                trace_fn=trace_fn,
                step_indices_to_trace=step_indices_to_trace,
                num_steps_traced=num_steps_traced)

        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results=initial_step_results,
                num_timesteps=num_timesteps,
                trace_fn=trace_fn,
                step_indices_to_trace=step_indices_to_trace))

        results = tf.nest.map_structure(
            lambda ta: ta.stack(), loop_results.accumulated_traced_results)
        if step_indices_to_trace is not None:
            # If we were passed a rank-0 (single scalar) step to trace, don't
            # return a time axis in the returned results.
            results = ps.cond(
                traced_steps_have_rank_zero,
                lambda: tf.nest.map_structure(lambda x: x[0, ...], results),
                lambda: results)

        return results
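
A hedged end-to-end usage sketch, assuming the public wrapper tfp.experimental.mcmc.particle_filter and a toy Gaussian random-walk model; the exact return structure depends on `trace_fn` (the default traces four tensors, as documented above):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

observations = tf.random.normal([20])
particles, log_weights, parent_indices, step_log_liks = (
    tfp.experimental.mcmc.particle_filter(
        observations=observations,
        initial_state_prior=tfd.Normal(0., 1.),
        transition_fn=lambda _, x: tfd.Normal(x, 0.1),
        observation_fn=lambda _, x: tfd.Normal(x, 0.5),
        num_particles=128))
print(particles.shape)   # (20, 128): one set of particles per timestep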