def _chain_gets_correct_expectations(self, x, event_dims, sess,
                                     feed_dict=None):
    """Runs an HMC chain on the log-gamma target and checks its moments.

    Builds an `hmc.chain` whose step size, number of leapfrog steps and
    number of iterations are fed through scalar placeholders, evaluates it,
    drops the first half of the samples, and asserts that the sample means
    of `x` and `exp(x)` are within 2e-2 of `self._expected_x` and
    `self._expected_exp_x` respectively. It also asserts that every
    acceptance probability lies in (0.5, 1.0].

    Args:
      x: Initial state, passed to `hmc.chain` as its initial position.
      event_dims: Event dimensions, forwarded both to
        `self._log_gamma_log_prob` and to `hmc.chain`.
      sess: Session used to evaluate the chain.
      feed_dict: Optional extra feeds (for example, a value for a placeholder
        used as `x`). The caller's dict is not modified; a copy is extended
        with the values for this chain's placeholders.
    """
    def log_gamma_log_prob(state):
        return self._log_gamma_log_prob(state, event_dims)

    step_size = array_ops.placeholder(np.float32, [], name='step_size')
    hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
    hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps')

    n_steps = 300
    # Work on a copy so the caller's dict is not polluted with placeholders
    # that belong to the graph built by this call.
    feed_dict = {} if feed_dict is None else dict(feed_dict)
    feed_dict.update({step_size: 0.1, hmc_lf_steps: 2, hmc_n_steps: n_steps})

    sample_chain, acceptance_prob_chain = hmc.chain(
        [hmc_n_steps], step_size, hmc_lf_steps, x, log_gamma_log_prob,
        event_dims)

    acceptance_probs, samples = sess.run(
        [acceptance_prob_chain, sample_chain], feed_dict)

    # Keep only the second half of the chain (the first half is discarded).
    samples = samples[n_steps // 2:]
    expected_x_est = samples.mean()
    expected_exp_x_est = np.exp(samples).mean()

    logging.vlog(1, 'True E[x, exp(x)]: {}\t{}'.format(
        self._expected_x, self._expected_exp_x))
    logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format(
        expected_x_est, expected_exp_x_est))

    self.assertNear(expected_x_est, self._expected_x, 2e-2)
    self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2)
    self.assertTrue((acceptance_probs > 0.5).all())
    self.assertTrue((acceptance_probs <= 1.0).all())
def testChainWorksIn16Bit(self):
    """Checks that `hmc.chain` keeps float16 for states and acceptance probs.

    Runs a 10-iteration chain on the unnormalized log-density
    -sum(x * x, axis=-1). The step size and initial state are both float16.
    The test asserts that the evaluated states and acceptance probabilities
    are float16 arrays.
    """
    def log_prob(x):
        return -math_ops.reduce_sum(x * x, axis=-1)

    states, acceptance_probs = hmc.chain(
        n_iterations=10,
        step_size=np.float16(0.01),
        n_leapfrog_steps=10,
        initial_x=np.zeros(5).astype(np.float16),
        target_log_prob_fn=log_prob,
        event_dims=[-1])
    with self.test_session() as sess:
        states_, acceptance_probs_ = sess.run([states, acceptance_probs])
    self.assertEqual(np.float16, states_.dtype)
    self.assertEqual(np.float16, acceptance_probs_.dtype)