Exemple #1
0
 def test_works_for_nested_namedtuple(self):
     """Checks `choose` on a namedtuple nesting a namedtuple and a list."""
     Results = collections.namedtuple('Results', ['field1', 'inner'])  # pylint: disable=invalid-name
     InnerResults = collections.namedtuple(  # pylint: disable=invalid-name
         'InnerResults', ['fieldA', 'fieldB'])

     def nested(field1, field_a, field_b32, field_b64):
         """Builds a `Results` whose leaves are two-element arrays."""
         return Results(
             field1=np.int32(field1),
             inner=InnerResults(
                 fieldA=np.float32(field_a),
                 fieldB=[np.float32(field_b32), np.float64(field_b64)]))

     accepted = nested([1, 3], [5, 7], [9, 11], [13, 15])
     rejected = nested([0, 2], [4, 6], [8, 10], [12, 14])

     is_accepted = tf.constant([False, True])
     chosen_ = self.evaluate(choose(is_accepted, accepted, rejected))

     # Every leaf should take its first entry from `rejected` (0, 4, 8, 12)
     # and its second entry from `accepted` (3, 7, 11, 15).
     expected = nested([0, 3], [4, 7], [8, 11], [12, 15])
     self.assertAllClose(expected, chosen_, atol=0., rtol=1e-5)
Exemple #2
0
 def test_works_for_nested_namedtuple(self):
   """Verifies `choose` picks per-batch entries through nested namedtuples."""
   Results = collections.namedtuple('Results', ['field1', 'inner'])  # pylint: disable=invalid-name
   InnerResults = collections.namedtuple('InnerResults', ['fieldA', 'fieldB'])  # pylint: disable=invalid-name

   # Structure offered where the mask is True.
   accepted_inner = InnerResults(
       np.float32([5, 7]),
       [np.float32([9, 11]), np.float64([13, 15])])
   accepted = Results(np.int32([1, 3]), accepted_inner)

   # Structure offered where the mask is False.
   rejected_inner = InnerResults(
       np.float32([4, 6]),
       [np.float32([8, 10]), np.float64([12, 14])])
   rejected = Results(np.int32([0, 2]), rejected_inner)

   mask = tf.constant([False, True])
   chosen_ = self.evaluate(choose(mask, accepted, rejected))

   # First entries come from `rejected` (0, 4, 8, 12); second entries come
   # from `accepted` (3, 7, 11, 15).
   expected_inner = InnerResults(
       np.float32([4, 7]),
       [np.float32([8, 11]), np.float64([12, 15])])
   expected = Results(np.int32([0, 3]), expected_inner)
   self.assertAllClose(expected, chosen_, atol=0., rtol=1e-5)
