Example No. 1
def _prepare_args(target_log_prob_fn, state, step_size,
                  target_log_prob=None, maybe_expand=False,
                  description='target_log_prob'):
  """Processes input args to meet list-like assumptions."""
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  state_parts = [tf.convert_to_tensor(s, name='current_state')
                 for s in state_parts]
  # Verifies that the input static shape is fully defined.
  state_shapes_defined = [s.shape.is_fully_defined() for s in state_parts]
  if not np.all(state_shapes_defined):
    raise ValueError('All static shapes must be fully defined.')
  target_log_prob = _maybe_call_fn(
      target_log_prob_fn,
      state_parts,
      target_log_prob,
      description)
  step_sizes = (list(step_size) if mcmc_util.is_list_like(step_size)
                else [step_size])
  step_sizes = [
      tf.convert_to_tensor(
          s, name='step_size', dtype=target_log_prob.dtype)
      for s in step_sizes]
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have the same length as `current_state`.')
  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      target_log_prob
  ]
  def _fn(state_parts, seed):
    """Adds a normal perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
        shape and type as `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
    with tf.name_scope(name, 'random_walk_normal_fn',
                       values=[state_parts, scale, seed]):
      scales = scale if mcmc_util.is_list_like(scale) else [scale]
      if len(scales) == 1:
        scales *= len(state_parts)
      if len(state_parts) != len(scales):
        raise ValueError('`scale` must broadcast with `state_parts`.')
      seed_stream = distributions.SeedStream(seed, salt='RandomWalkNormalFn')
      next_state_parts = [
          tf.random_normal(
              mean=state_part,
              stddev=scale_part,
              shape=tf.shape(state_part),
              dtype=state_part.dtype.base_dtype,
              seed=seed_stream()
          ) for scale_part, state_part in zip(scales, state_parts)]

      return next_state_parts
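A hedged NumPy sketch of the same Gaussian random-walk proposal, using a hypothetical `random_walk_normal_propose` helper (the real version draws with `tf.random_normal` inside the graph):

```python
import numpy as np

def random_walk_normal_propose(state_parts, scale, rng):
  scales = list(scale) if isinstance(scale, (list, tuple)) else [scale]
  if len(scales) == 1:
    scales *= len(state_parts)
  # One independent Gaussian perturbation per state part.
  return [rng.normal(loc=part, scale=s, size=np.shape(part))
          for s, part in zip(scales, state_parts)]

rng = np.random.default_rng(42)
print(random_walk_normal_propose([np.zeros(3)], scale=0.5, rng=rng))
```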
Example No. 3
  def one_step(self, current_state, previous_kernel_results):
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'rwm', 'one_step'),
        values=[self.seed,
                current_state,
                previous_kernel_results.target_log_prob]):
      with tf.name_scope('initialize'):
        current_state_parts = (list(current_state)
                               if mcmc_util.is_list_like(current_state)
                               else [current_state])
        current_state_parts = [tf.convert_to_tensor(s, name='current_state')
                               for s in current_state_parts]

      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='rwm_kernel_proposal')
      new_state_fn = self.new_state_fn
      next_state_parts = new_state_fn(current_state_parts, self._seed_stream)
      # Compute `target_log_prob` so it's available to MetropolisHastings.
      next_target_log_prob = self.target_log_prob_fn(*next_state_parts)

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(current_state) else x[0]

      return [
          maybe_flatten(next_state_parts),
          UncalibratedRandomWalkResults(
              log_acceptance_correction=tf.zeros(
                  shape=tf.shape(next_target_log_prob),
                  dtype=next_target_log_prob.dtype.base_dtype),
              target_log_prob=next_target_log_prob,
          ),
      ]
Example No. 4
 def build_assign_op():
   if mcmc_util.is_list_like(step_size_var):
     return [
         ss.assign_add(ss * tf.cast(adjustment, ss.dtype))
         for ss in step_size_var
     ]
   return step_size_var.assign_add(
       step_size_var * tf.cast(adjustment, step_size_var.dtype))
Example No. 5
def inverse_transform_fn(bijector):
  """Makes a function which applies a list of Bijectors' `inverse`s."""
  if not is_list_like(bijector):
    bijector = [bijector]
  def fn(state_parts):
    return [b.inverse(sp)
            for b, sp in zip(bijector, state_parts)]
  return fn
Example No. 6
 def bootstrap_results(self, init_state):
   with tf.name_scope(self.name, 'rwm_bootstrap_results', [init_state]):
     if not mcmc_util.is_list_like(init_state):
       init_state = [init_state]
     init_state = [tf.convert_to_tensor(x) for x in init_state]
     init_target_log_prob = self.target_log_prob_fn(*init_state)
     return UncalibratedRandomWalkResults(
         log_acceptance_correction=tf.zeros_like(init_target_log_prob),
         target_log_prob=init_target_log_prob)
Example No. 7
def forward_transform_fn(bijector):
  """Makes a function which applies a list of Bijectors' `forward`s."""
  if not is_list_like(bijector):
    bijector = [bijector]

  def fn(transformed_state_parts):
    return [b.forward(sp) for b, sp in zip(bijector, transformed_state_parts)]

  return fn
Example No. 8
  def one_step(self, current_state, previous_kernel_results):
    """Runs one iteration of the Transformed Kernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s
        representing the current state(s) of the Markov chain(s),
        _after_ application of `bijector.forward`. The first `r`
        dimensions index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`. The
        `inner_kernel.one_step` does not actually use `current_state`,
        rather it takes as input
        `previous_kernel_results.transformed_state` (because
        `TransformedTransitionKernel` creates a copy of the input
        inner_kernel with a modified `target_log_prob_fn` which
        internally applies the `bijector.forward`).
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.
    """
    with tf.name_scope(
        name=make_name(self.name, 'transformed_kernel', 'one_step'),
        values=[previous_kernel_results]):
      transformed_next_state, kernel_results = self._inner_kernel.one_step(
          previous_kernel_results.transformed_state,
          previous_kernel_results.inner_results)
      transformed_next_state_parts = (
          transformed_next_state
          if is_list_like(transformed_next_state) else [transformed_next_state])
      next_state_parts = self._forward_transform(transformed_next_state_parts)
      next_state = (
          next_state_parts
          if is_list_like(transformed_next_state) else next_state_parts[0])
      kernel_results = TransformedTransitionKernelResults(
          transformed_state=transformed_next_state,
          inner_results=kernel_results)
      return next_state, kernel_results
Example No. 9
def forward_log_det_jacobian_fn(bijector):
  """Makes a function which applies a list of Bijectors' `log_det_jacobian`s."""
  if not is_list_like(bijector):
    bijector = [bijector]

  def fn(transformed_state_parts, event_ndims):
    return sum([
        b.forward_log_det_jacobian(sp, event_ndims=e)
        for b, e, sp in zip(bijector, event_ndims, transformed_state_parts)
    ])

  return fn
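The three factories in Examples Nos. 5, 7, and 9 all follow one pattern: wrap a single bijector into a list, then map the corresponding method over the state parts (summing the per-part results in the Jacobian case). A self-contained toy sketch, with a hypothetical scalar `ExpBijector` standing in for `tfp.bijectors`:

```python
import math

def is_list_like(x):  # Local stand-in for `mcmc_util.is_list_like`.
  return isinstance(x, (list, tuple))

class ExpBijector(object):
  """Toy scalar bijector: y = exp(x)."""
  def forward(self, x): return math.exp(x)
  def inverse(self, y): return math.log(y)
  def forward_log_det_jacobian(self, x, event_ndims):
    return x  # log|d exp(x)/dx| = x for scalars.

def forward_transform_fn(bijector):
  if not is_list_like(bijector):
    bijector = [bijector]
  return lambda parts: [b.forward(p) for b, p in zip(bijector, parts)]

fwd = forward_transform_fn(ExpBijector())
y_parts = fwd([0.3])
assert abs(ExpBijector().inverse(y_parts[0]) - 0.3) < 1e-12  # Round trip.
```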
Example No. 10
def _maybe_call_fn(fn,
                   fn_arg_list,
                   fn_result=None,
                   description='target_log_prob'):
  """Helper which computes `fn_result` if needed."""
  fn_arg_list = (list(fn_arg_list) if mcmc_util.is_list_like(fn_arg_list)
                 else [fn_arg_list])
  if fn_result is None:
    fn_result = fn(*fn_arg_list)
  if not fn_result.dtype.is_floating:
    raise TypeError('`{}` must be a `Tensor` with `float` `dtype`.'.format(
        description))
  return fn_result
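A pure-Python sketch of the compute-only-if-missing pattern above (hypothetical names; the dtype check is omitted):

```python
def maybe_call_fn(fn, fn_arg_list, fn_result=None):
  fn_arg_list = (list(fn_arg_list) if isinstance(fn_arg_list, (list, tuple))
                 else [fn_arg_list])
  if fn_result is None:  # Only pay for the call when no cached value exists.
    fn_result = fn(*fn_arg_list)
  return fn_result

log_prob = lambda x: -0.5 * x * x
assert maybe_call_fn(log_prob, [2.0]) == -2.0                  # Computed.
assert maybe_call_fn(log_prob, [2.0], fn_result=-2.0) == -2.0  # Reused as-is.
```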
Example No. 11
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  state_parts = [tf.convert_to_tensor(s, name='current_state')
                 for s in state_parts]
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn,
      state_parts,
      target_log_prob,
      grads_target_log_prob)
  step_sizes = (list(step_size) if mcmc_util.is_list_like(step_size)
                else [step_size])
  step_sizes = [
      tf.convert_to_tensor(
          s, name='step_size', dtype=target_log_prob.dtype)
      for s in step_sizes]
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have the same length as `current_state`.')
  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      target_log_prob,
      grads_target_log_prob,
  ]
Example No. 12
 def bootstrap_results(self, init_state):
   with tf.name_scope(
       name=mcmc_util.make_name(self.name, 'slice', 'bootstrap_results'),
       values=[init_state]):
     if not mcmc_util.is_list_like(init_state):
       init_state = [init_state]
     init_state = [tf.convert_to_tensor(x) for x in init_state]
     direction = [tf.zeros_like(x) for x in init_state]
     init_target_log_prob = self.target_log_prob_fn(*init_state)  # pylint:disable=not-callable
     return SliceSamplerKernelResults(
         target_log_prob=init_target_log_prob,
         bounds_satisfied=tf.zeros(shape=tf.shape(init_target_log_prob),
                                   dtype=tf.bool),
         direction=direction,
         upper_bounds=tf.zeros_like(init_target_log_prob),
         lower_bounds=tf.zeros_like(init_target_log_prob)
     )
Example No. 13
 def _loop_body(iter_, ais_weights, current_state, kernel_results):
   """Closure which implements `tf.while_loop` body."""
   x = (current_state if mcmc_util.is_list_like(current_state)
        else [current_state])
   proposal_log_prob = proposal_log_prob_fn(*x)
   target_log_prob = target_log_prob_fn(*x)
   ais_weights += ((target_log_prob - proposal_log_prob) /
                   tf.cast(num_steps, ais_weights.dtype))
   kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
   next_state, inner_results = kernel.one_step(
       current_state, kernel_results.inner_results)
   kernel_results = AISResults(
       proposal_log_prob=proposal_log_prob,
       target_log_prob=target_log_prob,
       inner_results=inner_results,
   )
   return [iter_ + 1, ais_weights, next_state, kernel_results]
Example No. 14
  def step_size_simple_update_fn(step_size_var, kernel_results):
    """Updates (list of) `step_size` using a standard adaptive MCMC procedure.

    Args:
      step_size_var: (List of) `tf.Variable`s representing the per `state_part`
        HMC `step_size`.
      kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from most recent call to `one_step`.

    Returns:
      step_size_assign: (List of) `Tensor`(s) representing updated
        `step_size_var`(s).
    """

    if kernel_results is None:
      if mcmc_util.is_list_like(step_size_var):
        return [tf.identity(ss) for ss in step_size_var]
      return tf.identity(step_size_var)
    log_n = tf.log(tf.cast(tf.size(kernel_results.log_accept_ratio),
                           kernel_results.log_accept_ratio.dtype))
    log_mean_accept_ratio = tf.reduce_logsumexp(
        tf.minimum(kernel_results.log_accept_ratio, 0.)) - log_n
    adjustment = tf.where(
        log_mean_accept_ratio < tf.cast(
            tf.log(target_rate), log_mean_accept_ratio.dtype),
        -decrement_multiplier / (1. + decrement_multiplier),
        increment_multiplier)

    def build_assign_op():
      if mcmc_util.is_list_like(step_size_var):
        return [
            ss.assign_add(ss * tf.cast(adjustment, ss.dtype))
            for ss in step_size_var
        ]
      return step_size_var.assign_add(
          step_size_var * tf.cast(adjustment, step_size_var.dtype))

    if num_adaptation_steps is None:
      return build_assign_op()
    else:
      with tf.control_dependencies([step_counter.assign_add(1)]):
        return tf.cond(step_counter < num_adaptation_steps,
                       build_assign_op,
                       lambda: step_size_var)
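A NumPy sketch of the adaptation rule above: compute the mean acceptance probability via `logsumexp`, then shrink the step size when it falls below `target_rate` and grow it otherwise. The constants are hypothetical defaults; the real update runs inside the TF graph:

```python
import numpy as np

def simple_step_size_update(step_size, log_accept_ratio, target_rate=0.75,
                            decrement_multiplier=0.01,
                            increment_multiplier=0.01):
  log_n = np.log(log_accept_ratio.size)
  log_mean_accept = np.logaddexp.reduce(np.minimum(log_accept_ratio, 0.)) - log_n
  if log_mean_accept < np.log(target_rate):
    adjustment = -decrement_multiplier / (1. + decrement_multiplier)
  else:
    adjustment = increment_multiplier
  return step_size + step_size * adjustment

print(simple_step_size_update(0.5, np.log([0.2, 0.3, 0.1])))   # Shrinks.
print(simple_step_size_update(0.5, np.log([0.9, 0.95, 1.0])))  # Grows.
```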
Example No. 15
  def bootstrap_results(self, init_state):
    """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'remc', 'bootstrap_results'),
        values=[init_state]):
      replica_results = [self.replica_kernels[i].bootstrap_results(init_state)
                         for i in range(self.num_replica)]

      init_state_parts = (list(init_state)
                          if mcmc_util.is_list_like(init_state)
                          else [init_state])
      replica_states = [[tf.identity(s) for s in init_state_parts]
                        for i in range(self.num_replica)]

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(init_state) else x[0]
      replica_states = [maybe_flatten(s) for s in replica_states]
      next_replica_idx = tf.range(self.num_replica)
      [
          exchange_proposed,
          exchange_proposed_n,
      ] = self.exchange_proposed_fn(self.num_replica, seed=self._seed_stream)
      exchange_proposed = tf.zeros_like(exchange_proposed)
      exchange_proposed_n = tf.zeros_like(exchange_proposed_n)
      return ReplicaExchangeMCKernelResults(
          replica_states=replica_states,
          replica_results=replica_results,
          next_replica_idx=next_replica_idx,
          exchange_proposed=exchange_proposed,
          exchange_proposed_n=exchange_proposed_n,
          sampled_replica_states=replica_states,
          sampled_replica_results=replica_results,
      )
Example No. 16
 def bootstrap_results(self, init_state):
   with tf.name_scope(
       name=mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results'),
       values=[init_state]):
     if not mcmc_util.is_list_like(init_state):
       init_state = [init_state]
     if self.state_gradients_are_stopped:
       init_state = [tf.stop_gradient(x) for x in init_state]
     else:
       init_state = [tf.convert_to_tensor(x) for x in init_state]
     [
         init_target_log_prob,
         init_grads_target_log_prob,
     ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn, init_state)
     return UncalibratedHamiltonianMonteCarloKernelResults(
         log_acceptance_correction=tf.zeros_like(init_target_log_prob),
         target_log_prob=init_target_log_prob,
         grads_target_log_prob=init_grads_target_log_prob,
     )
Example No. 17
  def one_step(self, current_state, previous_kernel_results):
    """Runs one iteration of Hamiltonian Monte Carlo.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions index
        independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with the same
        length as `current_state`.
    """
    previous_step_size_assign = (
        [] if self.step_size_update_fn is None
        else (previous_kernel_results.extra.step_size_assign
              if mcmc_util.is_list_like(
                  previous_kernel_results.extra.step_size_assign)
              else [previous_kernel_results.extra.step_size_assign]))

    with tf.control_dependencies(previous_step_size_assign):
      next_state, kernel_results = self._impl.one_step(
          current_state, previous_kernel_results)
      if self.step_size_update_fn is not None:
        step_size_assign = self.step_size_update_fn(  # pylint: disable=not-callable
            self.step_size, kernel_results)
        kernel_results = kernel_results._replace(
            extra=HamiltonianMonteCarloExtraKernelResults(
                step_size_assign=step_size_assign))
      return next_state, kernel_results
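The control dependency above guarantees that any step-size assignment produced by the previous call takes effect before the next step runs. A plain-Python sketch of that sequencing contract (hypothetical `AdaptiveStepper`; TF enforces the ordering with `tf.control_dependencies` instead of explicit call order):

```python
class AdaptiveStepper(object):
  def __init__(self, step_size, update_fn):
    self.step_size = step_size
    self.update_fn = update_fn
    self.pending_step_size = None

  def one_step(self, state):
    if self.pending_step_size is not None:
      # Analog of `tf.control_dependencies(previous_step_size_assign)`.
      self.step_size = self.pending_step_size
    next_state = state + self.step_size  # Stand-in for the real kernel step.
    self.pending_step_size = self.update_fn(self.step_size)
    return next_state

stepper = AdaptiveStepper(0.5, update_fn=lambda ss: 0.9 * ss)
state = 0.0
for _ in range(3):
  state = stepper.one_step(state)  # Uses 0.5, then 0.45, then 0.405.
```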
Example No. 18
  def _fn(state_parts, seed):
    """Adds a uniform perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
        shape and type as `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
    with tf.name_scope(name, 'random_walk_uniform_fn',
                       values=[state_parts, scale, seed]):
      scales = scale if mcmc_util.is_list_like(scale) else [scale]
      if len(scales) == 1:
        scales *= len(state_parts)
      if len(state_parts) != len(scales):
        raise ValueError('`scale` must broadcast with `state_parts`.')
      next_state_parts = []
      for scale_part, state_part in zip(scales, state_parts):
        # Mutate seed with each use.
        seed = distributions_util.gen_new_seed(
            seed, salt='random_walk_uniform_fn')
        next_state_parts.append(tf.random_uniform(
            minval=state_part - scale_part,
            maxval=state_part + scale_part,
            shape=tf.shape(state_part),
            dtype=state_part.dtype.base_dtype,
            seed=seed))
      return next_state_parts
Example No. 19
 def one_step(current_state, previous_kernel_results):
   # Make next_state.
   if is_list_like(current_state):
     next_state = []
     for i, s in enumerate(current_state):
       next_state.append(tf.identity(s * dtype(i + 2),
                                     name='next_state'))
   else:
     next_state = tf.identity(2. * current_state,
                              name='next_state')
   # Make kernel_results.
   kernel_results = {}
   for fn in sorted(previous_kernel_results._fields):
     if fn == 'grads_target_log_prob':
       kernel_results['grads_target_log_prob'] = [
           tf.identity(0.5 * g, name='grad_target_log_prob')
           for g in previous_kernel_results.grads_target_log_prob]
     else:
       kernel_results[fn] = tf.identity(
           0.5 * getattr(previous_kernel_results, fn, None),
           name=fn)
   kernel_results = type(previous_kernel_results)(**kernel_results)
   # Done.
   return next_state, kernel_results
Example No. 20
 def maybe_flatten(x):
   return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
Example No. 21
 def maybe_flatten(x):
   return x if mcmc_util.is_list_like(current_state) else x[0]
Example No. 22
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if the `inner_kernel` results don't contain the member
        "target_log_prob".
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
        values=[current_state, previous_kernel_results]):
      # Take one inner step.
      [
          proposed_state,
          proposed_results,
      ] = self.inner_kernel.one_step(
          current_state,
          previous_kernel_results.accepted_results)

      if (not has_target_log_prob(proposed_results) or
          not has_target_log_prob(previous_kernel_results.accepted_results)):
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # Compute log(acceptance_ratio).
      to_sum = [proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob]
      try:
        if (not mcmc_util.is_list_like(
            proposed_results.log_acceptance_correction)
            or proposed_results.log_acceptance_correction):
          to_sum.append(proposed_results.log_acceptance_correction)
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      log_accept_ratio = mcmc_util.safe_sum(
          to_sum, name='compute_log_accept_ratio')

      # If proposed state reduces likelihood: randomly accept.
      # If proposed state increases likelihood: always accept.
      # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
      #       ==> log(u) < log_accept_ratio
      log_uniform = tf.log(tf.random_uniform(
          shape=tf.shape(proposed_results.target_log_prob),
          dtype=proposed_results.target_log_prob.dtype.base_dtype,
          seed=self._seed_stream()))
      is_accepted = log_uniform < log_accept_ratio

      next_state = mcmc_util.choose(
          is_accepted,
          proposed_state,
          current_state,
          name='choose_next_state')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=mcmc_util.choose(
              is_accepted,
              proposed_results,
              previous_kernel_results.accepted_results,
              name='choose_inner_results'),
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results
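A NumPy sketch of the acceptance step: accept exactly when `log(u) < log_accept_ratio` with `u ~ Uniform[0, 1)`, which is equivalent to `u < min(1, accept_ratio)`. The states and ratios below are made-up values, one per chain:

```python
import numpy as np

rng = np.random.default_rng(0)
current_state = np.zeros(3)
proposed_state = np.ones(3)
log_accept_ratio = np.array([-0.1, 0.4, -2.0])  # One entry per chain.

log_uniform = np.log(rng.uniform(size=log_accept_ratio.shape))
is_accepted = log_uniform < log_accept_ratio
next_state = np.where(is_accepted, proposed_state, current_state)
```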
Example No. 23
    def bootstrap_results(self, init_state=None, transformed_init_state=None):
        """Returns an object with the same type as returned by `one_step`.

    Unlike other `TransitionKernel`s,
    `TransformedTransitionKernel.bootstrap_results` has the option of
    initializing the `TransformedTransitionKernelResults` from either an initial
    state, e.g., requiring the computation of `bijector.inverse(init_state)`, or
    directly from `transformed_init_state`, i.e., a `Tensor` or list
    of `Tensor`s which is interpreted as the `bijector.inverse`
    transformed state.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s). Must specify `init_state` or
        `transformed_init_state` but not both.
      transformed_init_state: `Tensor` or Python `list` of `Tensor`s
        representing the transformed initial state(s) of the Markov chain(s).
        Must specify `init_state` or `transformed_init_state` but not both.

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if the `inner_kernel` results don't contain the member
        "target_log_prob".

    #### Examples

    To use `transformed_init_state` in context of
    `tfp.mcmc.sample_chain`, you need to explicitly pass the
    `previous_kernel_results`, e.g.,

    ```python
    transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
    init_state = ...        # Doesn't matter.
    transformed_init_state = ... # Does matter.
    results, _ = tfp.mcmc.sample_chain(
        num_results=...,
        current_state=init_state,
        previous_kernel_results=transformed_kernel.bootstrap_results(
            transformed_init_state=transformed_init_state),
        kernel=transformed_kernel)
    ```
    """
        if (init_state is None) == (transformed_init_state is None):
            raise ValueError('Must specify exactly one of `init_state` '
                             'or `transformed_init_state`.')
        with tf.compat.v1.name_scope(
                name=make_name(self.name, 'transformed_kernel',
                               'bootstrap_results'),
                values=[init_state, transformed_init_state]):
            if transformed_init_state is None:
                init_state_parts = (init_state if is_list_like(init_state) else
                                    [init_state])
                transformed_init_state_parts = self._inverse_transform(
                    init_state_parts)
                transformed_init_state = (transformed_init_state_parts
                                          if is_list_like(init_state) else
                                          transformed_init_state_parts[0])
            else:
                if is_list_like(transformed_init_state):
                    transformed_init_state = [
                        tf.convert_to_tensor(value=s,
                                             name='transformed_init_state')
                        for s in transformed_init_state
                    ]
                else:
                    transformed_init_state = tf.convert_to_tensor(
                        value=transformed_init_state,
                        name='transformed_init_state')
            kernel_results = TransformedTransitionKernelResults(
                transformed_state=transformed_init_state,
                inner_results=self._inner_kernel.bootstrap_results(
                    transformed_init_state))
            return kernel_results
Example No. 24
 def maybe_flatten(x):
   return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
Example No. 25
def sample_chain(num_results,
                 current_state,
                 previous_kernel_results=None,
                 kernel=None,
                 num_burnin_steps=0,
                 num_steps_between_results=0,
                 parallel_iterations=10,
                 name=None):
    """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain starting at `current_state`, whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set of
  states with decreased autocorrelation.  See [Owen (2017)][1]. Such "thinning"
  is made possible by setting `num_steps_between_results > 0`. The chain then
  takes `num_steps_between_results` extra steps between the steps that make it
  into the results. The extra steps are never materialized (in calls to
  `sess.run`), and thus do not increase memory requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
  `parallel_iterations=1`, otherwise results will not be reproducible.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
      of the Markov chain.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results.  The number of returned chain states is
      still equal to `num_results`.  Default value: 0 (i.e., no thinning).
    parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    next_states: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state` but with a prepended `num_results`-size dimension.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_likelihood(true_variances):
    return tfd.MultivariateNormalDiag(
        scale_diag=tf.sqrt(true_variances))

  dims = 10
  dtype = np.float32
  true_variances = tf.linspace(dtype(1), dtype(3), dims)
  likelihood = make_likelihood(true_variances)

  states, kernel_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(states, axis=0)
  sample_var = tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0)
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  for i=1..n:
    w[i] ~ Normal(0, eye(d))            # prior
    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
  ```

  where `F` denotes factors.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))

  # Setup data.
  num_weights = 10
  num_factors = 4
  num_chains = 100
  dtype = np.float32

  prior = make_prior(num_weights, dtype)
  weights = prior.sample(num_chains)
  factors = np.random.randn(num_factors, num_weights).astype(dtype)
  x = make_likelihood(weights, factors).sample(num_chains)

  def target_log_prob(w):
    # Target joint is: `f(w) = p(w, x | factors)`.
    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros([num_chains, num_weights], dtype),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
    if not kernel.is_calibrated:
        warnings.warn(
            "Supplied `TransitionKernel` is not calibrated. Markov "
            "chain may not converge to intended target distribution.")
    with tf.name_scope(
            name, "mcmc_sample_chain",
        [num_results, num_burnin_steps, num_steps_between_results]):
        num_results = tf.convert_to_tensor(num_results,
                                           dtype=tf.int32,
                                           name="num_results")
        num_burnin_steps = tf.convert_to_tensor(num_burnin_steps,
                                                dtype=tf.int64,
                                                name="num_burnin_steps")
        num_steps_between_results = tf.convert_to_tensor(
            num_steps_between_results,
            dtype=tf.int64,
            name="num_steps_between_results")

        if mcmc_util.is_list_like(current_state):
            current_state = [
                tf.convert_to_tensor(s, name="current_state")
                for s in current_state
            ]
        else:
            current_state = tf.convert_to_tensor(current_state,
                                                 name="current_state")

        def _scan_body(args_list, num_steps):
            """Closure which implements `tf.scan` body."""
            next_state, current_kernel_results = mcmc_util.smart_for_loop(
                loop_num_iter=num_steps,
                body_fn=kernel.one_step,
                initial_loop_vars=args_list,
                parallel_iterations=parallel_iterations)
            return [next_state, current_kernel_results]

        if previous_kernel_results is None:
            previous_kernel_results = kernel.bootstrap_results(current_state)

        return tf.scan(
            fn=_scan_body,
            elems=tf.one_hot(indices=0,
                             depth=num_results,
                             on_value=1 + num_burnin_steps,
                             off_value=1 + num_steps_between_results,
                             dtype=tf.int64),  # num_steps
            initializer=[current_state, previous_kernel_results],
            parallel_iterations=parallel_iterations)
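The `tf.one_hot` trick in the `elems` argument is what implements burn-in and thinning: the scan runs `num_results` times, taking `1 + num_burnin_steps` kernel steps on the first iteration and `1 + num_steps_between_results` on every later one. A NumPy sketch of the resulting per-iteration step counts (made-up sizes):

```python
import numpy as np

num_results, num_burnin_steps, num_steps_between_results = 5, 500, 2
elems = np.where(np.arange(num_results) == 0,
                 1 + num_burnin_steps,
                 1 + num_steps_between_results)
print(elems)  # [501   3   3   3   3]: burn-in first, then thinned steps.
```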
Example No. 26
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
    # Key difficulty:  The type of exchanges differs from one call to the
    # next...even the number of exchanges can differ.
    # As a result, exchanges must happen dynamically, in while loops.
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'remc', 'one_step'),
        values=[current_state, previous_kernel_results]):

      # Each replica does `one_step` to get pre-exchange states/KernelResults.
      sampled_replica_states, sampled_replica_results = zip(*[
          rk.one_step(previous_kernel_results.replica_states[i],
                      previous_kernel_results.replica_results[i])
          for i, rk in enumerate(self.replica_kernels)
      ])
      sampled_replica_states = list(sampled_replica_states)
      sampled_replica_results = list(sampled_replica_results)

      states_are_lists = mcmc_util.is_list_like(sampled_replica_states[0])

      if not states_are_lists:
        sampled_replica_states = [[s] for s in sampled_replica_states]
      num_state_parts = len(sampled_replica_states[0])

      dtype = sampled_replica_states[0][0].dtype

      # Must put states into TensorArrays.  Why?  We will read/write states
      # dynamically with Tensor index `i`, and you cannot do this with lists.
      # old_states[k][i] is Tensor of (old) state part k, for replica i.
      # The `k` will be known statically, and `i` is a Tensor.
      old_states = [
          tf.TensorArray(
              dtype,
              size=self.num_replica,
              dynamic_size=False,
              clear_after_read=False,
              tensor_array_name='old_states',
              # State part k has same shape, regardless of replica.  So use 0.
              element_shape=sampled_replica_states[0][k].shape)
          for k in range(num_state_parts)
      ]
      for k in range(num_state_parts):
        for i in range(self.num_replica):
          old_states[k] = old_states[k].write(i, sampled_replica_states[i][k])

      exchange_proposed = self.exchange_proposed_fn(
          self.num_replica, seed=self._seed_stream())
      exchange_proposed_n = tf.shape(exchange_proposed)[0]

      exchanged_states = self._get_exchanged_states(
          old_states, exchange_proposed, exchange_proposed_n,
          sampled_replica_states, sampled_replica_results)

      no_exchange_proposed, _ = tf.setdiff1d(
          tf.range(self.num_replica), tf.reshape(exchange_proposed, [-1]))

      exchanged_states = self._insert_old_states_where_no_exchange_was_proposed(
          no_exchange_proposed, old_states, exchanged_states)

      next_replica_states = []
      for i in range(self.num_replica):
        next_replica_states_i = []
        for k in range(num_state_parts):
          next_replica_states_i.append(exchanged_states[k].read(i))
        next_replica_states.append(next_replica_states_i)

      if not states_are_lists:
        next_replica_states = [s[0] for s in next_replica_states]
        sampled_replica_states = [s[0] for s in sampled_replica_states]

      # Now that states are/aren't exchanged, bootstrap next kernel_results.
      # The viewpoint is that after each exchange, we are starting anew.
      next_replica_results = [
          rk.bootstrap_results(state)
          for rk, state in zip(self.replica_kernels, next_replica_states)
      ]

      next_state = next_replica_states[0]  # Replica 0 is the returned state(s).

      kernel_results = ReplicaExchangeMCKernelResults(
          replica_states=next_replica_states,
          replica_results=next_replica_results,
          sampled_replica_states=sampled_replica_states,
          sampled_replica_results=sampled_replica_results,
      )

      return next_state, kernel_results
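A NumPy sketch of the replica-swap bookkeeping above: given proposed index pairs, swap those replicas' states and leave the rest untouched. The swap here is unconditional; the real kernel combines this with an acceptance test and per-replica kernel results (made-up states):

```python
import numpy as np

replica_states = np.array([0., 10., 20., 30.])  # One scalar state per replica.
exchange_proposed = np.array([[0, 2]])          # Propose swapping replicas 0 and 2.

next_states = replica_states.copy()
for i, j in exchange_proposed:
  next_states[[i, j]] = next_states[[j, i]]
print(next_states)  # [20. 10.  0. 30.]
```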
Example No. 27
def sample_annealed_importance_chain(num_steps,
                                     proposal_log_prob_fn,
                                     target_log_prob_fn,
                                     current_state,
                                     make_kernel_fn,
                                     parallel_iterations=10,
                                     name=None):
    """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial "proposal" distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: When running in graph mode, `proposal_log_prob_fn` and
  `target_log_prob_fn` are called exactly three times (although this may be
  reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of the
      initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_annealed_importance_chain` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "sample_annealed_importance_chain").

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same shape as
      input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Gamma(concentration=dtype(2),
                           rate=dtype(3)),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
    event_shape=[dims])

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
    with tf.compat.v1.name_scope(name, "sample_annealed_importance_chain",
                                 [num_steps, current_state]):
        num_steps = tf.convert_to_tensor(value=num_steps,
                                         dtype=tf.int32,
                                         name="num_steps")
        if mcmc_util.is_list_like(current_state):
            current_state = [
                tf.convert_to_tensor(value=s, name="current_state")
                for s in current_state
            ]
        else:
            current_state = tf.convert_to_tensor(value=current_state,
                                                 name="current_state")

        def _make_convex_combined_log_prob_fn(iter_):
            def _fn(*args):
                p = tf.identity(proposal_log_prob_fn(*args),
                                name="proposal_log_prob")
                t = tf.identity(target_log_prob_fn(*args),
                                name="target_log_prob")
                dtype = p.dtype.base_dtype
                beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype)
                return tf.identity(beta * t + (1. - beta) * p,
                                   name="convex_combined_log_prob")

            return _fn

        def _loop_body(iter_, ais_weights, current_state, kernel_results):
            """Closure which implements `tf.while_loop` body."""
            x = (current_state
                 if mcmc_util.is_list_like(current_state) else [current_state])
            proposal_log_prob = proposal_log_prob_fn(*x)
            target_log_prob = target_log_prob_fn(*x)
            ais_weights += ((target_log_prob - proposal_log_prob) /
                            tf.cast(num_steps, ais_weights.dtype))
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
            next_state, inner_results = kernel.one_step(
                current_state, kernel_results.inner_results)
            kernel_results = AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )
            return [iter_ + 1, ais_weights, next_state, kernel_results]

        def _bootstrap_results(init_state):
            """Creates first version of `previous_kernel_results`."""
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
            inner_results = kernel.bootstrap_results(init_state)

            convex_combined_log_prob = inner_results.accepted_results.target_log_prob
            dtype = convex_combined_log_prob.dtype.as_numpy_dtype
            shape = tf.shape(input=convex_combined_log_prob)
            proposal_log_prob = tf.fill(shape,
                                        dtype(np.nan),
                                        name="bootstrap_proposal_log_prob")
            target_log_prob = tf.fill(shape,
                                      dtype(np.nan),
                                      name="target_target_log_prob")

            return AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )

        previous_kernel_results = _bootstrap_results(current_state)
        inner_results = previous_kernel_results.inner_results

        ais_weights = tf.zeros(
            shape=tf.broadcast_dynamic_shape(
                tf.shape(input=inner_results.proposed_results.target_log_prob),
                tf.shape(input=inner_results.accepted_results.target_log_prob)),
            dtype=inner_results.proposed_results.target_log_prob.dtype.base_dtype)

        [_, ais_weights, current_state, kernel_results] = tf.while_loop(
            cond=lambda iter_, *args: iter_ < num_steps,
            body=_loop_body,
            loop_vars=[
                np.int32(0),  # iter_
                ais_weights,
                current_state,
                previous_kernel_results,
            ],
            parallel_iterations=parallel_iterations)

        return [current_state, ais_weights, kernel_results]
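A NumPy sketch of the AIS loop above on toy 1-D Gaussians: anneal with `beta_t = t / num_steps`, accumulate the per-step log-density gaps, and move the state with a crude random-walk Metropolis step targeting the convex-combined log density (a stand-in for the real `TransitionKernel`):

```python
import numpy as np

num_steps = 1000
proposal_log_prob = lambda x: -0.5 * x**2       # N(0, 1), unnormalized.
target_log_prob = lambda x: -0.5 * (x - 1.)**2  # N(1, 1), unnormalized.

rng = np.random.default_rng(0)
x = rng.normal()   # Draw the initial state from the proposal.
ais_weights = 0.
for t in range(num_steps):
  ais_weights += (target_log_prob(x) - proposal_log_prob(x)) / num_steps
  beta = (t + 1.) / num_steps
  combined = lambda z, b=beta: (b * target_log_prob(z)
                                + (1. - b) * proposal_log_prob(z))
  y = x + 0.5 * rng.normal()                    # Random-walk proposal.
  if np.log(rng.uniform()) < combined(y) - combined(x):
    x = y
# Both densities share one normalizer here, so E[exp(ais_weights)] is ~1.
```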
Example No. 28
 def build_assign_op():
     if mcmc_util.is_list_like(step_size_var):
         return [ss.assign_add(ss * adjustment) for ss in step_size_var]
     return step_size_var.assign_add(step_size_var * adjustment)
Example No. 29
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    parallel_iterations=10,
    name=None):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain starting at `current_state`, whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set of
  states with decreased autocorrelation.  See [Owen (2017)][1]. Such "thinning"
  is made possible by setting `num_steps_between_results > 0`. The chain then
  takes `num_steps_between_results` extra steps between the steps that make it
  into the results. The extra steps are never materialized (in calls to
  `sess.run`), and thus do not increase memory requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
  `parallel_iterations=1`, otherwise results will not be reproducible.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
      of the Markov chain.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results.  The number of returned chain states is
      still equal to `num_results`.  Default value: 0 (i.e., no thinning).
    parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    next_states: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state` but with a prepended `num_results`-size dimension.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_likelihood(true_variances):
    return tfd.MultivariateNormalDiag(
        scale_diag=tf.sqrt(true_variances))

  dims = 10
  dtype = np.float32
  true_variances = tf.linspace(dtype(1), dtype(3), dims)
  likelihood = make_likelihood(true_variances)

  states, kernel_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(states, axis=0)
  sample_var = tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0)
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  for i=1..n:
    w[i] ~ Normal(0, eye(d))            # prior
    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
  ```

  where `F` denotes factors.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))

  # Setup data.
  num_weights = 10
  num_factors = 4
  num_chains = 100
  dtype = np.float32

  prior = make_prior(num_weights, dtype)
  weights = prior.sample(num_chains)
  factors = np.random.randn(num_factors, num_weights).astype(dtype)
  x = make_likelihood(weights, factors).sample(num_chains)

  def target_log_prob(w):
    # Target joint is: `f(w) = p(w, x | factors)`.
    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros([num_chains, num_weights], dtype),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  if not kernel.is_calibrated:
    warnings.warn("Supplied `TransitionKernel` is not calibrated. Markov "
                  "chain may not converge to intended target distribution.")
  with tf.name_scope(
      name, "mcmc_sample_chain",
      [num_results, num_burnin_steps, num_steps_between_results]):
    num_results = tf.convert_to_tensor(
        num_results,
        dtype=tf.int32,
        name="num_results")
    num_burnin_steps = tf.convert_to_tensor(
        num_burnin_steps,
        dtype=tf.int32,
        name="num_burnin_steps")
    num_steps_between_results = tf.convert_to_tensor(
        num_steps_between_results,
        dtype=tf.int32,
        name="num_steps_between_results")

    if mcmc_util.is_list_like(current_state):
      current_state = [tf.convert_to_tensor(s, name="current_state")
                       for s in current_state]
    else:
      current_state = tf.convert_to_tensor(current_state, name="current_state")

    def _scan_body(args_list, num_steps):
      """Closure which implements `tf.scan` body."""
      current_state, previous_kernel_results = args_list
      return tf.while_loop(
          cond=lambda it_, *args: it_ < num_steps,
          body=lambda it_, cs, pkr: [it_ + 1] + list(kernel.one_step(cs, pkr)),
          loop_vars=[
              np.int32(0),  # it_
              current_state,
              previous_kernel_results,
          ],
          parallel_iterations=parallel_iterations)[1:]  # Lop off `it_`.

    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)
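    # `elems` supplies the number of kernel steps to take before emitting each
    # result: the first entry is `1 + num_burnin_steps` and the rest are
    # `1 + num_steps_between_results`, so burn-in and thinning share one loop.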
    return tf.scan(
        fn=_scan_body,
        elems=tf.one_hot(indices=0,
                         depth=num_results,
                         on_value=1 + num_burnin_steps,
                         off_value=1 + num_steps_between_results,
                         dtype=tf.int32),  # num_steps
        initializer=[current_state, previous_kernel_results],
        parallel_iterations=parallel_iterations)
Example No. 30
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'mh',
                                                    'one_step'),
                           values=[current_state, previous_kernel_results]):
            # Take one inner step.
            [
                proposed_state,
                proposed_results,
            ] = self.inner_kernel.one_step(
                current_state, previous_kernel_results.accepted_results)

            if (not has_target_log_prob(proposed_results)
                    or not has_target_log_prob(
                        previous_kernel_results.accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # Compute log(acceptance_ratio).
            to_sum = [
                proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob
            ]
            try:
                if (not mcmc_util.is_list_like(
                        proposed_results.log_acceptance_correction)
                        or proposed_results.log_acceptance_correction):
                    to_sum.append(proposed_results.log_acceptance_correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                to_sum, name='compute_log_accept_ratio')

            # If proposed state reduces likelihood: randomly accept.
            # If proposed state increases likelihood: always accept.
            # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
            #       ==> log(u) < log_accept_ratio
            log_uniform = tf.log(
                tf.random_uniform(
                    shape=tf.shape(proposed_results.target_log_prob),
                    dtype=proposed_results.target_log_prob.dtype.base_dtype,
                    seed=self._seed_stream()))
            is_accepted = log_uniform < log_accept_ratio

            next_state = mcmc_util.choose(is_accepted,
                                          proposed_state,
                                          current_state,
                                          name='choose_next_state')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=mcmc_util.choose(
                    is_accepted,
                    proposed_results,
                    previous_kernel_results.accepted_results,
                    name='choose_inner_results'),
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
            )

            return next_state, kernel_results
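
# A minimal usage sketch (not from the source; names below are illustrative):
# `MetropolisHastings` adds the accept/reject step implemented above to an
# uncalibrated proposal kernel.
import tensorflow as tf
import tensorflow_probability as tfp

def target_log_prob_fn(x):
  # Illustrative target: unnormalized standard-normal log-density.
  return -0.5 * tf.reduce_sum(x ** 2., axis=-1)

kernel = tfp.mcmc.MetropolisHastings(
    inner_kernel=tfp.mcmc.UncalibratedRandomWalk(
        target_log_prob_fn=target_log_prob_fn,
        seed=42))

states, kernel_results = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.zeros(2),
    kernel=kernel,
    num_burnin_steps=100)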
Example No. 31
  def bootstrap_results(self, init_state=None, transformed_init_state=None):
    """Returns an object with the same type as returned by `one_step`.

    Unlike other `TransitionKernel`s,
    `TransformedTransitionKernel.bootstrap_results` can initialize the
    `TransformedTransitionKernelResults` either from an initial state, which
    requires computing `bijector.inverse(init_state)`, or directly from
    `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s
    interpreted as the `bijector.inverse`-transformed state.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        state(s) of the Markov chain(s). Must specify `init_state` or
        `transformed_init_state` but not both.
      transformed_init_state: `Tensor` or Python `list` of `Tensor`s
        representing the state(s) of the Markov chain(s). Must specify
        `init_state` or `transformed_init_state` but not both.

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if none or both of `init_state` and `transformed_init_state`
        are specified.

    #### Examples

    To use `transformed_init_state` in context of
    `tfp.mcmc.sample_chain`, you need to explicitly pass the
    `previous_kernel_results`, e.g.,

    ```python
    transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
    init_state = ...        # Doesn't matter.
    transformed_init_state = ... # Does matter.
    results, _ = tfp.mcmc.sample_chain(
        num_results=...,
        current_state=init_state,
        previous_kernel_results=transformed_kernel.bootstrap_results(
            transformed_init_state=transformed_init_state),
        kernel=transformed_kernel)
    ```
    """
    if (init_state is None) == (transformed_init_state is None):
      raise ValueError('Must specify exactly one of `init_state` '
                       'or `transformed_init_state`.')
    with tf.name_scope(
        name=make_name(self.name, 'transformed_kernel', 'bootstrap_results'),
        values=[init_state, transformed_init_state]):
      if transformed_init_state is None:
        init_state_parts = (init_state if is_list_like(init_state)
                            else [init_state])
        transformed_init_state_parts = self._inverse_transform(init_state_parts)
        transformed_init_state = (
            transformed_init_state_parts
            if is_list_like(init_state) else transformed_init_state_parts[0])
      else:
        if is_list_like(transformed_init_state):
          transformed_init_state = [
              tf.convert_to_tensor(s, name='transformed_init_state')
              for s in transformed_init_state
          ]
        else:
          transformed_init_state = tf.convert_to_tensor(
              transformed_init_state, name='transformed_init_state')
      kernel_results = TransformedTransitionKernelResults(
          transformed_state=transformed_init_state,
          inner_results=self._inner_kernel.bootstrap_results(
              transformed_init_state))
      return kernel_results
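
# A minimal sketch (not from the source) of how this `bootstrap_results` is
# typically reached: pair an unconstrained inner kernel with a bijector, here
# an `Exp` bijector so HMC runs in log-space for a positive-valued target.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Gamma(concentration=2., rate=3.)

transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=0.1,
        num_leapfrog_steps=3),
    bijector=tfp.bijectors.Exp())

# `sample_chain` calls `bootstrap_results(init_state)` internally, computing
# `bijector.inverse(init_state)` as described above.
states, _ = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.ones([]),
    kernel=transformed_kernel)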
Example No. 32
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None,
                  diffusion_drift=None,
                  parallel_iterations=10):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts = list(state) if mcmc_util.is_list_like(state) else [state]

    [
        target_log_prob,
        grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn, state_parts,
                                          target_log_prob,
                                          grads_target_log_prob)
    [
        volatility_parts,
        grads_volatility,
    ] = _maybe_call_volatility_fn_and_grads(
        volatility_fn, state_parts, volatility, grads_volatility_fn,
        distribution_util.prefer_static_shape(target_log_prob),
        parallel_iterations)

    step_sizes = (list(step_size)
                  if mcmc_util.is_list_like(step_size) else [step_size])
    step_sizes = [
        tf.convert_to_tensor(value=s,
                             name='step_size',
                             dtype=target_log_prob.dtype) for s in step_sizes
    ]
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    if diffusion_drift is None:
        diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                           grads_volatility,
                                           grads_target_log_prob)
    else:
        diffusion_drift_parts = (list(diffusion_drift)
                                 if mcmc_util.is_list_like(diffusion_drift)
                                 else [diffusion_drift])
        if len(state_parts) != len(diffusion_drift_parts):
            raise ValueError(
                'There should be exactly one `diffusion_drift` or it '
                'should have same length as list-like `current_state`.')

    return [
        state_parts,
        step_sizes,
        target_log_prob,
        grads_target_log_prob,
        volatility_parts,
        grads_volatility,
        diffusion_drift_parts,
    ]
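
# The helper above shapes arguments for a Langevin-type transition kernel. A
# minimal usage sketch (not from the source; assumes this era's
# `tfp.mcmc.MetropolisAdjustedLangevinAlgorithm`):
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)

# Volatility defaults to the identity; the drift is derived from the step
# size, volatility, and target gradients assembled by `_prepare_args` above.
states, _ = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.zeros([]),
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target.log_prob,
        step_size=0.5),
    num_burnin_steps=100)
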
def sample_annealed_importance_chain(
    num_steps,
    proposal_log_prob_fn,
    target_log_prob_fn,
    current_state,
    make_kernel_fn,
    parallel_iterations=10,
    name=None):
  """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses Hamiltonian Monte Carlo to sample from a series of
  distributions that slowly interpolates between an initial "proposal"
  distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly
  three times (although this may be reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of the
      initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution. Note:
      `sample_annealed_importance_chain` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used
      as an argument to `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "sample_annealed_importance_chain").

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same shape as
      input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Gamma(concentration=dtype(2),
                           rate=dtype(3)),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
    event_shape=[dims])

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = (np.dot(x, true_weights) + np.random.randn(num_chains)).astype(dtype)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
  with tf.name_scope(
      name, "sample_annealed_importance_chain",
      [num_steps, current_state]):
    num_steps = tf.convert_to_tensor(
        num_steps,
        dtype=tf.int32,
        name="num_steps")
    if mcmc_util.is_list_like(current_state):
      current_state = [tf.convert_to_tensor(s, name="current_state")
                       for s in current_state]
    else:
      current_state = tf.convert_to_tensor(
          current_state, name="current_state")

    def _make_convex_combined_log_prob_fn(iter_):
      def _fn(*args):
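        # Linear annealing schedule: beta = (iter_ + 1) / num_steps moves the
        # combined density from the proposal (beta near 0) to the target
        # (beta = 1).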
        p = tf.identity(proposal_log_prob_fn(*args), name="proposal_log_prob")
        t = tf.identity(target_log_prob_fn(*args), name="target_log_prob")
        dtype = p.dtype.base_dtype
        beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype)
        return tf.identity(beta * t + (1. - beta) * p,
                           name="convex_combined_log_prob")
      return _fn

    def _loop_body(iter_, ais_weights, current_state, kernel_results):
      """Closure which implements `tf.while_loop` body."""
      x = (current_state if mcmc_util.is_list_like(current_state)
           else [current_state])
      proposal_log_prob = proposal_log_prob_fn(*x)
      target_log_prob = target_log_prob_fn(*x)
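      # Accumulate the log importance weight: under the linear beta schedule,
      # each of the `num_steps` annealing steps contributes
      # `(beta_{t+1} - beta_t) * (target_log_prob - proposal_log_prob)`.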
      ais_weights += ((target_log_prob - proposal_log_prob) /
                      tf.cast(num_steps, ais_weights.dtype))
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
      next_state, inner_results = kernel.one_step(
          current_state, kernel_results.inner_results)
      kernel_results = AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )
      return [iter_ + 1, ais_weights, next_state, kernel_results]

    def _bootstrap_results(init_state):
      """Creates first version of `previous_kernel_results`."""
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
      inner_results = kernel.bootstrap_results(init_state)

      convex_combined_log_prob = inner_results.accepted_results.target_log_prob
      dtype = convex_combined_log_prob.dtype.as_numpy_dtype
      shape = tf.shape(convex_combined_log_prob)
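      # No proposal/target log probs exist before the first step; use
      # correctly-shaped NaN placeholders until `_loop_body` overwrites them.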
      proposal_log_prob = tf.fill(shape, dtype(np.nan),
                                  name="bootstrap_proposal_log_prob")
      target_log_prob = tf.fill(shape, dtype(np.nan),
                                name="target_target_log_prob")

      return AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )

    previous_kernel_results = _bootstrap_results(current_state)
    inner_results = previous_kernel_results.inner_results

    ais_weights = tf.zeros(
        shape=tf.broadcast_dynamic_shape(
            tf.shape(inner_results.proposed_results.target_log_prob),
            tf.shape(inner_results.accepted_results.target_log_prob)),
        dtype=inner_results.proposed_results.target_log_prob.dtype.base_dtype)

    [_, ais_weights, current_state, kernel_results] = tf.while_loop(
        cond=lambda iter_, *args: iter_ < num_steps,
        body=_loop_body,
        loop_vars=[
            np.int32(0),  # iter_
            ais_weights,
            current_state,
            previous_kernel_results,
        ],
        parallel_iterations=parallel_iterations)

    return [current_state, ais_weights, kernel_results]
Example No. 34
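# Unwraps single-part states: a list-like `current_state` passes through
# unchanged, while a single `Tensor` state is returned without its wrapping
# list.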
def maybe_flatten(x):
  return x if mcmc_util.is_list_like(current_state) else x[0]