示例#1
0
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
    """Builds assertions on the outcome of a single `run_nuts_chain` step.

    A batch of starting points is sampled (with the fixed seed 4) from a
    multivariate normal with zero mean and identity covariance, handed to
    `run_nuts_chain` as `initial_state` with `num_steps=1`, and the resulting
    states are compared against those starting points.

    Args:
      event_size: Dimension of the normal the starting points are drawn from;
        also forwarded to `run_nuts_chain`.
      batch_size: Number of starting points drawn; also forwarded to
        `run_nuts_chain`.
      **kwargs: Extra keyword arguments passed through to `run_nuts_chain`.

    Returns:
      A 6-tuple of assertions, in this order: the final and starting points
      agree under the projected two-sample CDF test; the final points have
      leading dimension `batch_size`; `leapfrogs[0]` holds at least 3 distinct
      values; the mean of `leapfrogs[0]` is at least 4; the mean Euclidean
      distance between final and starting points is at least 0.3; and the
      minimum discrepancy detectable by the two-sample test is below 0.055.
    """
    prior = tfd.MultivariateNormalFullCovariance(
        loc=tf.zeros(event_size), covariance_matrix=tf.eye(event_size))
    start = prior.sample(batch_size, seed=4)
    states, step_counts = run_nuts_chain(
        event_size, batch_size, num_steps=1, initial_state=start, **kwargs)
    final = states[0][-1]

    cdf_check = (
        st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
            final, start, num_projections=100, false_fail_rate=1e-6))

    leading_dim = tf.shape(input=final)[0]
    shape_check = tf1.assert_equal(leading_dim, batch_size)

    # Require some variety in the values reported for the first entry.
    distinct_values, _ = tf.unique(step_counts[0])
    num_distinct = tf.shape(input=distinct_values)[0]
    variety_check = tf1.assert_greater_equal(num_distinct, 3)

    mean_steps = tf.math.reduce_mean(input_tensor=step_counts[0])
    minimum_mean = tf.constant(4, dtype=mean_steps.dtype)
    steps_check = tf1.assert_greater_equal(mean_steps, minimum_mean)

    distance_moved = tf.linalg.norm(tensor=final - start, axis=-1)
    # This movement distance (0.3) was copied from the univariate case.
    mean_distance = tf.reduce_mean(input_tensor=distance_moved)
    movement_check = tf1.assert_greater_equal(mean_distance, 0.3)

    detectable = (
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
            batch_size, batch_size, false_fail_rate=1e-8,
            false_pass_rate=1e-6))
    power_check = tf1.assert_less(detectable, 0.055)

    return (cdf_check, shape_check, variety_check, steps_check,
            movement_check, power_check)
  def test_dkwm_cdf_two_sample_batch_discrete_assertion(self, dtype):
    """Exercises the two-sample DKWM CDF assertion on batched discrete data.

    Two sample sets of shape `[52000, 3, 2]` drawn from the same 4-outcome
    categorical distribution must pass the assertion; sample sets drawn from
    two different probability vectors must make it raise.

    Args:
      dtype: Type the generated numpy samples are cast to.
    """
    rng = np.random.RandomState(seed=0)
    num_samples = 52000
    batch_shape = [3, 2]
    shape = [num_samples] + batch_shape
    true_probs = [0.1, 0.2, 0.3, 0.4]

    # Two independent draws from the same distribution should be accepted.
    reference = rng.choice(4, size=shape, p=true_probs).astype(dtype=dtype)
    same_dist = rng.choice(4, size=shape, p=true_probs).astype(dtype=dtype)
    agreement = st.assert_true_cdf_equal_by_dkwm_two_sample(
        reference, same_dist, false_fail_rate=1e-6)
    self.evaluate(agreement)

    # Confirm the sample count gives a detectable discrepancy below 0.05 in
    # every batch member before checking that mistakes are caught.
    counts = tf.ones(batch_shape) * num_samples
    detectable = (
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
            counts, counts, false_fail_rate=1e-6, false_pass_rate=1e-6))
    self.assertTrue(np.all(self.evaluate(detectable) < 0.05))

    mistaken_probs = (
        [0.1, 0.2, 0.3, 0.3, 0.1],
        [0.2, 0.2, 0.3, 0.3],
    )
    for wrong_probs in mistaken_probs:
      wrong_samples = rng.choice(
          len(wrong_probs), size=shape, p=wrong_probs).astype(dtype=dtype)
      with self.assertRaisesOpError(
          'Empirical CDFs outside joint K-S envelope'):
        self.evaluate(st.assert_true_cdf_equal_by_dkwm_two_sample(
            reference, wrong_samples, false_fail_rate=1e-6))
