def test_dkwm_cdf_one_sample_assertion(self, dtype):
  """Exercises the one-sample DKWM CDF assertion on uniform data.

  First confirms the DKWM test at this sample size can detect a CDF gap
  below 0.05. Then checks that standard-uniform samples pass against the
  identity CDF, and that U(0, 0.9) and U(0.1, 1.1) samples are rejected.
  """
  rng = np.random.RandomState(seed=0)
  n = 13000

  def identity(x):
    return x

  # Power check: the smallest detectable discrepancy must be under 0.05.
  detectable = self.evaluate(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          n, false_fail_rate=1e-6, false_pass_rate=1e-6))
  self.assertLess(detectable, 0.05)

  # The CDF of the standard uniform distribution is the identity.
  good = rng.uniform(size=n).astype(dtype=dtype)
  self.evaluate(st.assert_true_cdf_equal_by_dkwm(
      good, identity, false_fail_rate=1e-6))

  # A scaled (0, 0.9) and then a shifted (0.1, 1.1) uniform: neither has the
  # identity as its CDF, so the assertion op must fail for each.
  for lo, hi in ((0., 0.9), (0.1, 1.1)):
    with self.assertRaisesOpError('Empirical CDF outside K-S envelope'):
      bad = rng.uniform(low=lo, high=hi, size=n).astype(dtype=dtype)
      self.evaluate(st.assert_true_cdf_equal_by_dkwm(
          bad, identity, false_fail_rate=1e-6))
