Example #1
0
  def _start_trajectory_batched(self, state, target_log_prob, seed):
    """Draws the initial momentum, energy, and (optional) slice variable."""
    with tf.name_scope('start_trajectory_batched'):
      # One sub-seed per state part for the momentum draws, plus one final
      # sub-seed reserved for the slice variable.
      all_seeds = samplers.split_seed(seed, n=len(state) + 1)
      part_seeds = distribute_lib.fold_in_axis_index(
          all_seeds[:-1], self.experimental_shard_axis_names)

      momentum = []
      for i, state_part in enumerate(state):
        momentum.append(
            samplers.normal(
                shape=ps.shape(state_part),
                dtype=state_part.dtype,
                seed=part_seeds[i]))

      init_energy = compute_hamiltonian(
          target_log_prob, momentum,
          shard_axis_names=self.experimental_shard_axis_names)

      if MULTINOMIAL_SAMPLE:
        # Multinomial trajectory sampling does not use a slice variable.
        return momentum, init_energy, None

      # Slice sampling: draw u ~ Uniform(0, p(initial state, initial
      # momentum)) and work with log u. In log space
      # log u = log(u' * p(...)) = log u' + log p(...) with u' ~ Uniform(0, 1),
      # which is numerically stable.
      uniform_draw = samplers.uniform(
          shape=ps.shape(init_energy),
          dtype=init_energy.dtype,
          seed=all_seeds[len(state)])
      log_slice_sample = tf.math.log1p(-uniform_draw)
      return momentum, init_energy, log_slice_sample
Example #2
0
def _choose_random_direction(current_state_parts, batch_rank, seed=None,
                             experimental_shard_axis_names=None):
  """Chooses a random direction in the event space."""
  part_seeds = list(samplers.split_seed(seed, n=len(current_state_parts)))
  part_seeds = distribute_lib.fold_in_axis_index(
      part_seeds, experimental_shard_axis_names)

  # Draw one random direction per input component.
  directions = []
  for state_part, part_seed in zip(current_state_parts, part_seeds):
    part_shape = ps.shape(state_part)
    # Flatten the event dimensions into a single axis so one spherical draw
    # covers the whole part, then restore the original shape.
    event_size = ps.reduce_prod(part_shape[batch_rank:])
    raw_direction = random_ops.spherical_uniform(
        shape=part_shape[:batch_rank],
        dimension=event_size,
        dtype=state_part.dtype,
        seed=part_seed)
    directions.append(ps.reshape(raw_direction, part_shape))
  return directions
    def _fn(state_parts, seed, experimental_shard_axis_names=None):
        """Adds a normal perturbation to the input state.

        Args:
          state_parts: A list of `Tensor`s of any shape and real dtype
            representing the state parts of the `current_state` of the Markov
            chain.
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
          experimental_shard_axis_names: A structure of string names indicating
            how members of the state are sharded.

        Returns:
          perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
            shape and type as the `state_parts`.

        Raises:
          ValueError: if `scale` does not broadcast with `state_parts`.
        """
        with tf.name_scope(name or 'random_walk_normal_fn'):
            # Copy a list-like `scale` so the in-place broadcast below cannot
            # mutate the caller's (closed-over) list.
            scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
            if len(scales) == 1:
                # A single scale applies to every state part.
                scales *= len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')

            # One independent sub-seed per state part, folded with the shard
            # axis index so sharded chains draw distinct noise.
            part_seeds = samplers.split_seed(seed, n=len(state_parts))
            part_seeds = distribute_lib.fold_in_axis_index(
                part_seeds, experimental_shard_axis_names)

            next_state_parts = [
                samplers.normal(  # pylint: disable=g-complex-comprehension
                    mean=state_part,
                    stddev=scale_part,
                    shape=ps.shape(state_part),
                    dtype=dtype_util.base_dtype(state_part.dtype),
                    seed=seed_part)
                for scale_part, state_part, seed_part in zip(
                    scales, state_parts, part_seeds)
            ]

            return next_state_parts
