Ejemplo n.º 1
0
  def test_selects_batch_members_from_list_of_arrays(self):
    """Checks that `util.choose` selects whole batch members from array lists.

    Each array has shape [2, 3], i.e. [batch_size, event_size]. The boolean
    mask has one entry per batch member, so it must pick entire rows even
    though ordinary broadcasting would line it up with the rightmost (event)
    dimension first.
    """
    batch_size, event_size = 2, 3
    if_true = [np.zeros((batch_size, event_size))]
    if_false = [np.ones((batch_size, event_size))]
    batch_mask = tf.constant([True, False])

    result = self.evaluate(util.choose(batch_mask, if_true, if_false))

    # The enclosing list is structure, not an extra array dimension, so it
    # has to come back as a Python list.
    self.assertIsInstance(result, list)

    # Row 0 is taken from the zeros (mask True), row 1 from the ones (False).
    want = [np.array([[0., 0., 0.], [1., 1., 1.]])]
    self.assertAllEqual(want, result)
Ejemplo n.º 2
0
 def _swap_tensor(x):
     """Chooses, per the swap mask, between `x` gathered by `swaps` and `x`."""
     remapped = mcmc_util.index_remapping_gather(x, swaps)
     return mcmc_util.choose(is_swap_accepted_mask, remapped, x)
    def one_step(self, current_state, previous_kernel_results):
        """Advances the inner kernel one step and adapts its step size.

        Args:
          current_state: State, or nested structure of state parts, passed on
            to `inner_kernel.one_step`.
          previous_kernel_results: Results from the previous step. The fields
            `inner_results`, `new_step_size`, `target_accept_prob`,
            `adaptation_rate` and `step` are read here.

        Returns:
          new_state: The state produced by the inner kernel.
          kernel_results: `previous_kernel_results` with `inner_results`,
            `step` (incremented by one) and `new_step_size` replaced.
        """
        scope_name = mcmc_util.make_name(
            self.name, 'simple_step_size_adaptation', 'one_step')
        with tf.name_scope(scope_name):
            prev = previous_kernel_results

            # Install the step size computed on the previous call, then let
            # the inner kernel take its step with it.
            inner_results = self.step_size_setter_fn(
                prev.inner_results, prev.new_step_size)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results)

            # The comparison against the target happens in log space, in the
            # dtype of the inner kernel's log acceptance probability.
            log_accept_prob = self.log_accept_prob_getter_fn(new_inner_results)
            log_target_accept_prob = tf.math.log(
                tf.cast(prev.target_accept_prob, dtype=log_accept_prob.dtype))

            state_parts = tf.nest.flatten(current_state)
            step_size = self.step_size_getter_fn(new_inner_results)
            step_size_parts = tf.nest.flatten(step_size)
            log_accept_prob_rank = prefer_static.rank(log_accept_prob)

            def _adapt_part(step_size_part, state_part):
                """Returns the step size to use next for one step size part."""
                # log_accept_prob usually carries more leading (chain) dims
                # than a step size part can absorb, so it is averaged (in log
                # space) down to something that broadcasts into the step size
                # part from the left. Example:
                #
                #   state_part      : [2, 3, 4, 5]
                #   step_size_part  :    [1, 4, 1]
                #   log_accept_prob : [2, 3, 4]
                #
                # The step size has one rank fewer than the state, so the
                # leading dim of log_accept_prob is reduced away -> [3, 4].
                # The dims whose sizes still differ from the step size part
                # are then reduced with keepdims -> [1, 4], which lines up
                # with the leading dims of step_size_part.
                #
                # A length-1 list of step size parts may be
                # "structure-broadcast" across several state parts (see,
                # e.g., hmc.py). Then either the step size has no dims that
                # correspond to chain dims, or every state part has the same
                # shape. One formula covers both: in the latter case it is
                # the ordinary computation; in the former, every state part
                # has rank L + D_i + B (L: rank of log_accept_prob, D_i: dims
                # not shared between states, B: dims shared with the step
                # size), so after subtracting B the difference is >= L and
                # the `minimum` below clamps it to a full reduction.
                rank_gap = (prefer_static.rank(state_part) -
                            prefer_static.rank(step_size_part))
                num_leading = prefer_static.minimum(
                    log_accept_prob_rank, rank_gap)
                reduced = reduce_logmeanexp(
                    log_accept_prob, axis=prefer_static.range(num_leading))

                # Second pass: collapse (keeping dims) every axis on which the
                # reduced value and the step size part still disagree.
                mismatched_axes = get_differing_dims(reduced, step_size_part)
                reduced = reduce_logmeanexp(
                    reduced, axis=mismatched_axes, keepdims=True)

                # Grow the step size when accepting more often than the
                # target, shrink it otherwise.
                growth = 1. + tf.cast(
                    prev.adaptation_rate, dtype=step_size_part.dtype)
                adapted = mcmc_util.choose(
                    reduced > log_target_accept_prob,
                    step_size_part * growth,
                    step_size_part / growth)

                # Once the adaptation steps are used up, keep the step size
                # the inner kernel reported.
                still_adapting = prev.step < self.num_adaptation_steps
                return tf.where(still_adapting, adapted, step_size_part)

            new_step_size_parts = [
                _adapt_part(step_size_part, state_part)
                for step_size_part, state_part in zip(step_size_parts,
                                                      state_parts)
            ]
            new_step_size = tf.nest.pack_sequence_as(
                step_size, new_step_size_parts)

            new_kernel_results = prev._replace(
                inner_results=new_inner_results,
                step=1 + prev.step,
                new_step_size=new_step_size)
            return new_state, new_kernel_results
 def _swap(is_exchange_accepted, x, y):
     """Exchange `x` and `y` wherever the exchange was accepted."""
     with tf.compat.v1.name_scope('swap_where_exchange_accepted'):
         # Where accepted each output takes the other input; elsewhere it
         # keeps its own.
         swapped_x, swapped_y = (
             mcmc_util.choose(is_exchange_accepted, y, x),
             mcmc_util.choose(is_exchange_accepted, x, y),
         )
     return swapped_x, swapped_y