Exemple #3
0
    def test_selects_batch_members_from_list_of_arrays(self):
        """Checks that `choose` selects whole batch rows inside a list state.

        Each array has shape [batch_size, event_size] = [2, 3]. The mask has
        one entry per batch member, so it must pick rows even though ordinary
        broadcasting would line it up against the trailing (event) axis.
        """
        batch_size, event_size = 2, 3
        zeros_states = [np.zeros((batch_size, event_size))]
        ones_states = [np.ones((batch_size, event_size))]

        is_accepted = tf.constant([True, False])
        chosen_ = self.evaluate(choose(is_accepted, zeros_states, ones_states))

        # The enclosing list must survive as a list rather than being treated
        # as an extra array dimension.
        self.assertIsInstance(chosen_, list)

        first_row = [0., 0., 0.]  # taken from zeros_states (mask is True)
        second_row = [1., 1., 1.]  # taken from ones_states (mask is False)
        expected = [np.array([first_row, second_row])]
        self.assertAllEqual(expected, chosen_)
  def one_step(self, current_state, previous_kernel_results):
    """Advances the chain by one Metropolis-Hastings accept/reject step.

    The inner kernel proposes a new state; the proposal is then accepted or
    rejected per chain by comparing a log-uniform draw with the log
    acceptance ratio.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A `MetropolisHastingsKernelResults` holding the chosen
        inner results, the acceptance mask, the log acceptance ratio and the
        raw proposal.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
    accepted_results = previous_kernel_results.accepted_results
    scope_name = mcmc_util.make_name(self.name, 'mh', 'one_step')
    with tf.name_scope(
        name=scope_name,
        values=[current_state, previous_kernel_results]):
      # Ask the inner kernel for a proposal.
      proposed_state, proposed_results = self.inner_kernel.one_step(
          current_state, accepted_results)

      if not (has_target_log_prob(proposed_results) and
              has_target_log_prob(accepted_results)):
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # log(acceptance_ratio) = proposed log-prob - current log-prob
      #                         [+ optional correction from the inner kernel].
      log_ratio_terms = [
          proposed_results.target_log_prob,
          -accepted_results.target_log_prob,
      ]
      try:
        correction = proposed_results.log_acceptance_correction
        # An empty list-like correction contributes nothing, so skip it.
        if not mcmc_util.is_list_like(correction) or correction:
          log_ratio_terms.append(correction)
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      log_accept_ratio = mcmc_util.safe_sum(
          log_ratio_terms, name='compute_log_accept_ratio')

      # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1), which is
      # the same test as log(u) < log_accept_ratio. A proposal that raises
      # the target density is therefore always accepted.
      target_log_prob = proposed_results.target_log_prob
      uniform_draw = tf.random_uniform(
          shape=tf.shape(target_log_prob),
          dtype=target_log_prob.dtype.base_dtype,
          seed=self._seed_stream())
      log_uniform = tf.log(uniform_draw)
      is_accepted = log_uniform < log_accept_ratio

      next_state = mcmc_util.choose(
          is_accepted, proposed_state, current_state,
          name='choose_next_state')
      chosen_inner_results = mcmc_util.choose(
          is_accepted, proposed_results, accepted_results,
          name='choose_inner_results')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=chosen_inner_results,
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results
    def one_step(self, current_state, previous_kernel_results):
        """Runs one inner-kernel proposal and accepts or rejects it.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple`
            or `list` of `Tensor`s representing internal calculations made
            within the previous call to this function (or as returned by
            `bootstrap_results`).

        Returns:
          A two-element `list` `[next_state, kernel_results]`:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A `MetropolisHastingsKernelResults` holding the
            chosen inner results, the acceptance mask, the log acceptance
            ratio and the raw proposal.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        previous_accepted = previous_kernel_results.accepted_results

        # Ask the inner kernel for a proposal.
        proposed_state, proposed_results = self.inner_kernel.one_step(
            current_state, previous_accepted)

        if not (has_target_log_prob(proposed_results)
                and has_target_log_prob(previous_accepted)):
            raise ValueError('"target_log_prob" must be a member of '
                             '`inner_kernel` results.')

        # log(acceptance_ratio) = proposed log-prob - current log-prob
        #                         [+ correction, when the kernel has one].
        log_ratio_terms = [
            proposed_results.target_log_prob,
            -previous_accepted.target_log_prob,
        ]
        try:
            log_ratio_terms.append(proposed_results.log_acceptance_correction)
        except AttributeError:
            warnings.warn(
                'Supplied inner `TransitionKernel` does not have a '
                '`log_acceptance_correction`. Assuming its value is `0.`')
        log_accept_ratio = mcmc_util.safe_sum(
            log_ratio_terms, name='compute_log_accept_ratio')

        # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1), i.e.
        # log(u) < log_accept_ratio; a proposal that raises the target
        # density is always accepted.
        #
        # The stored seed is replaced BEFORE the draw so that successive
        # calls are not correlated, and so the draw does not reuse a seed the
        # user may also have handed to the inner kernel.
        # NOTE(review): presumably `self.seed` reads back `self._seed` --
        # confirm against the property definition.
        self._seed = distributions_util.gen_new_seed(
            self.seed, salt='metropolis_hastings_one_step')
        proposed_log_prob = proposed_results.target_log_prob
        uniform_draw = tf.random_uniform(
            shape=tf.shape(proposed_log_prob),
            dtype=proposed_log_prob.dtype.base_dtype,
            seed=self.seed)
        log_uniform = tf.log(uniform_draw)
        is_accepted = log_uniform < log_accept_ratio

        independent_chain_ndims = distributions_util.prefer_static_rank(
            proposed_log_prob)

        next_state = mcmc_util.choose(
            is_accepted, proposed_state, current_state,
            independent_chain_ndims)

        # Choose field by field, then rebuild a results object of the same
        # namedtuple type as the proposal.
        chosen_fields = {
            field_name: mcmc_util.choose(
                is_accepted,
                getattr(proposed_results, field_name),
                getattr(previous_accepted, field_name),
                independent_chain_ndims)
            for field_name in proposed_results._fields
        }
        accepted_results = type(proposed_results)(**chosen_fields)

        kernel_results = MetropolisHastingsKernelResults(
            accepted_results=accepted_results,
            is_accepted=is_accepted,
            log_accept_ratio=log_accept_ratio,
            proposed_state=proposed_state,
            proposed_results=proposed_results,
        )
        return [next_state, kernel_results]
Exemple #6
0
 def _swap(is_exchange_accepted, x, y):
     """Exchange `x` and `y` wherever `is_exchange_accepted` selects it.

     Returns:
       A `(new_x, new_y)` tuple: `new_x` is `choose(mask, y, x)` and `new_y`
       is `choose(mask, x, y)`.
     """
     with tf.name_scope('swap_where_exchange_accepted'):
         # Tuple elements are evaluated left to right, so the ops are created
         # in the same order: the one for x first, then the one for y.
         return (
             mcmc_util.choose(is_exchange_accepted, y, x),
             mcmc_util.choose(is_exchange_accepted, x, y),
         )
  def one_step(self, current_state, previous_kernel_results):
    """Proposes with the inner kernel, then applies the accept/reject test.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A `MetropolisHastingsKernelResults` carrying the chosen
        inner results, the acceptance mask, the log acceptance ratio, and the
        proposed state and results.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
    prior_results = previous_kernel_results.accepted_results
    step_scope = mcmc_util.make_name(self.name, 'mh', 'one_step')
    with tf.name_scope(
        name=step_scope,
        values=[current_state, previous_kernel_results]):
      # One step of the wrapped kernel yields the candidate state.
      proposed_state, proposed_results = self.inner_kernel.one_step(
          current_state, prior_results)

      both_have_log_prob = (
          has_target_log_prob(proposed_results) and
          has_target_log_prob(prior_results))
      if not both_have_log_prob:
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # Summands of log(acceptance_ratio).
      summands = [
          proposed_results.target_log_prob,
          -prior_results.target_log_prob,
      ]
      try:
        log_correction = proposed_results.log_acceptance_correction
        # Only an empty list-like correction is left out of the sum.
        if not mcmc_util.is_list_like(log_correction) or log_correction:
          summands.append(log_correction)
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      log_accept_ratio = mcmc_util.safe_sum(
          summands, name='compute_log_accept_ratio')

      # With u ~ Uniform[0, 1), accepting when u < min(1, accept_ratio) is
      # the same as accepting when log(u) < log_accept_ratio. Proposals that
      # increase the target density always pass; others pass at random.
      proposed_log_prob = proposed_results.target_log_prob
      log_uniform = tf.log(
          tf.random_uniform(
              shape=tf.shape(proposed_log_prob),
              dtype=proposed_log_prob.dtype.base_dtype,
              seed=self._seed_stream()))
      is_accepted = log_uniform < log_accept_ratio

      next_state = mcmc_util.choose(
          is_accepted, proposed_state, current_state,
          name='choose_next_state')
      selected_results = mcmc_util.choose(
          is_accepted, proposed_results, prior_results,
          name='choose_inner_results')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=selected_results,
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results