def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  """Builds checks on one `run_nuts_chain` step started from standard-normal draws.

  Draws `batch_size` points from a zero-mean, identity-covariance
  multivariate normal of dimension `event_size` (fixed seed 4), runs
  `run_nuts_chain` for a single step from those points, and builds a set of
  assertions comparing the result to the starting points.

  Args:
    event_size: Dimension of the multivariate normal; also forwarded to
      `run_nuts_chain`.
    batch_size: Number of starting points drawn; also forwarded to
      `run_nuts_chain`.
    **kwargs: Passed through unchanged to `run_nuts_chain`.

  Returns:
    A 6-tuple of the values returned by the assertion calls, in this order:
    the two-sample projected-CDF agreement check, the leading-dimension check
    (must equal `batch_size`), the check that at least 3 distinct leapfrog
    counts occur, the check that the mean leapfrog count is at least 4, the
    check that the mean Euclidean displacement is at least 0.3, and the check
    that the detectable CDF discrepancy is below 0.055.
  """
  standard_mvn = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size))
  start_state = standard_mvn.sample(batch_size, seed=4)

  chain_samples, leapfrogs = run_nuts_chain(
      event_size,
      batch_size,
      num_steps=1,
      initial_state=start_state,
      **kwargs)
  # Last entry of the first element of the returned samples structure.
  end_state = chain_samples[0][-1]

  # The state after the step should be statistically indistinguishable from
  # the starting draws.
  cdf_check = st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
      end_state, start_state, num_projections=100, false_fail_rate=1e-6)

  leading_dim = tf.shape(input=end_state)[0]
  shape_check = tf1.assert_equal(leading_dim, batch_size)

  # Leapfrog counts should take at least three distinct values ...
  step_counts = leapfrogs[0]
  distinct_counts, _ = tf.unique(step_counts)
  num_distinct = tf.shape(input=distinct_counts)[0]
  variety_check = tf1.assert_greater_equal(num_distinct, 3)

  # ... and average at least four.
  mean_count = tf.math.reduce_mean(input_tensor=step_counts)
  min_mean_count = tf.constant(4, dtype=mean_count.dtype)
  count_check = tf1.assert_greater_equal(mean_count, min_mean_count)

  displacement = tf.linalg.norm(tensor=end_state - start_state, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  mean_displacement = tf.reduce_mean(input_tensor=displacement)
  movement_check = tf1.assert_greater_equal(mean_displacement, 0.3)

  # Guard the experiment design: with this many samples the test must be able
  # to detect a CDF discrepancy smaller than 0.055.
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
      batch_size,
      batch_size,
      false_fail_rate=1e-8,
      false_pass_rate=1e-6)
  power_check = tf1.assert_less(detectable, 0.055)

  return (cdf_check, shape_check, variety_check, count_check, movement_check,
          power_check)
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  """Builds checks on one `run_nuts_chain` step started from standard-normal draws.

  Draws `batch_size` points from a zero-mean, identity-covariance
  multivariate normal of dimension `event_size`, seeded from the test seed
  stream, runs `run_nuts_chain` for a single step from those points, and
  builds assertions comparing the result to the starting points.

  Args:
    event_size: Dimension of the multivariate normal; also forwarded to
      `run_nuts_chain`.
    batch_size: Number of starting points drawn; also forwarded to
      `run_nuts_chain`.
    **kwargs: Passed through unchanged to `run_nuts_chain`.

  Returns:
    A 4-tuple of the values returned by the assertion calls, in this order:
    the two-sample projected-CDF agreement check, the leading-dimension check
    (must equal `batch_size`), the check that the mean Euclidean displacement
    is at least 0.3, and the check that the detectable CDF discrepancy is
    below 0.055.
  """
  seed_stream = tfp_test_util.test_seed_stream()
  standard_mvn = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size))
  start_state = standard_mvn.sample(batch_size, seed=seed_stream())

  chain_samples, _ = run_nuts_chain(
      event_size,
      batch_size,
      num_steps=1,
      initial_state=start_state,
      **kwargs)
  # Last entry of the first element of the returned samples structure.
  end_state = chain_samples[0][-1]

  # The state after the step should be statistically indistinguishable from
  # the starting draws.
  cdf_check = st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
      end_state, start_state, num_projections=100, false_fail_rate=1e-6)

  leading_dim = tf.shape(end_state)[0]
  shape_check = assert_util.assert_equal(leading_dim, batch_size)

  displacement = tf.linalg.norm(end_state - start_state, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  mean_displacement = tf.reduce_mean(displacement)
  movement_check = assert_util.assert_greater_equal(mean_displacement, 0.3)

  # Guard the experiment design: with this many samples the test must be able
  # to detect a CDF discrepancy smaller than 0.055.
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
      batch_size,
      batch_size,
      false_fail_rate=1e-8,
      false_pass_rate=1e-6)
  power_check = assert_util.assert_less(detectable, 0.055)

  return (cdf_check, shape_check, movement_check, power_check)
