Example no. 1
        def _loop_body(step, previous_step_results, accumulated_traced_results,
                       num_steps_traced):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                ps.equal(num_transitions_per_observation, 1)
                | ps.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            new_step_results = _filter_one_step(
                step=step,
                previous_step_results=previous_step_results,
                observation=current_observation,
                transition_fn=transition_fn,
                observation_fn=observation_fn,
                proposal_fn=proposal_fn,
                resample_criterion_fn=resample_criterion_fn,
                resample_fn=resample_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(
                step=step,
                current_step_results=new_step_results,
                accumulated_traced_results=accumulated_traced_results,
                trace_fn=trace_fn,
                step_indices_to_trace=step_indices_to_trace,
                num_steps_traced=num_steps_traced)
Example no. 2
  def _scan(level, elems):
    """Perform scan on `elems`."""
    elem_length = prefer_static.shape(elems[0])[0]

    # Apply `fn` to reduce adjacent pairs to a single entry.
    a = [elem[0:-1:2] for elem in elems]
    b = [elem[1::2] for elem in elems]
    reduced_elems = lowered_fn(a, b)

    def handle_base_case_elem_length_two():
      return [tf.concat([elem[0:1], reduced_elem], axis=0)
              for (reduced_elem, elem) in zip(reduced_elems, elems)]

    def handle_base_case_elem_length_three():
      reduced_reduced_elems = lowered_fn(
          reduced_elems, [elem[2:3] for elem in elems])
      return [
          tf.concat([elem[0:1], reduced_elem, reduced_reduced_elem], axis=0)
          for (reduced_reduced_elem, reduced_elem, elem)
          in zip(reduced_reduced_elems, reduced_elems, elems)]

    # Base case of recursion: assumes `elem_length` is 2 or 3.
    at_base_case = prefer_static.logical_or(
        prefer_static.equal(elem_length, 2),
        prefer_static.equal(elem_length, 3))
    base_value = lambda: prefer_static.cond(  # pylint: disable=g-long-lambda
        prefer_static.equal(elem_length, 2),
        handle_base_case_elem_length_two,
        handle_base_case_elem_length_three)

    if level <= 0:
      return base_value()

    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      results = prefer_static.cond(
          prefer_static.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([elem[0:1], result], axis=0)
                    for (elem, result) in zip(elems, results)]
      return list(map(_interleave, even_elems, odd_elems))

    return prefer_static.cond(at_base_case, base_value, recursive_case)
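The recursion above is easier to follow on a concrete, non-Tensor example. Below is a minimal pure-Python sketch (illustrative only: the names `interleave` and `scan`, and the use of addition as the combining `fn`, are assumptions, and the `level` bookkeeping is omitted) showing how reducing adjacent pairs, scanning the half-length sequence, and interleaving reproduces an inclusive cumulative sum.

from itertools import accumulate

def interleave(a, b):
  # [a0 a1 ...], [b0 b1 ...] -> [a0 b0 a1 b1 ...]; `a` may hold one extra element.
  out = [x for pair in zip(a, b) for x in pair]
  return out + a[len(b):]

def scan(xs):
  # Base case: length 2 or 3.
  if len(xs) in (2, 3):
    out = [xs[0], xs[0] + xs[1]]
    return out if len(xs) == 2 else out + [out[-1] + xs[2]]
  # Reduce adjacent pairs, scan the reduced sequence (these become the values
  # at odd positions), then fill in the even positions and interleave.
  reduced = [xs[i] + xs[i + 1] for i in range(0, len(xs) - 1, 2)]
  odd_elems = scan(reduced)
  partial = odd_elems[:-1] if len(xs) % 2 == 0 else odd_elems
  even_elems = [xs[0]] + [o + x for o, x in zip(partial, xs[2::2])]
  return interleave(even_elems, odd_elems)

assert scan([1, 2, 3, 4, 5]) == list(accumulate([1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]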
Example no. 3
def _initialize_loop_variables(initial_step_results, num_timesteps, trace_fn,
                               step_indices_to_trace):
    """Initialize arrays and other quantities passed through the filter loop."""

    # Create arrays to store traced values (particles, likelihoods, etc).
    num_steps_to_trace = (num_timesteps if step_indices_to_trace is None else
                          ps.size0(step_indices_to_trace))
    traced_results = trace_fn(initial_step_results)
    trace_arrays = tf.nest.map_structure(
        lambda x: tf.TensorArray(dtype=x.dtype, size=num_steps_to_trace),
        traced_results)
    # If we are supposed to trace at step 0, write the traced values.
    num_steps_traced, trace_arrays = ps.cond(
        (True if step_indices_to_trace is None else ps.equal(
            step_indices_to_trace[0], 0)),
        lambda: (
            1,  # pylint: disable=g-long-lambda
            tf.nest.map_structure(lambda ta, x: ta.write(0, x), trace_arrays,
                                  traced_results)),
        lambda: (0, trace_arrays))

    return ParticleFilterLoopVariables(
        step=1,
        previous_step_results=initial_step_results,
        accumulated_traced_results=trace_arrays,
        num_steps_traced=num_steps_traced)
Example no. 4
def _interleave(a, b, axis):
  """Interleaves two `Tensor`s along the given axis."""
  # [a b c ...] [d e f ...] -> [a d b e c f ...]
  num_elems_a = ps.shape(a)[axis]
  num_elems_b = ps.shape(b)[axis]

  # Note that interleaving implies rank(a)==rank(b).
  axis = ps.where(axis >= 0, axis, ps.rank(a) + axis)
  axis = (int(axis)  # Avoid ndarray values.
          if tf.get_static_value(axis) is not None
          else axis)

  def _interleave_with_b(a):
    return tf.reshape(
        # Work around lack of support for Tensor axes in `tf.stack` by using
        # `concat` and `expand_dims` instead.
        tf.concat([tf.expand_dims(a, axis=axis + 1),
                   tf.expand_dims(b, axis=axis + 1)],
                  axis=axis + 1),
        ps.concat(
            [
                ps.shape(a)[:axis],
                [2 * num_elems_b],
                ps.shape(a)[axis + 1:]
            ],
            axis=0))
  return ps.cond(
      ps.equal(num_elems_a, num_elems_b + 1),
      lambda: tf.concat([  # pylint: disable=g-long-lambda
          _interleave_with_b(_slice_along_axis(a, None, -1, axis=axis)),
          _slice_along_axis(a, -1, None, axis=axis)], axis=axis),
      lambda: _interleave_with_b(a))
Example no. 5
    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn(
            [slice_elem(odd_elem, 0, -1) for odd_elem in odd_elems],
            [slice_elem(elem, 2, None, 2) for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [slice_elem(elem, 2, None, 2) for elem in elems])

      results = ps.cond(
          ps.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([slice_elem(elem, 0, 1), result], axis=axis)
                    for (elem, result) in zip(elems, results)]
      return list(map(lambda a, b: _interleave(a, b, axis=axis),
                      even_elems,
                      odd_elems))
Example no. 6
def _update_loop_variables(step, current_step_results,
                           accumulated_traced_results, trace_fn,
                           step_indices_to_trace, num_steps_traced):
    """Update the loop state to reflect a step of filtering."""

    # Write particles, indices, and likelihoods to their respective arrays.
    trace_this_step = True
    if step_indices_to_trace is not None:
        trace_this_step = ps.equal(
            step_indices_to_trace[ps.minimum(
                num_steps_traced,
                ps.cast(ps.size0(step_indices_to_trace) - 1, dtype=np.int32))],
            step)
    num_steps_traced, accumulated_traced_results = ps.cond(
        trace_this_step,
        lambda: (
            num_steps_traced + 1,  # pylint: disable=g-long-lambda
            tf.nest.map_structure(lambda x, y: x.write(num_steps_traced, y),
                                  accumulated_traced_results,
                                  trace_fn(current_step_results))),
        lambda: (num_steps_traced, accumulated_traced_results))

    return ParticleFilterLoopVariables(
        step=step + 1,
        previous_step_results=current_step_results,
        accumulated_traced_results=accumulated_traced_results,
        num_steps_traced=num_steps_traced)
Example no. 7
 def _is_increasing(self, **kwargs):
   # desc(desc)=>asc, asc(asc)=>asc, other cases=>desc.
   is_increasing = True
   for b in self._bijectors:
     is_increasing = ps.equal(
         is_increasing, b._internal_is_increasing(**kwargs.get(b.name, {})))  # pylint: disable=protected-access
   return is_increasing
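A tiny standalone sketch (hypothetical helper name, plain Python booleans in place of the `ps` ops) of the same accumulation: `ps.equal` acts here as a logical XNOR, so the composition is increasing exactly when an even number of component bijectors are decreasing.

def composed_is_increasing(component_is_increasing):
  is_increasing = True
  for b_is_increasing in component_is_increasing:
    is_increasing = (is_increasing == b_is_increasing)  # Logical XNOR.
  return is_increasing

assert composed_is_increasing([False, False])      # desc(desc) => asc
assert composed_is_increasing([True, True])        # asc(asc)   => asc
assert not composed_is_increasing([True, False])   # mixed      => desc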
Example no. 8
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         # Avoid computing intermediates needed to construct the assertions.
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
         implicit_dim_mask = ps.equal(self._batch_shape_unexpanded, -1)
         assertions.append(
             assert_util.assert_rank(self._batch_shape_unexpanded,
                                     1,
                                     message='New shape must be a vector.'))
         assertions.append(
             assert_util.assert_less_equal(
                 tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32),
                 1,
                 message='At most one dimension can be unknown.'))
         assertions.append(
             assert_util.assert_non_negative(
                 self._batch_shape_unexpanded + 1,
                 message='Shape elements must be >=-1.'))
         # Check that the old and new shapes are the same size.
         expanded_new_shape, original_size = self._calculate_new_shape()
         new_size = ps.reduce_prod(expanded_new_shape)
         assertions.append(
             assert_util.assert_equal(new_size,
                                      tf.cast(original_size,
                                              new_size.dtype),
                                      message='Shape sizes do not match.'))
     return assertions
Example no. 9
  def test_step_indices_to_trace(self):
    num_particles = 1024
    (particles_1_3,
     log_weights_1_3,
     parent_indices_1_3,
     incremental_log_marginal_likelihood_1_3) = self.evaluate(
         tfp.experimental.mcmc.particle_filter(
             observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
             initial_state_prior=tfd.Normal(0., 1.),
             transition_fn=lambda _, state: tfd.Normal(state, 10.),
             observation_fn=lambda _, state: tfd.Normal(state, 0.1),
             num_particles=num_particles,
             trace_criterion_fn=lambda s, r: ps.logical_or(  # pylint: disable=g-long-lambda
                 ps.equal(r.steps, 2),
                 ps.equal(r.steps, 4)),
             static_trace_allocation_size=2,
             seed=test_util.test_seed()))
    self.assertLen(particles_1_3, 2)
    self.assertLen(log_weights_1_3, 2)
    self.assertLen(parent_indices_1_3, 2)
    self.assertLen(incremental_log_marginal_likelihood_1_3, 2)
    means = np.sum(np.exp(log_weights_1_3) * particles_1_3, axis=1)
    self.assertAllClose(means, [3., 7.], atol=1.)

    (final_particles,
     final_log_weights,
     final_cumulative_lp) = self.evaluate(
         tfp.experimental.mcmc.particle_filter(
             observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
             initial_state_prior=tfd.Normal(0., 1.),
             transition_fn=lambda _, state: tfd.Normal(state, 10.),
             observation_fn=lambda _, state: tfd.Normal(state, 0.1),
             num_particles=num_particles,
             trace_fn=lambda s, r: (s.particles,  # pylint: disable=g-long-lambda
                                    s.log_weights,
                                    r.accumulated_log_marginal_likelihood),
             trace_criterion_fn=None,
             seed=test_util.test_seed()))
    self.assertLen(final_particles, num_particles)
    self.assertLen(final_log_weights, num_particles)
    self.assertEqual(final_cumulative_lp.shape, ())
    means = np.sum(np.exp(final_log_weights) * final_particles)
    self.assertAllClose(means, 9., atol=1.5)
Example no. 10
        def _loop_body(step, previous_step_results, accumulated_step_results,
                       state_history):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                prefer_static.equal(num_transitions_per_observation, 1) |
                prefer_static.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            history_to_pass_into_fns = {}
            if num_steps_observation_history_to_pass:
                history_to_pass_into_fns[
                    'observation_history'] = _gather_history(
                        observations, observation_idx,
                        num_steps_observation_history_to_pass)
            if num_steps_state_history_to_pass:
                history_to_pass_into_fns['state_history'] = state_history

            new_step_results = _filter_one_step(
                step=step,
                previous_particles=previous_step_results.particles,
                log_weights=previous_step_results.log_weights,
                observation=current_observation,
                transition_fn=functools.partial(transition_fn,
                                                **history_to_pass_into_fns),
                observation_fn=functools.partial(observation_fn,
                                                 **history_to_pass_into_fns),
                proposal_fn=(None
                             if proposal_fn is None else functools.partial(
                                 proposal_fn, **history_to_pass_into_fns)),
                resample_criterion_fn=resample_criterion_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(step, new_step_results,
                                          accumulated_step_results,
                                          state_history)
Example no. 11
def _compute_observation_log_weights(step,
                                     particles,
                                     observations,
                                     observation_fn,
                                     num_transitions_per_observation=1):
  """Computes particle importance weights from an observation step.

  Args:
    step: int `Tensor` current step.
    particles: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], event_shape])`, where
      `b1, ..., bN` are optional batch dimensions and `event_shape` may
      differ across `Tensor`s.
    observations: Nested structure of `Tensor`s, each of shape
      `concat([[num_observations, b1, ..., bN], event_shape])`
      where `b1, ..., bN` are optional batch dimensions and `event_shape` may
      differ across `Tensor`s.
    observation_fn: callable with signature
      `observation_dist = observation_fn(step, particles)`, producing
      a batch of distributions over the `observation` at the given `step`,
      one for each particle.
    num_transitions_per_observation: optional int `Tensor` number of times
      to apply the transition model between successive observation steps.
      Default value: `1`.
  Returns:
    log_weights: `Tensor` of shape `[num_particles, b1, ..., bN]`.
  """
  with tf.name_scope('compute_observation_log_weights'):

    step_has_observation = (
        # The second of these conditions subsumes the first, but both are
        # useful because the first can often be evaluated statically.
        ps.equal(num_transitions_per_observation, 1) |
        ps.equal(step % num_transitions_per_observation, 0))
    observation_idx = step // num_transitions_per_observation
    observation = tf.nest.map_structure(
        lambda x, step=step: tf.gather(x, observation_idx), observations)

    log_weights = observation_fn(step, particles).log_prob(observation)
    return ps.where(step_has_observation,
                    log_weights,
                    tf.zeros_like(log_weights))
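For intuition, here is how `step_has_observation` and `observation_idx` behave when `num_transitions_per_observation = 3` (a throwaway illustration, not library code): weights are non-zero only on steps that consume an observation.

num_transitions_per_observation = 3
for step in range(7):
  step_has_observation = (step % num_transitions_per_observation == 0)
  observation_idx = step // num_transitions_per_observation
  print(step, step_has_observation, observation_idx)
# Steps 0, 3, 6 consume observations 0, 1, 2 respectively; the remaining
# steps get zero log-weights from the `ps.where` above.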
Example no. 12
def _interleave(a, b):
  """Interleaves two `Tensor`s along their first axis."""
  # [a b c ...] [d e f ...] -> [a d b e c f ...]
  num_elems_a = prefer_static.shape(a)[0]
  num_elems_b = prefer_static.shape(b)[0]

  def _interleave_with_b(a):
    return tf.reshape(
        tf.stack([a, b], axis=1),
        prefer_static.concat([[2 * num_elems_b],
                              prefer_static.shape(a)[1:]], axis=0))
  return prefer_static.cond(
      prefer_static.equal(num_elems_a, num_elems_b + 1),
      lambda: tf.concat([_interleave_with_b(a[:-1]), a[-1:]], axis=0),
      lambda: _interleave_with_b(a))
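A concrete check of the stack-and-flatten trick (standalone sketch; assumes only TensorFlow): with `a` one element longer than `b`, all but the last element of `a` are interleaved with `b` and the final element is appended.

import tensorflow as tf

a = tf.constant([1, 3, 5])
b = tf.constant([2, 4])
# Stacking along a new axis 1 pairs up corresponding elements; flattening then
# interleaves them.
interleaved = tf.reshape(tf.stack([a[:-1], b], axis=1), [-1])
result = tf.concat([interleaved, a[-1:]], axis=0)
# result == [1, 2, 3, 4, 5]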
Example no. 13
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
    """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
    step_indices_to_trace = tf.convert_to_tensor(
        step_indices_to_trace,
        dtype_hint=tf.int32)  # Warning: breaks gradients.
    traced_steps_have_rank_zero = ps.equal(
        ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
    # Canonicalize negative step indices as positive.
    step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                     num_timesteps + step_indices_to_trace,
                                     step_indices_to_trace)
    # Canonicalize scalars as length-one vectors.
    return (ps.reshape(step_indices_to_trace,
                       [ps.size(step_indices_to_trace)]),
            traced_steps_have_rank_zero)
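An equivalent NumPy sketch (illustrative names, `num_timesteps = 5`) of the two canonicalizations mentioned in the docstring:

import numpy as np

def canonicalize(step_indices, num_timesteps):
  step_indices = np.asarray(step_indices, dtype=np.int32)
  had_rank_zero = (step_indices.ndim == 0)
  # Negative indices count from the end, as with Python indexing.
  step_indices = np.where(step_indices < 0,
                          num_timesteps + step_indices, step_indices)
  return np.reshape(step_indices, [step_indices.size]), had_rank_zero

steps, was_scalar = canonicalize(3, 5)
assert steps.tolist() == [3] and was_scalar           # `3` -> `[3]`
steps, was_scalar = canonicalize([-2, -1], 5)
assert steps.tolist() == [3, 4] and not was_scalar    # `[-2, -1]` -> `[3, 4]`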
Example no. 14
def _validate_elem_length(max_num_levels, elems_flat):
  """Checks that elems all have the same length, and returns that length."""
  assertions = []

  elem_length = prefer_static.shape(elems_flat[0])[0]

  # The default size limit will overflow a 32-bit int, so make sure we're
  # using 64-bit.
  size_limit = 2**(prefer_static.cast(max_num_levels, np.int64) + 1)
  enough_levels = prefer_static.less(
      prefer_static.cast(elem_length, np.int64), size_limit)
  enough_levels_ = tf.get_static_value(enough_levels)
  if enough_levels_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            enough_levels, True,
            message='Input `Tensor`s must have first axis dimension less than'
            ' `2**(max_num_levels + 1)`'
            ' (saw: {} which is not less than 2**{} == {})'.format(
                elem_length,
                max_num_levels,
                size_limit)))
  elif not enough_levels_:
    raise ValueError(
        'Input `Tensor`s must have first axis dimension less than'
        ' `2**(max_num_levels + 1)`'
        ' (saw: {} which is not less than 2**{} == {})'.format(
            elem_length,
            max_num_levels,
            size_limit))

  is_consistent = prefer_static.reduce_all([
      prefer_static.equal(
          prefer_static.shape(elem)[0], elem_length)
      for elem in elems_flat[1:]])

  is_consistent_ = tf.get_static_value(is_consistent)
  if is_consistent_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            is_consistent, True,
            message='Input `Tensor`s must have the same first dimension.'
            ' (saw: {})'.format([elem.shape for elem in elems_flat])))
  elif not is_consistent_:
    raise ValueError(
        'Input `Tensor`s must have the same first dimension.'
        ' (saw: {})'.format([elem.shape for elem in elems_flat]))
  return elem_length, assertions
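For scale (illustrative numbers only): with `max_num_levels = 48` the cap is `2**49`, which already overflows 32-bit integers, hence the `int64` cast above.

import numpy as np

max_num_levels = 48
size_limit = np.int64(2) ** (np.int64(max_num_levels) + 1)  # 562949953421312; needs int64.
assert np.int64(1_000_000) < size_limit                     # A length-1e6 input is fine.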
Example no. 15
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        x1, x2 = self._x1_x2()
        if (self._num_matmul_parts is None
                or prefer_static.equal(self._num_matmul_parts, 1)):
            return tf.matmul(self._kernel().matrix(x1, x2),
                             x,
                             adjoint_a=adjoint,
                             adjoint_b=adjoint_arg)

        if adjoint or adjoint_arg:
            raise NotImplementedError(
                '`adjoint`, `adjoint_arg` NYI when `num_matmul_parts` specified.'
            )

        return _chunked_matmul(kernel_fn=self.kernel_fn,
                               kernel_args=self.kernel_args,
                               x1=x1,
                               x2=x2,
                               x=x,
                               num_matmul_parts=self._num_matmul_parts,
                               operator_shape=self.shape_tensor())
Example no. 16
    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      results = prefer_static.cond(
          prefer_static.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([elem[0:1], result], axis=0)
                    for (elem, result) in zip(elems, results)]
      return list(map(_interleave, even_elems, odd_elems))
Example no. 17
 def _calculate_new_shape(self):
     # Try to get the old shape statically if available.
     original_shape = self._distribution.batch_shape
     if not tensorshape_util.is_fully_defined(original_shape):
         original_shape = self._distribution.batch_shape_tensor()
     # This is not a check for falseness, it's a check for exactly that shape.
     if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
         # Force the size to be an integer, not a float, when the shape contains no
         # dtype information.
         original_size = 1
     else:
         original_size = ps.reduce_prod(original_shape)
     original_size = ps.cast(original_size, tf.int32)
     # Compute the new shape, filling in the `-1` dimension if present.
     new_shape = self._batch_shape_unexpanded
     implicit_dim_mask = ps.equal(new_shape, -1)
     size_implicit_dim = (original_size //
                          ps.maximum(1, -ps.reduce_prod(new_shape)))
     expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
         implicit_dim_mask, size_implicit_dim, new_shape)
     # Return the original size on the side because one caller would otherwise
     # have to recompute it.
     return expanded_new_shape, original_size
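The `-1` handling follows the NumPy-style reshape convention; a quick standalone check with NumPy in place of the `ps` ops (illustrative values only):

import numpy as np

original_size = np.int32(24)                   # e.g. a [2, 3, 4] batch shape
new_shape = np.array([-1, 4], dtype=np.int32)
implicit_dim_mask = np.equal(new_shape, -1)
# With exactly one `-1`, -prod(new_shape) is the product of the known dims.
size_implicit_dim = original_size // np.maximum(1, -np.prod(new_shape))
expanded_new_shape = np.where(implicit_dim_mask, size_implicit_dim, new_shape)
# expanded_new_shape == [6, 4], and 6 * 4 == 24 == original_size.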
Example no. 18
    def one_step(self, state, kernel_results, seed=None):
        """Takes one Sequential Monte Carlo inference step.

    Args:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        the current particles with (log) weights. The `log_weights` must be
        a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
        `particles` may be any structure of `Tensor`s, each of which
        must have shape `concat([log_weights.shape, event_shape])` for some
        `event_shape`, which may vary across components.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
        from a previous step.
      seed: Optional seed for reproducible sampling.

    Returns:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        new particles with (log) weights.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults`.
    """
        with tf.name_scope(self.name):
            with tf.name_scope('one_step'):
                seed = samplers.sanitize_seed(seed)
                proposal_seed, resample_seed = samplers.split_seed(seed)

                state = WeightedParticles(*state)  # Canonicalize.
                num_particles = ps.size0(state.log_weights)

                # Propose new particles and update weights for this step, unless it's
                # the initial step, in which case, use the user-provided initial
                # particles and weights.
                proposed_state = self.propose_and_update_log_weights_fn(
                    # Propose state[t] from state[t - 1].
                    ps.maximum(0, kernel_results.steps - 1),
                    state,
                    seed=proposal_seed)
                is_initial_step = ps.equal(kernel_results.steps, 0)
                # TODO(davmre): this `where` assumes the state size didn't change.
                state = tf.nest.map_structure(
                    lambda a, b: tf.where(is_initial_step, a, b), state,
                    proposed_state)

                normalized_log_weights = tf.nn.log_softmax(state.log_weights,
                                                           axis=0)
                # Every entry of `log_weights` differs from `normalized_log_weights`
                # by the same normalizing constant. We extract that constant by
                # examining an arbitrary entry.
                incremental_log_marginal_likelihood = (
                    state.log_weights[0] - normalized_log_weights[0])

                do_resample = self.resample_criterion_fn(state)

                # Some batch elements may require resampling and others not, so
                # we first do the resampling for all elements, then select whether to
                # use the resampled values for each batch element according to
                # `do_resample`. If there were no batching, we might prefer to use
                # `tf.cond` to avoid the resampling computation on steps where it's not
                # needed---but we're ultimately interested in adaptive resampling
                # for statistical (not computational) purposes, so this isn't a
                # dealbreaker.
                resampled_particles, resample_indices = weighted_resampling.resample(
                    state.particles,
                    state.log_weights,
                    self.resample_fn,
                    seed=resample_seed)
                uniform_weights = tf.fill(
                    ps.shape(state.log_weights),
                    value=-tf.math.log(
                        tf.cast(num_particles, state.log_weights.dtype)))
                (resampled_particles, resample_indices,
                 log_weights) = tf.nest.map_structure(
                     lambda r, p: ps.where(do_resample, r, p),
                     (resampled_particles, resample_indices, uniform_weights),
                     (state.particles, _dummy_indices_like(resample_indices),
                      normalized_log_weights))

            return (
                WeightedParticles(particles=resampled_particles,
                                  log_weights=log_weights),
                SequentialMonteCarloResults(
                    steps=kernel_results.steps + 1,
                    parent_indices=resample_indices,
                    incremental_log_marginal_likelihood=(
                        incremental_log_marginal_likelihood),
                    accumulated_log_marginal_likelihood=(
                        kernel_results.accumulated_log_marginal_likelihood +
                        incremental_log_marginal_likelihood),
                    seed=seed))
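The incremental log marginal likelihood extracted above is the log-normalizer of the weights, i.e. `logsumexp(log_weights)`; a small standalone NumPy check (illustrative values, single batch):

import numpy as np

log_weights = np.array([-1.0, -2.0, -0.5])
log_normalizer = np.log(np.sum(np.exp(log_weights)))     # logsumexp
normalized_log_weights = log_weights - log_normalizer    # log_softmax
# Every entry differs from its normalized value by the same constant.
assert np.allclose(log_weights - normalized_log_weights, log_normalizer)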
Example no. 19
 def _has_nonzero_rank(self, override_shape):
     return prefer_static.logical_not(
         prefer_static.equal(prefer_static.rank_from_shape(override_shape),
                             self._zero))
Example no. 20
    def _forward_log_det_jacobian(self, x):
        # Let Y be a symmetric, positive definite matrix and write:
        #   Y = X X.T
        # where X is lower-triangular.
        #
        # Observe that,
        #   dY[i,j]/dX[a,b]
        #   = d/dX[a,b] { X[i,:] X[j,:] }
        #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
        #
        # To compute the Jacobian d vec[Y]/d vec[X] we must represent X,Y as vectors. Since Y is
        # symmetric and X is lower-triangular, we need vectors of dimension:
        #   d = p (p + 1) / 2
        # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
        #   k = { i (i + 1) / 2 + j   i>=j
        #       { undef               i<j
        # and assume zero-based indexes. When k is undef, the element is dropped.
        # Example:
        #           j      k
        #        0 1 2 3  /
        #    0 [ 0 . . . ]
        # i  1 [ 1 2 . . ]
        #    2 [ 3 4 5 . ]
        #    3 [ 6 7 8 9 ]
        # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
        # slight abuse: k(i,j)=undef means the element is dropped.)
        #
        # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
        # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
        # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
        # (1) j<=i<a thus i,j!=a.
        # (2) i=a>j  thus i,j!=a.
        #
        # Since the Jacobian is lower-triangular, we need only compute the product
        # of diagonal elements:
        #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
        #   = X[j,j] + I[i=j] X[i,j]
        #   = 2 X[j,j].
        # Since there is a 2 X[j,j] term for every lower-triangular element of X we
        # conclude:
        #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
        diag = tf.linalg.diag_part(x)

        # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
        # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
        # output is unchanged.
        diag = self._make_columnar(diag)

        with tf.control_dependencies(self._assertions(x)):
            # Create a vector equal to: [p, p-1, ..., 2, 1].
            if tf.compat.dimension_value(x.shape[-1]) is None:
                p_int = tf.shape(x)[-1]
                p_float = tf.cast(p_int, dtype=x.dtype)
            else:
                p_int = tf.compat.dimension_value(x.shape[-1])
                p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
            exponents = tf.linspace(p_float, 1., p_int)

            sum_weighted_log_diag = tf.squeeze(tf.matmul(
                tf.math.log(diag), exponents[..., tf.newaxis]),
                                               axis=-1)
            fldj = p_float * np.log(2.) + sum_weighted_log_diag

            # We finally need to undo adding an extra column in non-scalar cases
            # where there is a single matrix as input.
            if tensorshape_util.rank(x.shape) is not None:
                if tensorshape_util.rank(x.shape) == 2:
                    fldj = tf.squeeze(fldj, axis=-1)
                return fldj

            shape = ps.shape(fldj)
            maybe_squeeze_shape = ps.concat([
                shape[:-1],
                distribution_util.pick_vector(ps.equal(
                    ps.rank(x), 2), np.array([], dtype=np.int32), shape[-1:])
            ], 0)
            return tf.reshape(fldj, maybe_squeeze_shape)
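A small numeric spot check of the final determinant formula for `p = 2` (hypothetical values; `X = [[a, 0], [b, c]]`, so `vec[X] = [a, b, c]` and `vec[Y] = [a*a, a*b, b*b + c*c]` under the row-major mapping described above):

import numpy as np

a, b, c = 1.3, 0.4, 2.1
jac = np.array([[2 * a, 0.,    0.],
                [b,     a,     0.],
                [0.,    2 * b, 2 * c]])  # d vec[Y] / d vec[X], lower triangular.
fldj_direct = np.log(np.abs(np.linalg.det(jac)))
# p * log(2) + sum_j (p - j) * log(X[j, j]) with p = 2.
fldj_formula = 2 * np.log(2.) + 2 * np.log(a) + 1 * np.log(c)
assert np.isclose(fldj_direct, fldj_formula)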
Example no. 21
def _is_scalar_from_shape_tensor(shape):
    """Returns `True` `Tensor` if `Tensor` shape implies a scalar."""
    return prefer_static.equal(prefer_static.rank_from_shape(shape), 0)
Example no. 22
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = tf.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            # We let TransformedDistribution access _graph_parents since this class
            # is more like a baseclass than derived.
            graph_parents=(
                distribution._graph_parents +  # pylint: disable=protected-access
                bijector.graph_parents),
            name=name)
Example no. 23
def _get_search_direction(state):
  """Computes the search direction to follow at the current state.

  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
  the most recent `m` correction pairs in position_deltas and gradient_deltas,
  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.

  Assuming these, the code below is an implementation of the L-BFGS two-loop
  recursion algorithm given by [Nocedal and Wright(2006)][1]:

  ```None
    q_direction = objective_gradient
    for i in reversed(range(m)):  # First loop.
      inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
      alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
      q_direction = q_direction - alpha[i] * gradient_deltas[i]

    kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
                              gradient_deltas[-1]^T * gradient_deltas[-1])
    r_direction = kth_inv_hessian_factor * I * q_direction

    for i in range(m):  # Second loop.
      beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
      r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)

    return -r_direction  # Approximates - H_k * objective_gradient.
  ```

  Args:
    state: A `LBfgsOptimizerResults` tuple with the current state of the
      search procedure.

  Returns:
    A real `Tensor` of the same shape as the `state.position`. The direction
    along which to perform line search.
  """
  # The number of correction pairs that have been collected so far.
  num_elements = ps.minimum(
      state.num_iterations,  # TODO(b/162733947): Change loop state -> closure.
      ps.shape(state.position_deltas)[0])

  def _two_loop_algorithm():
    """L-BFGS two-loop algorithm."""
    # Correction pairs are always appended to the end, so only the latest
    # `num_elements` vectors have valid position/gradient deltas. Vectors
    # that haven't been computed yet are zero.
    position_deltas = state.position_deltas
    gradient_deltas = state.gradient_deltas

    # Pre-compute all `inv_rho[i]`s.
    inv_rhos = tf.reduce_sum(
        gradient_deltas * position_deltas, axis=-1)

    def first_loop(acc, args):
      _, q_direction = acc
      position_delta, gradient_delta, inv_rho = args
      alpha = tf.math.divide_no_nan(
          tf.reduce_sum(position_delta * q_direction, axis=-1), inv_rho)
      direction_delta = alpha[..., tf.newaxis] * gradient_delta
      return (alpha, q_direction - direction_delta)

    # Run first loop body computing and collecting `alpha[i]`s, while also
    # computing the updated `q_direction` at each step.
    zero = tf.zeros_like(inv_rhos[-num_elements])
    alphas, q_directions = tf.scan(
        first_loop, [position_deltas, gradient_deltas, inv_rhos],
        initializer=(zero, state.objective_gradient), reverse=True)

    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
    gamma_k = inv_rhos[-1] / tf.reduce_sum(
        gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
    r_direction = gamma_k[..., tf.newaxis] * q_directions[-num_elements]

    def second_loop(r_direction, args):
      alpha, position_delta, gradient_delta, inv_rho = args
      beta = tf.math.divide_no_nan(
          tf.reduce_sum(gradient_delta * r_direction, axis=-1), inv_rho)
      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
      return r_direction + direction_delta

    # Finally, run second loop body computing the updated `r_direction` at each
    # step.
    r_directions = tf.scan(
        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
        initializer=r_direction)
    return -r_directions[-1]

  return ps.cond(ps.equal(num_elements, 0),
                 lambda: -state.objective_gradient,
                 _two_loop_algorithm)
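For reference, a compact single-batch NumPy sketch of the same two-loop recursion over a dense history of `m` valid correction pairs (illustrative only; the vectorized `tf.scan` formulation above additionally handles zero-padded histories, batching, and `divide_no_nan`):

import numpy as np

def two_loop_direction(objective_gradient, position_deltas, gradient_deltas):
  m = len(position_deltas)
  inv_rhos = [np.dot(gradient_deltas[i], position_deltas[i]) for i in range(m)]
  q_direction = np.array(objective_gradient, dtype=np.float64)
  alphas = [0.0] * m
  for i in reversed(range(m)):                    # First loop.
    alphas[i] = np.dot(position_deltas[i], q_direction) / inv_rhos[i]
    q_direction = q_direction - alphas[i] * gradient_deltas[i]
  gamma_k = inv_rhos[-1] / np.dot(gradient_deltas[-1], gradient_deltas[-1])
  r_direction = gamma_k * q_direction
  for i in range(m):                              # Second loop.
    beta = np.dot(gradient_deltas[i], r_direction) / inv_rhos[i]
    r_direction = r_direction + (alphas[i] - beta) * position_deltas[i]
  return -r_direction                             # Approximates -H_k @ gradient.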
Example no. 24
  def _scan(level, elems):
    """Perform scan on `elems`."""
    elem_length = ps.shape(elems[0])[axis]

    # Apply `fn` to reduce adjacent pairs to a single entry.
    a = [slice_elem(elem, 0, -1, step=2) for elem in elems]
    b = [slice_elem(elem, 1, None, step=2) for elem in elems]
    reduced_elems = lowered_fn(a, b)

    def handle_base_case_elem_length_two():
      return [tf.concat([slice_elem(elem, 0, 1), reduced_elem], axis=axis)
              for (reduced_elem, elem) in zip(reduced_elems, elems)]

    def handle_base_case_elem_length_three():
      reduced_reduced_elems = lowered_fn(
          reduced_elems,
          [slice_elem(elem, 2, 3) for elem in elems])
      return [
          tf.concat([slice_elem(elem, 0, 1),  # pylint: disable=g-complex-comprehension
                     reduced_elem,
                     reduced_reduced_elem], axis=axis)
          for (reduced_reduced_elem, reduced_elem, elem)
          in zip(reduced_reduced_elems, reduced_elems, elems)]

    # Base case of recursion: assumes `elem_length` is 2 or 3.
    at_base_case = ps.logical_or(
        ps.equal(elem_length, 2),
        ps.equal(elem_length, 3))
    base_value = lambda: ps.cond(  # pylint: disable=g-long-lambda
        ps.equal(elem_length, 2),
        handle_base_case_elem_length_two,
        handle_base_case_elem_length_three)

    if level <= 0:
      return base_value()

    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn(
            [slice_elem(odd_elem, 0, -1) for odd_elem in odd_elems],
            [slice_elem(elem, 2, None, 2) for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [slice_elem(elem, 2, None, 2) for elem in elems])

      results = ps.cond(
          ps.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([slice_elem(elem, 0, 1), result], axis=axis)
                    for (elem, result) in zip(elems, results)]
      return list(map(lambda a, b: _interleave(a, b, axis=axis),
                      even_elems,
                      odd_elems))

    return ps.cond(at_base_case, base_value, recursive_case)
Example no. 25
def _is_odd_integer(x):
    return ps.equal(x, ps.round(x)) & ps.not_equal(2. * ps.floor(x / 2.), x)
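A quick NumPy spot check of the same predicate (throwaway sketch, scalar inputs):

import numpy as np

def is_odd_integer(x):
  x = np.float64(x)
  return np.equal(x, np.round(x)) & np.not_equal(2. * np.floor(x / 2.), x)

assert is_odd_integer(3.)        # An odd integer.
assert is_odd_integer(-7.)       # Negative odd integers work too.
assert not is_odd_integer(4.)    # Even integer.
assert not is_odd_integer(3.5)   # Not an integer at all.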