def one_step(self, current_state, previous_kernel_results):
    """Draws one random-walk proposal from `current_state`.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: results from the previous step (or from
        `bootstrap_results`); only its `target_log_prob` is referenced here.

    Returns:
      A two-element list `[next_state, kernel_results]` where `next_state`
      mirrors the list-ness of `current_state` and `kernel_results` is an
      `UncalibratedRandomWalkResults` with a zero `log_acceptance_correction`
      and `target_log_prob` evaluated at the proposal.
    """
    scope_name = mcmc_util.make_name(self.name, 'rwm', 'one_step')
    scope_inputs = [self.seed,
                    current_state,
                    previous_kernel_results.target_log_prob]
    with tf.name_scope(name=scope_name, values=scope_inputs):
        with tf.name_scope('initialize'):
            state_is_list = mcmc_util.is_list_like(current_state)
            raw_parts = list(current_state) if state_is_list else [current_state]
            current_state_parts = [
                tf.convert_to_tensor(part, name='current_state')
                for part in raw_parts
            ]

        # Derive a new seed from the stored one (and persist it) before
        # drawing the proposal.
        self._seed_stream = distributions_util.gen_new_seed(
            self._seed_stream, salt='rwm_kernel_proposal')
        proposed_parts = self.new_state_fn(current_state_parts,
                                           self._seed_stream)
        # Evaluate the target at the proposal so MetropolisHastings can use it.
        proposed_log_prob = self.target_log_prob_fn(*proposed_parts)

        next_state = proposed_parts if state_is_list else proposed_parts[0]
        kernel_results = UncalibratedRandomWalkResults(
            log_acceptance_correction=tf.zeros(
                shape=tf.shape(proposed_log_prob),
                dtype=proposed_log_prob.dtype.base_dtype),
            target_log_prob=proposed_log_prob,
        )
        return [next_state, kernel_results]
  def bootstrap_results(self, init_state):
    """Builds the initial `MetropolisHastingsKernelResults`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: `MetropolisHastingsKernelResults` that wraps the inner
        kernel's bootstrap results as both accepted and proposed results,
        marks every chain as accepted and uses a zero log acceptance ratio.

    Raises:
      ValueError: if the results returned by `inner_kernel.bootstrap_results`
        have no `target_log_prob` member.
    """
    scope_name = mcmc_util.make_name(self.name, 'mh', 'bootstrap_results')
    with tf.name_scope(name=scope_name, values=[init_state]):
      inner_results = self.inner_kernel.bootstrap_results(init_state)
      if has_target_log_prob(inner_results):
        log_prob = inner_results.target_log_prob
        return MetropolisHastingsKernelResults(
            accepted_results=inner_results,
            is_accepted=tf.ones_like(log_prob, dtype=tf.bool),
            log_accept_ratio=tf.zeros_like(log_prob),
            proposed_state=init_state,
            proposed_results=inner_results,
            extra=[],
        )
      raise ValueError(
          '"target_log_prob" must be a member of `inner_kernel` results.')
    def one_step(self, current_state, previous_kernel_results):
        """Draws one random-walk proposal from `current_state`.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: results from the previous step (or from
            `bootstrap_results`); only its `target_log_prob` is referenced.

        Returns:
          `[next_state, kernel_results]`: `next_state` has the same list-ness
          as `current_state`; `kernel_results` is an
          `UncalibratedRandomWalkResults` with zero `log_acceptance_correction`
          and the target log-prob evaluated at the proposal.
        """
        scope_name = mcmc_util.make_name(self.name, 'rwm', 'one_step')
        scope_values = [self.seed, current_state,
                        previous_kernel_results.target_log_prob]
        with tf.name_scope(name=scope_name, values=scope_values):
            with tf.name_scope('initialize'):
                state_is_list = mcmc_util.is_list_like(current_state)
                if state_is_list:
                    parts = list(current_state)
                else:
                    parts = [current_state]
                current_state_parts = [
                    tf.convert_to_tensor(p, name='current_state') for p in parts
                ]

            # A fresh seed is pulled from the seed stream for every proposal.
            proposal_parts = self.new_state_fn(current_state_parts,
                                               self._seed_stream())
            # Needed downstream by MetropolisHastings.
            proposal_log_prob = self.target_log_prob_fn(*proposal_parts)

            next_state = proposal_parts if state_is_list else proposal_parts[0]
            zero_correction = tf.zeros(
                shape=tf.shape(proposal_log_prob),
                dtype=proposal_log_prob.dtype.base_dtype)
            return [
                next_state,
                UncalibratedRandomWalkResults(
                    log_acceptance_correction=zero_correction,
                    target_log_prob=proposal_log_prob,
                ),
            ]
  def bootstrap_results(self, init_state):
    """Creates the Metropolis-Hastings results used before the first step.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: `MetropolisHastingsKernelResults` whose accepted and
        proposed results are both the inner kernel's bootstrap results, with
        `is_accepted` all `True` and `log_accept_ratio` all zero.

    Raises:
      ValueError: if the inner kernel's bootstrap results lack a
        `target_log_prob` member.
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'mh', 'bootstrap_results'),
        values=[init_state]):
      bootstrap = self.inner_kernel.bootstrap_results(init_state)
      if not has_target_log_prob(bootstrap):
        raise ValueError(
            '"target_log_prob" must be a member of `inner_kernel` results.')
      target_log_prob = bootstrap.target_log_prob
      accepted_everywhere = tf.ones_like(target_log_prob, dtype=tf.bool)
      zero_log_ratio = tf.zeros_like(target_log_prob)
      return MetropolisHastingsKernelResults(
          accepted_results=bootstrap,
          is_accepted=accepted_everywhere,
          log_accept_ratio=zero_log_ratio,
          proposed_state=init_state,
          proposed_results=bootstrap,
          extra=[],
      )
Exemple #5
0
 def bootstrap_results(self, init_state):
     """Returns initial `UncalibratedHamiltonianMonteCarloKernelResults`.

     Args:
       init_state: `Tensor` or Python `list` of `Tensor`s representing the
         initial state(s) of the Markov chain(s).

     Returns:
       kernel_results: results whose `target_log_prob` and
         `grads_target_log_prob` are evaluated at `init_state` (gradients via
         `tf.gradients`) and whose `log_acceptance_correction` is zero.
     """
     scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
     with tf.name_scope(name=scope_name, values=[init_state]):
         state_parts = [
             tf.convert_to_tensor(part)
             for part in (init_state if mcmc_util.is_list_like(init_state)
                          else [init_state])
         ]
         log_prob = self.target_log_prob_fn(*state_parts)
         log_prob_grads = tf.gradients(log_prob, state_parts)
         return UncalibratedHamiltonianMonteCarloKernelResults(
             log_acceptance_correction=tf.zeros_like(log_prob),
             target_log_prob=log_prob,
             grads_target_log_prob=log_prob_grads,
         )
Exemple #6
0
 def bootstrap_results(self, init_state):
     """Returns initial `SliceSamplerKernelResults` for `init_state`.

     Args:
       init_state: `Tensor` or Python `list` of `Tensor`s representing the
         initial state(s) of the Markov chain(s).

     Returns:
       kernel_results: results with `target_log_prob` evaluated at
         `init_state`, all-`False` `bounds_satisfied`, zero `direction` per
         state part, and zero upper/lower bounds.
     """
     with tf.name_scope(name=mcmc_util.make_name(self.name, 'slice',
                                                 'bootstrap_results'),
                        values=[init_state]):
         parts = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
         parts = list(map(tf.convert_to_tensor, parts))
         zero_direction = [tf.zeros_like(part) for part in parts]
         log_prob = self.target_log_prob_fn(*parts)  # pylint:disable=not-callable
         return SliceSamplerKernelResults(
             target_log_prob=log_prob,
             bounds_satisfied=tf.zeros(shape=tf.shape(log_prob),
                                       dtype=tf.bool),
             direction=zero_direction,
             upper_bounds=tf.zeros_like(log_prob),
             lower_bounds=tf.zeros_like(log_prob))
Exemple #7
0
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            state(s) of the Markov chain(s).

        Returns:
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list`
            of `Tensor`s representing internal calculations made within this
            function. This includes replica states.
        """
        with tf.name_scope(
                name=mcmc_util.make_name(self.name, 'remc', 'bootstrap_results'),
                values=[init_state]):
            replica_results = [
                self.replica_kernels[k].bootstrap_results(init_state)
                for k in range(self.num_replica)
            ]

            is_list = mcmc_util.is_list_like(init_state)
            state_parts = list(init_state) if is_list else [init_state]
            # Every replica starts from its own identity copy of the state.
            replica_states = []
            for _ in range(self.num_replica):
                copies = [tf.identity(part) for part in state_parts]
                replica_states.append(copies if is_list else copies[0])

            next_replica_idx = tf.range(self.num_replica)
            proposed, proposed_n = self.exchange_proposed_fn(
                self.num_replica, seed=self._seed_stream)
            # The sampled exchange proposals are immediately replaced by zeros
            # of the same shape and dtype.
            return ReplicaExchangeMCKernelResults(
                replica_states=replica_states,
                replica_results=replica_results,
                next_replica_idx=next_replica_idx,
                exchange_proposed=tf.zeros_like(proposed),
                exchange_proposed_n=tf.zeros_like(proposed_n),
                sampled_replica_states=replica_states,
                sampled_replica_results=replica_results,
            )
 def bootstrap_results(self, init_state):
   """Returns initial `SliceSamplerKernelResults` for `init_state`.

   Args:
     init_state: `Tensor` or Python `list` of `Tensor`s representing the
       initial state(s) of the Markov chain(s).

   Returns:
     kernel_results: results holding the target log-prob at `init_state`,
       zero slice directions, all-`False` `bounds_satisfied` and zero bounds.
   """
   scope_name = mcmc_util.make_name(self.name, 'slice', 'bootstrap_results')
   with tf.name_scope(name=scope_name, values=[init_state]):
     if mcmc_util.is_list_like(init_state):
       state_parts = [tf.convert_to_tensor(x) for x in init_state]
     else:
       state_parts = [tf.convert_to_tensor(init_state)]
     direction = [tf.zeros_like(x) for x in state_parts]
     target_log_prob = self.target_log_prob_fn(*state_parts)  # pylint:disable=not-callable
     bounds_satisfied = tf.zeros(shape=tf.shape(target_log_prob),
                                 dtype=tf.bool)
     upper_bounds = tf.zeros_like(target_log_prob)
     lower_bounds = tf.zeros_like(target_log_prob)
     return SliceSamplerKernelResults(
         target_log_prob=target_log_prob,
         bounds_satisfied=bounds_satisfied,
         direction=direction,
         upper_bounds=upper_bounds,
         lower_bounds=lower_bounds)
