Ejemplo n.º 1
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples via inverse-CDF sampling.

    The CDF is piecewise quadratic, so inverting it needs a square root on
    each side of the peak. At `x = peak` the CDF equals
    `(peak - low) / (high - low)`; each uniform draw is compared with that
    value to decide which piece of the CDF to invert.

    Args:
      n: Number of samples; prepended to the batch shape of the result.
      seed: Optional seed, sanitized with the salt `'triangular'`.

    Returns:
      A tensor of samples with shape `[n] + batch_shape`.
    """
    lower = tf.convert_to_tensor(self.low)
    upper = tf.convert_to_tensor(self.high)
    mode = tf.convert_to_tensor(self.peak)

    seed = samplers.sanitize_seed(seed, salt='triangular')
    batch_shape = self._batch_shape_tensor(low=lower, high=upper, peak=mode)
    sample_shape = ps.concat([[n], batch_shape], axis=0)
    uniforms = samplers.uniform(
        shape=sample_shape, dtype=self.dtype, seed=seed)

    width = upper - lower
    left_width = mode - lower
    right_width = upper - mode

    # Left of the peak: invert (x - low) ** 2 / ((high - low) * (peak - low)).
    left_branch = lower + tf.sqrt(uniforms * width * left_width)
    # Right of the peak: invert
    # 1 - (high - x) ** 2 / ((high - low) * (high - peak)).
    right_branch = upper - tf.sqrt((1. - uniforms) * width * right_width)

    cdf_at_peak = left_width / width
    return tf.where(uniforms < cdf_at_peak, left_branch, right_branch)
Ejemplo n.º 2
0
def _random_gamma_no_gradient(
    shape, concentration, rate, log_rate, seed, log_space):
  """Draws gamma samples through a device-selected implementation.

  The sampler is chosen by `implementation_selection.implementation_selecting`
  between `_random_gamma_cpu` (the CPU function) and `_random_gamma_noncpu`
  (the default function).

  Args:
    shape: Sample shape.
    concentration: Concentration of gamma distribution.
    rate: Rate parameter of gamma distribution.
    log_rate: Log-rate parameter of gamma distribution.
    seed: int or Tensor seed.
    log_space: If `True`, draw log-of-gamma samples.

  Returns:
    samples: Samples from gamma distributions.
  """
  sanitized_seed = samplers.sanitize_seed(seed)

  gamma_sampler = implementation_selection.implementation_selecting(
      fn_name='gamma',
      default_fn=_random_gamma_noncpu,
      cpu_fn=_random_gamma_cpu)

  sampler_kwargs = dict(
      shape=shape,
      concentration=concentration,
      rate=rate,
      log_rate=log_rate,
      seed=sanitized_seed,
      log_space=log_space)
  return gamma_sampler(**sampler_kwargs)
Ejemplo n.º 3
0
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
    """Samples from the standardized von Mises distribution.

    The distribution sampled is vonMises(loc=0, concentration=concentration),
    so its mean is zero. A different location is obtained by adding it to the
    returned samples.

    Sampling uses rejection sampling with a wrapped Cauchy proposal [1], and
    the samples are pathwise differentiable using the approach of [2].

    Args:
      shape: The output sample shape.
      concentration: The concentration parameter of the von Mises
        distribution.
      dtype: The data type of concentration and the outputs.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      Differentiable samples of standardized von Mises.

    References:
      [1] Luc Devroye "Non-Uniform Random Variate Generation",
      Springer-Verlag, 1986; Chapter 9, p. 473-476.
      http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
      + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
      [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
      Reparameterization Gradients", 2018.
    """
    sample_shape = ps.convert_to_shape_tensor(
        shape, dtype_hint=tf.int32, name='shape')
    von_mises_seed = samplers.sanitize_seed(seed, salt='von_mises')
    concentration_tensor = tf.convert_to_tensor(
        concentration, dtype=dtype, name='concentration')

    return _von_mises_sample_with_gradient(
        sample_shape, concentration_tensor, von_mises_seed)
Ejemplo n.º 4
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a new state with `new_state_fn` and evaluates its log-prob.

        Args:
          current_state: A single state value or a list-like of state parts.
          previous_kernel_results: Not read by this method.
          seed: Optional seed; sanitized and stored in the returned results.

        Returns:
          A two-element list `[next_state, kernel_results]`. `next_state` is a
          list when `current_state` is list-like and a single value otherwise;
          `kernel_results` is an `UncalibratedRandomWalkResults`.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
            state_is_list = mcmc_util.is_list_like(current_state)

            with tf.name_scope('initialize'):
                raw_parts = (
                    list(current_state) if state_is_list else [current_state])
                current_state_parts = [
                    tf.convert_to_tensor(part, name='current_state')
                    for part in raw_parts
                ]

            # The sanitized seed is kept so it can be reported in the results.
            seed = samplers.sanitize_seed(seed)
            proposed_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable

            # `new_state_fn` is expected to keep each part's shape; pinning the
            # static shape makes a violation fail loudly.
            for proposed, current in zip(proposed_parts, current_state_parts):
                tensorshape_util.set_shape(proposed, current.shape)

            # Evaluated here so the value is available to MetropolisHastings.
            proposed_log_prob = self.target_log_prob_fn(*proposed_parts)  # pylint: disable=not-callable

            next_state = proposed_parts if state_is_list else proposed_parts[0]
            kernel_results = UncalibratedRandomWalkResults(
                log_acceptance_correction=tf.zeros_like(proposed_log_prob),
                target_log_prob=proposed_log_prob,
                seed=seed,
            )
            return [next_state, kernel_results]
Ejemplo n.º 5
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes the next state with one Euler-Maruyama Langevin step.

        Args:
          current_state: The current chain state; either a single value or a
            list-like of state parts (checked with `mcmc_util.is_list_like`).
          previous_kernel_results: Results from the previous step. This method
            reads its `target_log_prob`, `grads_target_log_prob`,
            `volatility`, `grads_volatility` and `diffusion_drift` fields.
          seed: Optional seed; it is normalized with `samplers.sanitize_seed`
            and then split into one seed per state part.

        Returns:
          A two-element list `[next_state, kernel_results]`. `next_state` is
          the list of proposed parts when `current_state` is list-like and
          its first element otherwise; `kernel_results` is an
          `UncalibratedLangevinKernelResults`.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'mala', 'one_step')):
            with tf.name_scope('initialize'):
                # Prepare input arguments to be passed to `_euler_method`.
                [
                    current_state_parts,
                    step_size_parts,
                    current_target_log_prob,
                    _,  # grads_target_log_prob
                    current_volatility_parts,
                    _,  # grads_volatility
                    current_drift_parts,
                ] = _prepare_args(
                    self.target_log_prob_fn, self.volatility_fn, current_state,
                    self.step_size, previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility,
                    previous_kernel_results.diffusion_drift,
                    self.parallel_iterations)

                seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
                # One derived seed per state part.
                seeds = list(
                    samplers.split_seed(seed,
                                        n=len(current_state_parts),
                                        salt='langevin.one_step'))
                # NOTE(review): presumably this makes the per-part seeds differ
                # across the named shard axes -- confirm against
                # `distribute_lib.fold_in_axis_index`.
                seeds = distribute_lib.fold_in_axis_index(
                    seeds, self.experimental_shard_axis_names)

                # Draw normal noise shaped and typed like each state part.
                random_draw_parts = []
                for state_part, part_seed in zip(current_state_parts, seeds):
                    random_draw_parts.append(
                        samplers.normal(shape=ps.shape(state_part),
                                        dtype=dtype_util.base_dtype(
                                            state_part.dtype),
                                        seed=part_seed))

            # Number of independent chains run by the algorithm.
            independent_chain_ndims = ps.rank(current_target_log_prob)

            # Generate the next state of the algorithm using Euler-Maruyama method.
            next_state_parts = _euler_method(random_draw_parts,
                                             current_state_parts,
                                             current_drift_parts,
                                             step_size_parts,
                                             current_volatility_parts)

            # Compute helper `UncalibratedLangevinKernelResults` to be processed by
            # `_compute_log_acceptance_correction` and in the next iteration of
            # `one_step` function.
            [
                _,  # state_parts
                _,  # step_sizes
                next_target_log_prob,
                next_grads_target_log_prob,
                next_volatility_parts,
                next_grads_volatility,
                next_drift_parts,
            ] = _prepare_args(self.target_log_prob_fn,
                              self.volatility_fn,
                              next_state_parts,
                              step_size_parts,
                              parallel_iterations=self.parallel_iterations)

            def maybe_flatten(x):
                # Return the list as-is for list-like input states, otherwise
                # unwrap the single element.
                return x if mcmc_util.is_list_like(current_state) else x[0]

            # Decide whether to compute the acceptance ratio
            # NOTE(review): both candidate values are built here, before the
            # `tf.cond` below; the cond only selects which one is returned.
            log_acceptance_correction_compute = _compute_log_acceptance_correction(
                current_state_parts,
                next_state_parts,
                current_volatility_parts,
                next_volatility_parts,
                current_drift_parts,
                next_drift_parts,
                step_size_parts,
                independent_chain_ndims,
                experimental_shard_axis_names=self.
                experimental_shard_axis_names)
            log_acceptance_correction_skip = tf.zeros_like(
                next_target_log_prob)

            log_acceptance_correction = tf.cond(
                pred=self.compute_acceptance,
                true_fn=lambda: log_acceptance_correction_compute,
                false_fn=lambda: log_acceptance_correction_skip)

            # Only the state and `volatility` are passed through
            # `maybe_flatten`; the other fields are stored as returned by
            # `_prepare_args`.
            return [
                maybe_flatten(next_state_parts),
                UncalibratedLangevinKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_grads_target_log_prob,
                    volatility=maybe_flatten(next_volatility_parts),
                    grads_volatility=next_grads_volatility,
                    diffusion_drift=next_drift_parts,
                    seed=seed,
                ),
            ]
Ejemplo n.º 6
0
 def _sample_n(self, n, seed, **kwargs):
     """Samples via the parent class with a replica-specific seed.

     The seed is sanitized with the salt `'sharded_independent_sample'` and
     then offset by `self.replica_id` before being handed to the parent
     `_sample_n`.
     """
     parent_sample_n = super(ShardedIndependent, self)._sample_n
     base_seed = samplers.sanitize_seed(
         seed, salt='sharded_independent_sample')
     replica_seed = base_seed + self.replica_id
     return parent_sample_n(n, replica_seed, **kwargs)
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one leapfrog-integrated HMC proposal step.

        Args:
          current_state: The current chain state; either a single value or a
            list-like of state parts (checked with `mcmc_util.is_list_like`).
          previous_kernel_results: Results from the previous step. This method
            reads `target_log_prob` and `grads_target_log_prob`, plus
            `step_size` and `num_leapfrog_steps` when
            `self._store_parameters_in_results` is true. It must support
            `_replace`.
          seed: Optional seed; sanitized and then used to sample the momentum.

        Returns:
          A tuple `(next_state, new_kernel_results)`. `next_state` is the list
          of proposed parts when `current_state` is list-like and its first
          element otherwise; `new_kernel_results` is
          `previous_kernel_results` with the fields set below replaced.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Parameters come either from the previous results or from the
            # kernel itself, depending on how the kernel was configured.
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                momentum_distribution,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                self.momentum_distribution,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)
            current_momentum_parts = momentum_distribution.sample(seed=seed)
            # Use `_log_prob_unnormalized` when the distribution provides it,
            # falling back to `log_prob` otherwise.
            momentum_log_prob = getattr(momentum_distribution,
                                        '_log_prob_unnormalized',
                                        momentum_distribution.log_prob)
            # Kinetic energy is the negated momentum log-probability.
            kinetic_energy_fn = lambda *args: -momentum_log_prob(*args)

            # Let the integrator handle the case where no momentum distribution
            # is provided
            if self.momentum_distribution is None:
                leapfrog_kinetic_energy_fn = None
            else:
                leapfrog_kinetic_energy_fn = kinetic_energy_fn

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(
                current_momentum_parts,
                current_state_parts,
                target=current_target_log_prob,
                target_grad_parts=current_target_log_prob_grad_parts,
                kinetic_energy_fn=leapfrog_kinetic_energy_fn)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                # Return the list as-is for list-like input states, otherwise
                # unwrap the single element.
                return x if mcmc_util.is_list_like(current_state) else x[0]

            # The acceptance correction always uses `kinetic_energy_fn`, even
            # when the integrator was given `None` above.
            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    kinetic_energy_fn, current_momentum_parts,
                    next_momentum_parts),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
Ejemplo n.º 8
0
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Runs one step by repeatedly calling `_loop_tree_doubling`.

    Args:
      current_state: The current chain state. A non-nested value is wrapped in
        a single-element list internally and unwrapped again on return.
      previous_kernel_results: Results from the previous step. This method
        reads its `momentum_distribution`, `target_log_prob`,
        `grads_target_log_prob` and `step_size` fields.
      seed: Optional seed; sanitized, then split into one seed for starting
        the trajectory and one for the tree-doubling loop.

    Returns:
      A tuple `(result_state, kernel_results)` where `kernel_results` is a
      `PreconditionedNUTSKernelResults`.
    """
    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    start_trajectory_seed, loop_seed = samplers.split_seed(seed)

    with tf.name_scope(self.name + '.one_step'):
      unwrap_state_list = not tf.nest.is_nested(current_state)
      if unwrap_state_list:
        current_state = [current_state]
      momentum_distribution = previous_kernel_results.momentum_distribution

      current_target_log_prob = previous_kernel_results.target_log_prob
      [init_momentum,
       init_energy,
       log_slice_sample] = self._start_trajectory_batched(
           current_state,
           current_target_log_prob,
           momentum_distribution=momentum_distribution,
           seed=start_trajectory_seed)

      def _copy(v):
        # Multiply by ones of shape [2, 1, ..., 1] (rank(v) trailing ones), so
        # the result holds two stacked copies of `v` along a new leading axis.
        # NOTE(review): presumably one copy per end of the trajectory --
        # confirm against `_loop_tree_doubling`.
        return v * ps.ones(
            ps.pad(
                [2], paddings=[[0, ps.rank(v)]], constant_values=1),
            dtype=v.dtype)
      # Only the second return value is used here; it is stored as the
      # initial velocity.
      _, init_velocity = mcmc_util.maybe_call_fn_and_grads(
          get_kinetic_energy_fn(momentum_distribution),
          [m + 0 for m in init_momentum])  # Breaks cache.

      initial_state = TreeDoublingState(
          momentum=init_momentum,
          velocity=init_velocity,
          state=current_state,
          target=current_target_log_prob,
          target_grad_parts=previous_kernel_results.grads_target_log_prob)
      initial_step_state = tf.nest.map_structure(_copy, initial_state)

      if MULTINOMIAL_SAMPLE:
        init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
      else:
        init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

      # The candidate starts out as the current state.
      candidate_state = TreeDoublingStateCandidate(
          state=current_state,
          target=current_target_log_prob,
          target_grad_parts=previous_kernel_results.grads_target_log_prob,
          energy=init_energy,
          weight=init_weight)

      initial_step_metastate = TreeDoublingMetaState(
          candidate_state=candidate_state,
          is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
          momentum_sum=init_momentum,
          energy_diff_sum=tf.zeros_like(init_energy),
          leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
          continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
          not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

      # Convert the write/read instruction into TensorArray so that it is
      # compatible with XLA.
      write_instruction = tf.TensorArray(
          TREE_COUNT_DTYPE,
          size=len(self._write_instruction),
          clear_after_read=False).unstack(self._write_instruction)
      read_instruction = tf.TensorArray(
          tf.int32,
          size=len(self._read_instruction),
          clear_after_read=False).unstack(self._read_instruction)

      current_step_meta_info = OneStepMetaInfo(
          log_slice_sample=log_slice_sample,
          init_energy=init_energy,
          write_instruction=write_instruction,
          read_instruction=read_instruction
          )

      velocity_state_memory = VelocityStateSwap(
          velocity_swap=self.init_velocity_state_memory(init_momentum),
          state_swap=self.init_velocity_state_memory(current_state))

      step_size = _prepare_step_size(
          previous_kernel_results.step_size,
          current_target_log_prob.dtype,
          len(current_state))
      # Loop until `max_tree_depth` iterations have run or no chain has
      # `continue_tree` set. Only the final metastate is kept.
      _, _, _, new_step_metastate = tf.while_loop(
          cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
              (iter_ < self.max_tree_depth) &
              tf.reduce_any(metastate.continue_tree)),
          body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
              step_size,
              velocity_state_memory,
              current_step_meta_info,
              iter_,
              state,
              metastate,
              momentum_distribution,
              seed),
          loop_vars=(
              tf.zeros([], dtype=tf.int32, name='iter'),
              loop_seed,
              initial_step_state,
              initial_step_metastate),
          parallel_iterations=self.parallel_iterations,
      )

      # `log_accept_ratio` is log(energy_diff_sum / leapfrog_count).
      # `step_size` and `momentum_distribution` are carried over unchanged.
      # NOTE(review): `reach_max_depth` is populated from the final
      # `continue_tree` flag -- confirm this matches the intended meaning.
      kernel_results = PreconditionedNUTSKernelResults(
          target_log_prob=new_step_metastate.candidate_state.target,
          grads_target_log_prob=(
              new_step_metastate.candidate_state.target_grad_parts),
          step_size=previous_kernel_results.step_size,
          log_accept_ratio=tf.math.log(
              new_step_metastate.energy_diff_sum /
              tf.cast(new_step_metastate.leapfrog_count,
                      dtype=new_step_metastate.energy_diff_sum.dtype)),
          leapfrogs_taken=(
              new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
          ),
          is_accepted=new_step_metastate.is_accepted,
          reach_max_depth=new_step_metastate.continue_tree,
          has_divergence=~new_step_metastate.not_divergence,
          energy=new_step_metastate.candidate_state.energy,
          momentum_distribution=momentum_distribution,
          seed=seed,
      )

      result_state = new_step_metastate.candidate_state.state
      if unwrap_state_list:
        result_state = result_state[0]

      return result_state, kernel_results
