def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # Default momentum distribution is None, but if `store_parameters_in_results`
    # is true, then `momentum_distribution` defaults to an empty list
    if momentum_distribution is None or isinstance(momentum_distribution,
                                                   list):
        batch_rank = ps.rank(target_log_prob)

        def _batched_isotropic_normal_like(state_part):
            event_ndims = ps.rank(state_part) - batch_rank
            return independent.Independent(
                normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
                reinterpreted_batch_ndims=event_ndims)

        momentum_distribution = jds.JointDistributionSequential([
            _batched_isotropic_normal_like(state_part)
            for state_part in state_parts
        ])

    # The momentum will get "maybe listified" to zip with the state parts,
    # and this step makes sure that the momentum distribution will have the
    # same "maybe listified" underlying shape.
    if not mcmc_util.is_list_like(momentum_distribution.dtype):
        momentum_distribution = jds.JointDistributionSequential(
            [momentum_distribution])

    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
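As a side note on the block above: when no `momentum_distribution` is supplied, the helper builds an isotropic standard normal per state part, batched like `target_log_prob`. Below is a minimal sketch of that construction using only the public `tfp.distributions` API; the shapes and the `batched_isotropic_normal_like` name are illustrative assumptions, not part of the library.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Assume two state parts over 3 chains: one with event shape [4], one scalar.
state_parts = [tf.zeros([3, 4]), tf.zeros([3])]
batch_rank = 1  # rank of target_log_prob (one chain dimension)

def batched_isotropic_normal_like(state_part):
    # Everything beyond the chain dimension is treated as event dimensions.
    event_ndims = len(state_part.shape) - batch_rank
    return tfd.Independent(
        tfd.Normal(tf.zeros_like(state_part), 1.),
        reinterpreted_batch_ndims=event_ndims)

momentum_distribution = tfd.JointDistributionSequential(
    [batched_isotropic_normal_like(part) for part in state_parts])

momentum_parts = momentum_distribution.sample()
# momentum_parts[0].shape == [3, 4], momentum_parts[1].shape == [3].
# momentum_distribution.dtype is a list, so the "maybe listified" check passes.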
Example #2
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        target_log_prob,
        grads_target_log_prob,
    ]
Example #3
def _swap_then_retemper(x):
    x, is_multipart = mcmc_util.prepare_state_parts(x)
    it_ratio_ = mcmc_util.left_justified_expand_dims_like(it_ratio, x[0])
    x = [swap_tensor_fn(x_part) * it_ratio_ for x_part in x]
    if not is_multipart:
        x = x[0]
    return x
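For context, `left_justified_expand_dims_like` pads `it_ratio` with trailing singleton axes so the per-replica ratio broadcasts against each (possibly higher-rank) state part. A rough equivalent in plain TensorFlow, with hypothetical shapes:

import tensorflow as tf

it_ratio = tf.constant([1.0, 0.9, 0.8, 0.7])  # shape [4]: one ratio per replica
state_part = tf.ones([4, 3])                  # shape [replica, event]

# Append trailing singleton axes until the ranks match, so broadcasting
# aligns on the leading (replica) axis rather than the trailing one.
ndims_to_add = len(state_part.shape) - len(it_ratio.shape)
it_ratio_ = tf.reshape(it_ratio, list(it_ratio.shape) + [1] * ndims_to_add)  # [4, 1]

retempered = state_part * it_ratio_  # shape [4, 3]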
Example #4
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False,
                  experimental_shard_axis_names=None):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # If no momentum distribution was supplied, build the default one.
    if momentum_distribution is None:
        momentum_distribution = pu.make_momentum_distribution(
            state_parts,
            ps.shape(target_log_prob),
            shard_axis_names=experimental_shard_axis_names)
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(target_log_prob))

    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
Example #5
def _prepare_step_size(step_size, dtype, n_state_parts):
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=dtype, name='step_size')
  if len(step_sizes) == 1:
    step_sizes *= n_state_parts
  if n_state_parts != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  return step_sizes
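The helper above encodes a simple rule: a single step size is shared across all state parts, and anything else must match the number of parts one-to-one. Here is a pure-Python sketch of just that rule (the real helper also converts to tensors of the target's dtype via `mcmc_util.prepare_state_parts`):

def expand_step_size(step_size, n_state_parts):
    # Wrap a lone scalar in a list, then replicate it across the state parts.
    step_sizes = list(step_size) if isinstance(step_size, (list, tuple)) else [step_size]
    if len(step_sizes) == 1:
        step_sizes *= n_state_parts
    if len(step_sizes) != n_state_parts:
        raise ValueError('There should be exactly one `step_size` or it should '
                         'have same length as `current_state`.')
    return step_sizes

expand_step_size(0.1, 3)         # -> [0.1, 0.1, 0.1]
expand_step_size([0.1, 0.2], 2)  # -> [0.1, 0.2]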
Example #6
 def bootstrap_results(self, init_state):
     with tf.name_scope(
             mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')):
         init_state, _ = mcmc_util.prepare_state_parts(init_state)
         if self.state_gradients_are_stopped:
             init_state = [tf.stop_gradient(x) for x in init_state]
         [
             init_target_log_prob,
             init_grads_target_log_prob,
         ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn,
                                               init_state)
         if self._store_parameters_in_results:
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=tf.zeros_like(
                     init_target_log_prob),
                 target_log_prob=init_target_log_prob,
                 grads_target_log_prob=init_grads_target_log_prob,
                 initial_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 final_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 # TODO(b/142590314): Try to use the following code once we commit to
                 # a tensorization policy.
                 # step_size=mcmc_util.prepare_state_parts(
                 #    self.step_size,
                 #    dtype=init_target_log_prob.dtype,
                 #    name='step_size')[0],
                 step_size=tf.nest.map_structure(
                     lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                         x,
                         dtype=init_target_log_prob.dtype,
                         name='step_size'),
                     self.step_size),
                 num_leapfrog_steps=tf.convert_to_tensor(
                     self.num_leapfrog_steps,
                     dtype=tf.int32,
                     name='num_leapfrog_steps'))
         else:
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=tf.zeros_like(
                     init_target_log_prob),
                 target_log_prob=init_target_log_prob,
                 grads_target_log_prob=init_grads_target_log_prob,
                 initial_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 final_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 step_size=[],
                 num_leapfrog_steps=[])
Example #7
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      state_parts, _ = mcmc_util.prepare_state_parts(init_state,
                                                     name='current_state')
      current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
          self.target_log_prob_fn, state_parts)
      # Confirm that the step size is compatible with the state parts.
      _ = _prepare_step_size(
          self.step_size, current_target_log_prob.dtype, len(init_state))
      momentum_distribution = self.momentum_distribution
      if momentum_distribution is None:
        momentum_distribution = pu.make_momentum_distribution(
            state_parts, ps.shape(current_target_log_prob),
            shard_axis_names=self.experimental_shard_axis_names)
      momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
          momentum_distribution, ps.shape(current_target_log_prob))
      momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

      return PreconditionedNUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          step_size=tf.nest.map_structure(
              lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                  x,
                  dtype=current_target_log_prob.dtype,
                  name='step_size'),
              self.step_size),
          log_accept_ratio=tf.zeros_like(
              current_target_log_prob, name='log_accept_ratio'),
          leapfrogs_taken=tf.zeros_like(
              current_target_log_prob,
              dtype=TREE_COUNT_DTYPE,
              name='leapfrogs_taken'),
          is_accepted=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='is_accepted'),
          reach_max_depth=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='reach_max_depth'),
          has_divergence=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='has_divergence'),
          energy=compute_hamiltonian(current_target_log_prob, momentum_parts,
                                     momentum_distribution),
          momentum_distribution=momentum_distribution,
          # Allow room for one_step's seed.
          seed=samplers.zeros_seed(),
      )
Example #8
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'phmc', 'bootstrap_results')):
            result = super(UncalibratedPreconditionedHamiltonianMonteCarlo,
                           self).bootstrap_results(init_state)

            state_parts, _ = mcmc_util.prepare_state_parts(
                init_state, name='current_state')
            target_log_prob = self.target_log_prob_fn(*state_parts)
            if (not self._store_parameters_in_results
                    or self.momentum_distribution is None):
                momentum_distribution = pu.make_momentum_distribution(
                    state_parts, ps.shape(target_log_prob))
            else:
                momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
                    self.momentum_distribution, ps.shape(target_log_prob))
            result = UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
                **result._asdict(),  # pylint: disable=protected-access
                momentum_distribution=momentum_distribution)
        return result
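The `**result._asdict()` line above copies every field of the parent kernel results into the subclass's results type and tacks on the extra `momentum_distribution` field. The same idiom, shown with two hypothetical namedtuples standing in for the kernel-results classes:

import collections

ParentResults = collections.namedtuple('ParentResults',
                                       ['target_log_prob', 'step_size'])
ChildResults = collections.namedtuple('ChildResults',
                                      ParentResults._fields + ('momentum_distribution',))

parent = ParentResults(target_log_prob=-3.2, step_size=0.1)
# Reuse all parent fields, then supply the one field the child adds.
child = ChildResults(**parent._asdict(), momentum_distribution='some distribution')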
Example #9
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(it_n_replica, state_n_replica))

            # We will now replicate each of the (possibly batched) initial states,
            # one copy per inverse temperature. So if init_state=[x, y] with shapes
            # [Sx, Sy], the new shapes are [(T, Sx), (T, Sy)], where (a, b) means
            # concatenation and T is the number of inverse_temperatures.
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = ps.convert_to_shape_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
            else:
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        x,
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state
                ]

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
                    '(`seed`) argument. `TransitionKernel` instances now receive seeds '
                    'via `one_step`.')

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = bu.left_justified_broadcast_to(tf.range(num_replica),
                                                   replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),
                                                                  shift=2)

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=_set_swapped_fields_to_nan(
                    replica_results),
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
                step_count=tf.zeros(shape=(), dtype=tf.int32),
                seed=samplers.zeros_seed(),
                potential_energy=tf.zeros_like(
                    pre_swap_replica_target_log_prob),
            )
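The "null swap" bookkeeping above amounts to an identity permutation per batch member plus an identity acceptance matrix whose batch dimensions trail the two replica dimensions. A small sketch with hypothetical sizes, using a plain `tf.transpose` in place of `distribution_util.rotate_transpose(..., shift=2)` (equivalent here because the tensor has rank 3):

import tensorflow as tf

num_replica = 3
batch_shape = [2]  # hypothetical: two independent chains per replica
replica_and_batch_shape = [num_replica] + batch_shape

# Null swap: each replica is paired with itself, replicated across the batch.
swaps = tf.broadcast_to(
    tf.reshape(tf.range(num_replica), [num_replica, 1]),
    replica_and_batch_shape)  # shape [3, 2], swaps[i, b] == i

# Identity "always accepted" matrix, rotated so the shape is
# [n_replica, n_replica] + batch_shape.
eye = tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool)  # [2, 3, 3]
is_swap_accepted = tf.transpose(eye, perm=[1, 2, 0])               # [3, 3, 2]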
Example #10
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(it_n_replica, state_n_replica))

            # We will now replicate each of the (possibly batched) initial states,
            # one copy per inverse temperature. So if init_state=[x, y] with shapes
            # [Sx, Sy], the new shapes are [(T, Sx), (T, Sy)], where (a, b) means
            # concatenation and T is the number of inverse_temperatures.
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
            else:
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        x,
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state
                ]

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we fall back to the previous behavior and warn.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The second (`seed`) argument to `ReplicaExchangeMC`s '
                    '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
                    'receive seeds via `bootstrap_results` and `one_step`. This '
                    'fallback may become an error 2020-09-20.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),
                                                                  shift=2)

            post_swap_replica_results = _make_post_swap_replica_results(
                replica_results,
                inverse_temperatures,
                inverse_temperatures,
                is_swap_accepted[0],
                lambda x: x,
            )

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
                step_count=tf.zeros(shape=(), dtype=tf.int32),
                seed=samplers.zeros_seed(),
            )
Example #11
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            # We will now replicate each of the (possibly batched) initial states,
            # one copy per inverse temperature. So if init_state=[x, y] with shapes
            # [Sx, Sy], the new shapes are [(T, Sx), (T, Sy)], where (a, b) means
            # concatenation and T is the number of inverse_temperatures.
            num_replica = prefer_static.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            replica_states = [
                tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                    x,
                    prefer_static.concat(
                        [replica_shape, prefer_static.shape(x)], axis=0),
                    name='replica_states') for x in init_state
            ]

            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                                 inverse_temperatures),
                self._seed_stream())
            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = prefer_static.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),
                                                                  shift=2)

            post_swap_replica_results = _make_post_swap_replica_results(
                replica_results,
                inverse_temperatures,
                inverse_temperatures,
                is_swap_accepted[0],
                lambda x: x,
            )

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
            )
Example #12
    def bootstrap_results(self, init_state):
        """Creates initial `previous_kernel_results` using a supplied `state`."""
        with tf.name_scope(self.name + '.bootstrap_results'):
            if not tf.nest.is_nested(init_state):
                init_state = [init_state]
            # Pad the step_size so it is compatible with the states.
            step_size = self.step_size
            if len(step_size) == 1:
                step_size = step_size * len(init_state)
            if len(step_size) != len(init_state):
                raise ValueError(
                    'Expected either one step size or {} (size of '
                    '`init_state`), but found {}'.format(
                        len(init_state), len(step_size)))
            state_parts, _ = mcmc_util.prepare_state_parts(
                init_state, name='current_state')
            current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
                self.target_log_prob_fn, state_parts)
            momentum_distribution = self.momentum_distribution
            if momentum_distribution is None:
                momentum_distribution = pu.make_momentum_distribution(
                    state_parts, ps.shape(current_target_log_prob))
            momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
                momentum_distribution, ps.shape(current_target_log_prob))
            momentum_parts = momentum_distribution.sample()

            def _init(shape_and_dtype):
                """Allocate TensorArray for storing state and velocity."""
                return [  # pylint: disable=g-complex-comprehension
                    ps.zeros(ps.concat([[max(self._write_instruction) + 1], s],
                                       axis=0),
                             dtype=d) for (s, d) in shape_and_dtype
                ]

            get_shapes_and_dtypes = lambda x: [
                (ps.shape(x_), x_.dtype)  # pylint: disable=g-long-lambda
                for x_ in x
            ]
            velocity_state_memory = VelocityStateSwap(
                velocity_swap=_init(get_shapes_and_dtypes(momentum_parts)),
                state_swap=_init(get_shapes_and_dtypes(init_state)))

            return PreconditionedNUTSKernelResults(
                target_log_prob=current_target_log_prob,
                grads_target_log_prob=current_grads_log_prob,
                velocity_state_memory=velocity_state_memory,
                step_size=step_size,
                log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                               name='log_accept_ratio'),
                leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                              dtype=TREE_COUNT_DTYPE,
                                              name='leapfrogs_taken'),
                is_accepted=tf.zeros_like(current_target_log_prob,
                                          dtype=tf.bool,
                                          name='is_accepted'),
                reach_max_depth=tf.zeros_like(current_target_log_prob,
                                              dtype=tf.bool,
                                              name='reach_max_depth'),
                has_divergence=tf.zeros_like(current_target_log_prob,
                                             dtype=tf.bool,
                                             name='has_divergence'),
                energy=compute_hamiltonian(current_target_log_prob,
                                           momentum_parts,
                                           momentum_distribution),
                momentum_distribution=momentum_distribution,
                # Allow room for one_step's seed.
                seed=samplers.zeros_seed(),
            )
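The `_init` helper above preallocates one zero buffer per state/velocity part, with a leading "slot" axis sized by the write instructions. Stripped of the TFP-specific plumbing, the allocation pattern looks like this (sizes and dtypes are hypothetical):

import tensorflow as tf

capacity = 8  # stands in for max(write_instruction) + 1
shapes_and_dtypes = [([3, 4], tf.float32), ([3], tf.float32)]

# One zero-filled buffer per part, with the slot axis prepended.
buffers = [tf.zeros([capacity] + shape, dtype=dtype)
           for shape, dtype in shapes_and_dtypes]
# buffers[0].shape == [8, 3, 4], buffers[1].shape == [8, 3]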
Example #13
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to DefaultStandardNormal().
  if (momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal)):
    batch_rank = ps.rank(target_log_prob)
    def _batched_isotropic_normal_like(state_part):
      return sample.Sample(
          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
          ps.shape(state_part)[batch_rank:])

    momentum_distribution = jds.JointDistributionSequential(
        [_batched_isotropic_normal_like(state_part)
         for state_part in state_parts])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # If all underlying distributions are independent, we can offer some help.
  # This code will also trigger for the output of the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
        for md in momentum_distribution.model
    ])

  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
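Finally, from the caller's side rather than the kernel internals: the default built above is a standard normal per state part, and passing an explicit `momentum_distribution` is how preconditioning enters. A hedged sketch of both, using only the public `tfp.distributions` API; the event shape and scales are made up, and whether a particular kernel accepts this exact structure depends on that kernel.

import tensorflow_probability as tfp

tfd = tfp.distributions

# Default-style momentum: standard normal matching a single state part of
# event shape [4].
default_momentum = tfd.JointDistributionSequential(
    [tfd.Sample(tfd.Normal(0., 1.), sample_shape=[4])])

# A possible preconditioned alternative: a diagonal Gaussian whose per-
# dimension scales act as a simple preconditioner.
preconditioned_momentum = tfd.JointDistributionSequential(
    [tfd.MultivariateNormalDiag(scale_diag=[0.5, 1.0, 2.0, 4.0])])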