Exemple #9
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one iteration of the Transformed Kernel.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s), _after_ application
            of `bijector.forward`. The first `r` dimensions index independent
            chains, `r = tf.rank(target_log_prob_fn(*current_state))`. It is
            not read here: the inner kernel steps from
            `previous_kernel_results.transformed_state` instead, because
            `TransformedTransitionKernel` wraps the inner kernel's
            `target_log_prob_fn` so that it applies `bijector.forward`.
          previous_kernel_results: `collections.namedtuple` containing
            `Tensor`s from the previous call to this function (or from
            `bootstrap_results`).

        Returns:
          next_state: Tensor or Python list of `Tensor`s representing the
            state(s) of the Markov chain(s) after exactly one step; same type
            and shape as `current_state`.
          kernel_results: `TransformedTransitionKernelResults` wrapping the
            new transformed state and the inner kernel's results.
        """
        scope = make_name(self.name, 'transformed_kernel', 'one_step')
        with tf.compat.v1.name_scope(name=scope,
                                     values=[previous_kernel_results]):
            transformed_next_state, inner_results = self._inner_kernel.one_step(
                previous_kernel_results.transformed_state,
                previous_kernel_results.inner_results)
            if is_list_like(transformed_next_state):
                next_state = self._forward_transform(transformed_next_state)
            else:
                next_state = self._forward_transform(
                    [transformed_next_state])[0]
            return next_state, TransformedTransitionKernelResults(
                transformed_state=transformed_next_state,
                inner_results=inner_results)
  def bootstrap_results(self, init_state):
    """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this
        function. This includes replica states.
    """
    scope_name = mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')
    with tf.name_scope(name=scope_name, values=[init_state]):
      replica_results = []
      for idx in range(self.num_replica):
        replica_results.append(
            self.replica_kernels[idx].bootstrap_results(init_state))

      state_is_list = mcmc_util.is_list_like(init_state)
      init_state_parts = list(init_state) if state_is_list else [init_state]
      replica_states = [[tf.identity(s) for s in init_state_parts]
                        for _ in range(self.num_replica)]
      if not state_is_list:
        replica_states = [states[0] for states in replica_states]

      next_replica_idx = tf.range(self.num_replica)
      sampled = self.exchange_proposed_fn(self.num_replica,
                                          seed=self._seed_stream)
      exchange_proposed, exchange_proposed_n = sampled
      # Keep only shape/dtype of the sampled proposals; values become zeros.
      exchange_proposed = tf.zeros_like(exchange_proposed)
      exchange_proposed_n = tf.zeros_like(exchange_proposed_n)
      return ReplicaExchangeMCKernelResults(
          replica_states=replica_states,
          replica_results=replica_results,
          next_replica_idx=next_replica_idx,
          exchange_proposed=exchange_proposed,
          exchange_proposed_n=exchange_proposed_n,
          sampled_replica_states=replica_states,
          sampled_replica_results=replica_results,
      )
  def one_step(self, current_state, previous_kernel_results):
    """Runs one iteration of the Transformed Kernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s), _after_ application of
        `bijector.forward`. The first `r` dimensions index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`. This argument is
        unused: the inner kernel is stepped from
        `previous_kernel_results.transformed_state`, since
        `TransformedTransitionKernel` gives the inner kernel a
        `target_log_prob_fn` that internally applies `bijector.forward`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        from the previous call to this function (or from
        `bootstrap_results`).

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after exactly one step; same type and shape as
        `current_state`.
      kernel_results: `TransformedTransitionKernelResults` with the new
        transformed state and the inner kernel's results.
    """
    with tf.name_scope(
        name=make_name(self.name, 'transformed_kernel', 'one_step'),
        values=[previous_kernel_results]):
      inner_kernel = self._inner_kernel
      transformed_next_state, inner_results = inner_kernel.one_step(
          previous_kernel_results.transformed_state,
          previous_kernel_results.inner_results)
      was_list = is_list_like(transformed_next_state)
      forward_parts = self._forward_transform(
          transformed_next_state if was_list else [transformed_next_state])
      next_state = forward_parts if was_list else forward_parts[0]
      next_results = TransformedTransitionKernelResults(
          transformed_state=transformed_next_state,
          inner_results=inner_results)
      return next_state, next_results
Exemple #12
0
 def bootstrap_results(self, init_state):
   """Returns initial `UncalibratedHamiltonianMonteCarloKernelResults`.

   Args:
     init_state: `Tensor` or Python `list` of `Tensor`s representing the
       initial state(s) of the Markov chain(s).

   Returns:
     kernel_results: results holding the target log-prob and its gradients at
       `init_state` (from `mcmc_util.maybe_call_fn_and_grads`) and a zero
       `log_acceptance_correction`.
   """
   scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
   with tf.name_scope(name=scope_name, values=[init_state]):
     parts = init_state if mcmc_util.is_list_like(init_state) else [init_state]
     # Optionally cut gradient flow into the initial state; otherwise just
     # convert each part to a tensor.
     prepare = (tf.stop_gradient if self.state_gradients_are_stopped
                else tf.convert_to_tensor)
     parts = [prepare(x) for x in parts]
     target_log_prob, grads_target_log_prob = (
         mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn, parts))
     return UncalibratedHamiltonianMonteCarloKernelResults(
         log_acceptance_correction=tf.zeros_like(target_log_prob),
         target_log_prob=target_log_prob,
         grads_target_log_prob=grads_target_log_prob,
     )
Exemple #13
0
 def bootstrap_results(self, init_state):
     """Returns initial `UncalibratedHamiltonianMonteCarloKernelResults`.

     When `self._store_parameters_in_results` is set, `step_size` (converted
     to the target log-prob dtype) and `num_leapfrog_steps` (as `tf.int64`)
     are stored in the results; otherwise both fields are empty lists.

     Args:
       init_state: `Tensor` or Python `list` of `Tensor`s representing the
         initial state(s) of the Markov chain(s).

     Returns:
       kernel_results: results holding the target log-prob and its gradients
         at `init_state` and a zero `log_acceptance_correction`.
     """
     with tf.name_scope(name=mcmc_util.make_name(self.name, 'hmc',
                                                 'bootstrap_results'),
                        values=[init_state]):
         parts = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
         if self.state_gradients_are_stopped:
             parts = [tf.stop_gradient(x) for x in parts]
         else:
             parts = [tf.convert_to_tensor(value=x) for x in parts]
         log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
             self.target_log_prob_fn, parts)
         correction = tf.zeros_like(log_prob)

         if self._store_parameters_in_results:
             step_size = tf.nest.map_structure(
                 lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                     value=x,
                     dtype=log_prob.dtype,
                     name='step_size'),
                 self.step_size)
             num_leapfrog_steps = tf.convert_to_tensor(
                 value=self.num_leapfrog_steps,
                 dtype=tf.int64,
                 name='num_leapfrog_steps')
         else:
             step_size = []
             num_leapfrog_steps = []

         return UncalibratedHamiltonianMonteCarloKernelResults(
             log_acceptance_correction=correction,
             target_log_prob=log_prob,
             grads_target_log_prob=log_prob_grads,
             step_size=step_size,
             num_leapfrog_steps=num_leapfrog_steps)
Exemple #14
0
    def bootstrap_results(self, init_state):
        """Returns initial `UncalibratedLangevinKernelResults` for `init_state`.

        Target log-prob, its gradients, volatility, volatility gradients and
        diffusion drift are all taken from `_prepare_args`;
        `log_acceptance_correction` starts at zero.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            initial state(s) of the Markov chain(s).

        Returns:
          kernel_results: `UncalibratedLangevinKernelResults`; `volatility`
            is reduced to its first element when `init_state` is not
            list-like.
        """
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'mala',
                                                    'bootstrap_results'),
                           values=[init_state]):
            state_is_list = mcmc_util.is_list_like(init_state)
            if state_is_list:
                state_parts = [tf.convert_to_tensor(x) for x in init_state]
            else:
                state_parts = [tf.convert_to_tensor(init_state)]
            # Volatility at the initial state, handed to `_prepare_args`.
            initial_volatility = self.volatility_fn(*state_parts)  # pylint: disable=not-callable

            (
                _,  # state_parts
                _,  # step_sizes
                target_log_prob,
                grads_target_log_prob,
                volatility,
                grads_volatility,
                diffusion_drift,
            ) = _prepare_args(self.target_log_prob_fn,
                              self.volatility_fn,
                              state=state_parts,
                              step_size=self.step_size,
                              volatility=initial_volatility,
                              parallel_iterations=self.parallel_iterations)

            return UncalibratedLangevinKernelResults(
                log_acceptance_correction=tf.zeros_like(target_log_prob),
                target_log_prob=target_log_prob,
                grads_target_log_prob=grads_target_log_prob,
                volatility=volatility if state_is_list else volatility[0],
                grads_volatility=grads_volatility,
                diffusion_drift=diffusion_drift)
Exemple #15
0
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            initial state(s) of the Markov chain(s).

        Returns:
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list`
            of `Tensor`s representing internal calculations made within this
            function. This includes replica states.
        """
        with tf.name_scope(
                name=mcmc_util.make_name(self.name, 'remc', 'bootstrap_results'),
                values=[init_state]):
            replica_results = [
                self.replica_kernels[r].bootstrap_results(init_state)
                for r in range(self.num_replica)
            ]

            state_is_list = mcmc_util.is_list_like(init_state)
            parts = list(init_state) if state_is_list else [init_state]

            # Each replica gets its own tensor conversion of every state part.
            replica_states = []
            for _ in range(self.num_replica):
                converted = [tf.convert_to_tensor(s) for s in parts]
                replica_states.append(
                    converted if state_is_list else converted[0])

            return ReplicaExchangeMCKernelResults(
                replica_states=replica_states,
                replica_results=replica_results,
                sampled_replica_states=replica_states,
                sampled_replica_results=replica_results,
            )
Exemple #16
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one (uncalibrated) Hamiltonian Monte Carlo step.

        Draws a Gaussian momentum for each state part and runs
        `self.num_leapfrog_steps` leapfrog steps in a `tf.while_loop`. It
        returns the final state together with the log acceptance correction
        computed from the initial and final momenta. No accept/reject is done
        here.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: results from the previous step (or from
            `bootstrap_results`); its `target_log_prob` and
            `grads_target_log_prob` are reused as the starting values.

        Returns:
          `[next_state, kernel_results]`: `next_state` matches the list-ness
          of `current_state`; `kernel_results` is an
          `UncalibratedHamiltonianMonteCarloKernelResults`.
        """
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'hmc',
                                                    'one_step'),
                           values=[
                               self.step_size, self.num_leapfrog_steps,
                               current_state,
                               previous_kernel_results.target_log_prob,
                               previous_kernel_results.grads_target_log_prob
                           ]):
            # Normalize state, step sizes and cached log-prob/gradients into
            # per-part lists (see `_prepare_args`).
            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                self.step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            # Rank of the target log-prob; presumably the number of leading
            # chain (batch) dimensions -- TODO confirm.
            independent_chain_ndims = distributions_util.prefer_static_rank(
                current_target_log_prob)

            # One standard-normal momentum per state part, matching its shape
            # and dtype; each draw pulls a fresh seed from the seed stream.
            current_momentum_parts = []
            for x in current_state_parts:
                current_momentum_parts.append(
                    tf.random_normal(shape=tf.shape(x),
                                     dtype=x.dtype.base_dtype,
                                     seed=self._seed_stream()))

            def _leapfrog_one_step(*args):
                """Closure representing computation done during each leapfrog step."""
                # args = (momentum_parts, state_parts, target_log_prob,
                #         target_log_prob_grad_parts), in loop_vars order.
                return _leapfrog_integrator_one_step(
                    target_log_prob_fn=self.target_log_prob_fn,
                    independent_chain_ndims=independent_chain_ndims,
                    step_sizes=step_sizes,
                    current_momentum_parts=args[0],
                    current_state_parts=args[1],
                    current_target_log_prob=args[2],
                    current_target_log_prob_grad_parts=args[3],
                    state_gradients_are_stopped=self.
                    state_gradients_are_stopped)

            # Do leapfrog integration.
            # loop_vars[0] is the iteration counter; `[1:]` drops it so only
            # the integrator outputs are unpacked.
            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = tf.while_loop(
                cond=lambda i, *args: i < self.num_leapfrog_steps,
                body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)
                                                     ),
                loop_vars=[
                    tf.zeros([], tf.int32, name='iter'),
                    current_momentum_parts,
                    current_state_parts,
                    current_target_log_prob,
                    current_target_log_prob_grad_parts,
                ])[1:]

            # Return a bare Tensor when the caller passed a bare Tensor.
            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            return [
                maybe_flatten(next_state_parts),
                UncalibratedHamiltonianMonteCarloKernelResults(
                    # Correction built from initial vs. final momenta.
                    log_acceptance_correction=
                    _compute_log_acceptance_correction(
                        current_momentum_parts, next_momentum_parts,
                        independent_chain_ndims),
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_target_log_prob_grad_parts,
                ),
            ]
    def bootstrap_results(self, init_state=None, transformed_init_state=None):
        """Returns an object with the same type as returned by `one_step`.

        Unlike other `TransitionKernel`s,
        `TransformedTransitionKernel.bootstrap_results` can initialize the
        `TransformedTransitionKernelResults` either from an initial state,
        which requires computing `bijector.inverse(init_state)`, or directly
        from `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s
        that is interpreted as the already `bijector.inverse`-transformed
        state.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            initial state(s) of the Markov chain(s). Exactly one of
            `init_state` and `transformed_init_state` must be specified.
          transformed_init_state: `Tensor` or Python `list` of `Tensor`s
            representing the initial state(s) of the Markov chain(s) after
            `bijector.inverse` has been applied. Exactly one of `init_state`
            and `transformed_init_state` must be specified.

        Returns:
          kernel_results: `TransformedTransitionKernelResults` holding the
            transformed initial state and the inner kernel's
            `bootstrap_results` for it.

        Raises:
          ValueError: if both or neither of `init_state` and
            `transformed_init_state` are specified. Errors raised by
            `inner_kernel.bootstrap_results` propagate unchanged.

        #### Examples

        To use `transformed_init_state` in context of
        `tfp.mcmc.sample_chain`, you need to explicitly pass the
        `previous_kernel_results`, e.g.,

        ```python
        transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
        init_state = ...        # Doesn't matter.
        transformed_init_state = ... # Does matter.
        results, _ = tfp.mcmc.sample_chain(
            num_results=...,
            current_state=init_state,
            previous_kernel_results=transformed_kernel.bootstrap_results(
                transformed_init_state=transformed_init_state),
            kernel=transformed_kernel)
        ```
        """
        # Exactly one of the two initial-state arguments must be given.
        if (init_state is None) == (transformed_init_state is None):
            raise ValueError('Must specify exactly one of `init_state` '
                             'or `transformed_init_state`.')
        with tf.name_scope(name=make_name(self.name, 'transformed_kernel',
                                          'bootstrap_results'),
                           values=[init_state, transformed_init_state]):
            if transformed_init_state is None:
                # Map the user-space state through the inverse transform,
                # preserving whether the caller passed a list or a Tensor.
                init_state_parts = (init_state if is_list_like(init_state) else
                                    [init_state])
                transformed_init_state_parts = self._inverse_transform(
                    init_state_parts)
                transformed_init_state = (transformed_init_state_parts
                                          if is_list_like(init_state) else
                                          transformed_init_state_parts[0])
            else:
                if is_list_like(transformed_init_state):
                    transformed_init_state = [
                        tf.convert_to_tensor(value=s,
                                             name='transformed_init_state')
                        for s in transformed_init_state
                    ]
                else:
                    transformed_init_state = tf.convert_to_tensor(
                        value=transformed_init_state,
                        name='transformed_init_state')
            kernel_results = TransformedTransitionKernelResults(
                transformed_state=transformed_init_state,
                inner_results=self._inner_kernel.bootstrap_results(
                    transformed_init_state))
            return kernel_results
  def one_step(self, current_state, previous_kernel_results):
    """Runs one iteration of Slice Sampler.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s with fully
        defined static shape, representing the current state(s) of the Markov
        chain(s). The first `r` dimensions index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        from the previous call to this function (or from
        `bootstrap_results`); its `target_log_prob` is reused.

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after exactly one step; same type and shape as
        `current_state`.
      kernel_results: `SliceSamplerKernelResults` with the new target
        log-prob, bound-satisfaction flags, direction and slice bounds.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
      ValueError: if `current_state` does not have a fully defined static
        shape.
      TypeError: if `not target_log_prob.dtype.is_floating`.
    """
    scope_name = mcmc_util.make_name(self.name, 'slice', 'one_step')
    scope_values = [self.step_size, self.max_doublings, self._seed_stream,
                    current_state,
                    previous_kernel_results.target_log_prob]
    with tf.name_scope(name=scope_name, values=scope_values):
      with tf.name_scope('initialize'):
        state_parts, step_sizes, log_prob = _prepare_args(
            self.target_log_prob_fn,
            current_state,
            self.step_size,
            previous_kernel_results.target_log_prob,
            maybe_expand=True)
        doublings = tf.convert_to_tensor(
            self.max_doublings, dtype=tf.int32, name='max_doublings')

      chain_ndims = distributions_util.prefer_static_rank(log_prob)

      (proposal_parts, proposal_log_prob, satisfied, slice_direction,
       upper, lower) = _sample_next(
           self.target_log_prob_fn,
           state_parts,
           step_sizes,
           doublings,
           log_prob,
           chain_ndims,
           seed=self._seed_stream())

      if mcmc_util.is_list_like(current_state):
        next_state = proposal_parts
      else:
        next_state = proposal_parts[0]
      results = SliceSamplerKernelResults(
          target_log_prob=proposal_log_prob,
          bounds_satisfied=satisfied,
          direction=slice_direction,
          upper_bounds=upper,
          lower_bounds=lower)
      return [next_state, results]