Ejemplo n.º 5
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple`
            or `list` of `Tensor`s with the internal calculations from the
            previous call to this function (or from `bootstrap_results`).
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s holding the next
            state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list`
            of `Tensor`s with the internal calculations made in this call.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        caller_gave_seed = seed is not None
        seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
        proposal_seed, acceptance_seed = samplers.split_seed(seed)

        scope_name = mcmc_util.make_name(self.name, 'mh', 'one_step')
        with tf.name_scope(scope_name):
            # Propose with the inner kernel. A seed is forwarded only when the
            # caller supplied one.
            inner_kwargs = {'seed': proposal_seed} if caller_gave_seed else {}
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state,
                previous_kernel_results.accepted_results,
                **inner_kwargs)
            if mcmc_util.is_list_like(current_state):
                # Give the proposal the same structure as the current state.
                proposed_state = tf.nest.pack_sequence_as(
                    current_state, proposed_state)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(
                        previous_kernel_results.accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # log(acceptance_ratio) = proposed log-prob - accepted log-prob
            # (+ the inner kernel's correction, when it reports one).
            log_ratio_terms = [
                proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # A list-like correction is only added when it is non-empty.
                if not mcmc_util.is_list_like(correction) or correction:
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1),
            # i.e. when log(u) < log_accept_ratio: proposals that raise the
            # likelihood always pass, the rest pass at random.
            proposed_log_prob = proposed_results.target_log_prob
            uniform_draw = samplers.uniform(
                shape=prefer_static.shape(proposed_log_prob),
                dtype=dtype_util.base_dtype(proposed_log_prob.dtype),
                seed=acceptance_seed)
            is_accepted = tf.math.log(uniform_draw) < log_accept_ratio

            next_state = mcmc_util.choose(
                is_accepted,
                proposed_state,
                current_state,
                name='choose_next_state')

            # Seeds are stripped from the proposed results before choosing:
            # unlike the other result fields a seed is not a per-chain value,
            # so there is no per-chain way to pick between the previously
            # accepted seed and the proposed one.
            accepted_results = mcmc_util.choose(
                is_accepted,
                mcmc_util.strip_seeds(proposed_results),
                previous_kernel_results.accepted_results,
                name='choose_inner_results')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
                seed=seed,
            )
            return next_state, kernel_results