Ejemplo n.º 9
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one iteration of the Elliptical Slice Sampler.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s). The first `r`
            dimensions index independent chains,
            `r = tf.rank(log_likelihood_fn(*normal_sampler_fn()))`.
          previous_kernel_results: `collections.namedtuple` containing
            `Tensor`s representing values from previous calls to this
            function (or from the `bootstrap_results` function.)
          seed: Optional seed, for reproducible sampling.

        Returns:
          next_state: Tensor or Python list of `Tensor`s representing the
            state(s) of the Markov chain(s) after taking exactly one step.
            Has same type and shape as `current_state`.
          kernel_results: `collections.namedtuple` of internal calculations
            used to advance the chain.

        Raises:
          TypeError: if `not log_likelihood.dtype.is_floating`.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'elliptical_slice',
                                    'one_step')):
            with tf.name_scope('initialize'):
                [init_state_parts, init_log_likelihood
                 ] = _prepare_args(self.log_likelihood_fn, current_state,
                                   previous_kernel_results.log_likelihood)

            seed = samplers.sanitize_seed(
                seed)  # Unsalted, for kernel results.
            # Separate seeds for the normal draw, the threshold, the initial
            # angle, and the shrinking loop.
            normal_seed, u_seed, angle_seed, loop_seed = samplers.split_seed(
                seed, n=4, salt='elliptical_slice_sampler')
            normal_samples = self.normal_sampler_fn(normal_seed)  # pylint: disable=not-callable
            normal_samples = list(normal_samples) if mcmc_util.is_list_like(
                normal_samples) else [normal_samples]
            u = samplers.uniform(
                shape=tf.shape(init_log_likelihood),
                seed=u_seed,
                dtype=init_log_likelihood.dtype.base_dtype,
            )
            # Log-likelihood level a proposal must reach to be kept.
            threshold = init_log_likelihood + tf.math.log(u)

            # Initial angle is drawn between 0 and 2*pi; the initial bracket
            # is [angle - 2*pi, angle].
            starting_angle = samplers.uniform(
                shape=tf.shape(init_log_likelihood),
                minval=0.,
                maxval=2 * np.pi,
                name='angle',
                seed=angle_seed,
                dtype=init_log_likelihood.dtype.base_dtype,
            )
            starting_angle_min = starting_angle - 2 * np.pi
            starting_angle_max = starting_angle

            starting_state_parts = _rotate_on_ellipse(init_state_parts,
                                                      normal_samples,
                                                      starting_angle)
            starting_log_likelihood = self.log_likelihood_fn(
                *starting_state_parts)  # pylint: disable=not-callable

            def chain_not_done(seed, angle, angle_min, angle_max,
                               current_state_parts, current_log_likelihood):
                # Loop condition: keep iterating while any chain is still
                # below its threshold.
                del seed, angle, angle_min, angle_max, current_state_parts
                return tf.reduce_any(current_log_likelihood < threshold)

            def sample_next_angle(seed, angle, angle_min, angle_max,
                                  current_state_parts, current_log_likelihood):
                """Slice sample a new angle, and rotate init_state by that amount."""
                angle_seed, next_seed = samplers.split_seed(seed)
                # Elementwise mask; shadows the loop-condition function of the
                # same name within this body.
                chain_not_done = current_log_likelihood < threshold
                # Box in on angle. Only update angles for which we haven't generated a
                # point that beats the threshold.
                angle_min = tf.where((angle < 0) & chain_not_done, angle,
                                     angle_min)
                angle_max = tf.where((angle >= 0) & chain_not_done, angle,
                                     angle_max)
                new_angle = samplers.uniform(
                    shape=tf.shape(current_log_likelihood),
                    minval=angle_min,
                    maxval=angle_max,
                    seed=angle_seed,
                    dtype=angle.dtype.base_dtype)
                # Chains that already beat the threshold keep their angle.
                angle = tf.where(chain_not_done, new_angle, angle)
                next_state_parts = _rotate_on_ellipse(init_state_parts,
                                                      normal_samples, angle)

                # Replace state parts only for chains that are not done. The
                # mask is padded to the rank of the first state part.
                new_state_parts = []
                broadcasted_chain_not_done = _right_pad_with_ones(
                    chain_not_done, tf.rank(next_state_parts[0]))
                for n_state, c_state in zip(next_state_parts,
                                            current_state_parts):
                    new_state_part = tf.where(broadcasted_chain_not_done,
                                              n_state, c_state)
                    new_state_parts.append(new_state_part)

                return (
                    next_seed,
                    angle,
                    angle_min,
                    angle_max,
                    new_state_parts,
                    self.log_likelihood_fn(*new_state_parts)  # pylint: disable=not-callable
                )

            # Only the final angle, state parts and log-likelihood are kept.
            [
                _,
                next_angle,
                _,
                _,
                next_state_parts,
                next_log_likelihood,
            ] = tf.while_loop(cond=chain_not_done,
                              body=sample_next_angle,
                              loop_vars=[
                                  loop_seed, starting_angle,
                                  starting_angle_min, starting_angle_max,
                                  starting_state_parts, starting_log_likelihood
                              ])

            # Return a list only when the input state was list-like.
            return [
                next_state_parts if mcmc_util.is_list_like(current_state) else
                next_state_parts[0],
                EllipticalSliceSamplerKernelResults(
                    log_likelihood=next_log_likelihood,
                    angle=next_angle,
                    normal_samples=normal_samples,
                    seed=seed,
                ),
            ]
