def testSampleChainSeedReproducibleWorksCorrectly(self):
  """Checks a thinned chain against every other result of a dense chain.

  Both chains share every argument, including `seed=52` and one random
  initial state. The first chain draws `2 * num_results` results with no
  steps between results. The second draws `num_results` results with one
  step between results. The test asserts that every second result (and
  target log prob) of the first chain matches the second chain.
  """
  with self.test_session(graph=ops.Graph()) as sess:
    num_results = 10
    independent_chain_ndims = 1

    def log_gamma_log_prob(x):
      # Reduce over every dimension from `independent_chain_ndims` onward.
      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
      return self._log_gamma_log_prob(x, event_dims)

    # Arguments shared by both chains. `current_state` is drawn once here so
    # both chains start from the same point.
    kwargs = dict(
        target_log_prob_fn=log_gamma_log_prob,
        current_state=np.random.rand(4, 3, 2),
        step_size=0.1,
        num_leapfrog_steps=2,
        num_burnin_steps=150,
        seed=52,
    )

    # `dict(base, **extra)` builds a merged copy without mutating `kwargs`.
    samples0, kernel_results0 = hmc.sample_chain(
        **dict(kwargs,
               num_results=2 * num_results,
               num_steps_between_results=0))
    samples1, kernel_results1 = hmc.sample_chain(
        **dict(kwargs,
               num_results=num_results,
               num_steps_between_results=1))

    [
        samples0_,
        samples1_,
        target_log_prob0_,
        target_log_prob1_,
    ] = sess.run([
        samples0,
        samples1,
        kernel_results0.current_target_log_prob,
        kernel_results1.current_target_log_prob,
    ])
    self.assertAllClose(samples0_[::2], samples1_, atol=1e-5, rtol=1e-5)
    self.assertAllClose(target_log_prob0_[::2], target_log_prob1_,
                        atol=1e-5, rtol=1e-5)
def testStateParts(self):
  # Samples a two-part state jointly with HMC: a scalar Normal part and an
  # Independent Gamma part. The per-part sample means and variances are then
  # compared with the distributions' analytic mean and variance.
  with self.test_session() as sess:
    dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
    gamma = gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                            rate=self.dtype([0.5, 0.75]))
    dist_y = independent_lib.Independent(gamma, reinterpreted_batch_ndims=1)

    def target_log_prob(x, y):
      return dist_x.log_prob(x) + dist_y.log_prob(y)

    initial_parts = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
    samples, _ = hmc.sample_chain(
        num_results=int(2e3),
        target_log_prob_fn=target_log_prob,
        current_state=initial_parts,
        step_size=0.85,
        num_leapfrog_steps=3,
        num_burnin_steps=int(250),
        seed=49)

    sample_means = [math_ops.reduce_mean(part, axis=0) for part in samples]
    sample_vars = [_reduce_variance(part, axis=0) for part in samples]
    true_means = [dist_x.mean(), dist_y.mean()]
    true_vars = [dist_x.variance(), dist_y.variance()]

    fetched = sess.run([sample_means, sample_vars, true_means, true_vars])
    sample_means_, sample_vars_, true_means_, true_vars_ = fetched
    self.assertAllClose(true_means_, sample_means_, atol=0.05, rtol=0.16)
    self.assertAllClose(true_vars_, sample_vars_, atol=0., rtol=0.30)
def testStateParts(self):
  # Joint HMC over a list-valued state (Normal scalar, Independent Gamma
  # vector), run in a fresh graph. Sample moments of each part are checked
  # against the analytic moments of the corresponding distribution.
  with self.test_session(graph=ops.Graph()) as sess:
    dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
    gamma = gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                            rate=self.dtype([0.5, 0.75]))
    dist_y = independent_lib.Independent(gamma, reinterpreted_batch_ndims=1)

    def target_log_prob(x, y):
      return dist_x.log_prob(x) + dist_y.log_prob(y)

    samples, _ = hmc.sample_chain(
        num_results=int(2e3),
        target_log_prob_fn=target_log_prob,
        current_state=[dist_x.sample(seed=1), dist_y.sample(seed=2)],
        step_size=0.85,
        num_leapfrog_steps=3,
        num_burnin_steps=int(250),
        seed=49)

    actual_means = [math_ops.reduce_mean(chain, axis=0) for chain in samples]
    actual_vars = [_reduce_variance(chain, axis=0) for chain in samples]
    expected_means = [dist_x.mean(), dist_y.mean()]
    expected_vars = [dist_x.variance(), dist_y.variance()]

    (actual_means_, actual_vars_,
     expected_means_, expected_vars_) = sess.run(
         [actual_means, actual_vars, expected_means, expected_vars])
    self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
    self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)