Example #4
0
    def _fn(state_parts, seed, experimental_shard_axis_names=None):
        """Adds a uniform perturbation to the input state.

        Args:
          state_parts: A list of `Tensor`s of any shape and real dtype
            representing the state parts of the `current_state` of the Markov
            chain.
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
          experimental_shard_axis_names: A structure of string names indicating
            how members of the state are sharded.

        Returns:
          perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
            shape and type as the `state_parts`.

        Raises:
          ValueError: if `scale` does not broadcast with `state_parts`.
        """
        with tf.name_scope(name or 'random_walk_uniform_fn'):
            # Copy a list-like `scale` so the in-place broadcast below cannot
            # mutate the caller's (closed-over) list.
            scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
            if len(scales) == 1:
                # A single scale applies to every state part.
                scales *= len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')

            # One independent sub-seed per state part, folded with the shard
            # axis index so sharded chains draw distinct noise.
            part_seeds = list(samplers.split_seed(seed, n=len(state_parts)))
            part_seeds = distribute_lib.fold_in_axis_index(
                part_seeds, experimental_shard_axis_names)

            # Each part is perturbed to Uniform(part - scale, part + scale).
            next_state_parts = [
                samplers.uniform(  # pylint: disable=g-complex-comprehension
                    minval=state_part - scale_part,
                    maxval=state_part + scale_part,
                    shape=tf.shape(state_part),
                    dtype=dtype_util.base_dtype(state_part.dtype),
                    seed=seed_part)
                for scale_part, state_part, seed_part in zip(
                    scales, state_parts, part_seeds)
            ]
            return next_state_parts
Example #5
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one step of the (uncalibrated) Langevin/MALA transition.

        Advances the chain with one Euler-Maruyama step of the Langevin
        diffusion, then computes the log acceptance correction needed by an
        outer Metropolis-Hastings wrapper (skipped via `tf.cond` when
        `self.compute_acceptance` is false).

        Args:
          current_state: `Tensor` or Python list of `Tensor`s holding the
            current chain state(s).
          previous_kernel_results: `UncalibratedLangevinKernelResults` carrying
            cached log-prob, gradients, volatility, and drift from the
            previous step.
          seed: Optional PRNG seed for this step.

        Returns:
          A two-element list of the next state (un-listed if `current_state`
          was not list-like) and a new `UncalibratedLangevinKernelResults`.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'mala', 'one_step')):
            with tf.name_scope('initialize'):
                # Prepare input arguments to be passed to `_euler_method`.
                [
                    current_state_parts,
                    step_size_parts,
                    current_target_log_prob,
                    _,  # grads_target_log_prob
                    current_volatility_parts,
                    _,  # grads_volatility
                    current_drift_parts,
                ] = _prepare_args(
                    self.target_log_prob_fn, self.volatility_fn, current_state,
                    self.step_size, previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility,
                    previous_kernel_results.diffusion_drift,
                    self.parallel_iterations)

                seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
                # One sub-seed per state part, folded with the shard axis index
                # so sharded chains draw distinct noise.
                seeds = list(
                    samplers.split_seed(seed,
                                        n=len(current_state_parts),
                                        salt='langevin.one_step'))
                seeds = distribute_lib.fold_in_axis_index(
                    seeds, self.experimental_shard_axis_names)

                # Standard-normal increments driving the diffusion.
                random_draw_parts = []
                for state_part, part_seed in zip(current_state_parts, seeds):
                    random_draw_parts.append(
                        samplers.normal(shape=ps.shape(state_part),
                                        dtype=dtype_util.base_dtype(
                                            state_part.dtype),
                                        seed=part_seed))

            # Number of independent chains run by the algorithm.
            independent_chain_ndims = ps.rank(current_target_log_prob)

            # Generate the next state of the algorithm using Euler-Maruyama method.
            next_state_parts = _euler_method(random_draw_parts,
                                             current_state_parts,
                                             current_drift_parts,
                                             step_size_parts,
                                             current_volatility_parts)

            # Compute helper `UncalibratedLangevinKernelResults` to be processed by
            # `_compute_log_acceptance_correction` and in the next iteration of
            # `one_step` function.
            [
                _,  # state_parts
                _,  # step_sizes
                next_target_log_prob,
                next_grads_target_log_prob,
                next_volatility_parts,
                next_grads_volatility,
                next_drift_parts,
            ] = _prepare_args(self.target_log_prob_fn,
                              self.volatility_fn,
                              next_state_parts,
                              step_size_parts,
                              parallel_iterations=self.parallel_iterations)

            def maybe_flatten(x):
                # Mirror the caller's structure: un-list if the input state
                # was a single Tensor.
                return x if mcmc_util.is_list_like(current_state) else x[0]

            # Decide whether to compute the acceptance ratio
            log_acceptance_correction_compute = _compute_log_acceptance_correction(
                current_state_parts,
                next_state_parts,
                current_volatility_parts,
                next_volatility_parts,
                current_drift_parts,
                next_drift_parts,
                step_size_parts,
                independent_chain_ndims,
                experimental_shard_axis_names=self.
                experimental_shard_axis_names)
            log_acceptance_correction_skip = tf.zeros_like(
                next_target_log_prob)

            log_acceptance_correction = tf.cond(
                pred=self.compute_acceptance,
                true_fn=lambda: log_acceptance_correction_compute,
                false_fn=lambda: log_acceptance_correction_skip)

            return [
                maybe_flatten(next_state_parts),
                UncalibratedLangevinKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_grads_target_log_prob,
                    volatility=maybe_flatten(next_volatility_parts),
                    grads_volatility=next_grads_volatility,
                    diffusion_drift=next_drift_parts,
                    seed=seed,
                ),
            ]