Ejemplo n.º 10
0
def _windowed_adaptive_impl(n_draws, joint_dist, *, kind, n_chains,
                            proposal_kernel_kwargs, num_adaptation_steps,
                            current_state, dual_averaging_kwargs, trace_fn,
                            return_final_kernel_results, discard_tuning, seed,
                            chain_axis_names, **pins):
    """Runs windowed sampling using either HMC or NUTS as internal sampler.

    Args:
      n_draws: Number of draws requested. When `discard_tuning` is false the
        adaptation steps are added to this count.
      joint_dist: Joint distribution, forwarded to `_setup_mcmc`.
      kind: Sampler kind, forwarded to `_do_sampling`.
      n_chains: Number of chains; an `int` is wrapped in a one-element list.
      proposal_kernel_kwargs: Mapping of keyword arguments for the proposal
        kernel. A `'step_size'` entry, if present and not `None`, is used as
        the initial step size. The caller's mapping is not modified.
      num_adaptation_steps: Number of adaptation steps.
      current_state: Initial position, forwarded to `_setup_mcmc` as
        `init_position`.
      dual_averaging_kwargs: Mapping of keyword arguments for dual averaging.
        Any `'num_adaptation_steps'` entry is overridden (with a warning).
        The caller's mapping is not modified.
      trace_fn: Trace callable, or `None` to trace nothing.
      return_final_kernel_results: Whether to also return the final kernel
        results.
      discard_tuning: Whether the adaptation steps are treated as burn-in
        rather than returned as draws.
      seed: Seed; sanitized and split into a setup seed and a sampling seed.
      chain_axis_names: Named chain axes; also the default for
        `experimental_reduce_chain_axis_names` in the dual averaging kwargs.
      **pins: Extra keyword arguments forwarded to `_setup_mcmc`.

    Returns:
      If `return_final_kernel_results` is true, a
      `sample.CheckpointableStatesAndTrace`. Otherwise the inverse-transformed
      draws alone when `trace_fn` was `None`, or a `sample.StatesAndTrace`
      when a trace function was given.

    Raises:
      ValueError: If no `step_size` was given and the batch shape is not
        scalar.
    """
    no_trace = trace_fn is None
    if no_trace:
        trace_fn = lambda *args: ()

    if isinstance(n_chains, int):
        n_chains = [n_chains]

    if (tf.executing_eagerly()
            or not control_flow_util.GraphOrParentsInXlaContext(
                tf1.get_default_graph())):
        # Only convert outside of XLA: a Tensor-valued step count breaks XLA,
        # which requires static TensorArray trace_fn result allocation sizes.
        num_adaptation_steps = ps.convert_to_shape_tensor(num_adaptation_steps)

    # Work on shallow copies so the caller's dictionaries are left untouched.
    # Mutating them in place would make a second call with the same dicts warn
    # spuriously about `num_adaptation_steps` and silently reuse the previous
    # call's `step_size`, `momentum_distribution` and `target_log_prob_fn`.
    dual_averaging_kwargs = dict(dual_averaging_kwargs)
    proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

    if 'num_adaptation_steps' in dual_averaging_kwargs:
        warnings.warn(
            'Dual averaging adaptation will use the value specified in'
            ' the `num_adaptation_steps` argument for its construction,'
            ' hence there is no need to specify it in the'
            ' `dual_averaging_kwargs` argument.')

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    dual_averaging_kwargs['num_adaptation_steps'] = num_adaptation_steps
    dual_averaging_kwargs.setdefault(
        'reduce_fn',
        functools.partial(
            generic_math.reduce_log_harmonic_mean_exp,
            # There is only one log_accept_prob per chain, and we reduce across
            # all chains, so typically the all_gather will be gathering scalars,
            # which should be relatively efficient.
            experimental_allow_all_gather=True))
    # By default, reduce over named axes for step size adaptation
    dual_averaging_kwargs.setdefault('experimental_reduce_chain_axis_names',
                                     chain_axis_names)
    setup_seed, sample_seed = samplers.split_seed(samplers.sanitize_seed(seed),
                                                  n=2)
    (target_log_prob_fn, initial_transformed_position, bijector,
     step_broadcast, batch_shape,
     shard_axis_names) = _setup_mcmc(joint_dist,
                                     n_chains=n_chains,
                                     init_position=current_state,
                                     seed=setup_seed,
                                     **pins)

    if proposal_kernel_kwargs.get('step_size') is None:
        if batch_shape.shape != (0, ):  # Scalar batch has a 0-vector shape.
            raise ValueError(
                'Batch target density must specify init_step_size. Got '
                f'batch shape {batch_shape} from joint {joint_dist}.')

        init_step_size = _get_step_size(initial_transformed_position,
                                        target_log_prob_fn)

    else:
        init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

    proposal_kernel_kwargs.update({
        'target_log_prob_fn':
        target_log_prob_fn,
        'step_size':
        init_step_size,
        'momentum_distribution':
        _init_momentum(initial_transformed_position,
                       batch_shape=ps.concat([n_chains, batch_shape], axis=0),
                       shard_axis_names=shard_axis_names)
    })

    # Start each running variance at zero samples, zero mean, unit variance.
    initial_running_variance = [
        sample_stats.RunningVariance.from_stats(  # pylint: disable=g-complex-comprehension
            num_samples=tf.zeros([], part.dtype),
            mean=tf.zeros_like(part),
            variance=tf.ones_like(part))
        for part in initial_transformed_position
    ]
    # TODO(phandu): Consider splitting out warmup and post warmup phases
    # to avoid executing adaptation code during the post warmup phase.
    ret = _do_sampling(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=n_draws if discard_tuning else n_draws +
        num_adaptation_steps,
        num_burnin_steps=num_adaptation_steps if discard_tuning else 0,
        initial_position=initial_transformed_position,
        initial_running_variance=initial_running_variance,
        bijector=bijector,
        trace_fn=trace_fn,
        return_final_kernel_results=return_final_kernel_results,
        chain_axis_names=chain_axis_names,
        shard_axis_names=shard_axis_names,
        seed=sample_seed)

    if return_final_kernel_results:
        draws, trace, fkr = ret
        return sample.CheckpointableStatesAndTrace(
            all_states=bijector.inverse(draws),
            trace=trace,
            final_kernel_results=fkr)

    draws, trace = ret
    if no_trace:
        return bijector.inverse(draws)
    return sample.StatesAndTrace(all_states=bijector.inverse(draws),
                                 trace=trace)
Ejemplo n.º 11
0
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Runs one uncalibrated Langevin proposal step.

    Draws one normal noise tensor per state part, advances the state with
    `_euler_method`, and re-evaluates the target log-prob, volatility and
    drift at the proposed state via `_prepare_args`.

    Args:
      current_state: `Tensor` or list-like of `Tensor`s holding the current
        state. When it is not list-like, the returned state and volatility
        are unwrapped to their single element.
      previous_kernel_results: Results of the previous step. Its
        `target_log_prob`, `grads_target_log_prob`, `volatility`,
        `grads_volatility` and `diffusion_drift` fields are read here.
      seed: Optional PRNG seed. When `None`, a seed is drawn from the
        kernel's own seed stream instead.

    Returns:
      A two-element list `[next_state, kernel_results]`, where
      `kernel_results` is an `UncalibratedLangevinKernelResults` that also
      records the sanitized seed used for this step.
    """
    with tf.name_scope(mcmc_util.make_name(self.name, 'mala', 'one_step')):
      with tf.name_scope('initialize'):
        # Gather everything `_euler_method` needs; the two gradient slots are
        # not used in this step and are discarded.
        (current_state_parts,
         step_size_parts,
         current_target_log_prob,
         _,
         current_volatility_parts,
         _,
         current_drift_parts) = _prepare_args(
             self.target_log_prob_fn,
             self.volatility_fn,
             current_state,
             self.step_size,
             previous_kernel_results.target_log_prob,
             previous_kernel_results.grads_target_log_prob,
             previous_kernel_results.volatility,
             previous_kernel_results.grads_volatility,
             previous_kernel_results.diffusion_drift,
             self.parallel_iterations)

        # TODO(b/159636942): Clean up after 2020-09-20.
        if seed is None:
          # No explicit seed: fall back to the constructor-provided stream,
          # warning when that stream was built from a user seed.
          if self._seed_stream.original_seed is not None:
            warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
          seed = samplers.sanitize_seed(self._seed_stream())
        else:
          seed = samplers.sanitize_seed(seed)
        part_seeds = samplers.split_seed(
            seed, n=len(current_state_parts), salt='langevin.one_step')

        # One independent normal draw per state part.
        random_draw_parts = [
            samplers.normal(
                shape=tf.shape(part),
                dtype=dtype_util.base_dtype(part.dtype),
                seed=part_seed)
            for part, part_seed in zip(current_state_parts, part_seeds)
        ]

      # Number of independent chains run by the algorithm.
      independent_chain_ndims = prefer_static.rank(current_target_log_prob)

      # Propose the next state with the Euler-Maruyama update.
      next_state_parts = _euler_method(
          random_draw_parts,
          current_state_parts,
          current_drift_parts,
          step_size_parts,
          current_volatility_parts)

      # Quantities at the proposed state; they feed both the acceptance
      # correction below and the next call to `one_step`. The first two
      # outputs (state parts and step sizes) are already known here.
      (_,
       _,
       next_target_log_prob,
       next_grads_target_log_prob,
       next_volatility_parts,
       next_grads_volatility,
       next_drift_parts) = _prepare_args(
           self.target_log_prob_fn,
           self.volatility_fn,
           next_state_parts,
           step_size_parts,
           parallel_iterations=self.parallel_iterations)

      # Both candidate corrections are built up front; `tf.cond` only selects
      # which one is reported.
      computed_correction = _compute_log_acceptance_correction(
          current_state_parts,
          next_state_parts,
          current_volatility_parts,
          next_volatility_parts,
          current_drift_parts,
          next_drift_parts,
          step_size_parts,
          independent_chain_ndims)
      zero_correction = tf.zeros_like(next_target_log_prob)

      def _use_computed_correction():
        return computed_correction

      def _use_zero_correction():
        return zero_correction

      log_acceptance_correction = tf.cond(
          pred=self.compute_acceptance,
          true_fn=_use_computed_correction,
          false_fn=_use_zero_correction)

      # Mirror the structure of the input: list-like in, list out; otherwise
      # return the single part.
      state_is_list_like = mcmc_util.is_list_like(current_state)
      next_state = (
          next_state_parts if state_is_list_like else next_state_parts[0])
      next_volatility = (
          next_volatility_parts if state_is_list_like
          else next_volatility_parts[0])

      kernel_results = UncalibratedLangevinKernelResults(
          log_acceptance_correction=log_acceptance_correction,
          target_log_prob=next_target_log_prob,
          grads_target_log_prob=next_grads_target_log_prob,
          volatility=next_volatility,
          grads_volatility=next_grads_volatility,
          diffusion_drift=next_drift_parts,
          seed=seed,
      )
      return [next_state, kernel_results]
Ejemplo n.º 12
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Advances the slice sampler by a single iteration.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the Markov chain(s). The leading `r`
            dimensions index independent chains, with
            `r = tf.rank(target_log_prob_fn(*current_state))`.
          previous_kernel_results: `collections.namedtuple` of `Tensor`s
            produced by the previous call to this function (or by
            `bootstrap_results`).
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

        Returns:
          next_state: `Tensor` or Python list of `Tensor`s with the chain
            state(s) after exactly one step. Same type and shape as
            `current_state`.
          kernel_results: `collections.namedtuple` of internal calculations
            used to advance the chain.

        Raises:
          ValueError: if there isn't one `step_size` or a list with same
            length as `current_state`.
          TypeError: if `not target_log_prob.dtype.is_floating`.
        """
        # Sanitize once so the same seed drives sampling and is stored in the
        # results for diagnostics.
        seed = samplers.sanitize_seed(seed)

        scope_name = mcmc_util.make_name(self.name, 'slice', 'one_step')
        with tf.name_scope(scope_name):
            with tf.name_scope('initialize'):
                prepared = _prepare_args(
                    self.target_log_prob_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    maybe_expand=True)
                state_parts, step_sizes, target_log_prob = prepared

                max_doublings = ps.convert_to_shape_tensor(
                    value=self.max_doublings,
                    dtype=tf.int32,
                    name='max_doublings')

            chain_ndims = ps.rank(target_log_prob)

            sampled = _sample_next(
                self.target_log_prob_fn,
                state_parts,
                step_sizes,
                max_doublings,
                target_log_prob,
                chain_ndims,
                seed=seed,
                experimental_shard_axis_names=(
                    self.experimental_shard_axis_names))
            (next_state_parts,
             next_target_log_prob,
             bounds_satisfied,
             direction,
             upper_bounds,
             lower_bounds) = sampled

            # Return the same structure we were given: a list stays a list,
            # a bare tensor comes back as a bare tensor.
            if mcmc_util.is_list_like(current_state):
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]

            kernel_results = SliceSamplerKernelResults(
                target_log_prob=next_target_log_prob,
                bounds_satisfied=bounds_satisfied,
                direction=direction,
                upper_bounds=upper_bounds,
                lower_bounds=lower_bounds,
                seed=seed,
            )
            return [next_state, kernel_results]