def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                     sess, feed_dict=None):
  """Runs HMC on the log-gamma target and checks moments and acceptance.

  Args:
    x: Initial state, passed to `hmc.sample_chain` as `current_state`.
    independent_chain_ndims: Index of the first dimension the target log
      prob reduces over (`event_dims` starts here).
    sess: Session used to evaluate the sampled tensors.
    feed_dict: Optional extra feeds for `sess.run`. It is copied before the
      placeholder feeds are added, so the caller's dict is left unchanged.
  """
  counter = collections.Counter()

  def log_gamma_log_prob(x):
    # Counts Python-level calls; checked right after graph construction.
    counter["target_calls"] += 1
    event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    return self._log_gamma_log_prob(x, event_dims)

  # These arguments are fed through placeholders rather than Python constants.
  num_results = array_ops.placeholder(np.int32, [], name="num_results")
  step_size = array_ops.placeholder(np.float32, [], name="step_size")
  num_leapfrog_steps = array_ops.placeholder(
      np.int32, [], name="num_leapfrog_steps")

  # Copy rather than mutate the caller's dict in place.
  feed_dict = {} if feed_dict is None else dict(feed_dict)
  feed_dict.update({
      num_results: 150,
      step_size: 0.05,
      num_leapfrog_steps: 2,
  })

  samples, kernel_results = hmc.sample_chain(
      num_results=num_results,
      target_log_prob_fn=log_gamma_log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_burnin_steps=150,
      seed=42)

  # Plain mapping comparison (Counter is a dict subclass); no array semantics.
  self.assertEqual(dict(target_calls=2), counter)

  # Reference values for E[x] and E[exp(x)] under the target.
  expected_x = (math_ops.digamma(self._shape_param)
                - np.log(self._rate_param))
  expected_exp_x = self._shape_param / self._rate_param

  log_accept_ratio_, samples_, expected_x_ = sess.run(
      [kernel_results.log_accept_ratio, samples, expected_x], feed_dict)

  actual_x = samples_.mean()
  actual_exp_x = np.exp(samples_).mean()
  # min(1, exp(log_accept_ratio)), computed in log space.
  acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

  logging_ops.vlog(
      1, "True E[x, exp(x)]: {}\t{}".format(expected_x_, expected_exp_x))
  logging_ops.vlog(
      1, "Estimated E[x, exp(x)]: {}\t{}".format(actual_x, actual_exp_x))
  self.assertNear(actual_x, expected_x_, 2e-2)
  self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
  # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
  self.assertAllEqual(np.ones_like(acceptance_probs, dtype=bool),
                      acceptance_probs > 0.5)
  self.assertAllEqual(np.ones_like(acceptance_probs, dtype=bool),
                      acceptance_probs <= 1.)