Exemple #19
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one iteration of the (uncalibrated) Langevin/MALA proposal.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: results `namedtuple` from the previous call
            (or from `bootstrap_results`); its `target_log_prob`,
            `grads_target_log_prob`, `volatility`, `grads_volatility` and
            `diffusion_drift` fields are reused.

        Returns:
          A two-element `list`: the proposed next state (a `list` exactly when
          `current_state` is list-like) and an
          `UncalibratedLangevinKernelResults`.
        """
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'mala',
                                                    'one_step'),
                           values=[
                               self.step_size, current_state,
                               previous_kernel_results.target_log_prob,
                               previous_kernel_results.grads_target_log_prob,
                               previous_kernel_results.volatility,
                               previous_kernel_results.diffusion_drift
                           ]):
            with tf.name_scope('initialize'):
                # Prepare input arguments to be passed to `_euler_method`.
                [
                    current_state_parts,
                    step_size_parts,
                    current_target_log_prob,
                    _,  # grads_target_log_prob
                    current_volatility_parts,
                    _,  # grads_volatility
                    current_drift_parts,
                ] = _prepare_args(
                    self.target_log_prob_fn, self.volatility_fn, current_state,
                    self.step_size, previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility,
                    previous_kernel_results.diffusion_drift,
                    self.parallel_iterations)

                # One standard-normal draw per state part, shaped like it.
                random_draw_parts = []
                for s in current_state_parts:
                    random_draw_parts.append(
                        tf.random_normal(shape=tf.shape(s),
                                         dtype=s.dtype.base_dtype,
                                         seed=self._seed_stream()))

            # Number of independent chains run by the algorithm.
            independent_chain_ndims = distribution_util.prefer_static_rank(
                current_target_log_prob)

            # Generate the next state of the algorithm using Euler-Maruyama method.
            next_state_parts = _euler_method(random_draw_parts,
                                             current_state_parts,
                                             current_drift_parts,
                                             step_size_parts,
                                             current_volatility_parts)

            # Compute helper `UncalibratedLangevinKernelResults` to be processed by
            # `_compute_log_acceptance_correction` and in the next iteration of
            # `one_step` function.
            [
                _,  # state_parts
                _,  # step_sizes
                next_target_log_prob,
                next_grads_target_log_prob,
                next_volatility_parts,
                next_grads_volatility,
                next_drift_parts,
            ] = _prepare_args(self.target_log_prob_fn,
                              self.volatility_fn,
                              next_state_parts,
                              step_size_parts,
                              parallel_iterations=self.parallel_iterations)

            def maybe_flatten(x):
                # Return a bare Tensor when the caller passed a bare Tensor.
                return x if mcmc_util.is_list_like(current_state) else x[0]

            # Decide whether to compute the acceptance ratio. Both branch
            # results are built *inside* the branch functions: in graph mode,
            # ops created outside `true_fn`/`false_fn` run unconditionally,
            # which would make `compute_acceptance=False` save no work.
            def _compute_correction():
                return _compute_log_acceptance_correction(
                    current_state_parts, next_state_parts,
                    current_volatility_parts, next_volatility_parts,
                    current_drift_parts, next_drift_parts, step_size_parts,
                    independent_chain_ndims)

            def _skip_correction():
                return tf.zeros_like(next_target_log_prob)

            log_acceptance_correction = tf.cond(
                self.compute_acceptance,
                _compute_correction,
                _skip_correction)

            return [
                maybe_flatten(next_state_parts),
                UncalibratedLangevinKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_grads_target_log_prob,
                    volatility=maybe_flatten(next_volatility_parts),
                    grads_volatility=next_grads_volatility,
                    diffusion_drift=next_drift_parts),
            ]
Exemple #20
0
  def one_step(self, current_state, previous_kernel_results):
    """Runs one leapfrog trajectory of (uncalibrated) Hamiltonian Monte Carlo.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: results `namedtuple` from the previous call (or
        from `bootstrap_results`). Its `target_log_prob` and
        `grads_target_log_prob` fields are reused. When
        `self._store_parameters_in_results` is true, its `step_size` and
        `num_leapfrog_steps` fields supply the integrator parameters.

    Returns:
      next_state: the proposed state; a `list` exactly when `current_state` is
        list-like.
      kernel_results: `previous_kernel_results` with
        `log_acceptance_correction`, `target_log_prob` and
        `grads_target_log_prob` replaced; all other fields are carried over.
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
        values=[self.step_size,
                self.num_leapfrog_steps,
                current_state,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob]):
      # Integrator parameters come either from the results (so they can vary
      # between steps) or from the kernel's constructor arguments.
      if self._store_parameters_in_results:
        step_size = previous_kernel_results.step_size
        num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
      else:
        step_size = self.step_size
        num_leapfrog_steps = self.num_leapfrog_steps

      [
          current_state_parts,
          step_sizes,
          current_target_log_prob,
          current_target_log_prob_grad_parts,
      ] = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          maybe_expand=True,
          state_gradients_are_stopped=self.state_gradients_are_stopped)

      independent_chain_ndims = distribution_util.prefer_static_rank(
          current_target_log_prob)

      # Resample an auxiliary momentum for every state part.
      current_momentum_parts = []
      for x in current_state_parts:
        current_momentum_parts.append(
            tf.random.normal(
                shape=tf.shape(input=x),
                dtype=x.dtype.base_dtype,
                seed=self._seed_stream()))

      def _leapfrog_one_step(*args):
        """Closure representing computation done during each leapfrog step."""
        return _leapfrog_integrator_one_step(
            target_log_prob_fn=self.target_log_prob_fn,
            independent_chain_ndims=independent_chain_ndims,
            step_sizes=step_sizes,
            current_momentum_parts=args[0],
            current_state_parts=args[1],
            current_target_log_prob=args[2],
            current_target_log_prob_grad_parts=args[3],
            state_gradients_are_stopped=self.state_gradients_are_stopped)

      # Convert the *selected* step count. The previous code re-read
      # `self.num_leapfrog_steps` here, silently ignoring the value taken from
      # `previous_kernel_results` when parameters are stored in results.
      # NOTE(review): assumes a stored `num_leapfrog_steps` is int64 (or a
      # Python int), since `convert_to_tensor` will not cast other dtypes.
      num_leapfrog_steps = tf.convert_to_tensor(
          value=num_leapfrog_steps,
          dtype=tf.int64,
          name='num_leapfrog_steps')

      [
          next_momentum_parts,
          next_state_parts,
          next_target_log_prob,
          next_target_log_prob_grad_parts,

      ] = tf.while_loop(
          cond=lambda i, *args: i < num_leapfrog_steps,
          body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)),
          loop_vars=[
              tf.zeros([], tf.int64, name='iter'),
              current_momentum_parts,
              current_state_parts,
              current_target_log_prob,
              current_target_log_prob_grad_parts
          ])[1:]

      def maybe_flatten(x):
        # Return a bare Tensor when the caller passed a bare Tensor.
        return x if mcmc_util.is_list_like(current_state) else x[0]

      new_kernel_results = previous_kernel_results._replace(
          log_acceptance_correction=_compute_log_acceptance_correction(
              current_momentum_parts, next_momentum_parts,
              independent_chain_ndims),
          target_log_prob=next_target_log_prob,
          grads_target_log_prob=next_target_log_prob_grad_parts,
      )

      return maybe_flatten(next_state_parts), new_kernel_results