Ejemplo n.º 13
0
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions.

        Drives the coroutine returned by `self._model_coroutine()`: each
        yielded distribution is recorded, a value for it is either taken from
        `value` or sampled, and that value is sent back into the coroutine
        until it is exhausted.

        Args:
          sample_shape: Sample shape passed to `sample` for distributions
            wrapped in `self.Root`; all other distributions are sampled with
            an empty sample shape.
          seed: PRNG seed. Handled as a stateful seed when
            `samplers.is_stateful_seed(seed)` is true, otherwise as a
            stateless seed.
          value: Optional indexable of per-distribution values. An entry that
            is present and not `None` is converted to tensors and used in
            place of sampling the distribution at that index.

        Returns:
          ds: List of the distributions yielded by the coroutine (unwrapped
            from `Root` where applicable), in yield order.
          values_out: List of the corresponding values, in the same order.

        Raises:
          ValueError: If `self._require_root` is set and the first yielded
            object is not a `self.Root`.
          TypeError: Re-raised from `sample` when it does not look like a
            seed-type incompatibility, or when no stateful seed is available
            to fall back on.
        """
        ds = []
        values_out = []
        # Set up the two seed flavours. A stateful seed always gets a
        # `SeedStream`; the raw seed is only kept when it will be converted
        # to a stateless one below.
        if samplers.is_stateful_seed(seed):
            seed_stream = SeedStream(seed, salt='JointDistributionCoroutine')
            if not self._stateful_to_stateless:
                seed = None
        else:
            seed_stream = None  # We got a stateless seed for seed=.

        # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
        # Sanitization is skipped only when running in JAX mode with no seed.
        if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
            seed = samplers.sanitize_seed(seed,
                                          salt='JointDistributionCoroutine')
        gen = self._model_coroutine()
        index = 0
        d = next(gen)
        if self._require_root and not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            # Loop ends when `gen.send` raises `StopIteration`.
            while True:
                actual_distribution = d.distribution if isinstance(
                    d, self.Root) else d
                ds.append(actual_distribution)
                # Ensure reproducibility even when xs are (partially) set. Always split.
                stateful_sample_seed = None if seed_stream is None else seed_stream(
                )
                if seed is None:
                    stateless_sample_seed = None
                else:
                    stateless_sample_seed, seed = samplers.split_seed(seed)

                # A user-supplied value for this index takes precedence over
                # sampling.
                if (value is not None and len(value) > index
                        and value[index] is not None):

                    def convert_tree_to_tensor(x, dtype_hint):
                        return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

                    # This signature does not allow kwarg names. Applies
                    # `convert_to_tensor` on the next value.
                    next_value = nest.map_structure_up_to(
                        ds[-1].dtype,  # shallow_tree
                        convert_tree_to_tensor,  # func
                        value[index],  # x
                        ds[-1].dtype)  # dtype_hint
                else:
                    try:
                        # Prefer the stateless seed when one exists; otherwise
                        # use the stateful one (which may itself be `None`).
                        next_value = actual_distribution.sample(
                            sample_shape=sample_shape if isinstance(
                                d, self.Root) else (),
                            seed=(stateful_sample_seed
                                  if stateless_sample_seed is None else
                                  stateless_sample_seed))
                    except TypeError as e:
                        # Only fall back when the error message matches a
                        # known seed-type complaint AND a stateful seed is
                        # available; anything else is propagated unchanged.
                        if ('Expected int for argument' not in str(e)
                                and TENSOR_SEED_MSG_PREFIX not in str(e)) or (
                                    stateful_sample_seed is None):
                            raise
                        msg = (
                            'Falling back to stateful sampling for distribution #{index} '
                            '(0-based) of type `{dist_cls}` with component name '
                            '{component_name} and `dist.name` "{dist_name}". Please '
                            'update to use `tf.random.stateless_*` RNGs. This fallback may '
                            'be removed after 20-Dec-2020. ({exc})')
                        component_name = (joint_distribution_lib.
                                          get_explicit_name_for_component(
                                              ds[-1]))
                        if component_name is None:
                            component_name = '[None specified]'
                        else:
                            component_name = '"{}"'.format(component_name)
                        warnings.warn(
                            msg.format(index=index,
                                       component_name=component_name,
                                       dist_name=ds[-1].name,
                                       dist_cls=type(ds[-1]),
                                       exc=str(e)))
                        # Retry the same draw with the stateful seed.
                        next_value = actual_distribution.sample(
                            sample_shape=sample_shape if isinstance(
                                d, self.Root) else (),
                            seed=stateful_sample_seed)

                if self._validate_args:
                    # Attach the shape assertions to the stored value by
                    # creating it under their control dependencies.
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(
                            tf.nest.map_structure(tf.identity, next_value))
                else:
                    values_out.append(next_value)

                index += 1
                # Hand the value back to the model and receive the next
                # distribution it yields.
                d = gen.send(next_value)
        except StopIteration:
            pass
        return ds, values_out
Ejemplo n.º 14
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Performs a single Metropolis-Hastings transition.

        The inner kernel proposes a state, which is then accepted or rejected
        by comparing a log-uniform draw against the log acceptance ratio.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple`
            or `list` of `Tensor`s with internal calculations from the
            previous call to this function (or from `bootstrap_results`).
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s holding the next
            state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list`
            of `Tensor`s with internal calculations made in this call.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        # Remember whether the caller passed a seed before it is sanitized:
        # the inner kernel is only handed a seed when the caller gave one.
        caller_gave_seed = seed is not None
        seed = samplers.sanitize_seed(seed)  # Stored in the results below.
        proposal_seed, acceptance_seed = samplers.split_seed(seed)

        with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
            accepted_results = previous_kernel_results.accepted_results

            # Let the inner kernel propose.
            inner_kwargs = {'seed': proposal_seed} if caller_gave_seed else {}
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state, accepted_results, **inner_kwargs)
            if mcmc_util.is_list_like(current_state):
                proposed_state = tf.nest.pack_sequence_as(
                    current_state, proposed_state)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # log(acceptance_ratio) = proposed - current (+ correction).
            proposed_log_prob = proposed_results.target_log_prob
            log_ratio_terms = [
                proposed_log_prob,
                -accepted_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # A list-like correction is only included when non-empty.
                if not mcmc_util.is_list_like(correction) or correction:
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1),
            # i.e. when log(u) < log_accept_ratio. A proposal that raises the
            # likelihood is therefore always accepted; one that lowers it is
            # accepted at random.
            uniform_draw = samplers.uniform(
                shape=prefer_static.shape(proposed_log_prob),
                dtype=dtype_util.base_dtype(proposed_log_prob.dtype),
                seed=acceptance_seed)
            log_uniform = tf.math.log(uniform_draw)
            is_accepted = log_uniform < log_accept_ratio

            next_state = mcmc_util.choose(
                is_accepted,
                proposed_state,
                current_state,
                name='choose_next_state')

            # Seeds are removed from the proposed results before choosing:
            # a seed is not a per-chain quantity, so there is no meaningful
            # per-chain selection between an old and a new seed.
            next_accepted_results = mcmc_util.choose(
                is_accepted,
                mcmc_util.strip_seeds(proposed_results),
                accepted_results,
                name='choose_inner_results')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=next_accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
                seed=seed,
            )

            return next_state, kernel_results
Ejemplo n.º 15
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one SNAPER HMC step and refreshes the running statistics.

        A kernel is built from the current adaptation state via
        `self._make_kernel`, stepped once on the flattened state, and the
        state and principal-component moving averages are then updated from
        the new state.

        Args:
          current_state: Current chain state; may be a nested structure, which
            is flattened for the inner kernel and re-packed afterwards.
          previous_kernel_results: Results of the previous step. Supplies the
            inner results and the moving-average fields read below.
          seed: PRNG seed; passed through `samplers.sanitize_seed`.

        Returns:
          state: The new state, with the same structure as `current_state`.
          kernel_results: `previous_kernel_results` with the inner results,
            moving averages and seed replaced.
        """
        scope_name = mcmc_util.make_name(
            self.name, 'snaper_hamiltonian_monte_carlo', 'one_step')
        with tf.name_scope(scope_name):
            prev_inner_results = previous_kernel_results.inner_results

            innermost_log_prob = unnest.get_innermost(
                previous_kernel_results, 'target_log_prob')
            batch_shape = ps.shape(innermost_log_prob)
            reduce_axes = ps.range(0, ps.size(batch_shape))
            step = prev_inner_results.step
            prev_state_ema_points = previous_kernel_results.state_ema_points

            kernel = self._make_kernel(
                batch_shape=batch_shape,
                step=step,
                state_ema_points=prev_state_ema_points,
                state=current_state,
                mean=previous_kernel_results.ema_mean,
                variance=previous_kernel_results.ema_variance,
                principal_component=(
                    previous_kernel_results.ema_principal_component),
            )

            # Carry the freshly built kernel's momentum distribution into the
            # inner results before stepping.
            momentum_distribution = (
                kernel.inner_kernel.parameters['momentum_distribution'])  # pylint: disable=protected-access
            stepped_from = unnest.replace_innermost(
                prev_inner_results,
                momentum_distribution=momentum_distribution,
            )

            seed = samplers.sanitize_seed(seed)
            flat_state = tf.nest.flatten(current_state)
            state_parts, new_inner_results = kernel.one_step(
                flat_state,
                stepped_from,
                seed=seed,
            )

            state = tf.nest.pack_sequence_as(current_state, state_parts)

            state_ema = self._update_state_ema(
                reduce_axes=reduce_axes,
                state=state,
                step=step,
                state_ema_points=prev_state_ema_points,
                ema_mean=previous_kernel_results.ema_mean,
                ema_variance=previous_kernel_results.ema_variance,
            )
            new_state_ema_points, new_ema_mean, new_ema_variance = state_ema

            principal_component_ema = self._update_principal_component_ema(
                reduce_axes=reduce_axes,
                state=state,
                step=step,
                principal_component_ema_points=(
                    previous_kernel_results.principal_component_ema_points),
                ema_principal_component=(
                    previous_kernel_results.ema_principal_component),
            )
            (new_principal_component_ema_points,
             new_ema_principal_component) = principal_component_ema

            kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                ema_mean=new_ema_mean,
                ema_variance=new_ema_variance,
                state_ema_points=new_state_ema_points,
                ema_principal_component=new_ema_principal_component,
                principal_component_ema_points=(
                    new_principal_component_ema_points),
                seed=seed,
            )

            return state, kernel_results
Ejemplo n.º 16
0
 def test_sanitize_tensor_or_tensorlike(self):
     """Sanitizing an evaluated seed and the seed itself must agree."""
     stateless_seed = test_util.test_seed(sampler_type='stateless')
     evaluated_seed = self.evaluate(stateless_seed)
     from_evaluated = samplers.sanitize_seed(seed=evaluated_seed)
     from_original = samplers.sanitize_seed(stateless_seed)
     self.assertAllEqual(from_evaluated, from_original)
Ejemplo n.º 17
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a new state with `new_state_fn` (no accept/reject here).

        Args:
          current_state: `Tensor` or list-like of `Tensor`s with the current
            state. When it is not list-like, the returned state is unwrapped
            to its single element.
          previous_kernel_results: Results of the previous step; not read by
            this method.
          seed: Optional PRNG seed. When given, `new_state_fn` must accept the
            sanitized stateless seed. When `None`, a seed is taken from the
            kernel's seed stream, with a fallback to the stateful seed if
            `new_state_fn` rejects the stateless one.

        Returns:
          A two-element list `[next_state, kernel_results]`, where
          `kernel_results` is an `UncalibratedRandomWalkResults` with a zero
          `log_acceptance_correction`.

        Raises:
          TypeError: Propagated from `new_state_fn` when a seed was supplied,
            or when the error does not look like a seed-type mismatch.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
            with tf.name_scope('initialize'):
                raw_parts = (list(current_state)
                             if mcmc_util.is_list_like(current_state)
                             else [current_state])
                current_state_parts = [
                    tf.convert_to_tensor(part, name='current_state')
                    for part in raw_parts
                ]

            # Why the seed handling is this involved: existing user code may
            # still hand `self.new_state_fn` an old-style stateful seed.
            # - Seed supplied: it is sanitized to stateless and used as-is.
            #   If `new_state_fn` cannot take it, the error propagates, since
            #   the contract for an explicit seed is stateless sampling.
            # - No seed supplied: a stateless seed is tried first, as that is
            #   the direction the API is moving in.
            # - If that attempt fails with what looks like a seed
            #   incompatibility, warn and retry with the stateful-style seed,
            #   so code that never set a seed does not suddenly break.
            # TODO(b/159636942): Clean up after 2020-09-20.
            force_stateless = seed is not None
            if force_stateless:
                seed = samplers.sanitize_seed(seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                stateful_seed = self._seed_stream()
                seed = samplers.sanitize_seed(stateful_seed)

            try:
                next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable
            except TypeError as e:
                error_text = str(e)
                looks_like_seed_mismatch = (
                    'Expected int for argument' in error_text
                    or TENSOR_SEED_MSG_PREFIX in error_text)
                if not looks_like_seed_mismatch or force_stateless:
                    raise
                msg = (
                    'Falling back to `int` seed for `new_state_fn` {}. Please update '
                    'to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 10-Sep-2020. ({})')
                warnings.warn(msg.format(self.new_state_fn, str(e)))
                # No stateless seed was actually used for this draw.
                seed = None
                next_state_parts = self.new_state_fn(  # pylint: disable=not-callable
                    current_state_parts, stateful_seed)

            # `target_log_prob` is computed here so MetropolisHastings can
            # read it from the results.
            next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

            if mcmc_util.is_list_like(current_state):
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]

            zero_correction = tf.zeros_like(next_target_log_prob)
            recorded_seed = samplers.zeros_seed() if seed is None else seed
            kernel_results = UncalibratedRandomWalkResults(
                log_acceptance_correction=zero_correction,
                target_log_prob=next_target_log_prob,
                seed=recorded_seed,
            )
            return [next_state, kernel_results]