def testSampleChainSeedReproducibleWorksCorrectly(self):
  # Runs two chains with identical seeded arguments. The dense chain draws 2N
  # results with no steps between them; the thinned chain draws N results
  # with one step between them. Every second dense result must match the
  # thinned results, for both samples and target log probs.
  with self.test_session(graph=ops.Graph()) as sess:
    n = 10
    chain_ndims = 1

    def log_gamma_log_prob(x):
      reduce_dims = math_ops.range(chain_ndims, array_ops.rank(x))
      return self._log_gamma_log_prob(x, reduce_dims)

    shared_args = dict(
        target_log_prob_fn=log_gamma_log_prob,
        current_state=np.random.rand(4, 3, 2),
        step_size=0.1,
        num_leapfrog_steps=2,
        num_burnin_steps=150,
        seed=52,
    )

    def run_chain(results, steps_between):
      args = shared_args.copy()
      args.update(num_results=results,
                  num_steps_between_results=steps_between)
      return hmc.sample_chain(**args)

    dense_samples, dense_results = run_chain(2 * n, 0)
    thin_samples, thin_results = run_chain(n, 1)

    fetched = sess.run([
        dense_samples,
        thin_samples,
        dense_results.current_target_log_prob,
        thin_results.current_target_log_prob,
    ])
    dense_samples_, thin_samples_, dense_tlp_, thin_tlp_ = fetched
    self.assertAllClose(dense_samples_[::2], thin_samples_,
                        atol=1e-5, rtol=1e-5)
    self.assertAllClose(dense_tlp_[::2], thin_tlp_, atol=1e-5, rtol=1e-5)
def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                     sess, feed_dict=None):
  """Samples the log-gamma target with HMC and validates the chain.

  Checks the number of target-function calls during graph construction, the
  estimated E[x] and E[exp(x)], and that all acceptance probabilities lie in
  (0.5, 1].

  Args:
    x: Initial state, used as `current_state`.
    independent_chain_ndims: First dimension index the target log prob
      reduces over.
    sess: Session in which the sampled tensors are evaluated.
    feed_dict: Optional extra feeds. A copy is extended with the placeholder
      feeds; the caller's dict is not modified.
  """
  counter = collections.Counter()

  def log_gamma_log_prob(x):
    # Python-level call count, asserted before any session run.
    counter["target_calls"] += 1
    event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    return self._log_gamma_log_prob(x, event_dims)

  # Sampler arguments supplied as placeholders and fed at run time.
  num_results = array_ops.placeholder(np.int32, [], name="num_results")
  step_size = array_ops.placeholder(np.float32, [], name="step_size")
  num_leapfrog_steps = array_ops.placeholder(
      np.int32, [], name="num_leapfrog_steps")

  # Work on a copy so the caller's dict is not mutated.
  feed_dict = {} if feed_dict is None else dict(feed_dict)
  feed_dict.update({num_results: 150, step_size: 0.05, num_leapfrog_steps: 2})

  samples, kernel_results = hmc.sample_chain(
      num_results=num_results,
      target_log_prob_fn=log_gamma_log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_burnin_steps=150,
      seed=42)

  # Compare mappings directly; Counter is a dict subclass.
  self.assertEqual(dict(target_calls=2), counter)

  expected_x = (math_ops.digamma(self._shape_param)
                - np.log(self._rate_param))
  expected_exp_x = self._shape_param / self._rate_param

  log_accept_ratio_, samples_, expected_x_ = sess.run(
      [kernel_results.log_accept_ratio, samples, expected_x], feed_dict)

  actual_x = samples_.mean()
  actual_exp_x = np.exp(samples_).mean()
  # Acceptance probability min(1, exp(log_accept_ratio)).
  acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

  logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format(
      expected_x_, expected_exp_x))
  logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
      actual_x, actual_exp_x))
  self.assertNear(actual_x, expected_x_, 2e-2)
  self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
  # Use builtin `bool`: the `np.bool` alias no longer exists in NumPy >= 1.24.
  self.assertAllEqual(np.ones_like(acceptance_probs, dtype=bool),
                      acceptance_probs > 0.5)
  self.assertAllEqual(np.ones_like(acceptance_probs, dtype=bool),
                      acceptance_probs <= 1.)