Ejemplo n.º 6
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple`
            or `list` of `Tensor`s with the internal calculations from the
            previous call to this function (or from `bootstrap_results`).

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s holding the next
            state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list`
            of `Tensor`s with the internal calculations made in this call.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        scope_name = mcmc_util.make_name(self.name, 'mh', 'one_step')
        with tf1.name_scope(name=scope_name,
                            values=[current_state, previous_kernel_results]):
            # Propose with the inner kernel.
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state, previous_kernel_results.accepted_results)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(
                        previous_kernel_results.accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # log(acceptance_ratio) = proposed log-prob - accepted log-prob
            # (+ the inner kernel's correction, when it reports one).
            log_ratio_terms = [
                proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # A list-like correction is only added when it is non-empty.
                if not mcmc_util.is_list_like(correction) or correction:
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1),
            # i.e. when log(u) < log_accept_ratio: proposals that raise the
            # likelihood always pass, the rest pass at random.
            proposed_log_prob = proposed_results.target_log_prob
            uniform_draw = tf.random.uniform(
                shape=tf.shape(input=proposed_log_prob),
                dtype=proposed_log_prob.dtype.base_dtype,
                seed=self._seed_stream())
            is_accepted = tf.math.log(uniform_draw) < log_accept_ratio

            next_state = mcmc_util.choose(
                is_accepted,
                proposed_state,
                current_state,
                name='choose_next_state')

            accepted_results = mcmc_util.choose(
                is_accepted,
                proposed_results,
                previous_kernel_results.accepted_results,
                name='choose_inner_results')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
            )
            return next_state, kernel_results
