Example #1
0
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
    """Builds assertions that one NUTS step conserves a standard MVN target.

    Draws `batch_size` starting points from a multivariate normal with zero
    mean and identity covariance of dimension `event_size`, advances them by
    one step of `run_nuts_chain`, and checks that the final states still agree
    in distribution with the starting points, have the expected leading
    dimension, moved far enough, and used a varied number of leapfrog steps.

    Args:
      event_size: Dimension of the multivariate normal target.
      batch_size: Number of starting points (and of final states compared).
      **kwargs: Forwarded unchanged to `run_nuts_chain`.

    Returns:
      A 6-tuple of assertion ops, in order: CDF agreement, sample shape,
      leapfrog-count variety, mean leapfrog count, mean movement, and
      sufficient statistical power of the two-sample test.
    """
    standard_mvn = tfd.MultivariateNormalFullCovariance(
        loc=tf.zeros(event_size),
        covariance_matrix=tf.eye(event_size))
    start = standard_mvn.sample(batch_size, seed=4)
    samples, leapfrogs = run_nuts_chain(
        event_size, batch_size, num_steps=1, initial_state=start, **kwargs)
    final_state = samples[0][-1]

    # Distributional check: final states vs. starting points on projections.
    check_cdf_agrees = (
        st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
            final_state, start, num_projections=100, false_fail_rate=1e-6))
    check_sample_shape = tf1.assert_equal(
        tf.shape(input=final_state)[0], batch_size)

    # Leapfrog counts must take at least 3 distinct values and average >= 4.
    per_chain_leapfrogs = leapfrogs[0]
    distinct_counts, _ = tf.unique(per_chain_leapfrogs)
    check_leapfrogs_vary = tf1.assert_greater_equal(
        tf.shape(input=distinct_counts)[0], 3)
    mean_leapfrogs = tf.math.reduce_mean(input_tensor=per_chain_leapfrogs)
    min_mean_leapfrogs = tf.constant(4, dtype=mean_leapfrogs.dtype)
    check_leapfrogs = tf1.assert_greater_equal(
        mean_leapfrogs, min_mean_leapfrogs)

    # Norm of each chain's displacement over its single step.
    displacement = tf.linalg.norm(tensor=final_state - start, axis=-1)
    # This movement distance (0.3) was copied from the univariate case.
    mean_displacement = tf.reduce_mean(input_tensor=displacement)
    check_movement = tf1.assert_greater_equal(mean_displacement, 0.3)

    # Require that, at this batch size, the smallest CDF discrepancy the
    # two-sample DKWM test can detect is below 0.055.
    detectable_discrepancy = (
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
            batch_size, batch_size, false_fail_rate=1e-8,
            false_pass_rate=1e-6))
    check_enough_power = tf1.assert_less(detectable_discrepancy, 0.055)

    return (check_cdf_agrees, check_sample_shape, check_leapfrogs_vary,
            check_leapfrogs, check_movement, check_enough_power)