def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                     sess, feed_dict=None):
  """Runs HMC on the log-gamma target and checks moments and acceptance.

  Args:
    x: Initial state, passed to `hmc.sample_chain` as `current_state`.
    independent_chain_ndims: Index of the first dimension the target log
      prob reduces over (`event_dims` starts here).
    sess: Session used to evaluate the sampled tensors.
    feed_dict: Optional extra feeds for `sess.run`. It is copied before the
      placeholder feeds are added, so the caller's dict is left unchanged.
  """
  def log_gamma_log_prob(x):
    event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    return self._log_gamma_log_prob(x, event_dims)

  # These arguments are fed through placeholders rather than Python constants.
  num_results = array_ops.placeholder(np.int32, [], name="num_results")
  step_size = array_ops.placeholder(np.float32, [], name="step_size")
  num_leapfrog_steps = array_ops.placeholder(
      np.int32, [], name="num_leapfrog_steps")

  # Copy rather than mutate the caller's dict in place.
  feed_dict = {} if feed_dict is None else dict(feed_dict)
  feed_dict.update({
      num_results: 150,
      step_size: 0.1,
      num_leapfrog_steps: 2,
  })

  samples, kernel_results = hmc.sample_chain(
      num_results=num_results,
      target_log_prob_fn=log_gamma_log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_burnin_steps=150,
      seed=42)

  # Reference values for E[x] and E[exp(x)] under the target.
  expected_x = (math_ops.digamma(self._shape_param)
                - np.log(self._rate_param))
  expected_exp_x = self._shape_param / self._rate_param

  acceptance_probs_, samples_, expected_x_ = sess.run(
      [kernel_results.acceptance_probs, samples, expected_x], feed_dict)

  actual_x = samples_.mean()
  actual_exp_x = np.exp(samples_).mean()

  logging_ops.vlog(
      1, "True E[x, exp(x)]: {}\t{}".format(expected_x_, expected_exp_x))
  logging_ops.vlog(
      1, "Estimated E[x, exp(x)]: {}\t{}".format(actual_x, actual_exp_x))
  self.assertNear(actual_x, expected_x_, 2e-2)
  self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
  self.assertTrue((acceptance_probs_ > 0.5).all())
  self.assertTrue((acceptance_probs_ <= 1.0).all())
def _testChainWorksDtype(self, dtype):
  # Draws a short chain from the unnormalized log density -sum(x**2) and
  # checks that states and acceptance probabilities come back as `dtype`.
  def neg_sum_of_squares(x):
    return -math_ops.reduce_sum(x**2., axis=-1)

  start = np.zeros(5).astype(dtype)
  states, kernel_results = hmc.sample_chain(
      num_results=10,
      target_log_prob_fn=neg_sum_of_squares,
      current_state=start,
      step_size=0.01,
      num_leapfrog_steps=10,
      seed=48)
  with self.test_session() as sess:
    fetched_states, fetched_probs = sess.run(
        [states, kernel_results.acceptance_probs])
  self.assertEqual(dtype, fetched_states.dtype)
  self.assertEqual(dtype, fetched_probs.dtype)
def _testChainWorksDtype(self, dtype):
  """Checks that a sampled chain preserves the requested floating dtype."""
  chain_kwargs = dict(
      num_results=10,
      target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
      current_state=np.zeros(5).astype(dtype),
      step_size=0.01,
      num_leapfrog_steps=10,
      seed=48)
  states, kernel_results = hmc.sample_chain(**chain_kwargs)
  with self.test_session() as sess:
    outputs = sess.run([states, kernel_results.acceptance_probs])
  for output in outputs:
    # Both the states and the acceptance probabilities must keep `dtype`.
    self.assertEqual(dtype, output.dtype)
