    def testDistribution(self, dist_name, data):
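        """Checks that log_prob can be re-evaluated in float64.

        Samples from a hypothesis-drawn distribution in its native precision,
        then casts the distribution's Tensor components (via its
        composite-tensor representation) and the samples to float64 and
        asserts that log_prob comes back as a float64 Tensor.
        """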
        if dist_name in WORKING_PRECISION_TEST_BLOCK_LIST:
            self.skipTest('{} is blocked'.format(dist_name))

        def eligibility_filter(name):
            return name not in WORKING_PRECISION_TEST_BLOCK_LIST

        dist = data.draw(
            dhps.distributions(dist_name=dist_name,
                               eligibility_filter=eligibility_filter,
                               enable_vars=False,
                               validate_args=False))
        hp.note('Trying distribution {}'.format(
            self.evaluate_dict(dist.parameters)))
        seed = test_util.test_seed()
        with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
            samples = dist.sample(5, seed=seed)
            self.assertIn(samples.dtype, [tf.float32, tf.int32])
            self.assertEqual(dist.log_prob(samples).dtype, tf.float32)

        def log_prob_function(dist, x):
            return dist.log_prob(x)

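        # Cast every Tensor component of the distribution to float64 through
        # its composite-tensor representation.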
        dist64 = tf.nest.map_structure(tensor_to_f64,
                                       tfe.as_composite(dist),
                                       expand_composites=True)
        with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
            result64 = log_prob_function(dist64, tensor_to_f64(samples))
        self.assertEqual(result64.dtype, tf.float64)

    def testCompositeTensor(self, bijector_name, data):
        # Test that making a composite tensor of this bijector doesn't throw any
        # errors.
        bijector, event_dim = self._draw_bijector(
            bijector_name,
            data,
            batch_shape=[],
            allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
                               set(COMPOSITE_TENSOR_IS_BROKEN)))
        composite_bij = experimental.as_composite(bijector)
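        # Round-trip the composite bijector through flatten/pack_sequence_as
        # to exercise the CompositeTensor encoding.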
        flat = tf.nest.flatten(composite_bij, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(composite_bij,
                                          flat,
                                          expand_composites=True)

        # Compare forward maps before and after compositing.
        n = 3
        xs = self._draw_domain_tensor(bijector,
                                      data,
                                      event_dim,
                                      sample_shape=[n])
        before_ys = bijector.forward(xs)
        after_ys = unflat.forward(xs)
        self.assertAllClose(*self.evaluate((before_ys, after_ys)))

        # Compare inverse maps before and after compositing.
        ys = self._draw_codomain_tensor(bijector,
                                        data,
                                        event_dim,
                                        sample_shape=[n])
        before_xs = bijector.inverse(ys)
        after_xs = unflat.inverse(ys)
        self.assertAllClose(*self.evaluate((before_xs, after_xs)))

    def _test_sample_and_log_prob(self, dist_name, dist):
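        """Checks that samples and log_probs agree across as_composite."""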
        seed = test_util.test_seed(sampler_type='stateless')
        num_samples = 3

        # Sample from the distribution before composite tensoring
        sample1 = self.evaluate(dist.sample(num_samples, seed=seed))
        hp.note('Drew samples {}'.format(sample1))

        # Sample from the distribution after composite tensoring
        composite_dist = experimental.as_composite(dist)
        flat = tf.nest.flatten(composite_dist, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(composite_dist,
                                          flat,
                                          expand_composites=True)
        sample2 = self.evaluate(unflat.sample(num_samples, seed=seed))
        hp.note('Drew samples {}'.format(sample2))

        # Check that the samples are the same
        self.assertAllClose(sample1, sample2)

        # Check that all the log_probs agree for the samples from before ...
        ct_lp1 = unflat.log_prob(sample1)
        orig_lp1 = dist.log_prob(sample1)
        ct_lp1_, orig_lp1_ = self.evaluate((ct_lp1, orig_lp1))
        self.assertAllClose(ct_lp1_, orig_lp1_)

        # ... and after.  (Even though they're supposed to be the same anyway.)
        ct_lp2 = unflat.log_prob(sample2)
        orig_lp2 = dist.log_prob(sample2)
        ct_lp2_, orig_lp2_ = self.evaluate((ct_lp2, orig_lp2))
        self.assertAllClose(ct_lp2_, orig_lp2_)

    def testLogProbAccuracy(self, dist_name, data):
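        """Checks the numerical accuracy of log_prob on a single draw."""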
        self.skip_if_tf1()
        dist = tfe.as_composite(
            data.draw(
                dhps.distributions(
                    dist_name=dist_name,
                    # Accuracy tools can't handle batches (yet?)
                    batch_shape=(),
                    # Variables presumably do not affect the numerics
                    enable_vars=False,
                    # Checking that samples pass validations (including in 64-bit
                    # arithmetic) is left for another test
                    validate_args=False)))
        seed = test_util.test_seed(sampler_type='stateless')
        with tfp_hps.no_tf_rank_errors():
            sample = dist.sample(seed=seed)
        if sample.dtype.is_floating:
            hp.assume(self.evaluate(tf.reduce_all(~tf.math.is_nan(sample))))
        hp.note('Testing on sample {}'.format(sample))

        as_tensors = tf.nest.flatten(dist, expand_composites=True)

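        # Repack the flat Tensors into a distribution so that log_prob is a
        # function of plain Tensor inputs.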
        def log_prob_function(tensors, x):
            dist_ = tf.nest.pack_sequence_as(dist,
                                             tensors,
                                             expand_composites=True)
            return dist_.log_prob(x)

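        # `excess_wrong_bits` estimates how much accuracy the float32 log_prob
        # loses beyond what the function's conditioning alone would explain.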
        with tfp_hps.finite_ground_truth_only():
            badness = nt.excess_wrong_bits(log_prob_function, as_tensors,
                                           sample)
        # TODO(axch): Lower the acceptable badness to 4, which corresponds
        # to slightly better accuracy than 1e-6 relative error for
        # well-conditioned functions.
        self.assertAllLess(badness, 20)

  def testCompositeTensor(self, bijector_name, data):
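    """Checks forward/inverse, tf.function, and gradients after compositing."""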

    bijector, event_dim = self._draw_bijector(
        bijector_name, data,
        batch_shape=[],
        validate_args=True,
        allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
                           set(COMPOSITE_TENSOR_IS_BROKEN)))

    # TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
    # when AutoCT is enabled for meta-bijectors and LinearOperator.
    if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
      composite_bij = experimental.as_composite(bijector)
    else:
      composite_bij = bijector

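    # In graph mode, materialize any DeferredTensor components as concrete
    # Tensors before flattening.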
    if not tf.executing_eagerly():
      composite_bij = tf.nest.map_structure(
          lambda x: (tf.convert_to_tensor(x)  # pylint: disable=g-long-lambda
                     if isinstance(x, DeferredTensor) else x),
          composite_bij,
          expand_composites=True)

    self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
    flat = tf.nest.flatten(composite_bij, expand_composites=True)
    unflat = tf.nest.pack_sequence_as(
        composite_bij, flat, expand_composites=True)

    # Compare forward maps before and after compositing.
    n = 3
    xs = self._draw_domain_tensor(bijector, data, event_dim, sample_shape=[n])
    before_ys = bijector.forward(xs)
    after_ys = unflat.forward(xs)
    self.assertAllClose(*self.evaluate((before_ys, after_ys)))

    # Compare inverse maps before and after compositing.
    ys = self._draw_codomain_tensor(bijector, data, event_dim, sample_shape=[n])
    before_xs = bijector.inverse(ys)
    after_xs = unflat.inverse(ys)
    self.assertAllClose(*self.evaluate((before_xs, after_xs)))

    # Check that the composite bijector can be passed as an input to a
    # tf.function.
    self.assertAllClose(
        before_ys,
        tf.function(lambda b: b.forward(xs))(composite_bij),
        rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
        atol=COMPOSITE_TENSOR_ATOL[bijector_name])

    # Forward mapping: Check differentiation through forward mapping with
    # respect to the input and parameter variables.  Also check that any
    # variables are not referenced overmuch.
    xs = self._draw_domain_tensor(bijector, data, event_dim)
    wrt_vars = [xs] + [v for v in composite_bij.trainable_variables
                       if v.dtype.is_floating]
    with tf.GradientTape() as tape:
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
    grads = tape.gradient(ys, wrt_vars)
    assert_no_none_grad(bijector, 'forward', wrt_vars, grads)