def assert_catches_mistake(mean, cov):
  """Expects the projected-CDF check to reject samples drawn with `mean`/`cov`.

  Reads `rng`, `num_samples`, `dtype`, `ground_truth`, `strm`, `st` and
  `self` from the enclosing scope.

  Args:
    mean: Mean passed to `rng.multivariate_normal` for the wrong samples.
    cov: Covariance passed to `rng.multivariate_normal` for the wrong samples.
  """
  drawn = rng.multivariate_normal(mean=mean, cov=cov, size=num_samples)
  wrong_samples = drawn.astype(dtype)
  expected_error = 'Empirical CDFs outside joint K-S envelope'
  with self.assertRaisesOpError(expected_error):
    # The assertion is built inside the context so that an error raised at
    # construction time is caught as well.
    rejection = (
        st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
            ground_truth,
            wrong_samples,
            num_projections=100,
            false_fail_rate=1e-6,
            seed=strm()))
    self.evaluate(rejection)
def test_random_projections(self, dtype):
  """Exercises the two-sample projected-CDF assertion on bivariate normals.

  Confirms that two independent sample sets from the same correlated normal
  pass the assertion, and that sample sets with a shifted mean or a changed
  correlation are rejected.

  Args:
    dtype: Type the NumPy samples are cast to before being tested.
  """
  seed_stream = test_util.test_seed_stream()
  random_state = np.random.RandomState(seed=seed_stream() % 2**31)
  num_samples = 57000
  num_projections = 100
  true_mean = [0, 0]
  true_cov = [[1, 0.5], [0.5, 1]]

  # Validate experiment design.
  # The false fail rate used here is the target rate of 1e-6 divided by the
  # number of projections.
  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
      num_samples, num_samples, false_fail_rate=1e-8, false_pass_rate=1e-6)
  # num_samples is chosen so the discrepancy is below 0.05, which should be
  # enough to detect a mean shift of around 1/8 of a standard deviation, or a
  # scale increase of around 25% (in any particular projection).
  self.assertLess(self.evaluate(detectable), 0.05)

  def draw(mean, cov):
    """Draws `num_samples` bivariate normal points cast to `dtype`."""
    points = random_state.multivariate_normal(
        mean=mean, cov=cov, size=num_samples)
    return points.astype(dtype)

  ground_truth = draw(true_mean, true_cov)
  more_samples = draw(true_mean, true_cov)

  # Two sample sets from the same distribution must pass.
  agreement = st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
      ground_truth,
      more_samples,
      num_projections=num_projections,
      false_fail_rate=1e-6,
      seed=seed_stream())
  self.evaluate(agreement)

  def assert_catches_mistake(mean, cov):
    """Expects samples drawn with `mean`/`cov` to be rejected."""
    wrong_samples = draw(mean, cov)
    expected_error = 'Empirical CDFs outside joint K-S envelope'
    with self.assertRaisesOpError(expected_error):
      rejection = (
          st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
              ground_truth,
              wrong_samples,
              num_projections=num_projections,
              false_fail_rate=1e-6,
              seed=seed_stream()))
      self.evaluate(rejection)

  # A shifted mean, a stronger correlation, and a weaker correlation.
  wrong_parameters = (
      ([0, 1], [[1, 0.5], [0.5, 1]]),
      ([0, 0], [[1, 0.7], [0.7, 1]]),
      ([0, 0], [[1, 0.3], [0.3, 1]]),
  )
  for wrong_mean, wrong_cov in wrong_parameters:
    assert_catches_mistake(wrong_mean, wrong_cov)