Exemple #21
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one replica-exchange step.

    Each replica kernel first advances its own replica. Then the proposed pairs
    of replicas are accepted or rejected sequentially, and the replica
    states/results are permuted accordingly.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). It is only passed to the name
        scope; the replica states are read from
        `previous_kernel_results.replica_states`.
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
        Its `replica_states` and `replica_results` fields are read here.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s). This is the post-exchange state
        of replica slot 0.
      kernel_results: `ReplicaExchangeMCKernelResults` holding the
        post-exchange replica states/results, the permutation
        `next_replica_idx`, the proposed exchanges, and the pre-exchange
        (`sampled_*`) replica states/results. This includes replica states.
    """
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'remc',
                                                    'one_step'),
                           values=[current_state, previous_kernel_results]):
            # Advance every replica independently with its own kernel.
            sampled_replica_states, sampled_replica_results = zip(*[
                rk.one_step(previous_kernel_results.replica_states[i],
                            previous_kernel_results.replica_results[i])
                for i, rk in enumerate(self.replica_kernels)
            ])
            sampled_replica_states = list(sampled_replica_states)
            sampled_replica_results = list(sampled_replica_results)

            # Divide each replica's `target_log_prob` by its inverse temperature.
            # Results without a top-level `target_log_prob` field are expected to
            # carry it in `accepted_results` (presumably MetropolisHastings-style
            # results -- confirm).
            sampled_replica_results_modified = [
                srr._replace(target_log_prob=srr.target_log_prob /
                             self.inverse_temperatures[i])
                if 'target_log_prob' in srr._fields else srr._replace(
                    accepted_results=srr.accepted_results._replace(
                        target_log_prob=srr.accepted_results.target_log_prob /
                        self.inverse_temperatures[i]))
                for i, srr in enumerate(sampled_replica_results)
            ]

            # Rescaled per-replica log-probs, stacked along the last axis.
            sampled_replica_ratios = [
                srr.target_log_prob if 'target_log_prob' in srr._fields else
                srr.accepted_results.target_log_prob
                for i, srr in enumerate(sampled_replica_results_modified)
            ]
            sampled_replica_ratios = tf.stack(sampled_replica_ratios, axis=-1)

            # next_replica_idx[k] is the index of the sampled replica whose
            # state ends up in slot k; start from the identity permutation.
            next_replica_idx = tf.range(self.num_replica)
            self._seed_stream = distributions_util.gen_new_seed(
                self._seed_stream, salt='replica_exchange_one_step')
            exchange_proposed, exchange_proposed_n = self.exchange_proposed_fn(
                self.num_replica, seed=self._seed_stream)
            i = tf.constant(0)

            def cond(i, next_replica_idx):  # pylint: disable=unused-argument
                # Visit each of the `exchange_proposed_n` proposed pairs once.
                return tf.less(i, exchange_proposed_n)

            def body(i, next_replica_idx):
                """`tf.while_loop` body: accept/reject the `i`-th proposed swap.

                Pairs are processed in order, and each decision sees the
                permutation produced by the earlier ones.
                """
                # Log acceptance ratio for swapping slots a and b:
                # (lp[idx[a]] - lp[idx[b]]) * (beta[b] - beta[a]).
                ratio = (sampled_replica_ratios[next_replica_idx[
                    exchange_proposed[i, 0]]] - sampled_replica_ratios[
                        next_replica_idx[exchange_proposed[i, 1]]])
                ratio *= (self.inverse_temperatures[exchange_proposed[i, 1]] -
                          self.inverse_temperatures[exchange_proposed[i, 0]])
                # NOTE(review): this reassignment runs while the loop body is
                # being built; in graph mode the body is traced once, so it
                # presumably advances the seed once, not per iteration.
                self._seed_stream = distributions_util.gen_new_seed(
                    self._seed_stream, salt='replica_exchange_one_step')
                log_uniform = tf.log(
                    tf.random_uniform(shape=tf.shape(ratio),
                                      dtype=ratio.dtype.base_dtype,
                                      seed=self._seed_stream))
                # Metropolis test. `tf.cond` below needs a scalar predicate, so
                # this assumes a scalar log-prob per replica -- confirm.
                exchange = log_uniform < ratio
                # Delta vector that, when added, swaps the entries at slots a
                # and b. NOTE(review): `tf.sparse_to_dense` validates that its
                # indices are sorted and unique by default, so this assumes
                # exchange_proposed[i, 0] < exchange_proposed[i, 1].
                exchange_op = tf.sparse_to_dense(
                    [exchange_proposed[i, 0], exchange_proposed[i, 1]],
                    [self.num_replica], [
                        next_replica_idx[exchange_proposed[i, 1]] -
                        next_replica_idx[exchange_proposed[i, 0]],
                        next_replica_idx[exchange_proposed[i, 0]] -
                        next_replica_idx[exchange_proposed[i, 1]]
                    ])
                next_replica_idx = tf.cond(
                    exchange, lambda: next_replica_idx + exchange_op,
                    lambda: next_replica_idx)
                return [i + 1, next_replica_idx]

            next_replica_idx = tf.while_loop(cond,
                                             body,
                                             loop_vars=[i,
                                                        next_replica_idx])[1]

            def _prep(list_):
                """Permutes `list_`: slot i gets list_[next_replica_idx[i]].

                The selection uses `tf.case` over every candidate j.
                `_stateful_lambda` is defined elsewhere and presumably wraps the
                value in the zero-argument callable that `tf.case` requires.
                """
                return list(
                    tf.case(
                        {
                            tf.equal(next_replica_idx[i], j): _stateful_lambda(
                                list_[j])
                            for j in range(self.num_replica)
                        },
                        exclusive=True) for i in range(self.num_replica))

            next_replica_states = _prep(sampled_replica_states)
            next_replica_results = _prep(sampled_replica_results_modified)

            # Multiply `target_log_prob` back by the inverse temperature of the
            # slot each result now occupies.
            next_replica_results = [
                nrr._replace(target_log_prob=nrr.target_log_prob *
                             self.inverse_temperatures[i])
                if 'target_log_prob' in nrr._fields else nrr._replace(
                    accepted_results=nrr.accepted_results._replace(
                        target_log_prob=nrr.accepted_results.target_log_prob *
                        self.inverse_temperatures[i]))
                for i, nrr in enumerate(next_replica_results)
            ]

            # The chain's reported state is whatever now sits in slot 0.
            next_state = tf.identity(next_replica_states[0])
            # `sampled_replica_results` are the replica kernels' raw results,
            # i.e. before the division/multiplication above.
            kernel_results = ReplicaExchangeMCKernelResults(
                replica_states=next_replica_states,
                replica_results=next_replica_results,
                next_replica_idx=next_replica_idx,
                exchange_proposed=exchange_proposed,
                exchange_proposed_n=exchange_proposed_n,
                sampled_replica_states=sampled_replica_states,
                sampled_replica_results=sampled_replica_results,
            )

            return next_state, kernel_results
  def one_step(self, current_state, previous_kernel_results):
    """Takes one replica-exchange step.

    Each replica kernel first advances its own replica. Then the proposed pairs
    of replicas are accepted or rejected sequentially, and the replica
    states/results are permuted accordingly.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). It is only passed to the name
        scope; the replica states are read from
        `previous_kernel_results.replica_states`.
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
        Its `replica_states` and `replica_results` fields are read here.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s). This is the post-exchange state
        of replica slot 0.
      kernel_results: `ReplicaExchangeMCKernelResults` holding the
        post-exchange replica states/results, the permutation
        `next_replica_idx`, the proposed exchanges, and the pre-exchange
        (`sampled_*`) replica states/results. This includes replica states.
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'remc', 'one_step'),
        values=[current_state, previous_kernel_results]):
      # Advance every replica independently with its own kernel.
      sampled_replica_states, sampled_replica_results = zip(*[
          rk.one_step(previous_kernel_results.replica_states[i],
                      previous_kernel_results.replica_results[i])
          for i, rk in enumerate(self.replica_kernels)])
      sampled_replica_states = list(sampled_replica_states)
      sampled_replica_results = list(sampled_replica_results)

      # Divide each replica's `target_log_prob` by its inverse temperature.
      # Results without a top-level `target_log_prob` field are expected to
      # carry it in `accepted_results` (presumably MetropolisHastings-style
      # results -- confirm).
      sampled_replica_results_modified = [
          srr._replace(target_log_prob=srr.target_log_prob /
                       self.inverse_temperatures[i])
          if 'target_log_prob' in srr._fields
          else srr._replace(accepted_results=srr.accepted_results._replace(
              target_log_prob=srr.accepted_results.target_log_prob /
              self.inverse_temperatures[i]))
          for i, srr in enumerate(sampled_replica_results)
      ]

      # Rescaled per-replica log-probs, stacked along the last axis.
      sampled_replica_ratios = [
          srr.target_log_prob if 'target_log_prob' in srr._fields
          else srr.accepted_results.target_log_prob
          for i, srr in enumerate(sampled_replica_results_modified)]
      sampled_replica_ratios = tf.stack(sampled_replica_ratios, axis=-1)

      # next_replica_idx[k] is the index of the sampled replica whose state
      # ends up in slot k; start from the identity permutation.
      next_replica_idx = tf.range(self.num_replica)
      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='replica_exchange_one_step')
      exchange_proposed, exchange_proposed_n = self.exchange_proposed_fn(
          self.num_replica, seed=self._seed_stream)
      i = tf.constant(0)

      def cond(i, next_replica_idx):  # pylint: disable=unused-argument
        # Visit each of the `exchange_proposed_n` proposed pairs once.
        return tf.less(i, exchange_proposed_n)

      def body(i, next_replica_idx):
        """`tf.while_loop` body: accept/reject the `i`-th proposed swap.

        Pairs are processed in order, and each decision sees the permutation
        produced by the earlier ones.
        """
        # Log acceptance ratio for swapping slots a and b:
        # (lp[idx[a]] - lp[idx[b]]) * (beta[b] - beta[a]).
        ratio = (
            sampled_replica_ratios[next_replica_idx[exchange_proposed[i, 0]]]
            - sampled_replica_ratios[next_replica_idx[exchange_proposed[i, 1]]])
        ratio *= (
            self.inverse_temperatures[exchange_proposed[i, 1]]
            - self.inverse_temperatures[exchange_proposed[i, 0]])
        # NOTE(review): this reassignment runs while the loop body is being
        # built; in graph mode the body is traced once, so it presumably
        # advances the seed once, not per iteration.
        self._seed_stream = distributions_util.gen_new_seed(
            self._seed_stream, salt='replica_exchange_one_step')
        log_uniform = tf.log(tf.random_uniform(
            shape=tf.shape(ratio),
            dtype=ratio.dtype.base_dtype,
            seed=self._seed_stream))
        # Metropolis test. `tf.cond` below needs a scalar predicate, so this
        # assumes a scalar log-prob per replica -- confirm.
        exchange = log_uniform < ratio
        # Delta vector that, when added, swaps the entries at slots a and b.
        # NOTE(review): `tf.sparse_to_dense` validates that its indices are
        # sorted and unique by default, so this assumes
        # exchange_proposed[i, 0] < exchange_proposed[i, 1].
        exchange_op = tf.sparse_to_dense(
            [exchange_proposed[i, 0], exchange_proposed[i, 1]],
            [self.num_replica],
            [next_replica_idx[exchange_proposed[i, 1]] -
             next_replica_idx[exchange_proposed[i, 0]],
             next_replica_idx[exchange_proposed[i, 0]] -
             next_replica_idx[exchange_proposed[i, 1]]])
        next_replica_idx = tf.cond(exchange,
                                   lambda: next_replica_idx + exchange_op,
                                   lambda: next_replica_idx)
        return [i + 1, next_replica_idx]

      next_replica_idx = tf.while_loop(
          cond, body, loop_vars=[i, next_replica_idx])[1]

      def _prep(list_):
        """Permutes `list_`: slot i gets list_[next_replica_idx[i]].

        The selection uses `tf.case` over every candidate j. `_stateful_lambda`
        is defined elsewhere and presumably wraps the value in the
        zero-argument callable that `tf.case` requires.
        """
        return list(
            tf.case({tf.equal(next_replica_idx[i], j):
                     _stateful_lambda(list_[j])
                     for j in range(self.num_replica)}, exclusive=True)
            for i in range(self.num_replica))
      next_replica_states = _prep(sampled_replica_states)
      next_replica_results = _prep(sampled_replica_results_modified)

      # Multiply `target_log_prob` back by the inverse temperature of the slot
      # each result now occupies.
      next_replica_results = [
          nrr._replace(target_log_prob=nrr.target_log_prob *
                       self.inverse_temperatures[i])
          if 'target_log_prob' in nrr._fields
          else nrr._replace(accepted_results=nrr.accepted_results._replace(
              target_log_prob=nrr.accepted_results.target_log_prob *
              self.inverse_temperatures[i]))
          for i, nrr in enumerate(next_replica_results)
      ]

      # The chain's reported state is whatever now sits in slot 0.
      next_state = tf.identity(next_replica_states[0])
      # `sampled_replica_results` are the replica kernels' raw results, i.e.
      # before the division/multiplication above.
      kernel_results = ReplicaExchangeMCKernelResults(
          replica_states=next_replica_states,
          replica_results=next_replica_results,
          next_replica_idx=next_replica_idx,
          exchange_proposed=exchange_proposed,
          exchange_proposed_n=exchange_proposed_n,
          sampled_replica_states=sampled_replica_states,
          sampled_replica_results=sampled_replica_results,
      )

      return next_state, kernel_results
