def testPreconditionedHMC(self):
  """Checks HMC on a Softplus-transformed MVN recovers the target's moments.

  Runs HMC in the unconstrained space (via `transform_log_prob_fn`), then
  compares the post-warmup chain's mean/covariance against exact samples
  drawn from the target distribution.
  """
  step_size = self._constant(0.2)
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2], dtype=self._dtype)

  base_mean = self._constant([1., 0])
  base_cov = self._constant([[1, 0.5], [0.5, 1]])

  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  def kernel(hmc_state, seed):
    # JAX threads an explicit seed; TF draws a fresh stateful seed per step.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_state.state_extra[0]

  if not self._is_on_jax:
    seed = _test_seed()
  else:
    seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             seed),
      fn=kernel,
      num_steps=num_steps))(state, seed)
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # Bug fix: the reference covariance must come from the exact samples, not
  # from `chain` itself (which made the assertion below trivially true).
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def testBasicHMC(self):
  """Smoke test: HMC on a diagonal Gaussian recovers its mean and variance."""
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = tf.constant([2., 3.])
  base_scale = tf.constant([2., 0.5])

  def target_log_prob_fn(x):
    # Unnormalized log density of N(base_mean, diag(base_scale**2)).
    standardized = (x - base_mean) / base_scale
    return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

  def kernel(hmc_state, seed):
    # JAX requires explicit seed threading; TF uses a fresh stateful seed.
    if self._is_on_jax:
      hmc_seed, seed = util.split_seed(seed, 2)
    else:
      hmc_seed = _test_seed()
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_extra

  if self._is_on_jax:
    seed = self._make_seed(_test_seed())
  else:
    seed = _test_seed()

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             seed),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state[0].state))(state, seed)
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

  true_samples = util.random_normal(
      shape=[4096, 2], dtype=tf.float32, seed=seed) * base_scale + base_mean

  true_mean = tf.reduce_mean(true_samples, axis=0)
  true_var = tf.math.reduce_variance(true_samples, axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
def computation(state, seed):
  """Runs HMC with an Adam-adapted step size on a transformed target.

  Closes over `base_mean`, `base_cov`, `step_size`, `num_leapfrog_steps`,
  `num_adapt_steps`, `num_steps` and `self` from the enclosing scope.

  Args:
    state: Initial chain state (mapped into unconstrained space below).
    seed: PRNG seed, threaded explicitly on JAX.

  Returns:
    chain: Traced chain states in the constrained (original) space.
    log_accept_ratio_trace: Per-step log acceptance ratios.
    true_samples: Exact samples drawn from the target distribution.
  """
  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Sample in unconstrained space; `state` is mapped through the bijector.
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step, seed):
    # JAX threads an explicit seed; TF draws a fresh stateful seed each step.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    # The step size is optimized in log space, hence the exp here.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    # Decay the adaptation rate to 0 over `num_adapt_steps`, effectively
    # freezing the step size for the remaining steps.
    rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
        step=step,
        step_size=self._constant(0.01),
        power=0.5,
        decay_steps=num_adapt_steps,
        final_step_size=0.)
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))
    # Surrogate loss whose gradient pushes mean acceptance towards 0.9.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=rate)

    return ((hmc_state, step_size_state, step + 1, seed),
            (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             fun_mcmc.adam_init(tf.math.log(step_size)), 0, seed),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(
      4096, seed=self._make_seed(_test_seed()))
  return chain, log_accept_ratio_trace, true_samples
def trace():
  """Runs four HMC steps via fun_mcmc.trace, discarding the traced output."""

  def kernel(state):
    return fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=_test_seed())

  fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(
          tf.zeros([1]), target_log_prob_fn),
      fn=kernel,
      num_steps=4,
      trace_fn=lambda *args: ())
def trace():
  """Runs four HMC steps with a fixed test seed, discarding the trace."""

  def kernel(state):
    return fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=self._make_seed(tfp_test_util.test_seed()))

  fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(
          state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn),
      fn=kernel,
      num_steps=4,
      trace_fn=lambda *args: ())
