Example #1
0
def stochastic_gradient_ascent_hmc_init(
    state: 'fun_mc.State',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    init_trajectory_length: 'fun_mc.FloatTensor',
    trajectory_length_params_init_fn:
    'Callable[[fun_mc.FloatTensor], Any]' = default_trajectory_length_init):
  """Builds the starting state for Stochastic Gradient Ascent HMC.

  The returned state bundles an HMC state for `state`, a scalar int32 step
  counter, Adam optimizer state over the trajectory length parameters, and a
  running-mean state whose mean is seeded with those same parameters.

  Args:
    state: Initial Markov Chain state.
    target_log_prob_fn: Target log prob fn, used to initialize the HMC state.
    init_trajectory_length: Initial trajectory length. It is converted to a
      Tensor and passed to `trajectory_length_params_init_fn`.
    trajectory_length_params_init_fn: Callable mapping the initial trajectory
      length to the (possibly nested) trajectory length parameters.

  Returns:
    sga_hmc_state: New `StochasticGradientAscentHMCState`.
  """
  trajectory_length = tf.convert_to_tensor(init_trajectory_length)
  params = trajectory_length_params_init_fn(trajectory_length)

  # Sub-states are built in the same order as the fields below; the callables
  # involved are user-supplied, so their call order is kept stable.
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  # NOTE(review): the step counter starts at 1; confirm this matches how
  # consumers of this state count steps.
  step = tf.ones([], tf.int32)
  opt_state = fun_mc.adam_init(params)

  # The running mean is created from the per-leaf shapes/dtypes of the
  # parameters, then its mean is overwritten with the initial parameters.
  param_shapes = util.map_tree(lambda x: x.shape, params)
  param_dtypes = util.map_tree(lambda x: x.dtype, params)
  rmean_state = fun_mc.running_mean_init(param_shapes, param_dtypes)
  rmean_state = rmean_state._replace(mean=params)

  return StochasticGradientAscentHMCState(
      hmc_state=hmc_state,
      step=step,
      trajectory_length_params_opt_state=opt_state,
      trajectory_length_params_rmean_state=rmean_state,
  )
Example #2
0
def step_size_adaptation_init(
    init_step_size: 'fun_mc.FloatTensor') -> 'StepSizeAdaptationState':
  """Creates the starting `StepSizeAdaptationState`.

  The step counter is a scalar int32 set to 0. The Adam optimizer state is
  initialized on the log of the step size. The running-mean state's mean is
  seeded with the initial step size itself.

  Args:
    init_step_size: Floating point Tensor. Initial step size.

  Returns:
    step_size_adaptation_state: `StepSizeAdaptationState`
  """
  step_size = tf.convert_to_tensor(init_step_size)

  # Running mean shaped like the step size, with its mean pre-set to it.
  running_mean = fun_mc.running_mean_init(
      step_size.shape, step_size.dtype)._replace(mean=step_size)

  step = tf.constant(0, tf.int32)
  # The optimizer operates on log(step_size), presumably so that updates keep
  # the step size positive; note a non-positive initial step size would yield
  # a non-finite log here.
  log_step_size = tf.math.log(step_size)
  opt_state = fun_mc.adam_init(log_step_size)

  return StepSizeAdaptationState(
      step=step,
      opt_state=opt_state,
      rms_state=running_mean,
  )