Example #2
0
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  """Builds assertions that one NUTS step conserves a standard MVN target.

  Draws `batch_size` starting points from a multivariate normal with zero
  mean and identity covariance of dimension `event_size`, advances them by
  one step of `run_nuts_chain`, and checks that the final states still agree
  in distribution with the starting points, have the expected leading
  dimension, and moved far enough on average.

  Args:
    event_size: Dimension of the multivariate normal target.
    batch_size: Number of starting points (and of final states compared).
    **kwargs: Forwarded unchanged to `run_nuts_chain`.

  Returns:
    A 4-tuple of assertion ops, in order: CDF agreement on projections,
    sample shape, mean movement, and sufficient statistical power of the
    two-sample test.
  """
  strm = tfp_test_util.test_seed_stream()
  initialization = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size)).sample(
          batch_size, seed=strm())
  samples, _ = run_nuts_chain(
      event_size, batch_size, num_steps=1,
      initial_state=initialization, **kwargs)
  answer = samples[0][-1]
  # The projection check draws its projection directions itself. Seed it from
  # the same test seed stream as the initial state, as the sibling
  # test_random_projections does. This keeps a failure reproducible instead
  # of depending on unseeded randomness.
  check_cdf_agrees = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          answer, initialization, num_projections=100, false_fail_rate=1e-6,
          seed=strm()))
  check_sample_shape = assert_util.assert_equal(
      tf.shape(answer)[0], batch_size)
  movement = tf.linalg.norm(answer - initialization, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 0.3)
  # Require that, at this batch size, the smallest CDF discrepancy the
  # two-sample DKWM test can detect is below 0.055.
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
          batch_size, batch_size, false_fail_rate=1e-8, false_pass_rate=1e-6),
      0.055)
  return (
      check_cdf_agrees,
      check_sample_shape,
      check_movement,
      check_enough_power,
  )
 def assert_catches_mistake(mean, cov):
   """Asserts the projection test rejects samples drawn from N(mean, cov).

   Uses the enclosing `rng` to draw `num_samples` points, casts them to
   `dtype`, and checks that comparing them against `ground_truth` raises an
   op error whose message reports an out-of-envelope empirical CDF.
   """
   bad_draws = rng.multivariate_normal(mean=mean, cov=cov, size=num_samples)
   bad_draws = bad_draws.astype(dtype=dtype)
   expected_error = 'Empirical CDFs outside joint K-S envelope'
   with self.assertRaisesOpError(expected_error):
     projection_check = (
         st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
             ground_truth, bad_draws, num_projections=100,
             false_fail_rate=1e-6, seed=strm()))
     self.evaluate(projection_check)
    def test_random_projections(self, dtype):
        """Exercises the random-projection two-sample MVN test both ways.

        First checks that the chosen sample size gives the DKWM test enough
        power. Then checks that two samples from the same bivariate normal
        pass. Finally checks that a shifted mean or an altered correlation
        is caught.
        """
        seed_stream = test_util.test_seed_stream()
        rng = np.random.RandomState(seed=seed_stream() % 2**31)
        num_samples = 57000
        true_mean = [0, 0]
        true_cov = [[1, 0.5], [0.5, 1]]

        def draw(mean, cov):
            # One batch of `num_samples` bivariate normal points, cast to dtype.
            return rng.multivariate_normal(
                mean=mean, cov=cov, size=num_samples).astype(dtype)

        # Validate experiment design. The per-projection false fail rate is
        # the overall target of 1e-6 divided by the number of projections.
        min_detectable = (
            st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
                num_samples, num_samples,
                false_fail_rate=1e-8, false_pass_rate=1e-6))
        # Choose num_samples so the discrepancy is below 0.05, which should be
        # enough to detect a mean shift of around 1/8 of a standard deviation, or a
        # scale increase of around 25% (in any particular projection).
        self.assertLess(self.evaluate(min_detectable), 0.05)

        # Two independent samples from the same distribution must pass.
        ground_truth = draw(true_mean, true_cov)
        more_samples = draw(true_mean, true_cov)
        self.evaluate(
            st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
                ground_truth, more_samples, num_projections=100,
                false_fail_rate=1e-6, seed=seed_stream()))

        def assert_catches_mistake(mean, cov):
            wrong_samples = draw(mean, cov)
            with self.assertRaisesOpError(
                    'Empirical CDFs outside joint K-S envelope'):
                self.evaluate(
                    st.
                    assert_multivariate_true_cdf_equal_on_projections_two_sample(
                        ground_truth, wrong_samples, num_projections=100,
                        false_fail_rate=1e-6, seed=seed_stream()))

        # Shifted mean, stronger correlation, weaker correlation.
        wrong_distributions = (
            ([0, 1], [[1, 0.5], [0.5, 1]]),
            ([0, 0], [[1, 0.7], [0.7, 1]]),
            ([0, 0], [[1, 0.3], [0.3, 1]]),
        )
        for wrong_mean, wrong_cov in wrong_distributions:
            assert_catches_mistake(wrong_mean, wrong_cov)