def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                     sess, feed_dict=None):
  """Samples the log-gamma target with HMC and validates the chain.

  Checks the estimated E[x] and E[exp(x)] against reference values, and that
  all acceptance probabilities lie in (0.5, 1].

  Args:
    x: Initial state, used as `current_state`.
    independent_chain_ndims: First dimension index the target log prob
      reduces over.
    sess: Session in which the sampled tensors are evaluated.
    feed_dict: Optional extra feeds. A copy is extended with the placeholder
      feeds; the caller's dict is not modified.
  """
  def log_gamma_log_prob(x):
    event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    return self._log_gamma_log_prob(x, event_dims)

  # Sampler arguments supplied as placeholders and fed at run time.
  num_results = array_ops.placeholder(
      np.int32, [], name="num_results")
  step_size = array_ops.placeholder(
      np.float32, [], name="step_size")
  num_leapfrog_steps = array_ops.placeholder(
      np.int32, [], name="num_leapfrog_steps")

  # Work on a copy so the caller's dict is not mutated.
  feed_dict = {} if feed_dict is None else dict(feed_dict)
  feed_dict.update({num_results: 150, step_size: 0.1, num_leapfrog_steps: 2})

  samples, kernel_results = hmc.sample_chain(
      num_results=num_results,
      target_log_prob_fn=log_gamma_log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_burnin_steps=150,
      seed=42)

  expected_x = (math_ops.digamma(self._shape_param)
                - np.log(self._rate_param))
  expected_exp_x = self._shape_param / self._rate_param

  acceptance_probs_, samples_, expected_x_ = sess.run(
      [kernel_results.acceptance_probs, samples, expected_x], feed_dict)

  actual_x = samples_.mean()
  actual_exp_x = np.exp(samples_).mean()

  logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format(
      expected_x_, expected_exp_x))
  logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
      actual_x, actual_exp_x))
  self.assertNear(actual_x, expected_x_, 2e-2)
  self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
  self.assertTrue((acceptance_probs_ > 0.5).all())
  self.assertTrue((acceptance_probs_ <= 1.0).all())
def testChainWorksCorrelatedMultivariate(self):
  # Samples a 2-D Gaussian (unit variances, covariance 0.5) whose coordinates
  # are given to HMC as two scalar state parts. Then compares the sample
  # mean and covariance with the true ones.
  dtype = np.float32
  true_mean = dtype([0, 0])
  true_cov = dtype([[1, 0.5], [0.5, 1]])
  num_results = 2000
  counter = collections.Counter()
  with self.test_session(graph=ops.Graph()) as sess:

    def target_log_prob(x, y):
      counter["target_calls"] += 1
      # Unnormalized MVN log density:
      #   z = matmul(inv(chol(true_cov)), [x, y] - true_mean)
      #   log p = -0.5 * sum(z**2)
      centered = array_ops.stack([x, y], axis=-1) - true_mean
      whitened = gen_linalg_ops.matrix_triangular_solve(
          np.linalg.cholesky(true_cov), centered[..., array_ops.newaxis])
      z = array_ops.squeeze(whitened, axis=-1)
      return -0.5 * math_ops.reduce_sum(z**2., axis=-1)

    states, _ = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=target_log_prob,
        current_state=[dtype(-2), dtype(2)],
        step_size=[0.5, 0.5],
        num_leapfrog_steps=2,
        num_burnin_steps=200,
        num_steps_between_results=1,
        seed=54)

    self.assertAllEqual(dict(target_calls=2), counter)

    stacked = array_ops.stack(states, axis=-1)
    self.assertEqual(num_results, stacked.shape[0].value)
    sample_mean = math_ops.reduce_mean(stacked, axis=0)
    deviations = stacked - sample_mean
    sample_cov = (math_ops.matmul(deviations, deviations, transpose_a=True)
                  / dtype(num_results))
    sample_mean_, sample_cov_ = sess.run([sample_mean, sample_cov])
    self.assertAllClose(true_mean, sample_mean_, atol=0.05, rtol=0.)
    self.assertAllClose(true_cov, sample_cov_, atol=0., rtol=0.1)
def testChainWorksCorrelatedMultivariate(self):
  """Checks HMC moments on a correlated bivariate Gaussian split in parts."""
  dtype = np.float32
  true_mean = dtype([0, 0])
  true_cov = dtype([[1, 0.5], [0.5, 1]])
  num_results = 2000
  counter = collections.Counter()
  with self.test_session(graph=ops.Graph()) as sess:

    def target_log_prob(x, y):
      counter["target_calls"] += 1
      # Log of an unnormalized MVN: -0.5 * ||L^{-1} ([x, y] - mean)||^2,
      # with L = chol(true_cov).
      diff = array_ops.stack([x, y], axis=-1) - true_mean
      chol = np.linalg.cholesky(true_cov)
      z = array_ops.squeeze(
          gen_linalg_ops.matrix_triangular_solve(
              chol, diff[..., array_ops.newaxis]),
          axis=-1)
      return -0.5 * math_ops.reduce_sum(z**2., axis=-1)

    chain_parts, _ = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=target_log_prob,
        current_state=[dtype(-2), dtype(2)],
        step_size=[0.5, 0.5],
        num_leapfrog_steps=2,
        num_burnin_steps=200,
        num_steps_between_results=1,
        seed=54)

    self.assertAllEqual(dict(target_calls=2), counter)

    samples = array_ops.stack(chain_parts, axis=-1)
    self.assertEqual(num_results, samples.shape[0].value)
    mean_est = math_ops.reduce_mean(samples, axis=0)
    resid = samples - mean_est
    cov_est = math_ops.matmul(resid, resid, transpose_a=True) / dtype(
        num_results)
    mean_est_, cov_est_ = sess.run([mean_est, cov_est])
    self.assertAllClose(true_mean, mean_est_, atol=0.05, rtol=0.)
    self.assertAllClose(true_cov, cov_est_, atol=0., rtol=0.1)