def assert_univariate_target_conservation(test, mk_target, step_size, stackless):
  """Checks that one NUTS step started at exact samples preserves a 1-D target.

  Draws `num_samples` independent initial states from the target, advances
  each by one step of `tfp.experimental.mcmc.NoUTurnSampler` (auto-batched),
  and builds checks that the resulting states still follow the target's CDF,
  that the sampler moved, and that it took a varying number of leapfrog steps.

  Args:
    test: Test case; its `assertAllEqual` performs the shape checks, which run
      immediately inside this function.
    mk_target: Zero-argument callable returning the target distribution; the
      returned object must provide `sample`, `log_prob` and `cdf`.
    step_size: Leapfrog step size passed to the sampler; also used as the lower
      bound on the mean absolute movement.
    stackless: Forwarded unchanged to the sampler's `stackless` argument.

  Returns:
    Tuple of five assertion ops, in order: empirical CDF inside the DKWM
    envelope; DKWM power (minimum detectable discrepancy) below 0.025; at
    least 3 distinct leapfrog counts; mean leapfrog count >= 4; mean absolute
    movement >= `step_size`. NOTE(review): presumably the caller evaluates
    these (e.g. via `test.evaluate`) — confirm at call sites.
  """
  # Sample count limited partly by memory reliably available on Forge. The test
  # remains reasonable even if the nuts recursion limit is severely curtailed
  # (e.g., 3 or 4 levels), so use that to recover some memory footprint and bump
  # the sample count.
  num_samples = int(5e4)
  num_steps = 1
  target_d = mk_target()
  # Seeds are drawn in call order: first for the initial sample, then for the
  # kernel. Reordering the `strm()` calls changes both seeds.
  strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
  # We wrap the initial values in `tf.identity` to avoid broken gradients
  # resulting from a bijector cache hit, since bijectors of the same
  # type/parameterization now share a cache.
  # TODO(b/72831017): Fix broken gradients caused by bijector caching.
  initialization = tf.identity(target_d.sample([num_samples], seed=strm()))

  def target(*args):
    # TODO(axch): Just use target_d.log_prob directly, and accept target_d
    # itself as an argument instead of a maker function. Blocked by
    # b/128932888. It would then also be nice not to eta-expand
    # target_d.log_prob; that was blocked by b/122414321, but maybe tfp's port
    # of value_and_gradients_function fixed that bug.
    return mk_target().log_prob(*args)

  operator = tfp.experimental.mcmc.NoUTurnSampler(
      target,
      step_size=step_size,
      max_tree_depth=3,
      use_auto_batching=True,
      stackless=stackless,
      unrolled_leapfrog_steps=2,
      seed=strm())
  # `extra` is the traced kernel results; `leapfrogs_taken` is read below.
  result, extra = tfp.mcmc.sample_chain(
      num_results=num_steps,
      num_burnin_steps=0,
      current_state=initialization,
      kernel=operator)
  # Note: sample_chain puts the chain history on top, not the (independent)
  # chains.
  test.assertAllEqual([num_steps, num_samples], result.shape)
  answer = result[0]
  # Distributional check: the states after one step must still match the
  # target's CDF under the DKWM envelope.
  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      answer, target_d.cdf, false_fail_rate=1e-6)
  # Guard that the sample count gives the CDF check enough power to be useful.
  check_enough_power = tf1.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6),
      0.025)
  test.assertAllEqual([num_samples], extra.leapfrogs_taken[0].shape)
  # Trajectory lengths should differ across chains (at least 3 distinct
  # values), otherwise the U-turn criterion is not being exercised.
  unique, _ = tf.unique(extra.leapfrogs_taken[0])
  check_leapfrogs_vary = tf1.assert_greater_equal(
      tf.shape(input=unique)[0], 3)
  # NOTE(review): if `leapfrogs_taken` is an integer tensor, `reduce_mean`
  # truncates toward zero — confirm that is acceptable for the >= 4 bound.
  avg_leapfrogs = tf.math.reduce_mean(input_tensor=extra.leapfrogs_taken[0])
  check_leapfrogs = tf1.assert_greater_equal(
      avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))
  movement = tf.abs(answer - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until 100
  # runs with independent seeds all passed.
  check_movement = tf1.assert_greater_equal(
      tf.reduce_mean(input_tensor=movement), 1 * step_size)
  return (check_cdf_agrees, check_enough_power, check_leapfrogs_vary,
          check_leapfrogs, check_movement)
def test_dkwm_cdf_one_sample_batch_discrete_assertion(self, dtype):
  """Checks the DKWM CDF assertion on batched samples from a discrete law.

  Samples come from a categorical distribution on {0, 1, 2, 3} with
  probabilities [0.1, 0.2, 0.3, 0.4] and a [3, 2] batch shape. The test checks
  four things:
  - the assertion passes against the true right- and left-continuous CDFs;
  - the DKWM test has enough power at this sample size (discrepancy < 0.05);
  - the assertion fails for samples from an extra-category distribution;
  - the assertion fails for samples from a re-weighted distribution.
  """
  rng = np.random.RandomState(seed=0)
  num_samples = 13000
  batch_shape = [3, 2]
  shape = [num_samples] + batch_shape
  probs = [0.1, 0.2, 0.3, 0.4]
  samples = rng.choice(4, size=shape, p=probs).astype(dtype=dtype)

  def cdf(x):
    # Right-continuous CDF P(X <= x); cumulative probs are 0.1, 0.3, 0.6, 1.
    ones = tf.ones_like(x)
    answer = tf.where(x < 3, 0.6 * ones, ones)
    answer = tf.where(x < 2, 0.3 * ones, answer)
    answer = tf.where(x < 1, 0.1 * ones, answer)
    return tf.where(x < 0, 0 * ones, answer)

  def left_continuous_cdf(x):
    # Left-continuous CDF P(X < x); needed because the law has atoms.
    ones = tf.ones_like(x)
    answer = tf.where(x <= 3, 0.6 * ones, ones)
    answer = tf.where(x <= 2, 0.3 * ones, answer)
    answer = tf.where(x <= 1, 0.1 * ones, answer)
    return tf.where(x <= 0, 0 * ones, answer)

  self.evaluate(st.assert_true_cdf_equal_by_dkwm(
      samples, cdf, left_continuous_cdf=left_continuous_cdf,
      false_fail_rate=1e-6))
  d = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
      tf.ones(batch_shape) * num_samples,
      false_fail_rate=1e-6, false_pass_rate=1e-6)
  # Actually assert the power bound for every batch member; merely evaluating
  # `d < 0.05` would discard the result and never fail.
  self.assertTrue(np.all(self.evaluate(d) < 0.05))

  def check_catches_mistake(wrong_probs):
    # Samples from a different categorical law must be rejected.
    wrong_samples = rng.choice(
        len(wrong_probs), size=shape, p=wrong_probs).astype(dtype=dtype)
    with self.assertRaisesOpError('Empirical CDF outside K-S envelope'):
      self.evaluate(st.assert_true_cdf_equal_by_dkwm(
          wrong_samples, cdf, left_continuous_cdf=left_continuous_cdf,
          false_fail_rate=1e-6))

  # Extra category (CDF at 3 is 0.9 instead of 1) and shifted mass (CDF is
  # off by 0.1 below 3); both gaps are far outside the DKWM envelope.
  check_catches_mistake([0.1, 0.2, 0.3, 0.3, 0.1])
  check_catches_mistake([0.2, 0.2, 0.3, 0.3])