Ejemplo n.º 18
0
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  The adaptation budget is split by `_get_window_sizes` into a first fast
  window, a series of slow windows whose size doubles each time, and a last
  fast window. After adaptation, `_do_sampling` produces the requested draws.

  Args:
    n_draws: Number of draws, passed to `_do_sampling` as `num_draws`.
    joint_dist: Joint distribution handed to `_setup_mcmc`.
    kind: Forwarded unchanged to the window helpers and `_do_sampling`.
    n_chains: Number of chains; passed to `_setup_mcmc` and used as the
      leading dimension of the momentum distribution's batch shape.
    proposal_kernel_kwargs: `dict` of kernel keyword arguments. NOTE: this
      dict is updated in place (`target_log_prob_fn`, `step_size`,
      `momentum_distribution`).
    num_adaptation_steps: Total number of adaptation steps.
    current_state: Passed to `_setup_mcmc` as `init_position`.
    dual_averaging_kwargs: Forwarded unchanged to the window helpers.
    trace_fn: Trace callable, or `None` to use a function returning `()`.
    return_final_kernel_results: If true, the result also carries the final
      kernel results returned by `_do_sampling`.
    discard_tuning: If true, only the draws from `_do_sampling` are returned;
      otherwise draws (and traces) from all adaptation windows are
      concatenated in front of them along axis 0.
    seed: PRNG seed; sanitized and split for each stage.
    **pins: Forwarded to `_setup_mcmc`.

  Returns:
    A `sample.CheckpointableStatesAndTrace` when
    `return_final_kernel_results` is true. Otherwise the states alone when
    `trace_fn` was `None`, or a `sample.StatesAndTrace`. States are mapped
    through `bijector.inverse` in every case.

  Raises:
    ValueError: If `proposal_kernel_kwargs` has no `step_size` and the batch
      shape obtained from `_setup_mcmc` is not scalar.
  """
  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  # One seed for setup, one for the first fast window; the remainder is split
  # further below.
  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  # Without a user-provided step size, one is derived from the initial
  # position, which is only done for a scalar batch.
  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError('Batch target density must specify init_step_size. Got '
                       f'batch shape {batch_shape} from joint {joint_dist}.')

    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)

  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  # NOTE: mutates the caller's dict.
  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([[n_chains], batch_shape], axis=0))})

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)

  all_traces = []
  # Using tf.function here and on _slow_window_closure caches tracing
  # of _fast_window and _slow_window, respectively, within a single
  # call to windowed sampling.  Why not annotate _fast_window and
  # _slow_window directly?  Two reasons:
  # - Caching across calls to windowed sampling is probably futile,
  #   because the trace function and bijector will be different Python
  #   objects, preventing cache hits.
  # - The cache of a global tf.function sticks around for the lifetime
  #   of the Python process, potentially leaking memory.
  @tf.function(autograph=False)
  def _fast_window_closure(proposal_kernel_kwargs,
                           window_size,
                           initial_position,
                           seed):
    return _fast_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)
  # First fast window.
  draws, trace, step_size, running_variances = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=first_window_size,
      initial_position=initial_transformed_position,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  # One growing list of draw chunks per state part.
  all_draws = [[d] for d in draws]
  all_traces.append(trace)
  # All but the last of the split seeds drive the slow windows; the last one
  # is carried forward for the final fast window and sampling.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  @tf.function(autograph=False)
  def _slow_window_closure(proposal_kernel_kwargs,
                           window_size,
                           initial_position,
                           running_variances,
                           seed):
    return _slow_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        initial_running_variance=running_variances,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)
  for idx, slow_seed in enumerate(slow_seeds):
    # Each slow window is twice as long as the previous one.
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variances, momentum_distribution
     ) = _slow_window_closure(
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         window_size=window_size,
         # Continue from the last draw of the previous window.
         initial_position=[d[-1] for d in draws],
         running_variances=running_variances,
         seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  # Last fast window, then the actual sampling run.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=last_window_size,
      initial_position=[d[-1] for d in draws],
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if discard_tuning:
    # Report only what `_do_sampling` produced.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(draws),
          trace=trace,
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      if no_trace:
        return bijector.inverse(draws)
      else:
        return sample.StatesAndTrace(all_states=bijector.inverse(draws),
                                     trace=trace)
  else:
    # Report adaptation draws and traces too, concatenated along axis 0 in
    # the order the windows ran.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_traces.append(trace)
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(
              [tf.concat(d, axis=0) for d in all_draws]),
          trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                      *all_traces, expand_composites=True),
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_states = bijector.inverse([tf.concat(d, axis=0) for d in all_draws])
      if no_trace:
        return all_states
      else:
        all_traces.append(trace)
        return sample.StatesAndTrace(
            all_states=all_states,
            trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                        *all_traces, expand_composites=True))
Ejemplo n.º 19
0
def minimize(loss_fn,
             num_steps,
             optimizer,
             convergence_criterion=None,
             batch_convergence_reduce_fn=tf.reduce_all,
             trainable_variables=None,
             trace_fn=_trace_loss,
             return_full_length_trace=True,
             jit_compile=False,
             seed=None,
             name='minimize'):
    """Minimize a loss function using a provided optimizer.

  Args:
    loss_fn: Python callable with signature `loss = loss_fn()`, where `loss`
      is a `Tensor` loss to be minimized. This may optionally take a `seed`
      keyword argument, used to specify a per-iteration seed for stochastic
      loss functions (a stateless `Tensor` seed will be passed; see
      `tfp.random.sanitize_seed`).
    num_steps: Python `int` maximum number of steps to run the optimizer.
    optimizer: Optimizer instance to use. This may be a TF1-style
      `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any Python
      object that implements `optimizer.apply_gradients(grads_and_vars)`.
    convergence_criterion: Optional instance of
      `tfp.optimizer.convergence_criteria.ConvergenceCriterion`
      representing a criterion for detecting convergence. If `None`,
      the optimization will run for `num_steps` steps, otherwise, it will run
      for at *most* `num_steps` steps, as determined by the provided criterion.
      Default value: `None`.
    batch_convergence_reduce_fn: Python `callable` of signature
      `has_converged = batch_convergence_reduce_fn(batch_has_converged)`
      whose input is a `Tensor` of boolean values of the same shape as the
      `loss` returned by `loss_fn`, and output is a scalar
      boolean `Tensor`. This determines the behavior of batched
      optimization loops when `loss_fn`'s return value is non-scalar.
      For example, `tf.reduce_all` will stop the optimization
      once all members of the batch have converged, `tf.reduce_any` once *any*
      member has converged,
      `lambda x: tf.reduce_mean(tf.cast(x, tf.float32)) > 0.5` once more than
      half have converged, etc.
      Default value: `tf.reduce_all`.
    trainable_variables: list of `tf.Variable` instances to optimize with
      respect to. If `None`, defaults to the set of all variables accessed
      during the execution of `loss_fn()`.
      Default value: `None`.
    trace_fn: Python callable with signature `traced_values = trace_fn(
      traceable_quantities)`, where the argument is an instance of
      `tfp.math.MinimizeTraceableQuantities` and the returned `traced_values`
      may be a `Tensor` or nested structure of `Tensor`s. The traced values are
      stacked across steps and returned.
      The default `trace_fn` simply returns the loss. In general, trace
      functions may also examine the gradients, values of parameters,
      the state propagated by the specified `convergence_criterion`, if any (if
      no convergence criterion is specified, this will be `None`),
      as well as any other quantities captured in the closure of `trace_fn`,
      for example, statistics of a variational distribution.
      Default value: `lambda traceable_quantities: traceable_quantities.loss`.
    return_full_length_trace: Python `bool` indicating whether to return a trace
      of the full length `num_steps`, even if a convergence criterion stopped
      the optimization early, by tiling the value(s) traced at the final
      optimization step. This enables use in contexts such as XLA that require
      shapes to be known statically.
      Default value: `True`.
    jit_compile: If True, compiles the minimization loop using
      XLA. XLA performs compiler optimizations, such as fusion, and attempts to
      emit more efficient code. This may drastically improve the performance.
      See the docs for `tf.function`. (In JAX, this will apply `jax.jit`).
      Default value: `False`.
    seed: PRNG seed for stochastic losses; see `tfp.random.sanitize_seed`.
      Default value: `None`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: 'minimize'.

  Returns:
    trace: `Tensor` or nested structure of `Tensor`s, according to the
      return type of `trace_fn`. Each `Tensor` has an added leading dimension
      stacking the trajectory of the traced values over the course of the
      optimization. The size of this dimension is equal to `num_steps` if
      a convergence criterion was not specified and/or
      `return_full_length_trace=True`, and otherwise it is equal to the
      number of optimization steps taken.

  ### Examples

  To minimize the scalar function `(x - 5)**2`:

  ```python
  x = tf.Variable(0.)
  loss_fn = lambda: (x - 5.)**2
  losses = tfp.math.minimize(loss_fn,
                             num_steps=100,
                             optimizer=tf.optimizers.Adam(learning_rate=0.1))

  # In TF2/eager mode, the optimization runs immediately.
  print("optimized value is {} with loss {}".format(x, losses[-1]))
  ```

  In graph mode (e.g., inside of `tf.function` wrapping), retrieving any Tensor
  that depends on the minimization op will trigger the optimization:

  ```python
  with tf.control_dependencies([losses]):
    optimized_x = tf.identity(x)  # Use a dummy op to attach the dependency.
  ```

  We can attempt to automatically detect convergence and stop the optimization
  by passing an instance of
  `tfp.optimizer.convergence_criteria.ConvergenceCriterion`. For example, to
  stop the optimization once a moving average of the per-step decrease in loss
  drops below `0.01`:

  ```python
  losses = tfp.math.minimize(
    loss_fn, num_steps=1000, optimizer=tf.optimizers.Adam(learning_rate=0.1),
    convergence_criterion=(
      tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))
  ```

  Here `num_steps=1000` defines an upper bound: the optimization will be
  stopped after 1000 steps even if no convergence is detected.

  In some cases, we may want to track additional context inside the
  optimization. We can do this by defining a custom `trace_fn`. Note that
  the `trace_fn` is passed the loss and gradients, as well as any auxiliary
  state maintained by the convergence criterion (if any), for example, moving
  averages of the loss or gradients, but it may also report the
  values of trainable parameters or other derived quantities by capturing them
  in its closure. For example, we can capture `x` and track its value over the
  optimization:

  ```python
  # `x` is the tf.Variable instance defined above.
  trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss, 'x': x}
  trace = tfp.math.minimize(loss_fn, num_steps=100,
                            optimizer=tf.optimizers.Adam(0.1),
                            trace_fn=trace_fn)
  print(trace['loss'].shape,   # => [100]
        trace['x'].shape)      # => [100]
  ```

  When optimizing a batch of losses, some batch members will converge before
  others. The optimization will continue until the condition defined by the
  `batch_convergence_reduce_fn` becomes `True`. During these additional steps,
  converged elements will continue to be updated and may become unconverged.
  The convergence status of batch members can be diagnosed by tracing
  `has_converged`:

  ```python
  batch_size = 10
  x = tf.Variable([0.] * batch_size)
  trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss,
    'has_converged': traceable_quantities.has_converged}
  trace = tfp.math.minimize(loss_fn, num_steps=100,
                            optimizer=tf.optimizers.Adam(0.1),
                            trace_fn=trace_fn,
                            convergence_criterion=(
      tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

  for i in range(batch_size):
    print('Batch element {} final state is {}converged.'
          ' It first converged at step {}.'.format(
          i, '' if trace['has_converged'][-1, i] else 'not ',
          np.argmax(trace['has_converged'][:, i])))
  ```

  """

    if jit_compile:
        # Run the entire minimization inside a jit-compiled function. This is
        # typically faster than jit-compiling the individual steps.
        # NOTE: `locals()` must be captured here, before any other local name
        # is bound, so that it holds exactly this function's arguments and can
        # be forwarded as keyword arguments to the recursive call below.
        parameters = dict(locals())
        # Disable the flag for the inner call so it does not recurse again.
        parameters['jit_compile'] = False

        @tf.function(autograph=False, jit_compile=True)
        def run_jitted_minimize():
            return minimize(**parameters)

        return run_jitted_minimize()

    # Loop-termination predicate. It is called with the loop variables passed
    # positionally; when no convergence criterion is in use the trailing `None`
    # entries are filtered out of `loop_vars` below, so `has_converged` keeps
    # its default of `None` and this always returns `False`.
    def convergence_detected(step,
                             seed,
                             trace_arrays,
                             has_converged=None,
                             convergence_criterion_state=None):
        del step
        del seed
        del trace_arrays
        del convergence_criterion_state
        return (has_converged is not None  # Convergence criterion in use.
                and batch_convergence_reduce_fn(has_converged))

    # Main optimization routine.
    with tf.name_scope(name) as name:
        seed = samplers.sanitize_seed(seed, salt='minimize')

        # Take an initial training step to obtain the initial loss and values, which
        # will define the shape(s) of the `TensorArray`(s) that we create to hold
        # the results, and are used to initialize the convergence criterion.
        # This will trigger tf.function tracing of `optimizer_step_fn`, which is
        # then reused inside the training loop (i.e., it is only traced once).
        optimizer_step_fn = _make_optimizer_step_fn(
            loss_fn=loss_fn,
            optimizer=optimizer,
            trainable_variables=trainable_variables)
        initial_loss, initial_grads, initial_parameters = optimizer_step_fn(
            seed=seed)
        has_converged = None
        initial_convergence_criterion_state = None
        if convergence_criterion is not None:
            # Start with every batch member marked as not yet converged.
            has_converged = tf.zeros(tf.shape(initial_loss), dtype=tf.bool)
            initial_convergence_criterion_state = convergence_criterion.bootstrap(
                initial_loss, initial_grads, initial_parameters)
        initial_traced_values = trace_fn(
            MinimizeTraceableQuantities(
                loss=initial_loss,
                gradients=initial_grads,
                parameters=initial_parameters,
                step=0,
                has_converged=has_converged,
                convergence_criterion_state=initial_convergence_criterion_state
            ))

        # The trace is only truncated when a criterion may stop the loop early
        # and the caller did not ask for a full-length trace.
        trace_arrays = _initialize_arrays(
            initial_values=initial_traced_values,
            num_steps=num_steps,
            truncate_at_convergence=(convergence_criterion is not None
                                     and not return_full_length_trace))

        # Run the optimization loop.
        with tf.control_dependencies([initial_loss]):
            # The step counter starts at 1 because step 0 was taken above.
            potential_loop_vars = (1, seed, trace_arrays, has_converged,
                                   initial_convergence_criterion_state)
            results = tf.while_loop(
                cond=lambda *args: tf.logical_not(convergence_detected(*args)),  # pylint: disable=no-value-for-parameter
                body=_make_training_loop_body(
                    optimizer_step_fn=optimizer_step_fn,
                    convergence_criterion=convergence_criterion,
                    trace_fn=trace_fn),
                # Drop the `None` placeholders used when no criterion is given.
                loop_vars=[x for x in potential_loop_vars if x is not None],
                parallel_iterations=1,
                # One step has already been taken outside of the loop.
                maximum_iterations=num_steps - 1)
            indices, _, trace_arrays = results[:3]  # Guaranteed to be present.

            if convergence_criterion is not None and return_full_length_trace:
                # Fill out the trace by tiling the last written values.
                # NOTE(review): presumably `indices` holds the next step index
                # to be written, hence the `- 1` -- confirm against
                # `_make_training_loop_body`.
                last_written_idx = tf.reduce_max(indices) - 1
                trace_arrays = tf.nest.map_structure(
                    lambda ta: _tile_last_written_value(ta, last_written_idx),
                    trace_arrays)

        return tf.nest.map_structure(lambda array: array.stack(), trace_arrays)
Ejemplo n.º 20
0
  def _sample_n(self, n, seed=None):
    """Draws `n` truncated-normal samples for every batch member.

    When executing eagerly, or when the default graph is not inside an XLA
    context, the stateless parameterized truncated-normal sampler is used
    directly. Otherwise samples are drawn from a standardized truncated
    normal wrapped in a custom gradient (so that the truncation bounds
    receive gradients), and are then scaled by `scale` and shifted by `loc`.

    Args:
      n: Number of samples to draw.
      seed: Optional PRNG seed.

    Returns:
      A `Tensor` of samples with shape `[n] + batch_shape`.
    """
    loc, scale, low, high = self._loc_scale_low_high()
    batch_shape = self._batch_shape_tensor(
        loc=loc, scale=scale, low=low, high=high)
    sample_and_batch_shape = ps.concat([[n], batch_shape], 0)

    # TODO(b/162522020): Use this behavior unconditionally.
    inside_xla_graph = (
        not tf.executing_eagerly() and
        control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph()))
    if not inside_xla_graph:
      stateless_seed = samplers.sanitize_seed(seed)
      return tf.random.stateless_parameterized_truncated_normal(
          shape=sample_and_batch_shape,
          means=loc,
          stddevs=scale,
          minvals=low,
          maxvals=high,
          seed=stateless_seed)

    # The sampler below puts the sample dimension last, so the requested shape
    # is [flattened batch size, n].
    flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])

    # To stay reparameterizable, sample from a truncated normal with zero mean
    # and unit scale (using the standardized truncation bounds) and apply
    # `loc` and `scale` afterwards.
    @tf.custom_gradient
    def _draw_standardized(lower, upper):
      """Standard truncated Normal with gradient support for low, high."""
      # Note: Unlike the convention in TFP, parameterized_truncated_normal
      # returns a tensor with the final dimension being the sample dimension.
      draws = random_ops.parameterized_truncated_normal(
          shape=flat_batch_and_sample_shape,
          means=0.0,
          stddevs=1.0,
          minvals=lower,
          maxvals=upper,
          dtype=self.dtype,
          seed=seed)

      def grad(dy):
        """Gradients of the samples wrt the truncation bounds.

        The sampler blocks gradients to the bounds, so they are supplied
        here with a custom, numerically stable expression rather than by
        automatic differentiation of the CDF.

        Args:
          dy: Output gradients.

        Returns:
          A list `[grad_lower, grad_upper]` holding the gradients wrt the
          lower bound and the upper bound, in that order.
        """
        # `draws` carries a trailing sample dimension; give the bounds a
        # matching trailing axis so that they broadcast against it.
        lower_col = lower[..., tf.newaxis]
        upper_col = upper[..., tf.newaxis]

        # Position of each draw within the truncated CDF.
        cdf_numer = special_math.ndtr(draws) - special_math.ndtr(lower_col)
        cdf_denom = special_math.ndtr(upper_col) - special_math.ndtr(lower_col)
        cdf = cdf_numer / cdf_denom

        # Keep the CDF strictly inside (0, 1) so that neither logarithm below
        # receives a zero argument.
        finfo = np.finfo(dtype_util.as_numpy_dtype(self.dtype))
        cdf = tf.clip_by_value(cdf, finfo.tiny, 1 - finfo.eps)

        d_upper = tf.exp(
            0.5 * (draws**2 - upper_col**2) + tf.math.log(cdf))
        d_lower = tf.exp(
            0.5 * (draws**2 - lower_col**2) + tf.math.log1p(-cdf))

        # Sum the contributions over the (last) sample dimension.
        grad_upper = tf.reduce_sum(dy * d_upper, axis=-1)
        grad_lower = tf.reduce_sum(dy * d_lower, axis=-1)
        return [grad_lower, grad_upper]

      return draws, grad

    std_low, std_high = self._standardized_low_and_high(
        low=low, high=high, loc=loc, scale=scale)
    bounds_shape = tf.broadcast_dynamic_shape(
        tf.shape(std_low), tf.shape(std_high))
    # Bring both bounds to a common shape, then flatten them to vectors.
    broadcast_bounds = [
        tf.broadcast_to(bound, bounds_shape) for bound in (std_low, std_high)]
    flat_low, flat_high = [
        tf.reshape(bound, [-1]) for bound in broadcast_bounds]

    flat_samples = _draw_standardized(flat_low, flat_high)

    # The sampler returned [flat_batch, n]; move the sample dimension to the
    # front and restore the batch dimensions.
    samples_first = tf.transpose(flat_samples, perm=[1, 0])
    standardized = tf.reshape(samples_first, sample_and_batch_shape)
    return standardized * scale[tf.newaxis] + loc[tf.newaxis]
Ejemplo n.º 21
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    The inner kernel is stepped on the replica states stored in
    `previous_kernel_results`, a swap permutation over replicas is proposed,
    and each proposed swap is accepted or rejected independently per batch
    member.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.

    Raises:
      TypeError: if calling `make_kernel_fn` raises a `TypeError` whose message
        mentions 'argument'; it is re-raised with a hint that `make_kernel_fn`
        no longer receives a `seed`.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                # Only signature-style errors (message mentions 'argument') are
                # rewritten; any other TypeError propagates unchanged.
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            # One seed each for: the inner kernel step, the swap proposal, and
            # the log-uniform draws used to accept or reject swaps.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            # The leading dimension indexes replicas; the rest is batch shape.
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                # Backwards compatibility: retry without `step_count` only when
                # the error message mentions that argument.
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            # `null_swaps` is the identity permutation `range(num_replica)`,
            # expanded to have the same rank as `swaps`.
            null_swaps = bu.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_negative_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_negative_energy_ignoring_ulp, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_negative_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_negative_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff * bu.left_justified_expand_dims_to(
                inverse_temp_diff, replica_and_batch_rank))

            # Any non-finite ratio (NaN or +/-inf) is replaced by -inf, so the
            # strict `<` comparison below can never accept such a swap.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            # For each replica take the smaller of its own index and its swap
            # partner's index, so both members of a pair gather the same draw.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            # Where the swap was accepted take the value gathered from the swap
            # partner; elsewhere keep the original value.
            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = bu.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                # Keep only the first replica (index 0 along the leading axis)
                # of each state part.
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            # Return the state in the same form the caller passed it in: a
            # list if `current_state` was list-like, otherwise a single Tensor.
            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case its a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
                potential_energy=-untempered_negative_energy_ignoring_ulp,
            )

            return states, post_swap_kernel_results
Ejemplo n.º 22
0
 def _sample_n(self, n, seed, **kwargs):
     """Samples from the wrapped distribution with a replica-specific seed.

     The sanitized seed is folded together with `self.replica_id` before it
     is handed to `self.distribution.sample`.

     Args:
       n: Sample shape forwarded as `sample_shape`.
       seed: PRNG seed.
       **kwargs: Extra keyword arguments forwarded to
         `self.distribution.sample`.

     Returns:
       The result of `self.distribution.sample`.
     """
     base_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
     replica_index = tf.cast(self.replica_id, tf.int32)
     replica_seed = samplers.fold_in(base_seed, replica_index)
     return self.distribution.sample(
         sample_shape=n, seed=replica_seed, **kwargs)
Ejemplo n.º 23
0
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=lambda current_state, kernel_results: kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
    """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain starting at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  This function can sample from multiple chains, in parallel.  Whether or not
  there are multiple chains is dictated by how the `kernel` treats its inputs.
  Typically, the shape of the independent chains is shape of the result of the
  `target_log_prob_fn` used by the `kernel` when applied to the given
  `current_state`.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set of
  states with decreased autocorrelation.  See [Owen (2017)][1]. Such 'thinning'
  is made possible by setting `num_steps_between_results > 0`. The chain then
  takes `num_steps_between_results` extra steps between the steps that make it
  into the results. The extra steps are never materialized, and thus do not
  increase memory requirements.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. See below for some examples of this feature.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`). If `None`,
      `kernel.bootstrap_results(current_state)` is used.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
      of the Markov chain. Although the default is `None`, a kernel must be
      supplied: its `is_calibrated` attribute is read unconditionally.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results.  The number of returned chain states is
      still equal to `num_results`.  Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the previous
      kernel results and returns a `Tensor` or a nested collection of `Tensor`s
      that is then traced along with the chain state.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details. A per-step
      seed is only passed to `kernel.one_step` when `seed` is not `None`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_sample_chain').

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step. Has
      same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  I.e.,

  ```none
  for i=1..n:
    x[i] ~ MultivariateNormal(loc=0, scale=diag(true_stddev))  # likelihood
  ```

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dims = 10
  true_stddev = tf.sqrt(tf.linspace(1., 3., dims))
  likelihood = tfd.MultivariateNormalDiag(loc=0., scale_diag=true_stddev)

  states = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=None)

  sample_mean = tf.reduce_mean(states, axis=0)
  # ==> approx all zeros

  sample_stddev = tf.sqrt(tf.reduce_mean(
      tf.math.squared_difference(states, sample_mean),
      axis=0))
  # ==> approx equal true_stddev
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  # prior
  w ~ MultivariateNormal(loc=0, scale=eye(d))
  for i=1..n:
    # likelihood
    x[i] ~ Normal(loc=w^T F[i], scale=1)
  ```

  where `F` denotes factors.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Specify model.
  def make_prior(dims):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.matmul(weights, factors, adjoint_b=True))

  def joint_log_prob(num_weights, factors, x, w):
    return (make_prior(num_weights).log_prob(w) +
            make_likelihood(w, factors).log_prob(x))

  def unnormalized_log_posterior(w):
    # Posterior is proportional to: `p(W, X=x | factors)`.
    return joint_log_prob(num_weights, factors, x, w)

  # Setup data.
  num_weights = 10 # == d
  num_factors = 40 # == n
  num_chains = 100

  weights = make_prior(num_weights).sample(1)
  factors = tf.random.normal([num_factors, num_weights])
  x = make_likelihood(weights, factors).sample()

  # Sample from Hamiltonian Monte Carlo Markov Chain.

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros([num_chains, num_weights], name='init_weights'),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_posterior,
        step_size=0.1,
        num_leapfrog_steps=2))

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  # ==> approx equal to weights

  sample_var = tf.reduce_mean(
      tf.math.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  # ==> less than 1
  ```

  ##### Custom tracing functions.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  likelihood = tfd.Normal(loc=0., scale=1.)

  def sample_chain(trace_fn):
    return tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=0.,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=trace_fn)

  def trace_log_accept_ratio(states, previous_kernel_results):
    return previous_kernel_results.log_accept_ratio

  def trace_everything(states, previous_kernel_results):
    return previous_kernel_results

  _, log_accept_ratio = sample_chain(trace_fn=trace_log_accept_ratio)
  _, kernel_results = sample_chain(trace_fn=trace_everything)

  acceptance_prob = tf.math.exp(tf.minimum(log_accept_ratio, 0.))
  # Equivalent to, but more efficient than:
  acceptance_prob = tf.math.exp(tf.minimum(
      kernel_results.log_accept_ratio, 0.))
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
    # Record whether the caller supplied a seed *before* it is sanitized; the
    # flag decides below whether a per-step seed is forwarded to the kernel.
    is_seeded = seed is not None
    seed = samplers.sanitize_seed(seed, salt='mcmc.sample_chain')

    if not kernel.is_calibrated:
        warnings.warn(
            'supplied `TransitionKernel` is not calibrated. Markov '
            'chain may not converge to intended target distribution.')
    with tf.name_scope(name or 'mcmc_sample_chain'):
        num_results = ps.convert_to_shape_tensor(num_results,
                                                 dtype=tf.int32,
                                                 name='num_results')
        num_burnin_steps = ps.convert_to_shape_tensor(num_burnin_steps,
                                                      dtype=tf.int32,
                                                      name='num_burnin_steps')
        num_steps_between_results = ps.convert_to_shape_tensor(
            num_steps_between_results,
            dtype=tf.int32,
            name='num_steps_between_results')
        current_state = tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(x, name='current_state'),
            current_state)
        # Cold start: let the kernel build its own initial results.
        if previous_kernel_results is None:
            previous_kernel_results = kernel.bootstrap_results(current_state)

        if trace_fn is None:
            # It simplifies the logic to use a dummy function here.
            trace_fn = lambda *args: ()
            no_trace = True
        else:
            no_trace = False
        # `__defaults__[4]` is the default object bound to the `trace_fn`
        # parameter (the fifth parameter that has a default), so this identity
        # check detects that the caller did not pass `trace_fn` at all.
        if trace_fn is sample_chain.__defaults__[4]:
            warnings.warn(
                'Tracing all kernel results by default is deprecated. Set '
                'the `trace_fn` argument to None (the future default '
                'value) or an explicit callback that traces the values '
                'you are interested in.')

        def _seeded_one_step(seed, *state_and_results):
            # Loop body: the seed is carried as the first loop variable. When
            # the caller gave a seed it is split into a seed for this step and
            # one passed along to the next iteration; otherwise the kernel is
            # called without a `seed` keyword and the carried seed is unchanged.
            step_seed, passalong_seed = (samplers.split_seed(seed)
                                         if is_seeded else (None, seed))
            one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
            return [passalong_seed] + list(
                kernel.one_step(*state_and_results, **one_step_kwargs))

        def _trace_scan_fn(seed_state_and_results, num_steps):
            # Advances the chain by `num_steps` kernel steps; only the state
            # reached after the last of them is handed back for tracing.
            seed, next_state, current_kernel_results = loop_util.smart_for_loop(
                loop_num_iter=num_steps,
                body_fn=_seeded_one_step,
                initial_loop_vars=list(seed_state_and_results),
                parallel_iterations=parallel_iterations)
            return seed, next_state, current_kernel_results

        (_, _,
         final_kernel_results), (all_states, trace) = loop_util.trace_scan(
             loop_fn=_trace_scan_fn,
             initial_state=(seed, current_state, previous_kernel_results),
             # Number of kernel steps to take for each of the `num_results`
             # results: `1 + num_burnin_steps` for the first result (index 0)
             # and `1 + num_steps_between_results` for every later one.
             elems=tf.one_hot(indices=0,
                              depth=num_results,
                              on_value=1 + num_burnin_steps,
                              off_value=1 + num_steps_between_results,
                              dtype=tf.int32),
             # Element 0 of the loop state is the seed; it is dropped so that
             # the user `trace_fn` only sees (state, kernel_results).
             # pylint: disable=g-long-lambda
             trace_fn=lambda seed_state_and_results: (seed_state_and_results[
                 1], trace_fn(*seed_state_and_results[1:])),
             # pylint: enable=g-long-lambda
             parallel_iterations=parallel_iterations)

        if return_final_kernel_results:
            return CheckpointableStatesAndTrace(
                all_states=all_states,
                trace=trace,
                final_kernel_results=final_kernel_results)
        else:
            if no_trace:
                return all_states
            else:
                return StatesAndTrace(all_states=all_states, trace=trace)