Exemple #23
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one leapfrog trajectory of (uncalibrated) Hamiltonian Monte Carlo.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the Markov chain(s).
          previous_kernel_results: results `namedtuple` from the previous call
            (or from `bootstrap_results`); its `target_log_prob` and
            `grads_target_log_prob` fields are reused.

        Returns:
          A two-element `list`: the proposed next state (a `list` exactly when
          `current_state` is list-like) and an
          `UncalibratedHamiltonianMonteCarloKernelResults`.
        """
        scope_name = mcmc_util.make_name(self.name, 'hmc', 'one_step')
        scope_values = [
            self.step_size,
            self.num_leapfrog_steps,
            self._seed_stream,
            current_state,
            previous_kernel_results.target_log_prob,
            previous_kernel_results.grads_target_log_prob,
        ]
        with tf.name_scope(name=scope_name, values=scope_values):
            with tf.name_scope('initialize'):
                (state_parts, step_sizes, target_log_prob,
                 target_log_prob_grads) = _prepare_args(
                     self.target_log_prob_fn,
                     current_state,
                     self.step_size,
                     previous_kernel_results.target_log_prob,
                     previous_kernel_results.grads_target_log_prob,
                     maybe_expand=True)

                momentums = []
                for part in state_parts:
                    # Re-derive the seed *before* each draw so successive draws
                    # differ, and so an outer kernel given the same user seed
                    # (e.g. `MetropolisHastings`) does not reuse it.
                    self._seed_stream = distributions_util.gen_new_seed(
                        self._seed_stream, salt='hmc_kernel_momentums')
                    momentums.append(
                        tf.random_normal(shape=tf.shape(part),
                                         dtype=part.dtype.base_dtype,
                                         seed=self._seed_stream))

                num_steps = tf.convert_to_tensor(
                    self.num_leapfrog_steps,
                    dtype=tf.int32,
                    name='num_leapfrog_steps')

            chain_ndims = distributions_util.prefer_static_rank(target_log_prob)

            (next_momentums, next_state_parts, next_target_log_prob,
             next_target_log_prob_grads) = _leapfrog_integrator(
                 momentums, self.target_log_prob_fn, state_parts, step_sizes,
                 num_steps, target_log_prob, target_log_prob_grads)

            # Hand back a bare Tensor when the caller passed a bare Tensor.
            if mcmc_util.is_list_like(current_state):
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]

            return [
                next_state,
                UncalibratedHamiltonianMonteCarloKernelResults(
                    log_acceptance_correction=_compute_log_acceptance_correction(
                        momentums, next_momentums, chain_ndims),
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_target_log_prob_grads,
                ),
            ]
Exemple #24
0
  def one_step(self, current_state, previous_kernel_results):
    """Runs one leapfrog trajectory of (uncalibrated) Hamiltonian Monte Carlo.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s holding the
        current state(s) of the Markov chain(s).
      previous_kernel_results: results `namedtuple` from the previous call (or
        from `bootstrap_results`); its `target_log_prob` and
        `grads_target_log_prob` fields are reused.

    Returns:
      A two-element `list`: the proposed next state (a `list` exactly when
      `current_state` is list-like) and an
      `UncalibratedHamiltonianMonteCarloKernelResults`.
    """
    scope_name = mcmc_util.make_name(self.name, 'hmc', 'one_step')
    scope_values = [
        self.step_size,
        self.num_leapfrog_steps,
        current_state,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
    ]
    with tf.name_scope(name=scope_name, values=scope_values):
      (state_parts, step_sizes, target_log_prob,
       target_log_prob_grads) = _prepare_args(
           self.target_log_prob_fn,
           current_state,
           self.step_size,
           previous_kernel_results.target_log_prob,
           previous_kernel_results.grads_target_log_prob,
           maybe_expand=True,
           state_gradients_are_stopped=self.state_gradients_are_stopped)

      chain_ndims = distribution_util.prefer_static_rank(target_log_prob)

      # Resample an auxiliary momentum for every state part.
      momentum_parts = [
          tf.random_normal(shape=tf.shape(part),
                           dtype=part.dtype.base_dtype,
                           seed=self._seed_stream())
          for part in state_parts
      ]

      def leapfrog_step(momentum, state, log_prob, log_prob_grads):
        """Advances the integrator by a single leapfrog step."""
        return _leapfrog_integrator_one_step(
            target_log_prob_fn=self.target_log_prob_fn,
            independent_chain_ndims=chain_ndims,
            step_sizes=step_sizes,
            current_momentum_parts=momentum,
            current_state_parts=state,
            current_target_log_prob=log_prob,
            current_target_log_prob_grad_parts=log_prob_grads,
            state_gradients_are_stopped=self.state_gradients_are_stopped)

      num_steps = tf.convert_to_tensor(
          self.num_leapfrog_steps, dtype=tf.int64, name='num_leapfrog_steps')

      def keep_going(step, *unused_loop_state):
        return step < num_steps

      def advance(step, *loop_state):
        return [step + 1] + list(leapfrog_step(*loop_state))

      initial_loop_vars = [
          tf.zeros([], tf.int64, name='iter'),
          momentum_parts,
          state_parts,
          target_log_prob,
          target_log_prob_grads,
      ]
      final_loop_vars = tf.while_loop(
          cond=keep_going, body=advance, loop_vars=initial_loop_vars)
      # Drop the step counter; keep the integrator outputs.
      (next_momentum_parts, next_state_parts, next_target_log_prob,
       next_target_log_prob_grads) = final_loop_vars[1:]

      # Hand back a bare Tensor when the caller passed a bare Tensor.
      if mcmc_util.is_list_like(current_state):
        next_state = next_state_parts
      else:
        next_state = next_state_parts[0]

      log_acceptance_correction = _compute_log_acceptance_correction(
          momentum_parts, next_momentum_parts, chain_ndims)
      return [
          next_state,
          UncalibratedHamiltonianMonteCarloKernelResults(
              log_acceptance_correction=log_acceptance_correction,
              target_log_prob=next_target_log_prob,
              grads_target_log_prob=next_target_log_prob_grads,
          ),
      ]
  def bootstrap_results(self, init_state=None, transformed_init_state=None):
    """Returns an object with the same type as returned by `one_step`.

    Unlike other `TransitionKernel`s,
    `TransformedTransitionKernel.bootstrap_results` can initialize the
    `TransformedTransitionKernelResults` either from an initial state, which
    requires computing `bijector.inverse(init_state)`, or directly from
    `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s that is
    interpreted as the `bijector.inverse`-transformed state.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        state(s) of the Markov chain(s). Exactly one of `init_state` and
        `transformed_init_state` must be given.
      transformed_init_state: `Tensor` or Python `list` of `Tensor`s
        representing the already inverse-transformed state(s) of the Markov
        chain(s), i.e. the state the inner kernel operates on. Exactly one of
        `init_state` and `transformed_init_state` must be given.

    Returns:
      kernel_results: A `TransformedTransitionKernelResults` whose
        `transformed_state` is the inverse-transformed initial state and whose
        `inner_results` come from the inner kernel's `bootstrap_results`.

    Raises:
      ValueError: if both or neither of `init_state` and
        `transformed_init_state` are specified.
      Any error raised by the inner kernel's `bootstrap_results` propagates
      unchanged.

    #### Examples

    To use `transformed_init_state` in the context of
    `tfp.mcmc.sample_chain`, you need to explicitly pass the
    `previous_kernel_results`, e.g.,

    ```python
    transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
    init_state = ...        # Doesn't matter.
    transformed_init_state = ... # Does matter.
    results, _ = tfp.mcmc.sample_chain(
        num_results=...,
        current_state=init_state,
        previous_kernel_results=transformed_kernel.bootstrap_results(
            transformed_init_state=transformed_init_state),
        kernel=transformed_kernel)
    ```
    """
    # Exactly one of the two initial states must be supplied.
    if (init_state is None) == (transformed_init_state is None):
      raise ValueError('Must specify exactly one of `init_state` '
                       'or `transformed_init_state`.')
    with tf.name_scope(
        name=make_name(self.name, 'transformed_kernel', 'bootstrap_results'),
        values=[init_state, transformed_init_state]):
      if transformed_init_state is None:
        # Map the user-space state into the inner kernel's space, preserving
        # whether the caller passed a list or a single Tensor.
        init_state_parts = (init_state if is_list_like(init_state)
                            else [init_state])
        transformed_init_state_parts = self._inverse_transform(init_state_parts)
        transformed_init_state = (
            transformed_init_state_parts
            if is_list_like(init_state) else transformed_init_state_parts[0])
      else:
        # Already transformed: only normalize to Tensor(s).
        if is_list_like(transformed_init_state):
          transformed_init_state = [
              tf.convert_to_tensor(s, name='transformed_init_state')
              for s in transformed_init_state
          ]
        else:
          transformed_init_state = tf.convert_to_tensor(
              transformed_init_state, name='transformed_init_state')
      kernel_results = TransformedTransitionKernelResults(
          transformed_state=transformed_init_state,
          inner_results=self._inner_kernel.bootstrap_results(
              transformed_init_state))
      return kernel_results
Exemple #26
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This inculdes replica states.
    """
        # Key difficulty:  The type of exchanges differs from one call to the
        # next...even the number of exchanges can differ.
        # As a result, exchanges must happen dynamically, in while loops.
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'remc',
                                                    'one_step'),
                           values=[current_state, previous_kernel_results]):

            # Each replica does `one_step` to get pre-exchange states/KernelResults.
            sampled_replica_states, sampled_replica_results = zip(*[
                rk.one_step(previous_kernel_results.replica_states[i],
                            previous_kernel_results.replica_results[i])
                for i, rk in enumerate(self.replica_kernels)
            ])
            sampled_replica_states = list(sampled_replica_states)
            sampled_replica_results = list(sampled_replica_results)

            states_are_lists = mcmc_util.is_list_like(
                sampled_replica_states[0])

            if not states_are_lists:
                sampled_replica_states = [[s] for s in sampled_replica_states]
            num_state_parts = len(sampled_replica_states[0])

            dtype = sampled_replica_states[0][0].dtype

            # Must put states into TensorArrays.  Why?  We will read/write states
            # dynamically with Tensor index `i`, and you cannot do this with lists.
            # old_states[k][i] is Tensor of (old) state part k, for replica i.
            # The `k` will be known statically, and `i` is a Tensor.
            old_states = [
                tf.TensorArray(
                    dtype,
                    size=self.num_replica,
                    dynamic_size=False,
                    clear_after_read=False,
                    tensor_array_name='old_states',
                    # State part k has same shape, regardless of replica.  So use 0.
                    element_shape=sampled_replica_states[0][k].shape)
                for k in range(num_state_parts)
            ]
            for k in range(num_state_parts):
                for i in range(self.num_replica):
                    old_states[k] = old_states[k].write(
                        i, sampled_replica_states[i][k])

            exchange_proposed = self.exchange_proposed_fn(
                self.num_replica, seed=self._seed_stream())
            exchange_proposed_n = tf.shape(exchange_proposed)[0]

            exchanged_states = self._get_exchanged_states(
                old_states, exchange_proposed, exchange_proposed_n,
                sampled_replica_states, sampled_replica_results)

            no_exchange_proposed, _ = tf.setdiff1d(
                tf.range(self.num_replica), tf.reshape(exchange_proposed,
                                                       [-1]))

            exchanged_states = self._insert_old_states_where_no_exchange_was_proposed(
                no_exchange_proposed, old_states, exchanged_states)

            next_replica_states = []
            for i in range(self.num_replica):
                next_replica_states_i = []
                for k in range(num_state_parts):
                    next_replica_states_i.append(exchanged_states[k].read(i))
                next_replica_states.append(next_replica_states_i)

            if not states_are_lists:
                next_replica_states = [s[0] for s in next_replica_states]
                sampled_replica_states = [s[0] for s in sampled_replica_states]

            # Now that states are/aren't exchanged, bootstrap next kernel_results.
            # The viewpoint is that after each exchange, we are starting anew.
            next_replica_results = [
                rk.bootstrap_results(state)
                for rk, state in zip(self.replica_kernels, next_replica_states)
            ]

            next_state = next_replica_states[
                0]  # Replica 0 is the returned state(s).

            kernel_results = ReplicaExchangeMCKernelResults(
                replica_states=next_replica_states,
                replica_results=next_replica_results,
                sampled_replica_states=sampled_replica_states,
                sampled_replica_results=sampled_replica_results,
            )

            return next_state, kernel_results
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Runs `inner_kernel.one_step` to obtain a proposal, then accepts or rejects
    it with a Metropolis-Hastings test. The log acceptance ratio is the sum of
    the proposed `target_log_prob`, minus the previously accepted
    `target_log_prob`, plus the inner kernel's `log_acceptance_correction`
    when one is reported.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
        Its `accepted_results` member is handed to the inner kernel.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: `MetropolisHastingsKernelResults` holding the chosen
        inner results, `is_accepted`, `log_accept_ratio`, the proposal and its
        inner results.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
        values=[current_state, previous_kernel_results]):
      # Take one inner step.
      [
          proposed_state,
          proposed_results,
      ] = self.inner_kernel.one_step(
          current_state,
          previous_kernel_results.accepted_results)

      if (not has_target_log_prob(proposed_results) or
          not has_target_log_prob(previous_kernel_results.accepted_results)):
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # Compute log(acceptance_ratio).
      to_sum = [proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob]
      # Only the attribute lookup belongs inside the `try`. An AttributeError
      # raised while *inspecting* an existing correction must not be mistaken
      # for a missing one, since that would silently drop it from the ratio.
      try:
        log_acceptance_correction = proposed_results.log_acceptance_correction
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      else:
        # An empty list-like correction contributes nothing, so it is skipped.
        if (not mcmc_util.is_list_like(log_acceptance_correction)
            or log_acceptance_correction):
          to_sum.append(log_acceptance_correction)
      log_accept_ratio = mcmc_util.safe_sum(
          to_sum, name='compute_log_accept_ratio')

      # If proposed state reduces likelihood: randomly accept.
      # If proposed state increases likelihood: always accept.
      # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
      #       ==> log(u) < log_accept_ratio
      log_uniform = tf.log(tf.random_uniform(
          shape=tf.shape(proposed_results.target_log_prob),
          dtype=proposed_results.target_log_prob.dtype.base_dtype,
          seed=self._seed_stream()))
      is_accepted = log_uniform < log_accept_ratio

      # Per chain, keep the proposal where accepted, else the current state.
      next_state = mcmc_util.choose(
          is_accepted,
          proposed_state,
          current_state,
          name='choose_next_state')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=mcmc_util.choose(
              is_accepted,
              proposed_results,
              previous_kernel_results.accepted_results,
              name='choose_inner_results'),
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Runs `inner_kernel.one_step` to obtain a proposal, then accepts or rejects
    it with a Metropolis-Hastings test. The log acceptance ratio is the sum of
    the proposed `target_log_prob`, minus the previously accepted
    `target_log_prob`, plus the inner kernel's `log_acceptance_correction`
    when one is reported.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
        Its `accepted_results` member is handed to the inner kernel.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: `MetropolisHastingsKernelResults` holding the chosen
        inner results, `is_accepted`, `log_accept_ratio`, the proposal and its
        inner results.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
        values=[current_state, previous_kernel_results]):
      # Take one inner step.
      [
          proposed_state,
          proposed_results,
      ] = self.inner_kernel.one_step(
          current_state,
          previous_kernel_results.accepted_results)

      if (not has_target_log_prob(proposed_results) or
          not has_target_log_prob(previous_kernel_results.accepted_results)):
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # Compute log(acceptance_ratio).
      to_sum = [proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob]
      # Only the attribute lookup belongs inside the `try`. An AttributeError
      # raised while *inspecting* an existing correction must not be mistaken
      # for a missing one, since that would silently drop it from the ratio.
      try:
        log_acceptance_correction = proposed_results.log_acceptance_correction
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      else:
        # An empty list-like correction contributes nothing, so it is skipped.
        if (not mcmc_util.is_list_like(log_acceptance_correction)
            or log_acceptance_correction):
          to_sum.append(log_acceptance_correction)
      log_accept_ratio = mcmc_util.safe_sum(
          to_sum, name='compute_log_accept_ratio')

      # If proposed state reduces likelihood: randomly accept.
      # If proposed state increases likelihood: always accept.
      # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
      #       ==> log(u) < log_accept_ratio
      log_uniform = tf.log(tf.random_uniform(
          shape=tf.shape(proposed_results.target_log_prob),
          dtype=proposed_results.target_log_prob.dtype.base_dtype,
          seed=self._seed_stream()))
      is_accepted = log_uniform < log_accept_ratio

      # Per chain, keep the proposal where accepted, else the current state.
      next_state = mcmc_util.choose(
          is_accepted,
          proposed_state,
          current_state,
          name='choose_next_state')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=mcmc_util.choose(
              is_accepted,
              proposed_results,
              previous_kernel_results.accepted_results,
              name='choose_inner_results'),
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results
# Example #29
# 0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one iteration of Slice Sampler.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s). The first `r`
            dimensions index independent chains,
            `r = tf.rank(target_log_prob_fn(*current_state))`.
          previous_kernel_results: `collections.namedtuple` containing
            `Tensor`s representing values from previous calls to this function
            (or from the `bootstrap_results` function.)

        Returns:
          next_state: Tensor or Python list of `Tensor`s representing the
            state(s) of the Markov chain(s) after taking exactly one step. Has
            same type and shape as `current_state`.
          kernel_results: `SliceSamplerKernelResults` carrying the new
            `target_log_prob` plus the `bounds_satisfied`, `direction`,
            `upper_bounds` and `lower_bounds` produced by the sampling step.

        Raises:
          ValueError: if there isn't one `step_size` or a list with same length
            as `current_state`.
          TypeError: if `not target_log_prob.dtype.is_floating`.
        """
        scope_name = mcmc_util.make_name(self.name, 'slice', 'one_step')
        scope_inputs = [
            self.step_size,
            self.max_doublings,
            self._seed_stream,
            current_state,
            previous_kernel_results.target_log_prob,
        ]
        with tf.name_scope(name=scope_name, values=scope_inputs):
            with tf.name_scope('initialize'):
                prepared = _prepare_args(
                    self.target_log_prob_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    maybe_expand=True)
                state_parts, step_sizes, start_log_prob = prepared

                doublings = tf.convert_to_tensor(
                    self.max_doublings, dtype=tf.int32, name='max_doublings')

            # Rank of the target log-prob = number of leading batch (chain)
            # dimensions.
            chain_ndims = distributions_util.prefer_static_rank(start_log_prob)

            sampled = _sample_next(
                self.target_log_prob_fn,
                state_parts,
                step_sizes,
                doublings,
                start_log_prob,
                chain_ndims,
                seed=self._seed_stream())
            (new_parts, new_log_prob, bounds_ok, step_direction,
             upper, lower) = sampled

            # Hand back the same container kind the caller passed in: a bare
            # state gets unwrapped from its single-element list.
            if mcmc_util.is_list_like(current_state):
                next_state = new_parts
            else:
                next_state = new_parts[0]

            results = SliceSamplerKernelResults(
                target_log_prob=new_log_prob,
                bounds_satisfied=bounds_ok,
                direction=step_direction,
                upper_bounds=upper,
                lower_bounds=lower)
            return [next_state, results]