def testSampleEmpiricalCDF(self):
  """HalfStudentT samples must pass a DKWM CDF check that has enough power."""
  n = 300000
  half_t = tfd.HalfStudentT(df=5., loc=10., scale=2., validate_args=True)
  draws = half_t.sample(n, seed=test_util.test_seed())
  cdf_ok = st.assert_true_cdf_equal_by_dkwm(
      draws, half_t.cdf, false_fail_rate=1e-6)
  # At this sample size the check must resolve CDF gaps smaller than 0.01.
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
      n, false_fail_rate=1e-6, false_pass_rate=1e-6)
  powerful = assert_util.assert_less(detectable, 0.01)
  self.evaluate([cdf_ok, powerful])
def testSampleEmpiricalCDF(self):
  """PERT samples must pass a DKWM CDF check that has enough power."""
  n = 300000
  low, peak, high = 1., 7., 10.
  temperature = 2.
  pert = tfd.PERT(low, peak, high, temperature, validate_args=True)
  draws = pert.sample(n, seed=test_util.test_seed())
  cdf_ok = st.assert_true_cdf_equal_by_dkwm(
      draws, pert.cdf, false_fail_rate=1e-6)
  # At this sample size the check must resolve CDF gaps smaller than 0.01.
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
      n, false_fail_rate=1e-6, false_pass_rate=1e-6)
  powerful = assert_util.assert_less(detectable, 0.01)
  self.evaluate([cdf_ok, powerful])
def test_dkwm_cdf_one_sample_batch_discrete_assertion(self, dtype):
  """Checks the DKWM CDF assertion on batched discrete samples.

  Draws a [13000, 3, 2] array from the categorical law on {0, 1, 2, 3} with
  probabilities [0.1, 0.2, 0.3, 0.4]. The assertion must pass against its
  right- and left-continuous CDFs and must reject two wrong laws. The DKWM
  power must also be adequate (discrepancy < 0.05).
  """
  rng = np.random.RandomState(seed=0)
  n = 13000
  batch_shape = [3, 2]
  shape = [n] + batch_shape
  samples = rng.choice(4, size=shape, p=[0.1, 0.2, 0.3, 0.4]).astype(
      dtype=dtype)

  # (threshold, CDF value below it), applied from the largest threshold down
  # so that smaller thresholds overwrite the values set by larger ones.
  steps = ((3, 0.6), (2, 0.3), (1, 0.1), (0, 0))

  def make_step_cdf(strict):
    # strict=True gives P(X <= x); strict=False gives P(X < x).
    def step_cdf(x):
      ones = tf.ones_like(x)
      answer = ones
      for threshold, value in steps:
        below = (x < threshold) if strict else (x <= threshold)
        answer = tf1.where(below, value * ones, answer)
      return answer
    return step_cdf

  cdf = make_step_cdf(strict=True)
  left_continuous_cdf = make_step_cdf(strict=False)

  def dkwm_check(draws):
    return st.assert_true_cdf_equal_by_dkwm(
        draws, cdf, left_continuous_cdf=left_continuous_cdf,
        false_fail_rate=1e-6)

  self.evaluate(dkwm_check(samples))
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
      tf.ones(batch_shape) * n, false_fail_rate=1e-6, false_pass_rate=1e-6)
  self.assertTrue(np.all(self.evaluate(detectable) < 0.05))

  for wrong_probs in ([0.1, 0.2, 0.3, 0.3, 0.1], [0.2, 0.2, 0.3, 0.3]):
    wrong = rng.choice(len(wrong_probs), size=shape, p=wrong_probs).astype(
        dtype=dtype)
    with self.assertRaisesOpError('Empirical CDF outside K-S envelope'):
      self.evaluate(dkwm_check(wrong))
