Example #1
def _init_momentum(initial_transformed_position, *, batch_shape):
  """Initialize momentum so trace_fn can be concatenated."""
  variance_parts = [ps.ones_like(p) for p in initial_transformed_position]
  return preconditioning_utils.make_momentum_distribution(
      state_parts=initial_transformed_position,
      batch_shape=batch_shape,
      running_variance_parts=variance_parts)
Example #2
def _init_momentum(initial_transformed_position):
  """Initialize momentum so trace_fn can be concatenated."""
  event_shape = ps.shape(initial_transformed_position)[-1]
  return preconditioning_utils.make_momentum_distribution(
      state_parts=tf.nest.flatten(initial_transformed_position),
      batch_ndims=1,
      running_variance_parts=[ps.ones(event_shape)])
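Both `_init_momentum` variants above delegate to `make_momentum_distribution`, which builds a joint Gaussian over the state parts whose (diagonal) variance plays the role of the inverse mass matrix. The following is a minimal sketch of the kind of object this produces, written against the public `tfp.distributions` API only; the shapes and the identity preconditioner are illustrative assumptions, not the library's exact construction.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two state parts, each with a leading batch of 8 chains (assumed shapes).
state_parts = [tf.zeros([8, 3]), tf.zeros([8, 2])]
variance_parts = [tf.ones([8, 3]), tf.ones([8, 2])]  # identity preconditioner

# One independent Normal per state part; momentum ~ N(0, variance), so the
# variance acts as the (diagonal) inverse mass matrix.
momentum_distribution = tfd.JointDistributionSequential([
    tfd.Independent(tfd.Normal(loc=tf.zeros_like(p), scale=tf.sqrt(v)),
                    reinterpreted_batch_ndims=1)
    for p, v in zip(state_parts, variance_parts)
])
momentum_parts = momentum_distribution.sample(seed=42)  # list of two Tensors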
Example #3
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'bootstrap_results')):
            if isinstance(self.initial_running_variance,
                          sample_stats.RunningVariance):
                variance_parts = [self.initial_running_variance]
            else:
                variance_parts = list(self.initial_running_variance)

            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]

            # Bootstrap the inner kernel's results.
            inner_results = self.inner_kernel.bootstrap_results(init_state)
            # Set the momentum.
            batch_shape = ps.shape(
                unnest.get_innermost(inner_results, 'target_log_prob'))
            init_state_parts = tf.nest.flatten(init_state)
            momentum_distribution = preconditioning_utils.make_momentum_distribution(
                init_state_parts, batch_shape, diags)
            inner_results = self.momentum_distribution_setter_fn(
                inner_results, momentum_distribution)
            proposed = unnest.get_innermost(inner_results,
                                            'proposed_results',
                                            default=None)
            if proposed is not None:
                proposed = proposed._replace(
                    momentum_distribution=momentum_distribution)
                inner_results = unnest.replace_innermost(
                    inner_results, proposed_results=proposed)
            return DiagonalMassMatrixAdaptationResults(
                inner_results=inner_results, running_variance=variance_parts)
Example #4
  def test_momentum_dists(self):
   state_parts = [
       tf.ones([13, 5, 3]), tf.ones([13, 5]), tf.ones([13, 5, 2, 4])]
   batch_shape = [13, 5]
   md = pu.make_momentum_distribution(state_parts, batch_shape)
   md = pu.update_momentum_distribution(
       md,
       tf.nest.map_structure(
           lambda s: tf.reduce_sum(s, (0, 1)), state_parts))
   self.evaluate(tf.nest.flatten(md, expand_composites=True))
Example #5
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      state_parts, _ = mcmc_util.prepare_state_parts(init_state,
                                                     name='current_state')
      current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
          self.target_log_prob_fn, state_parts)
      # Confirm that the step size is compatible with the state parts.
      _ = _prepare_step_size(
          self.step_size, current_target_log_prob.dtype, len(init_state))
      momentum_distribution = self.momentum_distribution
      if momentum_distribution is None:
        momentum_distribution = pu.make_momentum_distribution(
            state_parts, ps.shape(current_target_log_prob),
            shard_axis_names=self.experimental_shard_axis_names)
      momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
          momentum_distribution, ps.shape(current_target_log_prob))
      momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

      return PreconditionedNUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          step_size=tf.nest.map_structure(
              lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                  x,
                  dtype=current_target_log_prob.dtype,
                  name='step_size'),
              self.step_size),
          log_accept_ratio=tf.zeros_like(
              current_target_log_prob, name='log_accept_ratio'),
          leapfrogs_taken=tf.zeros_like(
              current_target_log_prob,
              dtype=TREE_COUNT_DTYPE,
              name='leapfrogs_taken'),
          is_accepted=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='is_accepted'),
          reach_max_depth=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='reach_max_depth'),
          has_divergence=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='has_divergence'),
          energy=compute_hamiltonian(current_target_log_prob, momentum_parts,
                                     momentum_distribution),
          momentum_distribution=momentum_distribution,
          # Allow room for one_step's seed.
          seed=samplers.zeros_seed(),
      )
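This `bootstrap_results` appears to belong to the preconditioned NUTS kernel exposed as `tfp.experimental.mcmc.PreconditionedNoUTurnSampler`. A hedged usage sketch follows; the target, step size, chain count, and seed are made up for illustration, and when no `momentum_distribution` is supplied the identity-preconditioned default from `make_momentum_distribution` is used, as in the code above.

import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative standard-normal target in 4 dimensions.
target = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(4))

kernel = tfp.experimental.mcmc.PreconditionedNoUTurnSampler(
    target_log_prob_fn=target.log_prob,
    step_size=0.3)

samples = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([8, 4]),  # 8 chains
    kernel=kernel,
    num_burnin_steps=50,
    trace_fn=None,
    seed=(1, 2))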
Example #6
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False,
                  experimental_shard_axis_names=None):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # Build a default (identity-preconditioned) momentum distribution if none was given.
    if momentum_distribution is None:
        momentum_distribution = pu.make_momentum_distribution(
            state_parts,
            ps.shape(target_log_prob),
            shard_axis_names=experimental_shard_axis_names)
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(target_log_prob))

    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
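The step-size handling in `_prepare_args` encodes a simple broadcasting rule: a single step size is reused for every state part, otherwise there must be exactly one per part. A small self-contained sketch of just that rule (the function name is illustrative):

def broadcast_step_sizes(step_sizes, state_parts):
  """Mirrors `_prepare_args`: one step size total, or one per state part."""
  if len(step_sizes) == 1:
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  return step_sizes

# A single step size is tiled across three state parts.
assert broadcast_step_sizes([0.1], ['x', 'y', 'z']) == [0.1, 0.1, 0.1]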
Example #7
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'bootstrap_results')):
            # Bootstrap the inner kernel's results.
            inner_results = self.inner_kernel.bootstrap_results(init_state)

            # Bootstrap the results.
            results = self._bootstrap_from_inner_results(
                init_state, inner_results)
            if self.num_estimation_steps is not None:
                # We only update the momentum at the end of adaptation phase,
                # so we do not need to set the momentum here.
                return results

            # Set the momentum.
            diags = [
                variance_part.variance()
                for variance_part in results.running_variance
            ]
            inner_results = results.inner_results
            batch_shape = ps.shape(
                unnest.get_innermost(inner_results, 'target_log_prob'))
            init_state_parts = tf.nest.flatten(init_state)
            momentum_distribution = preconditioning_utils.make_momentum_distribution(
                init_state_parts,
                batch_shape,
                diags,
                shard_axis_names=self.experimental_shard_axis_names)
            inner_results = self.momentum_distribution_setter_fn(
                inner_results, momentum_distribution)
            proposed = unnest.get_innermost(inner_results,
                                            'proposed_results',
                                            default=None)
            if proposed is not None:
                proposed = proposed._replace(
                    momentum_distribution=momentum_distribution)
                inner_results = unnest.replace_innermost(
                    inner_results, proposed_results=proposed)
            results = results._replace(inner_results=inner_results)
            return results
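This adaptive kernel is available as `tfp.experimental.mcmc.DiagonalMassMatrixAdaptation`. Below is a hedged sketch of wrapping a preconditioned HMC inner kernel with it; the dimensionality, step size, number of leapfrog steps, the `RunningVariance` seed statistics, and the seed are illustrative assumptions.

import tensorflow as tf
import tensorflow_probability as tfp

dims = 4
target = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(dims))

inner_kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn=target.log_prob,
    step_size=0.3,
    num_leapfrog_steps=10)

# Seed the running variance with a weakly informative identity estimate.
initial_running_variance = tfp.experimental.stats.RunningVariance.from_stats(
    num_samples=10., mean=tf.zeros(dims), variance=tf.ones(dims))

kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
    inner_kernel=inner_kernel,
    initial_running_variance=initial_running_variance)

samples = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros(dims),
    kernel=kernel,
    num_burnin_steps=50,
    trace_fn=None,
    seed=(3, 4))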
Example #8
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'phmc', 'bootstrap_results')):
            result = super(UncalibratedPreconditionedHamiltonianMonteCarlo,
                           self).bootstrap_results(init_state)

            state_parts, _ = mcmc_util.prepare_state_parts(
                init_state, name='current_state')
            target_log_prob = self.target_log_prob_fn(*state_parts)
            if (not self._store_parameters_in_results
                    or self.momentum_distribution is None):
                momentum_distribution = pu.make_momentum_distribution(
                    state_parts, ps.shape(target_log_prob))
            else:
                momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
                    self.momentum_distribution, ps.shape(target_log_prob))
            result = UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
                **result._asdict(),  # pylint: disable=protected-access
                momentum_distribution=momentum_distribution)
        return result
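The calibrated counterpart of this kernel, `tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo`, accepts the `momentum_distribution` directly. Here is a hedged sketch of supplying one by hand for a badly scaled target; the scales are made up, and the rule of thumb used is that the momentum covariance (the mass matrix) should roughly equal the inverse covariance of the target.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Target with very different scales per dimension (variances 1 and 100).
target = tfd.MultivariateNormalDiag(scale_diag=[1., 10.])

# Momentum variances chosen as the inverse target variances: [1., 0.01].
momentum_distribution = tfd.MultivariateNormalDiag(scale_diag=[1., 0.1])

kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn=target.log_prob,
    step_size=0.5,
    num_leapfrog_steps=10,
    momentum_distribution=momentum_distribution)

This kernel can then be driven with `tfp.mcmc.sample_chain` exactly as in the earlier sketches.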
Example #9
    def bootstrap_results(self, init_state):
        """Creates initial `previous_kernel_results` using a supplied `state`."""
        with tf.name_scope(self.name + '.bootstrap_results'):
            if not tf.nest.is_nested(init_state):
                init_state = [init_state]
            # Pad the step_size so it is compatible with the state parts.
            step_size = self.step_size
            if len(step_size) == 1:
                step_size = step_size * len(init_state)
            if len(step_size) != len(init_state):
                raise ValueError(
                    'Expected either one step size or {} (size of '
                    '`init_state`), but found {}'.format(
                        len(init_state), len(step_size)))
            state_parts, _ = mcmc_util.prepare_state_parts(
                init_state, name='current_state')
            current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
                self.target_log_prob_fn, state_parts)
            momentum_distribution = self.momentum_distribution
            if momentum_distribution is None:
                momentum_distribution = pu.make_momentum_distribution(
                    state_parts, ps.shape(current_target_log_prob))
            momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
                momentum_distribution, ps.shape(current_target_log_prob))
            momentum_parts = momentum_distribution.sample()

            def _init(shape_and_dtype):
                """Allocate TensorArray for storing state and velocity."""
                return [  # pylint: disable=g-complex-comprehension
                    ps.zeros(ps.concat([[max(self._write_instruction) + 1], s],
                                       axis=0),
                             dtype=d) for (s, d) in shape_and_dtype
                ]

            get_shapes_and_dtypes = lambda x: [
                (ps.shape(x_), x_.dtype)  # pylint: disable=g-long-lambda
                for x_ in x
            ]
            velocity_state_memory = VelocityStateSwap(
                velocity_swap=_init(get_shapes_and_dtypes(momentum_parts)),
                state_swap=_init(get_shapes_and_dtypes(init_state)))

            return PreconditionedNUTSKernelResults(
                target_log_prob=current_target_log_prob,
                grads_target_log_prob=current_grads_log_prob,
                velocity_state_memory=velocity_state_memory,
                step_size=step_size,
                log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                               name='log_accept_ratio'),
                leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                              dtype=TREE_COUNT_DTYPE,
                                              name='leapfrogs_taken'),
                is_accepted=tf.zeros_like(current_target_log_prob,
                                          dtype=tf.bool,
                                          name='is_accepted'),
                reach_max_depth=tf.zeros_like(current_target_log_prob,
                                              dtype=tf.bool,
                                              name='reach_max_depth'),
                has_divergence=tf.zeros_like(current_target_log_prob,
                                             dtype=tf.bool,
                                             name='has_divergence'),
                energy=compute_hamiltonian(current_target_log_prob,
                                           momentum_parts,
                                           momentum_distribution),
                momentum_distribution=momentum_distribution,
                # Allow room for one_step's seed.
                seed=samplers.zeros_seed(),
            )
Example #10
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'one_step')):
            variance_parts = previous_kernel_results.running_variance
            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]
            # Set the momentum.
            batch_ndims = ps.rank(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            state_parts = tf.nest.flatten(current_state)
            new_momentum_distribution = preconditioning_utils.make_momentum_distribution(
                state_parts, batch_ndims, diags)
            inner_results = self.momentum_distribution_setter_fn(
                previous_kernel_results.inner_results,
                new_momentum_distribution)

            # Step the inner kernel.
            inner_kwargs = {} if seed is None else dict(seed=seed)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results, **inner_kwargs)
            new_state_parts = tf.nest.flatten(new_state)
            new_variance_parts = []
            for variance_part, diag, state_part in zip(variance_parts, diags,
                                                       new_state_parts):
                # Compute new variance for each variance part, accounting for partial
                # batching of the variance calculation across chains (ie, some, all, or
                # none of the chains may share the estimated mass matrix).
                #
                # For example, say
                #
                # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                # variance_part has shape          [4] + [5, 6]
                # log_prob has shape         [2, 3, 4]
                #
                # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                # matrices, each being shared across a [2, 3]-batch of chains. Note this
                # division is inferred from the shapes of the state part, the log_prob,
                # and the user-provided initial running variances.
                #
                # Until RunningVariance supports rank > 1 chunking, we need to flatten
                # the states that go into updating the variance estimates. In the above
                # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                # fed to `RunningVariance.update(state_part, axis=0)`, recording
                # 6 new observations in the running variance calculation.
                # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                # the resulting momentum distribution will have batch shape of
                # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                state_rank = ps.rank(state_part)
                variance_rank = ps.rank(diag)
                num_reduce_dims = state_rank - variance_rank

                state_part_shape = ps.shape(state_part)
                # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
                # dimensions to a single one otherwise.
                reshaped_state = ps.reshape(
                    state_part,
                    ps.concat(
                        [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                         state_part_shape[num_reduce_dims:]],
                        axis=0))

                # The `axis=0` here removes the leading dimension we got from the
                # reshape above, so the new_variance_parts have the correct shape again.
                new_variance_parts.append(
                    variance_part.update(reshaped_state, axis=0))

            new_kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts)

            return new_state, new_kernel_results
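The long comment in the loop above explains how each state part is collapsed before being fed to `RunningVariance.update`. The following standalone check reproduces just that reshape with the shapes from the comment's example (pure TensorFlow, no TFP internals):

import tensorflow as tf

# From the comment: chains of batch shape [2, 3, 4], event shape [5, 6], and
# 4 variance estimates of shape [4, 5, 6] shared across the leading [2, 3]
# chain dimensions.
state_part = tf.zeros([2, 3, 4, 5, 6])
variance_estimate = tf.zeros([4, 5, 6])

num_reduce_dims = len(state_part.shape) - len(variance_estimate.shape)  # 2
state_part_shape = tf.shape(state_part)

# Collapse the leading [2, 3] dims into a single axis of 6 "observations".
reshaped_state = tf.reshape(
    state_part,
    tf.concat([[tf.reduce_prod(state_part_shape[:num_reduce_dims])],
               state_part_shape[num_reduce_dims:]], axis=0))

print(reshaped_state.shape)  # (6, 4, 5, 6)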