def testKernelResultsUsingTruncatedDistribution(self):
  """Checks kernel results on a target whose log prob is -inf for x < 0."""

  def log_prob(x):
    return array_ops.where(
        x >= 0.,
        -x - x**2,  # Non-constant gradient.
        array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
  # This log_prob has the property that it is likely to attract
  # the HMC flow toward, and below, zero...but for x < 0,
  # log_prob(x) = -inf, which should result in rejection, as well
  # as a non-finite log_prob. Thus, this distribution gives us an opportunity
  # to test out the kernel results ability to correctly capture rejections due
  # to finite AND non-finite reasons.
  # Why use a non-constant gradient? This ensures the leapfrog integrator
  # will not be exact.

  num_results = 1000
  # Large step size, will give rejections due to integration error in addition
  # to rejection due to going into a region of log_prob = -inf.
  step_size = 0.1
  num_leapfrog_steps = 5
  num_chains = 2

  with self.test_session(graph=ops.Graph()) as sess:

    # Start multiple independent chains.
    initial_state = ops.convert_to_tensor([0.1] * num_chains)

    states, kernel_results = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=log_prob,
        current_state=initial_state,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        seed=42)

    states_, kernel_results_ = sess.run([states, kernel_results])
    pstates_ = kernel_results_.proposed_state

    neg_inf_mask = np.isneginf(
        kernel_results_.proposed_target_log_prob)

    # First: Test that the mathematical properties of the above log prob
    # function in conjunction with HMC show up as expected in kernel_results_.

    # We better have log_prob = -inf some of the time.
    self.assertLess(0, neg_inf_mask.sum())
    # We better have some rejections due to something other than -inf.
    self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum())
    # We better have been accepted a decent amount, even near the end of the
    # chain, or else this HMC run just got stuck at some point.
    self.assertLess(
        0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean())
    # We better not have any NaNs in proposed state or log_prob.
    # We may have some NaN in grads, which involve multiplication/addition due
    # to gradient rules. This is the known "NaN grad issue with tf.where."
    self.assertAllEqual(
        np.zeros_like(states_),
        np.isnan(kernel_results_.proposed_target_log_prob))
    self.assertAllEqual(np.zeros_like(states_), np.isnan(states_))
    # We better not have any +inf in states, grads, or log_prob.
    self.assertAllEqual(
        np.zeros_like(states_),
        np.isposinf(kernel_results_.proposed_target_log_prob))
    self.assertAllEqual(
        np.zeros_like(states_),
        np.isposinf(kernel_results_.proposed_grads_target_log_prob[0]))
    self.assertAllEqual(np.zeros_like(states_), np.isposinf(states_))

    # Second: Test that kernel_results is congruent with itself and
    # acceptance/rejection of states.

    # Proposed state is negative iff proposed target log prob is -inf.
    np.testing.assert_array_less(pstates_[neg_inf_mask], 0.)
    np.testing.assert_array_less(0., pstates_[~neg_inf_mask])

    # Acceptance probs are zero whenever proposed state is negative.
    self.assertAllEqual(np.zeros_like(pstates_[neg_inf_mask]),
                        kernel_results_.acceptance_probs[neg_inf_mask])

    # The move is accepted ==> state = proposed state.
    self.assertAllEqual(
        states_[kernel_results_.is_accepted],
        pstates_[kernel_results_.is_accepted],
    )

    # The move was accepted <==> state[t] != state[t - 1], for t >= 1.
    # One vectorized comparison over all (step, chain) pairs. This is
    # equivalent to an assertNotEqual / assertEqual per pair.
    moved = states_[1:] != states_[:-1]
    self.assertAllEqual(kernel_results_.is_accepted[1:], moved)
