Beispiel #1
0
def adaptive_hamiltonian_monte_carlo_init(
    state: 'fun_mc.TensorNest',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'fun_mc.FloatTensor' = 1e-2,
    initial_mean: 'fun_mc.FloatNest' = 0.,
    initial_scale: 'fun_mc.FloatNest' = 1.,
    scale_smoothing_steps: 'fun_mc.IntTensor' = 10,
) -> 'AdaptiveHamiltonianMonteCarloState':
  """Creates the initial `AdaptiveHamiltonianMonteCarloState`.

  Three pieces of state are assembled:

  * the HMC kernel state, built from `state` and `target_log_prob_fn`;
  * a running-variance accumulator over the trailing (non-chain) dimensions of
    every state part. It is pre-seeded with `initial_mean`, with
    `initial_scale**2` and with a pseudo-count of
    `num_chains * scale_smoothing_steps`, which damps early adaptation;
  * the step size adaptation state, seeded with `step_size` cast to the dtype
    of the first state part.

  The step counter starts at 0.

  Args:
    state: Initial state of the chain.
    target_log_prob_fn: Target log prob fn.
    step_size: Initial scalar step size.
    initial_mean: Initial mean for computing the running variance estimate. Must
      broadcast structurally and tensor-wise with state.
    initial_scale: Initial scale for computing the running variance estimate.
      Must broadcast structurally and tensor-wise with state.
    scale_smoothing_steps: How much weight to assign to the `initial_mean` and
      `initial_scale`. Increase this to stabilize early adaptation.

  Returns:
    adaptive_hmc_state: State of the `adaptive_hamiltonian_monte_carlo_step`
      `TransitionOperator`.
  """
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  chain_state = hmc_state.state
  log_prob_shape = hmc_state.target_log_prob.shape

  # The step size is expressed in the dtype of the first state part.
  first_part_dtype = util.flatten_tree(chain_state)[0].dtype
  # The leading dimensions shared with the target log prob are treated as chain
  # dimensions; the variance is tracked over the remaining trailing dimensions.
  num_chain_dims = len(log_prob_shape)

  var_state = fun_mc.running_variance_init(
      shape=util.map_tree(lambda part: part.shape[num_chain_dims:],
                          chain_state),
      dtype=util.map_tree(lambda part: part.dtype, chain_state),
  )
  mean_nest = fun_mc.maybe_broadcast_structure(initial_mean, state)
  scale_nest = fun_mc.maybe_broadcast_structure(initial_scale, state)

  def _pseudo_count(num_points):
    # Act as if `scale_smoothing_steps` updates were already made per chain.
    num_chains = int(np.prod(log_prob_shape))
    return num_chains * tf.cast(scale_smoothing_steps, num_points.dtype)

  def _seeded_mean(mean, init_mean):
    return tf.ones_like(mean) * init_mean

  def _seeded_variance(variance, init_scale):
    return tf.ones_like(variance) * init_scale**2

  # Smoothing matters here: the first updates can be very different from the
  # stationary distribution.
  # TODO(siege): Add a pseudo-update functionality to avoid fiddling with the
  # internals here.
  var_state = var_state._replace(
      num_points=util.map_tree(_pseudo_count, var_state.num_points),
      mean=util.map_tree(_seeded_mean, var_state.mean, mean_nest),
      variance=util.map_tree(_seeded_variance, var_state.variance,
                             scale_nest),
  )

  return AdaptiveHamiltonianMonteCarloState(
      hmc_state=hmc_state,
      running_var_state=var_state,
      ssa_state=step_size_adaptation_init(
          tf.convert_to_tensor(step_size, dtype=first_part_dtype)),
      step=tf.zeros([], tf.int32))
Beispiel #2
0
def stochastic_gradient_ascent_hmc_init(
    state: 'fun_mc.State',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    init_trajectory_length: 'fun_mc.FloatTensor',
    trajectory_length_params_init_fn:
    'Callable[[fun_mc.FloatTensor], Any]' = default_trajectory_length_init):
  """Creates the initial Stochastic Gradient Ascent HMC state.

  The trajectory length parameters returned by
  `trajectory_length_params_init_fn` seed both an Adam optimizer state and a
  running-mean state. The running mean's `mean` is overwritten with those same
  parameters. The step counter starts at 1.

  Args:
    state: Initial Markov Chain state.
    target_log_prob_fn: Target log prob fn.
    init_trajectory_length: Initial trajectory length. Passed to
      `trajectory_length_params_init_fn`.
    trajectory_length_params_init_fn: Initializer for the trajectory length
      parameters.

  Returns:
    sga_hmc_state: New Stochastic Gradient Ascent HMC state.
  """
  trajectory_length = tf.convert_to_tensor(init_trajectory_length)
  params = trajectory_length_params_init_fn(trajectory_length)

  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  step = tf.ones([], tf.int32)
  opt_state = fun_mc.adam_init(params)

  param_shapes = util.map_tree(lambda p: p.shape, params)
  param_dtypes = util.map_tree(lambda p: p.dtype, params)
  # Start the running mean at the initial parameters instead of the mean that
  # `running_mean_init` produces on its own.
  rmean_state = fun_mc.running_mean_init(param_shapes, param_dtypes)._replace(
      mean=params)

  return StochasticGradientAscentHMCState(
      hmc_state=hmc_state,
      step=step,
      trajectory_length_params_opt_state=opt_state,
      trajectory_length_params_rmean_state=rmean_state,
  )