示例#3
0
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  """Builds assertions on the outcome of a single `run_nuts_chain` step.

  Starting points are sampled from a multivariate normal with zero mean and
  identity covariance, seeded from `tfp_test_util.test_seed_stream()`. They
  are handed to `run_nuts_chain` as `initial_state` with `num_steps=1`, and
  the resulting states are compared against them.

  Args:
    event_size: Dimension of the normal the starting points are drawn from;
      also forwarded to `run_nuts_chain`.
    batch_size: Number of starting points drawn; also forwarded to
      `run_nuts_chain`.
    **kwargs: Extra keyword arguments passed through to `run_nuts_chain`.

  Returns:
    A 4-tuple of assertions, in this order: the final and starting points
    agree under the projected two-sample CDF test; the final points have
    leading dimension `batch_size`; the mean Euclidean distance between final
    and starting points is at least 0.3; and the minimum discrepancy
    detectable by the two-sample test is below 0.055.
  """
  seed_stream = tfp_test_util.test_seed_stream()
  prior = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size), covariance_matrix=tf.eye(event_size))
  start = prior.sample(batch_size, seed=seed_stream())
  states, _ = run_nuts_chain(
      event_size,
      batch_size,
      num_steps=1,
      initial_state=start,
      **kwargs)
  final = states[0][-1]

  cdf_check = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          final, start, num_projections=100, false_fail_rate=1e-6))

  leading_dim = tf.shape(final)[0]
  shape_check = assert_util.assert_equal(leading_dim, batch_size)

  distance_moved = tf.linalg.norm(final - start, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  mean_distance = tf.reduce_mean(distance_moved)
  movement_check = assert_util.assert_greater_equal(mean_distance, 0.3)

  detectable = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
      batch_size, batch_size, false_fail_rate=1e-8, false_pass_rate=1e-6)
  power_check = assert_util.assert_less(detectable, 0.055)

  return (cdf_check, shape_check, movement_check, power_check)
    def test_random_projections(self, dtype):
        """Exercises the projection-based multivariate two-sample CDF check.

        Two independent draws of 57000 points from the same bivariate normal
        must pass the assertion; draws with a shifted mean or a changed
        covariance must make it raise.

        Args:
          dtype: Type the generated numpy samples are cast to.
        """
        strm = test_util.test_seed_stream()
        rng = np.random.RandomState(seed=strm() % 2**31)
        num_samples = 57000
        true_mean = [0, 0]
        true_cov = [[1, 0.5], [0.5, 1]]

        # Validate experiment design

        # False fail rate here is the target rate of 1e-6 divided by the number
        # of projections.
        detectable = (
            st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
                num_samples, num_samples,
                false_fail_rate=1e-8, false_pass_rate=1e-6))
        # Choose num_samples so the discrepancy is below 0.05, which should be
        # enough to detect a mean shift of around 1/8 of a standard deviation,
        # or a scale increase of around 25% (in any particular projection).
        self.assertLess(self.evaluate(detectable), 0.05)

        ground_truth = rng.multivariate_normal(
            mean=true_mean, cov=true_cov, size=num_samples).astype(dtype)
        more_samples = rng.multivariate_normal(
            mean=true_mean, cov=true_cov, size=num_samples).astype(dtype)
        agreement = (
            st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
                ground_truth, more_samples,
                num_projections=100, false_fail_rate=1e-6, seed=strm()))
        self.evaluate(agreement)

        # Each (mean, cov) pair differs from the ground truth and must be
        # rejected by the assertion.
        mistakes = (
            ([0, 1], [[1, 0.5], [0.5, 1]]),
            ([0, 0], [[1, 0.7], [0.7, 1]]),
            ([0, 0], [[1, 0.3], [0.3, 1]]),
        )
        msg = 'Empirical CDFs outside joint K-S envelope'
        for mean, cov in mistakes:
            wrong_samples = rng.multivariate_normal(
                mean=mean, cov=cov, size=num_samples).astype(dtype=dtype)
            with self.assertRaisesOpError(msg):
                check = st.assert_multivariate_true_cdf_equal_on_projections_two_sample
                self.evaluate(
                    check(
                        ground_truth,
                        wrong_samples,
                        num_projections=100,
                        false_fail_rate=1e-6,
                        seed=strm()))