Example #6
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one step of (uncalibrated) Hamiltonian Monte Carlo.

        Samples fresh momentum, integrates the trajectory with a simple
        leapfrog integrator, and records the log acceptance correction for an
        outer Metropolis-Hastings wrapper.

        Args:
          current_state: `Tensor` or Python list of `Tensor`s holding the
            current chain state(s).
          previous_kernel_results: Kernel results from the previous step,
            carrying cached log-prob/gradients and (optionally) the step size
            and leapfrog count.
          seed: Optional PRNG seed for this step.

        Returns:
          A `(next_state, new_kernel_results)` pair; `next_state` is un-listed
          if `current_state` was not list-like.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Parameters may live in the kernel results (so they can be
            # adapted between steps) or on the kernel itself.
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            # One sub-seed per state part, folded with the shard axis index so
            # sharded chains draw distinct momenta.
            seeds = samplers.split_seed(seed, n=len(current_state_parts))
            seeds = distribute_lib.fold_in_axis_index(
                seeds, self.experimental_shard_axis_names)

            # Sample auxiliary Gaussian momentum, one part per state part.
            current_momentum_parts = []
            for part_seed, x in zip(seeds, current_state_parts):
                current_momentum_parts.append(
                    samplers.normal(shape=ps.shape(x),
                                    dtype=self._momentum_dtype
                                    or dtype_util.base_dtype(x.dtype),
                                    seed=part_seed))

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                # Mirror the caller's structure: un-list if the input state
                # was a single Tensor.
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = ps.rank(current_target_log_prob)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    current_momentum_parts,
                    next_momentum_parts,
                    independent_chain_ndims,
                    shard_axis_names=self.experimental_shard_axis_names),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
Example #7
0
 def _sample_n(self, n, seed, **kwargs):
   """Draws `n` samples from the wrapped distribution with a sharded seed."""
   # Normalize the seed, then fold in the shard axis index so each shard
   # draws from a distinct stream.
   local_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
   local_seed = distribute_lib.fold_in_axis_index(
       local_seed, self.experimental_shard_axis_names)
   return self.distribution.sample(
       sample_shape=n, seed=local_seed, **kwargs)
Example #8
0
 def one_step(self, current_state, previous_kernel_results, seed=None):
   """Delegates one step to the inner kernel with a chain-axis-folded seed."""
   # Normalize the seed, then fold in the chain axis index so each shard of
   # the chain uses a distinct stream.
   step_seed = samplers.sanitize_seed(seed, salt='sharded_kernel')
   step_seed = distribute_lib.fold_in_axis_index(step_seed, self.chain_axis_names)
   return self.inner_kernel.one_step(
       current_state, previous_kernel_results, seed=step_seed)