Beispiel #3
0
        def hmc_step(trajectory_length, axis_name=()):
            """Runs one sharded HMC step and returns a scalar summary of it.

            Builds a toy model whose second component is wrapped in
            `Sharded(..., axis_name)`. It then takes a single
            `prefab.hamiltonian_monte_carlo_with_state_grads_step` starting
            from a 2-draw sample of that model, and reduces the proposed state
            to the sum of squares of all its elements.

            NOTE(review): `seed` and `epsilon` are not defined in this
            function; they are presumably locals of the enclosing test method
            — confirm.

            Args:
              trajectory_length: Trajectory length passed to the HMC step.
              axis_name: Axis name(s) used to shard the model's and the
                momentum distribution's second component. With the default
                `()`, `sum_state` skips the cross-shard `psum`.

            Returns:
              A tuple `(sum_sq, ())`. `sum_sq` is the sum of squares of every
              element of the proposed state, `psum`-reduced across each part's
              shard axis when one is set. The second element is an empty
              extra tuple.
            """
            # Model: z ~ N(0, 1), then 8 iid N(z, 1) draws sharded over
            # `axis_name`.
            @tfp.experimental.distribute.JointDistributionCoroutine
            def model():
                z = yield root(tfd.Normal(0., 1))
                yield tfp.experimental.distribute.Sharded(
                    tfd.Sample(tfd.Normal(z, 1.), 8), axis_name)

            # Momentum distribution with the same structure as the model:
            # N(0, 2) for z and 8 iid N(0, 3) draws sharded like the model.
            @tfp.experimental.distribute.JointDistributionCoroutine
            def momentum_dist():
                yield root(tfd.Normal(0., 2))
                yield root(
                    tfp.experimental.distribute.Sharded(
                        tfd.Sample(tfd.Normal(0., 3.), 8), axis_name))

            # Returns a (value, extra) pair; no extra outputs are needed.
            def target_log_prob_fn(x):
                return model.log_prob(x), ()

            # Kinetic energy is the negative log density of the momentum
            # distribution, also returned as a (value, extra) pair.
            def kinetic_energy_fn(m):
                return -momentum_dist.log_prob(m), ()

            # Draws 2 momenta, matching the 2 model draws used as the state.
            def momentum_sample_fn(seed):
                return momentum_dist.sample(2, seed=seed)

            state = model.sample(2, seed=seed)
            hmc_state = fun_mc.hamiltonian_monte_carlo_init(
                state, target_log_prob_fn)
            # The per-element step size scale is 1 + |x| of the initial
            # state; the per-part shard axis names come from the model.
            hmc_state, hmc_extra = (
                prefab.hamiltonian_monte_carlo_with_state_grads_step(
                    hmc_state,
                    trajectory_length=trajectory_length,
                    scalar_step_size=epsilon,
                    step_size_scale=util.map_tree(lambda x: 1. + tf.abs(x),
                                                  state),
                    target_log_prob_fn=target_log_prob_fn,
                    seed=seed,
                    kinetic_energy_fn=kinetic_energy_fn,
                    momentum_sample_fn=momentum_sample_fn,
                    shard_axis_names=model.experimental_shard_axis_names))

            # Sum of squares of one state part. When the part has a shard
            # axis (non-empty `axis_name`), it is psum'd so every shard sees
            # the global total.
            def sum_state(x, axis_name):
                res = tf.reduce_sum(x**2)
                if axis_name:
                    res = backend.distribute_lib.psum(res, axis_name)
                return res

            # Pair each proposed-state part with its shard axis name(s), then
            # add the per-part totals into a single scalar.
            sum_sq = util.map_tree_up_to(hmc_extra.proposed_state, sum_state,
                                         hmc_extra.proposed_state,
                                         model.experimental_shard_axis_names)
            sum_sq = sum(util.flatten_tree(sum_sq))
            return sum_sq, ()