def adaptive_hamiltonian_monte_carlo_init(
    state: 'fun_mc.TensorNest',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'fun_mc.FloatTensor' = 1e-2,
    initial_mean: 'fun_mc.FloatNest' = 0.,
    initial_scale: 'fun_mc.FloatNest' = 1.,
    scale_smoothing_steps: 'fun_mc.IntTensor' = 10,
) -> 'AdaptiveHamiltonianMonteCarloState':
  """Builds the starting `AdaptiveHamiltonianMonteCarloState`.

  The returned state bundles three pieces: the plain HMC state, a running
  variance accumulator pre-seeded with `initial_mean`/`initial_scale`, and an
  Adam optimizer state holding the log of the step size.

  Args:
    state: Initial state of the chain.
    target_log_prob_fn: Target log prob fn.
    step_size: Initial scalar step size.
    initial_mean: Initial mean for the running variance estimate. Must
      broadcast structurally and tensor-wise with `state`.
    initial_scale: Initial scale for the running variance estimate. Must
      broadcast structurally and tensor-wise with `state`.
    scale_smoothing_steps: Weight given to `initial_mean` and `initial_scale`
      in the running estimate. Larger values stabilize early adaptation.

  Returns:
    adaptive_hmc_state: State of the `adaptive_hamiltonian_monte_carlo_step`
      `TransitionOperator`.
  """
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  # The step size optimizer works in the dtype of the first state part.
  dtype = util.flatten_tree(hmc_state.state)[0].dtype
  chain_shape = hmc_state.target_log_prob.shape
  chain_ndims = len(chain_shape)

  def _event_shape(part):
    # Drop the leading chain dimensions; statistics are kept per event.
    return part.shape[chain_ndims:]

  def _part_dtype(part):
    return part.dtype

  running_var_state = fun_mc.running_variance_init(
      shape=util.map_tree(_event_shape, hmc_state.state),
      dtype=util.map_tree(_part_dtype, hmc_state.state),
  )

  initial_mean = fun_mc.maybe_broadcast_structure(initial_mean, state)
  initial_scale = fun_mc.maybe_broadcast_structure(initial_scale, state)

  # Total number of chains, i.e. the product of the chain dimensions.
  num_chains = int(np.prod(chain_shape))

  def _smoothed_num_points(points):
    # Pretend we already saw `scale_smoothing_steps` updates from each chain.
    return num_chains * tf.cast(scale_smoothing_steps, points.dtype)

  def _filled_mean(mean, init_mean):
    return tf.ones_like(mean) * init_mean

  def _filled_variance(variance, init_scale):
    return tf.ones_like(variance) * init_scale**2

  # Smoothing matters here because early updates can look very different from
  # the stationary distribution.
  # TODO(siege): Add a pseudo-update functionality to avoid fiddling with the
  # internals here.
  running_var_state = running_var_state._replace(
      num_points=util.map_tree(_smoothed_num_points,
                               running_var_state.num_points),
      mean=util.map_tree(_filled_mean, running_var_state.mean, initial_mean),
      variance=util.map_tree(_filled_variance, running_var_state.variance,
                             initial_scale),
  )

  # The step size is adapted in log-space.
  log_step_size = tf.math.log(tf.convert_to_tensor(step_size, dtype=dtype))
  log_step_size_opt_state = fun_mc.adam_init(log_step_size)

  return AdaptiveHamiltonianMonteCarloState(
      hmc_state=hmc_state,
      running_var_state=running_var_state,
      log_step_size_opt_state=log_step_size_opt_state,
      step=tf.zeros([], tf.int32))
def computation(state, seed):
  """Runs step-size-adapted HMC on a Softplus-transformed normal target.

  Relies on names from the enclosing scope: `self`, `base_mean`, `base_cov`,
  `num_leapfrog_steps`, `num_adapt_steps`, `num_steps` and `step_size`.

  Args:
    state: Initial chain state, passed through
      `fun_mcmc.transform_log_prob_fn` together with the bijector.
    seed: Seed threaded through the kernel and split on every step.

  Returns:
    A tuple `(chain, log_accept_ratio_trace, true_samples)`: the traced
    `state_extra[0]` of the HMC state, the traced log accept ratios, and 4096
    samples drawn directly from the target distribution.
  """
  bijector = tfp.bijectors.Softplus()
  untransformed_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(untransformed_dist)

  def constrained_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, transformed_state = fun_mcmc.transform_log_prob_fn(
      constrained_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step, seed):
    hmc_seed, next_seed = util.split_seed(seed, 2)

    # The optimizer state stores the log of the step size.
    current_step_size = tf.exp(step_size_state.state)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=current_step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)

    # Learning rate decays to zero over the adaptation phase.
    learning_rate = prefab._polynomial_decay(  # pylint: disable=protected-access
        step=step,
        step_size=self._constant(0.01),
        power=0.5,
        decay_steps=num_adapt_steps,
        final_step_size=0.)

    # Acceptance probability is exp(min(0, log_accept_ratio)), averaged.
    clipped_log_accept = tf.minimum(
        self._constant(0.), hmc_extra.log_accept_ratio)
    avg_accept_prob = tf.reduce_mean(tf.exp(clipped_log_accept))

    def accept_gap_grad_fn(_):
      # Gap between the 0.9 acceptance target and the observed average.
      return 0.9 - avg_accept_prob, ()

    loss_fn = fun_mcmc.make_surrogate_loss_fn(accept_gap_grad_fn)
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=learning_rate)

    next_state = (hmc_state, step_size_state, step + 1, next_seed)
    # NOTE(review): `state_extra[0]` is presumably the state mapped back
    # through the bijector -- confirm against `transform_log_prob_fn`.
    traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
    return next_state, traced

  initial_kernel_state = (
      fun_mcmc.hamiltonian_monte_carlo_init(transformed_state,
                                            target_log_prob_fn),
      fun_mcmc.adam_init(tf.math.log(step_size)),
      0,
      seed,
  )
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=initial_kernel_state,
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )

  # Reference samples drawn directly from the target for comparison.
  sample_seed = self._make_seed(_test_seed())
  true_samples = target_dist.sample(4096, seed=sample_seed)

  return chain, log_accept_ratio_trace, true_samples
def testAdam(self):
  """Checks that 1000 Adam steps minimize a two-variable quadratic.

  The loss is `(x - 1)^2 + (y - 2)^2`, both variables start at zero, and the
  final traced values must be within 1e-3 of `x == 1`, `y == 2`, `loss == 0`.
  """

  def loss_fn(x, y):
    squared_error = tf.square(x - 1.) + tf.square(y - 2.)
    return squared_error, []

  def adam_kernel(adam_state):
    return fun_mcmc.adam_step(
        adam_state, loss_fn, learning_rate=self._constant(0.01))

  def state_and_loss(state, extra):
    return [state.state, extra.loss]

  initial_adam_state = fun_mcmc.adam_init(
      [self._constant(0.), self._constant(0.)])

  _, [(x_trace, y_trace), loss_trace] = fun_mcmc.trace(
      initial_adam_state,
      adam_kernel,
      num_steps=1000,
      trace_fn=state_and_loss)

  # Only the last traced step needs to have converged.
  self.assertAllClose(1., x_trace[-1], atol=1e-3)
  self.assertAllClose(2., y_trace[-1], atol=1e-3)
  self.assertAllClose(0., loss_trace[-1], atol=1e-3)