def assert_univariate_target_conservation(test, target_d, step_size):
  """Checks that one preconditioned-NUTS step preserves a 1-D target.

  Starts 5e4 independent chains at exact draws from `target_d`, runs a single
  `PreconditionedNoUTurnSampler` step inside a `tf.function`, and builds
  checks that the new states still match `target_d.cdf` and have moved.

  Args:
    test: Test case whose `assertAllEqual` performs the shape checks, which
      run immediately.
    target_d: Univariate distribution exposing `sample`, `log_prob` and `cdf`.
    step_size: Leapfrog step size; also the minimum mean absolute movement.

  Returns:
    Tuple `(cdf_check, power_check, movement_check)` of assertion ops.
  """
  # Sample count limited partly by memory reliably available on Forge. The
  # test remains reasonable even if the nuts recursion limit is severely
  # curtailed (e.g., 3 or 4 levels), so use that to recover some memory
  # footprint and bump the sample count.
  chains = int(5e4)
  steps = 1
  seeds = test_util.test_seed_stream()
  # Initial values are wrapped in `tf.identity` to avoid broken gradients
  # resulting from a bijector cache hit, since bijectors of the same
  # type/parameterization now share a cache.
  # TODO(b/72831017): Fix broken gradients caused by bijector caching.
  start = tf.identity(target_d.sample([chains], seed=seeds()))

  @tf.function(autograph=False)
  def run_chain():
    kernel = tfp.experimental.mcmc.PreconditionedNoUTurnSampler(
        target_d.log_prob,
        step_size=step_size,
        max_tree_depth=3,
        unrolled_leapfrog_steps=2)
    # The second seed is drawn from the stream here, when the function traces.
    return tfp.mcmc.sample_chain(
        num_results=steps,
        num_burnin_steps=0,
        current_state=start,
        trace_fn=None,
        kernel=kernel,
        seed=seeds())

  states = run_chain()
  test.assertAllEqual([steps, chains], states.shape)
  # sample_chain stacks results along the leading axis; keep the only step.
  final = states[0]

  cdf_check = st.assert_true_cdf_equal_by_dkwm(
      final, target_d.cdf, false_fail_rate=1e-6)
  power_check = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          chains, false_fail_rate=1e-6, false_pass_rate=1e-6),
      0.025)

  displacement = tf.abs(final - start)
  test.assertAllEqual([chains], displacement.shape)
  # This movement distance (1 * step_size) was selected by reducing until 100
  # runs with independent seeds all passed.
  movement_check = assert_util.assert_greater_equal(
      tf.reduce_mean(displacement), 1 * step_size)
  return cdf_check, power_check, movement_check
def testSample(self):
  """Checks SigmoidBeta sample shapes and distribution (K-S and DKWM)."""
  a = tf.constant(1.0)
  b = tf.constant(2.0)
  n = 500000
  dist = tfd.SigmoidBeta(concentration0=a, concentration1=b,
                         validate_args=True)
  draws = dist.sample(n, seed=test_util.test_seed())
  draws_np = self.evaluate(draws)
  # Both the static tensor shape and the evaluated array must be (n,).
  for observed in (draws.shape, draws_np.shape):
    self.assertEqual(observed, (n,))
  self.assertTrue(self._kstest(a, b, draws_np))

  self.evaluate(st.assert_true_cdf_equal_by_dkwm(
      draws, dist.cdf, false_fail_rate=1e-6))
  # The DKWM check must be able to resolve CDF gaps smaller than 0.01.
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
      n, false_fail_rate=1e-6, false_pass_rate=1e-6)
  self.evaluate(assert_util.assert_less(detectable, 0.01))
def assert_univariate_target_conservation(test, target_d, step_size):
  """Checks that one NUTS step started at exact samples preserves a 1-D target.

  Starts 5e4 independent chains at exact draws from `target_d`, runs a single
  `tfp.mcmc.NoUTurnSampler` step inside a `tf.function`, and builds checks
  that the resulting states still match `target_d.cdf` and have moved.

  Args:
    test: Test case whose `assertAllEqual` performs the shape checks, which
      run immediately.
    target_d: Univariate distribution exposing `sample`, `log_prob` and `cdf`.
    step_size: Leapfrog step size; also the minimum mean absolute movement.

  Returns:
    Tuple `(check_cdf_agrees, check_enough_power, check_movement)` of
    assertion ops.
  """
  # Sample count limited partly by memory reliably available on Forge. The test
  # remains reasonable even if the nuts recursion limit is severely curtailed
  # (e.g., 3 or 4 levels), so use that to recover some memory footprint and bump
  # the sample count.
  num_samples = int(5e4)
  num_steps = 1
  # Seeds are drawn in call order: the initial sample first, then the kernel
  # (at trace time of `run_chain`).
  strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
  # We wrap the initial values in `tf.identity` to avoid broken gradients
  # resulting from a bijector cache hit, since bijectors of the same
  # type/parameterization now share a cache.
  # TODO(b/72831017): Fix broken gradients caused by bijector caching.
  initialization = tf.identity(target_d.sample([num_samples], seed=strm()))

  @tf.function(autograph=False)
  def run_chain():
    nuts = tfp.mcmc.NoUTurnSampler(
        target_d.log_prob,
        step_size=step_size,
        max_tree_depth=3,
        unrolled_leapfrog_steps=2,
        seed=strm())
    # Kernel results (second element) are not needed here.
    result, _ = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=0,
        current_state=initialization,
        kernel=nuts)
    return result

  result = run_chain()
  # sample_chain puts the chain history on top, not the independent chains.
  test.assertAllEqual([num_steps, num_samples], result.shape)
  answer = result[0]
  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      answer, target_d.cdf, false_fail_rate=1e-6)
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6),
      0.025)
  movement = tf.abs(answer - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until 100
  # runs with independent seeds all passed.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 1 * step_size)
  return (check_cdf_agrees, check_enough_power, check_movement)