def testPreconditionedHMC(self):
  """Checks HMC on a Softplus-transformed MVN recovers the target's moments.

  Runs HMC in the unconstrained space (via `transform_log_prob_fn`), then
  compares the post-warmup chain's mean/covariance against exact samples
  drawn from the target distribution.
  """
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=_test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state.state_extra[0])
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=_test_seed())

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # Bug fix: the reference covariance must come from the exact samples, not
  # from `chain` itself (which made the assertion below trivially true).
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def computation(state):
  """Runs HMC with an Adam-adapted step size on a transformed target.

  Closes over `base_mean`, `base_cov`, `step_size`, `num_leapfrog_steps`,
  `num_adapt_steps` and `num_steps` from the enclosing scope.

  Args:
    state: Initial chain state (mapped into unconstrained space below).

  Returns:
    chain: Traced chain states in the constrained (original) space.
    log_accept_ratio_trace: Per-step log acceptance ratios.
    true_samples: Exact samples drawn from the target distribution.
  """
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Sample in unconstrained space; `state` is mapped through the bijector.
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step):
    # The step size is optimized in log space, hence the exp here.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)
    # Decay the adaptation rate to 0 over `num_adapt_steps`, effectively
    # freezing the step size for the remaining steps.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    # Surrogate loss whose gradient pushes mean acceptance towards 0.9.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=rate)

    return ((hmc_state, step_size_state, step + 1),
            (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             fun_mcmc.adam_init(tf.math.log(step_size)), 0),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(4096, seed=_test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                         aggregation):
  """Checks the streaming autocovariance estimate against the batch estimate.

  Args:
    state_shape: Shape of the chain state.
    event_ndims: None for a scalar-event target; otherwise the trailing axis
      is summed out of the log prob.
    aggregation: Axes (or None) over which the running estimator aggregates.
  """
  # We'll use HMC as the source of our chain.
  # While HMC is being sampled, we also compute the running autocovariance.
  step_size = 0.2
  num_steps = 1000
  num_leapfrog_steps = 10
  max_lags = 300

  state = tf.constant(np.zeros(state_shape).astype(np.float32))

  def target_log_prob_fn(x):
    # Standard normal; reduce over the event axis when one is present.
    lp = -0.5 * tf.square(x)
    if event_ndims is None:
      return lp, ()
    else:
      return tf.reduce_sum(lp, -1), ()

  def kernel(hmc_state, raac_state, seed):
    # JAX threads an explicit seed; TF draws a fresh stateful seed each step.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    # Fold the new chain state into the running autocovariance estimate.
    raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
        raac_state, hmc_state.state, axis=aggregation)
    return (hmc_state, raac_state, seed), hmc_extra

  if not self._is_on_jax:
    seed = _test_seed()
  else:
    seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  (_, raac_state, _), chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
      state=(
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          fun_mcmc.running_approximate_auto_covariance_init(
              max_lags=max_lags,
              state_shape=state_shape,
              dtype=state.dtype,
              axis=aggregation),
          seed,
      ),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state[0].state))(state, seed)

  # Shift the aggregation axes by one to account for the leading time axis
  # of the traced chain.
  true_aggregation = (0,) + (() if aggregation is None else tuple(
      [a + 1 for a in util.flatten_tree(aggregation)]))
  true_variance = np.array(
      tf.math.reduce_variance(np.array(chain), true_aggregation))
  true_autocov = np.array(
      tfp.stats.auto_correlation(np.array(chain), axis=0, max_lags=max_lags))
  if aggregation is not None:
    true_autocov = tf.reduce_mean(
        true_autocov, [a + 1 for a in util.flatten_tree(aggregation)])

  self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
  # Normalize by the lag-0 term to compare correlations; the tolerance is
  # loose because the running estimator is approximate.
  self.assertAllClose(
      true_autocov,
      raac_state.auto_covariance / raac_state.auto_covariance[0],
      atol=0.1)