def testKernelResultsUsingTruncatedDistribution(self):
  """Checks kernel results on a target whose log prob is -inf for x < 0."""

  def log_prob(x):
    return array_ops.where(
        x >= 0.,
        -x - x**2,  # Non-constant gradient.
        array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
  # This log_prob has the property that it is likely to attract
  # the flow toward, and below, zero...but for x < 0,
  # log_prob(x) = -inf, which should result in rejection, as well
  # as a non-finite log_prob. Thus, this distribution gives us an opportunity
  # to test out the kernel results ability to correctly capture rejections due
  # to finite AND non-finite reasons.
  # Why use a non-constant gradient? This ensures the leapfrog integrator
  # will not be exact.

  num_results = 1000
  # Large step size, will give rejections due to integration error in addition
  # to rejection due to going into a region of log_prob = -inf.
  step_size = 0.1
  num_leapfrog_steps = 5
  num_chains = 2

  with self.test_session(graph=ops.Graph()) as sess:

    # Start multiple independent chains.
    initial_state = ops.convert_to_tensor([0.1] * num_chains)

    states, kernel_results = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=log_prob,
        current_state=initial_state,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        seed=42)

    states_, kernel_results_ = sess.run([states, kernel_results])
    pstates_ = kernel_results_.proposed_state

    neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob)

    # First: Test that the mathematical properties of the above log prob
    # function in conjunction with HMC show up as expected in kernel_results_.

    # We better have log_prob = -inf some of the time.
    self.assertLess(0, neg_inf_mask.sum())
    # We better have some rejections due to something other than -inf.
    self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum())
    # We better have accepted a decent amount, even near end of the chain.
    self.assertLess(
        0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean())
    # We better not have any NaNs in states or log_prob.
    # We may have some NaN in grads, which involve multiplication/addition due
    # to gradient rules. This is the known "NaN grad issue with tf.where."
    self.assertAllEqual(np.zeros_like(states_),
                        np.isnan(kernel_results_.proposed_target_log_prob))
    self.assertAllEqual(np.zeros_like(states_), np.isnan(states_))
    # We better not have any +inf in states, grads, or log_prob.
    self.assertAllEqual(np.zeros_like(states_),
                        np.isposinf(kernel_results_.proposed_target_log_prob))
    self.assertAllEqual(
        np.zeros_like(states_),
        np.isposinf(kernel_results_.proposed_grads_target_log_prob[0]))
    self.assertAllEqual(np.zeros_like(states_), np.isposinf(states_))

    # Second: Test that kernel_results is congruent with itself and
    # acceptance/rejection of states.

    # Proposed state is negative iff proposed target log prob is -inf.
    np.testing.assert_array_less(pstates_[neg_inf_mask], 0.)
    np.testing.assert_array_less(0., pstates_[~neg_inf_mask])

    # Acceptance probs are zero whenever proposed state is negative.
    # min(1, exp(log_accept_ratio)), computed in log space.
    acceptance_probs = np.exp(np.minimum(
        kernel_results_.log_accept_ratio, 0.))
    self.assertAllEqual(
        np.zeros_like(pstates_[neg_inf_mask]),
        acceptance_probs[neg_inf_mask])

    # The move is accepted ==> state = proposed state.
    self.assertAllEqual(
        states_[kernel_results_.is_accepted],
        pstates_[kernel_results_.is_accepted],
    )

    # The move was accepted <==> state[t] != state[t - 1], for t >= 1.
    # One vectorized comparison over all (step, chain) pairs. This is
    # equivalent to an assertNotEqual / assertEqual per pair.
    moved = states_[1:] != states_[:-1]
    self.assertAllEqual(kernel_results_.is_accepted[1:], moved)