Ejemplo n.º 24
0
def step_kernel(
    num_steps,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
    """Advances `kernel` by `num_steps` steps starting from `current_state`.

    A deliberately small driver around `kernel.one_step`; `sample_chain`
    offers tracing and thinning on top of the same idea.

    Args:
      num_steps: Integer number of Markov chain steps.
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
        used to warm-start the kernel. When `None`, the results of
        `kernel.bootstrap_results(current_state)` are used instead.
      kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
        step of the Markov chain.
      return_final_kernel_results: If `True`, the kernel results after the
        last step are returned together with the final state (useful for
        inspection or a later warm restart).
      parallel_iterations: The number of iterations allowed to run in
        parallel. It must be a positive integer. See `tf.while_loop` for
        more details.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details. The kernel
        only receives a `seed` keyword when this is not `None`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'mcmc_step_kernel').

    Returns:
      next_state: Markov chain state after `num_steps` steps are taken, of
        identical type as `current_state`.
      final_kernel_results: kernel results, as supplied by `kernel.one_step`
        after `num_steps` steps are taken. This is only returned if
        `return_final_kernel_results` is `True`.
    """
    # Must be decided before sanitization replaces a `None` seed.
    is_seeded = seed is not None
    seed = samplers.sanitize_seed(seed, salt='experimental.mcmc.step_kernel')

    if not kernel.is_calibrated:
        warnings.warn(
            'supplied `TransitionKernel` is not calibrated. Markov '
            'chain may not converge to intended target distribution.')

    with tf.name_scope(name or 'mcmc_step_kernel'):
        num_steps = tf.convert_to_tensor(
            num_steps, dtype=tf.int32, name='num_steps')

        def _as_state_tensor(part):
            return tf.convert_to_tensor(part, name='current_state')

        current_state = tf.nest.map_structure(_as_state_tensor, current_state)

        if previous_kernel_results is None:
            previous_kernel_results = kernel.bootstrap_results(current_state)

        def _advance_once(carried_seed, *state_and_results):
            # The seed travels as the first loop variable. With a user seed,
            # split it into (seed for this step, seed for the next iteration);
            # without one, call the kernel seedlessly and carry it unchanged.
            if is_seeded:
                step_seed, next_seed = samplers.split_seed(carried_seed)
                one_step_kwargs = dict(seed=step_seed)
            else:
                next_seed = carried_seed
                one_step_kwargs = {}
            stepped = kernel.one_step(*state_and_results, **one_step_kwargs)
            return [next_seed] + list(stepped)

        loop_outputs = mcmc_util.smart_for_loop(
            loop_num_iter=num_steps,
            body_fn=_advance_once,
            initial_loop_vars=[seed, current_state, previous_kernel_results],
            parallel_iterations=parallel_iterations)
        _, next_state, final_kernel_results = loop_outputs

        # Plain tuple / bare state: the return contract is simple enough that
        # the named tuples used by `sample_chain` are not needed here.
        if not return_final_kernel_results:
            return next_state
        return next_state, final_kernel_results
Ejemplo n.º 25
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Draws fresh momenta and runs the leapfrog integrator once.

        Args:
          current_state: Current chain state; a single value or a list-like
            of state parts.
          previous_kernel_results: Results from the previous step. Its
            `target_log_prob` and `grads_target_log_prob` are reused, and it
            supplies `step_size` / `num_leapfrog_steps` when
            `self._store_parameters_in_results` is true.
          seed: Optional PRNG seed, passed through `samplers.sanitize_seed`.

        Returns:
          next_state: The integrated state parts; a list when `current_state`
            is list-like, otherwise the single first part.
          kernel_results: `previous_kernel_results` with the acceptance
            correction, target log-prob, gradients, initial/final momenta and
            the sanitized seed replaced.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Tunable parameters live either in the results or on the kernel.
            param_holder = (previous_kernel_results
                            if self._store_parameters_in_results else self)
            step_size = param_holder.step_size
            num_leapfrog_steps = param_holder.num_leapfrog_steps

            (state_parts, step_sizes, start_log_prob,
             start_grad_parts) = _prepare_args(
                 self.target_log_prob_fn,
                 current_state,
                 step_size,
                 previous_kernel_results.target_log_prob,
                 previous_kernel_results.grads_target_log_prob,
                 maybe_expand=True,
                 state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            # One seed per state part, each with the shard axis index folded in.
            part_seeds = distribute_lib.fold_in_axis_index(
                samplers.split_seed(seed, n=len(state_parts)),
                self.experimental_shard_axis_names)

            start_momentum_parts = [
                samplers.normal(
                    shape=ps.shape(part),
                    dtype=(self._momentum_dtype
                           or dtype_util.base_dtype(part.dtype)),
                    seed=part_seed)
                for part_seed, part in zip(part_seeds, state_parts)
            ]

            leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
            (end_momentum_parts, end_state_parts, end_log_prob,
             end_grad_parts) = leapfrog(start_momentum_parts, state_parts,
                                        start_log_prob, start_grad_parts)

            if self.state_gradients_are_stopped:
                end_state_parts = list(map(tf.stop_gradient, end_state_parts))

            chain_ndims = ps.rank(start_log_prob)
            log_correction = _compute_log_acceptance_correction(
                start_momentum_parts,
                end_momentum_parts,
                chain_ndims,
                shard_axis_names=self.experimental_shard_axis_names)

            updated_results = previous_kernel_results._replace(
                log_acceptance_correction=log_correction,
                target_log_prob=end_log_prob,
                grads_target_log_prob=end_grad_parts,
                initial_momentum=start_momentum_parts,
                final_momentum=end_momentum_parts,
                seed=seed,
            )

            # Hand the state back in the same form the caller passed it in.
            if mcmc_util.is_list_like(current_state):
                next_state = end_state_parts
            else:
                next_state = end_state_parts[0]
            return next_state, updated_results
Ejemplo n.º 26
0
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation proceeds as: one fast window, four slow windows whose sizes
  double each time (`slow_window_size * 2**idx`), one final fast window, and
  then the actual sampling run of `n_draws` draws. Each stage starts from the
  last draw of the previous stage.

  Args:
    n_draws: Number of draws for the final sampling stage.
    joint_dist: Joint distribution handed to `_setup_mcmc`.
    kind: Sampler selector forwarded unchanged to `_fast_window`,
      `_slow_window` and `_do_sampling`.
    n_chains: Number of chains; used for `_setup_mcmc` and for the shape
      `[n_chains, 1]` of the initial step size.
    proposal_kernel_kwargs: Mapping of keyword arguments for the proposal
      kernel. It is copied on entry, so the caller's mapping is not modified.
    num_adaptation_steps: Total adaptation budget, split by
      `_get_window_sizes`.
    dual_averaging_kwargs: Forwarded to the fast and slow windows.
    trace_fn: Tracing callable forwarded to every stage, or `None` to trace
      nothing.
    return_final_kernel_results: If `True`, `_do_sampling` is expected to
      return `(draws, trace, final_kernel_results)` and a
      `CheckpointableStatesAndTrace` is returned.
    discard_tuning: If `True`, only the draws/trace of the final sampling
      stage are returned; otherwise the adaptation stages are concatenated in
      front of them along axis 0.
    seed: PRNG seed, sanitized via `samplers.sanitize_seed`.
    **pins: Forwarded to `_setup_mcmc`.

  Returns:
    The (inverse-bijected) states alone when `trace_fn` was `None` and
    `return_final_kernel_results` is false; otherwise a `StatesAndTrace` or,
    with `return_final_kernel_results`, a `CheckpointableStatesAndTrace`.
  """
  no_trace = trace_fn is None
  if no_trace:
    trace_fn = lambda *args: ()

  # Work on a private copy: the kwargs are updated after every window and
  # those updates must not leak into the dict owned by the caller.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

  num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
      joint_dist, n_chains=n_chains, seed=setup_seed, **pins)

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)
  # If we (over) optimistically assume good scaling, this will be near the
  # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. “A Condition Number for
  # Hamiltonian Monte Carlo.” arXiv [stat.CO]. arXiv.
  # http://arxiv.org/abs/1905.09813.
  init_step_size = tf.cast(
      ps.shape(initial_transformed_position)[-1], tf.float32) ** -0.25

  all_draws = []
  all_traces = []
  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': tf.fill([n_chains, 1], init_step_size),
      'momentum_distribution': _init_momentum(initial_transformed_position),
  })

  # Stage 1: initial fast window.
  draws, trace, step_size, running_variance = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=first_window_size,
      initial_position=initial_transformed_position,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  all_draws.append(draws)
  all_traces.append(trace)

  # Stage 2: four slow windows (five seeds, the last one is carried forward).
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  for idx, slow_seed in enumerate(slow_seeds):
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variance,
     momentum_distribution) = _slow_window(
         kind=kind,
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         dual_averaging_kwargs=dual_averaging_kwargs,
         num_draws=window_size,
         initial_position=draws[-1],
         initial_running_variance=running_variance,
         bijector=bijector,
         trace_fn=trace_fn,
         seed=slow_seed)
    all_draws.append(draws)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  # Stage 3: final fast window.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, running_variance = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=last_window_size,
      initial_position=draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  all_draws.append(draws)
  all_traces.append(trace)

  # Stage 4: the actual sampling run.
  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
  else:
    draws, trace = ret

  if not discard_tuning:
    # Keep the adaptation draws in front of the sampling draws.
    all_draws.append(draws)
    all_traces.append(trace)
    draws = tf.concat(all_draws, axis=0)
  states = bijector.inverse(draws)

  if no_trace and not return_final_kernel_results:
    return states

  if not discard_tuning:
    trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                  *all_traces, expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=states,
        trace=trace,
        final_kernel_results=fkr)
  return sample.StatesAndTrace(all_states=states, trace=trace)
Ejemplo n.º 27
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one NUTS transition by repeated tree doubling.

        Args:
          current_state: Current chain state; a single value or a flat
            list-like of state parts.
          previous_kernel_results: Results of the previous step; supplies the
            target log-prob, its gradients, the step size and the
            momentum/state memory.
          seed: Optional PRNG seed, passed through `samplers.sanitize_seed`.

        Returns:
          result_state: The selected candidate state, packed with the same
            structure as `current_state`.
          kernel_results: A `NUTSKernelResults` for this step.

        Raises:
          TypeError: If `current_state` is nested but is not a flat
            list-like structure.
        """
        seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
        start_trajectory_seed, loop_seed = samplers.split_seed(seed)

        with tf.name_scope(self.name + '.one_step'):
            flat_state = tf.nest.flatten(current_state)
            if tf.nest.is_nested(current_state):
                is_flat_list = (mcmc_util.is_list_like(current_state)
                                and len(flat_state) == len(current_state))
                if not is_flat_list:
                    # TODO(b/170865194): Support dictionaries and other non-list-like state.
                    raise TypeError(
                        'NUTS does not currently support nested or '
                        'non-list-like state structures (saw: {}).'.format(
                            current_state))

            prev_target_log_prob = previous_kernel_results.target_log_prob
            prev_grad_parts = previous_kernel_results.grads_target_log_prob
            step_size = previous_kernel_results.step_size
            momentum_state_memory = (
                previous_kernel_results.momentum_state_memory)

            (init_momentum, init_energy,
             log_slice_sample) = self._start_trajectory_batched(
                 flat_state,
                 prev_target_log_prob,
                 seed=start_trajectory_seed)

            def _duplicate(v):
                # Multiply by ones of shape [2, 1, ..., 1] (rank(v) trailing
                # ones) so the value gains a leading axis of size two.
                ones_shape = ps.pad(
                    [2], paddings=[[0, ps.rank(v)]], constant_values=1)
                return v * ps.ones(ones_shape, dtype=v.dtype)

            initial_step_state = tf.nest.map_structure(
                _duplicate,
                TreeDoublingState(
                    momentum=init_momentum,
                    state=flat_state,
                    target=prev_target_log_prob,
                    target_grad_parts=prev_grad_parts))

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=TreeDoublingStateCandidate(
                    state=flat_state,
                    target=prev_target_log_prob,
                    target_grad_parts=prev_grad_parts,
                    energy=init_energy,
                    weight=init_weight),
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                momentum_sum=init_momentum,
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instruction into TensorArray so that it is
            # compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=len(self._write_instruction),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(
                tf.int32,
                size=len(self._read_instruction),
                clear_after_read=False).unstack(self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            def _should_continue(iter_, tree_seed, state, metastate):
                # Stop at the depth limit or once no entry of `continue_tree`
                # is still set.
                return ((iter_ < self.max_tree_depth)
                        & tf.reduce_any(metastate.continue_tree))

            def _double_tree(iter_, tree_seed, state, metastate):
                return self._loop_tree_doubling(
                    step_size, momentum_state_memory, current_step_meta_info,
                    iter_, state, metastate, tree_seed)

            initial_iter = tf.zeros([], dtype=tf.int32, name='iter')
            loop_outputs = tf.while_loop(
                cond=_should_continue,
                body=_double_tree,
                loop_vars=(initial_iter, loop_seed, initial_step_state,
                           initial_step_metastate),
                parallel_iterations=self.parallel_iterations,
            )
            final_metastate = loop_outputs[3]
            final_candidate = final_metastate.candidate_state

            energy_diff_sum = final_metastate.energy_diff_sum
            leapfrog_count = final_metastate.leapfrog_count
            log_accept_ratio = tf.math.log(
                energy_diff_sum /
                tf.cast(leapfrog_count, dtype=energy_diff_sum.dtype))

            kernel_results = NUTSKernelResults(
                target_log_prob=final_candidate.target,
                grads_target_log_prob=final_candidate.target_grad_parts,
                momentum_state_memory=momentum_state_memory,
                step_size=step_size,
                log_accept_ratio=log_accept_ratio,
                leapfrogs_taken=(leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=final_metastate.is_accepted,
                reach_max_depth=final_metastate.continue_tree,
                has_divergence=~final_metastate.not_divergence,
                energy=final_candidate.energy,
                seed=seed,
            )

            result_state = tf.nest.pack_sequence_as(
                current_state, final_candidate.state)
            return result_state, kernel_results
Ejemplo n.º 28
0
 def _sample_n(self, n, seed, **kwargs):
   """Samples `n` draws from the wrapped distribution using a shard seed.

   The incoming seed is sanitized with the 'sharded_sample' salt and then
   passed, together with `self.experimental_shard_axis_names`, through
   `distribute_lib.fold_in_axis_index` before sampling.

   Args:
     n: Sample shape forwarded as `sample_shape`.
     seed: PRNG seed accepted by `samplers.sanitize_seed`.
     **kwargs: Extra keyword arguments forwarded to
       `self.distribution.sample`.

   Returns:
     Whatever `self.distribution.sample` returns.
   """
   sanitized_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
   shard_seed = distribute_lib.fold_in_axis_index(
       sanitized_seed, self.experimental_shard_axis_names)
   return self.distribution.sample(
       sample_shape=n, seed=shard_seed, **kwargs)
    def one_step(self, state, kernel_results, seed=None):
        """Takes one Sequential Monte Carlo inference step.

        The step proposes new particles and updates their weights (except on
        the initial step, where the provided `state` is used as-is), computes
        the incremental log marginal likelihood from the unnormalized weights,
        and then resamples wherever `resample_criterion_fn` requests it.

        Args:
          state: instance of `tfp.experimental.mcmc.WeightedParticles`
            representing the current particles with (log) weights. The
            `log_weights` must be a float `Tensor` of shape
            `[num_particles, b1, ..., bN]`. The `particles` may be any
            structure of `Tensor`s, each of which must have shape
            `concat([log_weights.shape, event_shape])` for some `event_shape`,
            which may vary across components.
          kernel_results: instance of
            `tfp.experimental.mcmc.SequentialMonteCarloResults` representing
            results from a previous step.
          seed: Optional seed for reproducible sampling.

        Returns:
          state: instance of `tfp.experimental.mcmc.WeightedParticles`
            representing new particles with (log) weights. Where resampling
            occurred the weights are uniform (`-log(num_particles)`);
            elsewhere they are the normalized updated weights.
          kernel_results: instance of
            `tfp.experimental.mcmc.SequentialMonteCarloResults` with `steps`
            incremented, the resampling `parent_indices`, the incremental and
            accumulated log marginal likelihoods, and the sanitized `seed`
            used for this step.
        """
        with tf.name_scope(self.name):
            with tf.name_scope('one_step'):
                seed = samplers.sanitize_seed(seed)
                proposal_seed, resample_seed = samplers.split_seed(seed)

                state = WeightedParticles(*state)  # Canonicalize.
                num_particles = ps.size0(state.log_weights)

                # Propose new particles and update weights for this step,
                # unless it's the initial step, in which case, use the
                # user-provided initial particles and weights.
                proposed_state = self.propose_and_update_log_weights_fn(
                    # Propose state[t] from state[t - 1].
                    ps.maximum(0, kernel_results.steps - 1),
                    state,
                    seed=proposal_seed)
                is_initial_step = ps.equal(kernel_results.steps, 0)
                # TODO(davmre): this `where` assumes the state size didn't
                # change.
                state = tf.nest.map_structure(
                    lambda a, b: tf.where(is_initial_step, a, b), state,
                    proposed_state)

                normalized_log_weights = tf.nn.log_softmax(state.log_weights,
                                                           axis=0)
                # Every entry of `log_weights` differs from
                # `normalized_log_weights` by the same normalizing constant,
                # `logsumexp(log_weights, axis=0)`. Compute it directly rather
                # than as `log_weights[0] - normalized_log_weights[0]`: that
                # difference is `-inf - (-inf) = NaN` whenever the first
                # particle has zero weight, even if other particles do not.
                incremental_log_marginal_likelihood = (
                    tf.math.reduce_logsumexp(state.log_weights, axis=0))

                do_resample = self.resample_criterion_fn(state)

                # Some batch elements may require resampling and others not,
                # so we first do the resampling for all elements, then select
                # whether to use the resampled values for each batch element
                # according to `do_resample`. If there were no batching, we
                # might prefer to use `tf.cond` to avoid the resampling
                # computation on steps where it's not needed---but we're
                # ultimately interested in adaptive resampling for statistical
                # (not computational) purposes, so this isn't a dealbreaker.
                resampled_particles, resample_indices = (
                    weighted_resampling.resample(
                        state.particles,
                        state.log_weights,
                        self.resample_fn,
                        seed=resample_seed))
                # After resampling every particle carries equal weight
                # `1 / num_particles`.
                uniform_weights = tf.fill(
                    ps.shape(state.log_weights),
                    value=-tf.math.log(
                        tf.cast(num_particles, state.log_weights.dtype)))
                # Where `do_resample` is false, keep the un-resampled
                # particles, placeholder parent indices, and the normalized
                # (non-uniform) weights.
                (resampled_particles, resample_indices,
                 log_weights) = tf.nest.map_structure(
                     lambda r, p: ps.where(do_resample, r, p),
                     (resampled_particles, resample_indices, uniform_weights),
                     (state.particles, _dummy_indices_like(resample_indices),
                      normalized_log_weights))

            return (
                WeightedParticles(particles=resampled_particles,
                                  log_weights=log_weights),
                SequentialMonteCarloResults(
                    steps=kernel_results.steps + 1,
                    parent_indices=resample_indices,
                    incremental_log_marginal_likelihood=(
                        incremental_log_marginal_likelihood),
                    accumulated_log_marginal_likelihood=(
                        kernel_results.accumulated_log_marginal_likelihood +
                        incremental_log_marginal_likelihood),
                    seed=seed))
Ejemplo n.º 30
0
 def test_sanitize_none(self):
     """Checks two `sanitize_seed(seed=None)` calls give differing seeds."""
     fresh_seeds = [samplers.sanitize_seed(seed=None) for _ in range(2)]
     self.assertNotAllEqual(*fresh_seeds)