def train_p(q, u, x_pos, step_size, opt_p):
  """Train P using the standard CD objective.

  Args:
    q: `ModelQ`.
    u: A callable representing the energy function.
    x_pos: A batch of positive examples.
    step_size: Step size to use for HMC.
    opt_p: A `tf.optimizer.Optimizer`.

  Returns:
    x_neg_q: Negative samples sampled from `q`.
    x_neg_p: Negative samples used to train `p`, possibly generated via HMC.
    p_accept: Acceptance rate of HMC.
    step_size: The new step size, possibly adapted to adjust the acceptance
      rate.
    pos_e: Mean energy of the positive samples across the batch.
    pos_e_updated: Mean energy of the positive samples across the batch,
      after the parameter update.
    neg_e_q: Mean energy of `x_neg_q` across the batch.
    neg_e_p: Mean energy of `x_neg_p` across the batch.
    neg_e_p_updated: Mean energy of `x_neg_p` across the batch, after the
      parameter update.
  """

  def create_momentum_sample_fn(state):
    # Momentum matching `state`'s shape, with a configurable stddev.
    sample_fn = lambda seed: tf.random.normal(  # pylint: disable=g-long-lambda
        tf.shape(state), stddev=FLAGS.mcmc_momentum_stddev)
    return sample_fn

  # Negative samples proposed by Q, optionally refined by HMC below.
  _, x_neg_q, _ = q.sample_with_log_prob(
      FLAGS.batch_size, temp=FLAGS.q_temperature)
  neg_e_q = tf.reduce_mean(u(x_neg_q))

  def p_log_prob(x):
    return -u(x)

  if FLAGS.use_mcmc:

    def log_prob_non_transformed(x):
      p_log_p = p_log_prob(x)
      return p_log_p, (x,)

    # TODO(siege): Why aren't we actually using NeuTra?
    # def log_prob_transformed(z):
    #   x, logdet = q.reverse(z)
    #   p_log_p = p_log_prob(x)
    #   return p_log_p + logdet, (x,)

    def kernel(hmc_state, step_size, step):
      """HMC kernel."""
      hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=FLAGS.mcmc_leapfrog_steps,
          momentum_sample_fn=create_momentum_sample_fn(hmc_state.state),
          target_log_prob_fn=log_prob_non_transformed)

      mean_p_accept = tf.reduce_mean(
          tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
      # Optionally adapt the step size towards a 0.9 acceptance rate.
      if FLAGS.mcmc_adapt_step_size:
        step_size = fun_mcmc.sign_adaptation(
            step_size, output=mean_p_accept, set_point=0.9)

      return (hmc_state, step_size, step + 1), hmc_extra

    hmc_state, is_accepted = fun_mcmc.trace(
        state=(fun_mcmc.hamiltonian_monte_carlo_init(
            x_neg_q, log_prob_non_transformed), step_size, 0),
        fn=kernel,
        num_steps=FLAGS.mcmc_num_steps,
        trace_fn=lambda _, hmc_extra: hmc_extra.is_accepted)

    x_neg_p = hmc_state[0].state_extra[0]
    step_size = hmc_state[1]
    p_accept = tf.reduce_mean(tf.cast(is_accepted, tf.float32))
  else:
    # No MCMC refinement: train directly on Q's samples.
    x_neg_p = x_neg_q
    p_accept = 0.0
    step_size = 0.0

  with tf.GradientTape() as tape:
    tape.watch(u.trainable_variables)
    pos_e = tf.reduce_mean(u(x_pos))
    neg_e_p = tf.reduce_mean(u(x_neg_p))
    # CD loss, with a regularizer keeping the positive energy near zero.
    loss = pos_e - neg_e_p + tf.square(pos_e) * FLAGS.p_center_regularizer
  variables = u.trainable_variables
  grads = tape.gradient(loss, variables)
  grads_and_vars = list(zip(grads, variables))
  opt_p.apply_gradients(grads_and_vars)

  # Re-evaluate the energies after the parameter update for diagnostics.
  pos_e_updated = tf.reduce_mean(u(x_pos))
  neg_e_p_updated = tf.reduce_mean(u(x_neg_p))

  return (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated,
          neg_e_q, neg_e_p, neg_e_p_updated)