Ejemplo n.º 7
0
 def _do_flip(state, i):
     """Returns `state` with the `i`-th flip applied where `should_flip[i]`."""
     flip_index = tf.gather(flip_idxs, i)
     flipped_state = sampler._flip_feature(state, flip_index)
     flip_now = tf.gather(should_flip, i)
     return mcmc_util.choose(flip_now, flipped_state, state)
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Steps the inner kernel with a jittered trajectory length and adapts it.

    Args:
      current_state: State (possibly a nested structure); every part is
        converted to a `Tensor` of the dtype of
        `previous_kernel_results.adaptation_rate`.
      previous_kernel_results: Results from the previous step.
      seed: Seed; it is split into one seed for the jitter draw and one that
        is passed on to the inner kernel.

    Returns:
      new_state: The state returned by the inner kernel.
      new_kernel_results: Updated kernel results with `inner_results`, `step`
        (incremented by one) and `criterion` replaced.
    """
    scope_name = mcmc_util.make_name(
        self.name, 'gradient_based_trajectory_length_adaptation', 'one_step')
    with tf.name_scope(scope_name):
      prev = previous_kernel_results
      jitter_seed, inner_seed = samplers.split_seed(seed)

      dtype = prev.adaptation_rate.dtype

      def _as_tensor(part):
        return tf.convert_to_tensor(part, dtype=dtype)

      current_state = tf.nest.map_structure(_as_tensor, current_state)
      step_f = tf.cast(prev.step, dtype)

      # Base jitter: either the Halton sequence evaluated at the step count
      # or a fresh scalar uniform draw.
      if not self.use_halton_sequence_jitter:
        trajectory_jitter = samplers.uniform((), seed=jitter_seed, dtype=dtype)
      else:
        trajectory_jitter = _halton_sequence(step_f)

      # Scale the jitter by `jitter_amount` and offset it by the remainder.
      jitter_amount = prev.jitter_amount
      trajectory_jitter = (
          trajectory_jitter * jitter_amount + (1. - jitter_amount))

      # While adapting use the live estimate; afterwards the averaged one.
      adapting = prev.step < self.num_adaptation_steps
      max_trajectory_length = tf.where(
          adapting,
          prev.max_trajectory_length,
          prev.averaged_max_trajectory_length)
      jittered_trajectory_length = max_trajectory_length * trajectory_jitter

      step_size = _ensure_step_size_is_scalar(
          self.step_size_getter_fn(prev), self.validate_args)

      # At least one leapfrog step, otherwise enough to cover the trajectory.
      steps_to_cover = tf.math.ceil(jittered_trajectory_length / step_size)
      num_leapfrog_steps = tf.cast(
          tf.maximum(tf.ones([], dtype), steps_to_cover), tf.int32)

      jittered_results = self.num_leapfrog_steps_setter_fn(
          prev, num_leapfrog_steps)

      new_state, new_inner_results = self.inner_kernel.one_step(
          current_state, jittered_results.inner_results, inner_seed)

      adapted_results = _update_trajectory_grad(
          jittered_results,
          previous_state=current_state,
          proposed_state=self.proposed_state_getter_fn(new_inner_results),
          proposed_velocity=self.proposed_velocity_getter_fn(
              new_inner_results),
          trajectory_jitter=trajectory_jitter,
          accept_prob=tf.exp(
              self.log_accept_prob_getter_fn(new_inner_results)),
          step_size=step_size,
          criterion_fn=self.criterion_fn,
          max_leapfrog_steps=self.max_leapfrog_steps)

      # Outside the burn-in phase the adaptation is rolled back. The criterion
      # is kept because it is a diagnostic, and the leapfrog step count is
      # kept because it comes from the jitter (and doubles as a diagnostic).
      criterion = adapted_results.criterion
      chosen_results = mcmc_util.choose(
          adapting, adapted_results, jittered_results)

      new_kernel_results = chosen_results._replace(
          inner_results=new_inner_results,
          step=prev.step + 1,
          criterion=criterion)
      return new_state, new_kernel_results
Ejemplo n.º 9
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Steps the inner kernel and updates the running variance estimates.

        When `self.num_estimation_steps` is `None`, the running variances and
        the inner kernel's momentum distribution are updated on every call.
        Otherwise the variances are updated while the new step count is at
        most `previous_kernel_results.num_estimation_steps`, and the momentum
        distribution is updated only when the two are equal.

        Args:
          current_state: State passed on to `inner_kernel.one_step`.
          previous_kernel_results: Results from the previous step; its
            `running_variance`, `inner_results`, `step` and (when used)
            `num_estimation_steps` fields are read here.
          seed: Optional seed, forwarded to the inner kernel only when it is
            not `None`.

        Returns:
          new_state: The state returned by the inner kernel.
          new_kernel_results: `previous_kernel_results` with `inner_results`,
            `running_variance` and `step` replaced.
        """
        scope_name = mcmc_util.make_name(
            self.name, 'diagonal_mass_matrix_adaptation', 'one_step')
        with tf.name_scope(scope_name):
            prev = previous_kernel_results
            variance_parts = prev.running_variance

            # Step the inner kernel.
            inner_kwargs = {'seed': seed} if seed is not None else {}
            new_state, stepped_inner_results = self.inner_kernel.one_step(
                current_state, prev.inner_results, **inner_kwargs)

            def _collapse_leading_dims(state_part, diag):
                """Folds the dims `state_part` has beyond `diag` into one."""
                # The variance calculation may be partially batched across
                # chains (some, all, or none of the chains may share an
                # estimated mass matrix). For example, with
                #
                #   state_part    : [2, 3, 4] + [5, 6]  (batch + event)
                #   variance_part :       [4] + [5, 6]
                #   log_prob      : [2, 3, 4]
                #
                # there are 4 mass matrices, each shared by a [2, 3]-batch of
                # chains; the split is inferred from the shapes involved.
                # Until RunningVariance supports rank > 1 chunking the extra
                # leading dims are flattened, so `state_part` becomes
                # [6, 4, 5, 6] and contributes 6 observations along axis 0.
                num_extra_dims = ps.rank(state_part) - ps.rank(diag)
                full_shape = ps.shape(state_part)
                # With no extra dims this prepends a 1; otherwise it merges
                # all the leading dims into a single one.
                merged_dim = ps.reduce_prod(full_shape[:num_extra_dims])
                target_shape = ps.concat(
                    [[merged_dim], full_shape[num_extra_dims:]], axis=0)
                return ps.reshape(state_part, target_shape)

            def _updated_running_variances():
                """Feeds the new state into every running variance part."""
                diags = [part.variance() for part in variance_parts]
                state_parts = tf.nest.flatten(new_state)
                # Updating along axis 0 consumes the dimension introduced by
                # the reshape, so each part keeps its original shape.
                return [
                    part.update(
                        _collapse_leading_dims(state_part, diag), axis=0)
                    for part, diag, state_part in zip(
                        variance_parts, diags, state_parts)
                ]

            def _with_updated_momentum(running_variances):
                """Returns inner results whose momentum uses the variances."""
                diags = [part.variance() for part in running_variances]
                old_distribution = self.momentum_distribution_getter_fn(
                    stepped_inner_results)
                new_distribution = (
                    preconditioning_utils.update_momentum_distribution(
                        old_distribution, diags))
                return self.momentum_distribution_setter_fn(
                    stepped_inner_results, new_distribution)

            step = prev.step + 1
            if self.num_estimation_steps is not None:
                new_variance_parts = mcmc_util.choose(
                    step <= prev.num_estimation_steps,
                    _updated_running_variances(),
                    variance_parts)
                new_inner_results = mcmc_util.choose(
                    tf.equal(step, prev.num_estimation_steps),
                    _with_updated_momentum(new_variance_parts),
                    stepped_inner_results)
            else:
                new_variance_parts = _updated_running_variances()
                new_inner_results = _with_updated_momentum(new_variance_parts)

            new_kernel_results = prev._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts,
                step=step)
            return new_state, new_kernel_results