def testDistribution(self, dist_name, data):
    """Compares `experimental_sample_and_log_prob` with `sample`/`log_prob`.

    Using the same stateless seed, the fused sampler must reproduce the
    samples drawn by `sample` (to within 1e-4). Its log-prob is only required
    to have the same shape as `log_prob` evaluated on those samples.
    """
    strategy = dhps.distributions(
        dist_name=dist_name, enable_vars=False, validate_args=False)
    dist = data.draw(strategy)
    seed = test_util.test_seed(sampler_type='stateless')
    shape = [2, 1]
    with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
        fused_samples, fused_lp = dist.experimental_sample_and_log_prob(
            shape, seed=seed)
        plain_samples = dist.sample(shape, seed=seed)
        self.assertAllClose(fused_samples, plain_samples, atol=1e-4)

        # Only shapes are compared: the fused implementation may be more
        # numerically stable (so values can differ arbitrarily) or yield NaN.
        reference_lp = dist.log_prob(fused_samples)
        self.assertAllEqual(fused_lp.shape, reference_lp.shape)
# Example #2
# 0
  def testDistribution(self, dist_name, data):
    """Checks dtypes of samples/log-probs in float32 and after a float64 cast.

    Samples of the drawn distribution must be float32 or int32 with a float32
    log-prob. After mapping `tensor_to_f64` over every tensor component of the
    distribution (via `tfe.as_composite`), the log-prob of the float64-cast
    samples must be float64.
    """
    if dist_name in WORKING_PRECISION_TEST_BLOCK_LIST:
      self.skipTest('{} is blocked'.format(dist_name))

    dist = data.draw(dhps.distributions(
        dist_name=dist_name,
        eligibility_filter=(
            lambda name: name not in WORKING_PRECISION_TEST_BLOCK_LIST),
        enable_vars=False,
        validate_args=False))
    params = self.evaluate_dict(dist.parameters)
    hp.note('Trying distribution {}'.format(params))
    seed = test_util.test_seed()
    with tfp_hps.no_tf_rank_errors():
      samples = dist.sample(5, seed=seed)
      self.assertIn(samples.dtype, [tf.float32, tf.int32])
      lp32 = dist.log_prob(samples)
      self.assertEqual(lp32.dtype, tf.float32)

    composite = tfe.as_composite(dist)
    dist64 = tf.nest.map_structure(
        tensor_to_f64, composite, expand_composites=True)
    with tfp_hps.no_tf_rank_errors():
      result64 = dist64.log_prob(tensor_to_f64(samples))
    self.assertEqual(result64.dtype, tf.float64)
    def testDistribution(self, dist_name, data):
        """Checks that 20 samples drawn from the distribution contain no NaNs.

        Distributions in `NO_NANS_IN_SAMPLE_TEST_BLOCK_LIST` are skipped, and
        the same list is used as the `eligibility_filter` passed to
        `dhps.distributions`.
        """
        if dist_name in NO_NANS_IN_SAMPLE_TEST_BLOCK_LIST:
            self.skipTest('{} is blocked'.format(dist_name))

        def _not_blocked(name):
            return name not in NO_NANS_IN_SAMPLE_TEST_BLOCK_LIST

        strategy = dhps.distributions(dist_name=dist_name,
                                      enable_vars=False,
                                      eligibility_filter=_not_blocked)
        dist = data.draw(strategy)
        params = self.evaluate_dict(dist.parameters)
        hp.note('Trying distribution {}'.format(params))
        seed = test_util.test_seed(sampler_type='stateless')
        with tfp_hps.no_tf_rank_errors():
            draws = self.evaluate(dist.sample(20, seed=seed))
        # An all-zero (all-False) NaN mask means every sample is a number.
        expected_mask = np.zeros_like(draws)
        self.assertAllEqual(expected_mask, np.isnan(draws))
  def testKernelGradient(self, kernel_name, data):
    """Checks that `kernel.apply` has gradients w.r.t. inputs and variables.

    Draws a kernel with variable parameters and checks three things:
    - those variables appear in `kernel.variables`;
    - `apply` yields non-None gradients for the inputs and every variable;
    - a kernel rebuilt from `kernel._parameters` gives the same `apply` result.
    """
    # Draw order matters to Hypothesis; keep it stable.
    event_dim = data.draw(hps.integers(min_value=2, max_value=4))
    feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
    feature_dim = data.draw(hps.integers(min_value=2, max_value=4))

    kernel, expected_var_names = data.draw(
        kernel_hps.kernels(
            kernel_name=kernel_name,
            event_dim=event_dim,
            feature_dim=feature_dim,
            feature_ndims=feature_ndims,
            enable_vars=True))

    # `strip` removes leading/trailing underscores, digits and colons (e.g. a
    # `:0` suffix) so variable names can be compared with parameter names.
    actual_var_names = {
        v.name.strip('_0123456789:') for v in kernel.variables}
    self.assertEqual(set(expected_var_names), actual_var_names)

    example_ndims = data.draw(hps.integers(min_value=1, max_value=2))
    input_batch_shape = data.draw(
        tfp_hps.broadcast_compatible_shape(kernel.batch_shape))
    raw_inputs = data.draw(kernel_hps.kernel_input(
        batch_shape=input_batch_shape,
        example_ndims=example_ndims,
        feature_dim=feature_dim,
        feature_ndims=feature_ndims))
    xs = tf.identity(raw_inputs)

    # Differentiate w.r.t. the inputs and every kernel variable.
    wrt_vars = [xs, *kernel.variables]
    self.evaluate([v.initializer for v in kernel.variables])

    usage_check = tfp_hps.assert_no_excessive_var_usage(
        'method `apply` of {}'.format(kernel))
    with tf.GradientTape() as tape, usage_check:
      tape.watch(wrt_vars)
      with tfp_hps.no_tf_rank_errors():
        diag = kernel.apply(xs, xs, example_ndims=example_ndims)
    grads = tape.gradient(diag, wrt_vars)
    assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

    rebuilt = type(kernel)(**kernel._parameters)
    self.assertAllClose(
        diag, rebuilt.apply(xs, xs, example_ndims=example_ndims))
# Example #5
# 0
 def check_statistic(self, dist, statistic, expected_static_shape,
                     expected_dynamic_shape):
     """Checks that the shape of `dist.<statistic>()` fits the expected shape.

     Broadcasting the result's shape against the expected shape must reproduce
     the expected shape. Statically the broadcast must be compatible with it;
     dynamically it must be exactly equal.

     Args:
       dist: The distribution under test.
       statistic: Name of a zero-argument method of `dist` to call.
       expected_static_shape: Static shape (supporting `is_compatible_with`)
         the result's static shape must broadcast into.
       expected_dynamic_shape: Shape tensor the result's dynamic shape must
         broadcast into.

     Statistics whose call raises `NotImplementedError` are skipped.
     """
     try:
         with tfp_hps.no_tf_rank_errors():
             result = getattr(dist, statistic)()
     except NotImplementedError:
         # Not every distribution implements every statistic. Only the call
         # itself is guarded, so errors in the shape checks below still surface.
         return

     msg = 'Shape {} not compatible with expected {}.'.format(
         result.shape, expected_static_shape)
     self.assertTrue(
         expected_static_shape.is_compatible_with(
             tf.broadcast_static_shape(result.shape, expected_static_shape)),
         msg)
     self.assertAllEqual(
         self.evaluate(expected_dynamic_shape),
         self.evaluate(
             tf.broadcast_dynamic_shape(tf.shape(result),
                                        expected_dynamic_shape)))
    def testDistribution(self, dist_name, data):
        """Checks that `log_prob` of the distribution's own samples has no NaNs.

        The samples come from `self.check_samples_not_nan`, which presumably
        also asserts that the samples themselves are NaN-free (verify there).
        """
        if dist_name in NO_NANS_TEST_BLOCK_LIST:
            self.skipTest('{} is blocked'.format(dist_name))

        dist = data.draw(dhps.distributions(
            dist_name=dist_name,
            enable_vars=False,
            eligibility_filter=lambda name: name not in NO_NANS_TEST_BLOCK_LIST))
        samples = self.check_samples_not_nan(dist)
        self.assume_loc_scale_ok(dist)

        hp.note('Testing on samples {}'.format(samples))
        with tfp_hps.no_tf_rank_errors():
            log_probs = self.evaluate(dist.log_prob(samples))
        # Expect an all-False NaN mask.
        self.assertAllEqual(np.zeros_like(log_probs), np.isnan(log_probs))
    def _test_sample_and_log_prob(self, dist_name, dist):
        """Checks XLA-compiled `log_prob` against the non-compiled `log_prob`.

        Draws samples with an XLA-compiled (`jit_compile=True`) `sample` and a
        stateless seed. It then requires the XLA and non-XLA log-probs of those
        samples to agree within `XLA_LOGPROB_ATOL[dist_name]` /
        `XLA_LOGPROB_RTOL[dist_name]`.

        Args:
          dist_name: Name of the distribution; key into the tolerance tables.
          dist: The distribution under test.
        """
        seed = test_util.test_seed(sampler_type='stateless')
        # Record the parameters before running any compiled code, so they are
        # reported by Hypothesis even if XLA sampling itself fails.
        hp.note('Trying distribution {}'.format(
            self.evaluate_dict(dist.parameters)))

        num_samples = 3
        sample = self.evaluate(
            tf.function(dist.sample, jit_compile=True)(num_samples, seed=seed))
        hp.note('Drew samples {}'.format(sample))

        xla_lp = self.evaluate(
            tf.function(dist.log_prob,
                        jit_compile=True)(tf.convert_to_tensor(sample)))
        with tfp_hps.no_tf_rank_errors():
            graph_lp = self.evaluate(dist.log_prob(sample))
        self.assertAllClose(xla_lp,
                            graph_lp,
                            atol=XLA_LOGPROB_ATOL[dist_name],
                            rtol=XLA_LOGPROB_RTOL[dist_name])
# Example #8
# 0
    def _test_slicing(self, data, kernel_name, kernel, feature_dim,
                      feature_ndims):
        """Checks that slicing a kernel's batch dims agrees with numpy slicing.

        Draws a valid batch slice and checks two things:
        - the sliced kernel's `batch_shape` matches slicing a zeros array of the
          original batch shape;
        - `apply` on the sliced kernel equals slicing the original kernel's
          `apply` result, with the example dimensions left unsliced.

        Args:
          data: Hypothesis `data` object used for draws.
          kernel_name: Name of the kernel under test (not used by this method).
          kernel: The kernel under test.
          feature_dim: Forwarded to `kernel_hps.kernel_input`.
          feature_ndims: Forwarded to `kernel_hps.kernel_input`.
        """
        # `collections.Sequence` was removed in Python 3.10; use the abc module.
        from collections import abc as collections_abc

        example_ndims = data.draw(hps.integers(min_value=0, max_value=2))
        batch_shape = kernel.batch_shape
        slices = data.draw(tfp_hps.valid_slices(batch_shape))
        slice_str = 'kernel[{}]'.format(', '.join(
            tfp_hps.stringify_slices(slices)))
        # Make sure the slice string appears in Hypothesis' attempted example log
        hp.note('Using slice ' + slice_str)
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_kernel = kernel[slices]
        hp.note('Using sliced kernel {}.'.format(sliced_kernel))
        hp.note('Using sliced zeros {}.'.format(sliced_zeros.shape))

        # Check that slicing modifies batch shape as expected.
        self.assertAllEqual(sliced_zeros.shape, sliced_kernel.batch_shape)

        xs = tf.identity(
            data.draw(
                kernel_hps.kernel_input(batch_shape=[],
                                        example_ndims=example_ndims,
                                        feature_dim=feature_dim,
                                        feature_ndims=feature_ndims)))

        # Check that apply of the original and sliced kernels executes.
        with tfp_hps.no_tf_rank_errors():
            results = self.evaluate(
                kernel.apply(xs, xs, example_ndims=example_ndims))
            hp.note('Using results shape {}.'.format(results.shape))
            sliced_results = self.evaluate(
                sliced_kernel.apply(xs, xs, example_ndims=example_ndims))

        # Come up with the slices for apply (which must also include example dims).
        apply_slices = (tuple(slices)
                        if isinstance(slices, collections_abc.Sequence)
                        else (slices,))
        apply_slices += (slice(None),) * example_ndims

        # Check that applying the sliced kernel gives the same values as
        # slicing the original kernel's `apply` result.
        self.assertAllClose(results[apply_slices], sliced_results)
    def check_event_space_bijector_constrains(self, dist, data):
        """Checks that the default event-space bijector maps into the support.

        Returns early when the distribution has no default event-space
        bijector. Otherwise it does three things:
        - draws a tensor `y` (constrained only by `tfp_hps.identity_fn`) whose
          shape is a random sample shape followed by the bijector's inverse
          event shape;
        - maps `y` through the bijector;
        - evaluates the result under `dist._sample_control_dependencies`,
          which presumably asserts the values lie in the support (confirm).
        """
        bijector = dist.experimental_default_event_space_bijector()
        if bijector is None:
            return

        sample_shape = data.draw(tfp_hps.shapes())
        batch_and_event = tensorshape_util.concatenate(
            dist.batch_shape, dist.event_shape)
        # Shape of `y` in the bijector's domain (R**n) such that `bijector(y)`
        # has the distribution's batch and event shape.
        inv_event_shape = bijector.inverse_event_shape(batch_and_event)

        # TODO(b/174778703): Draw a shape that merely broadcasts with
        # `inv_event_shape` instead of using it verbatim, e.g.:
        #   compat = data.draw(
        #       tfp_hps.broadcast_compatible_shape(inv_event_shape))
        #   compat = tensorshape_util.concatenate(
        #       (1,) * (len(inv_event_shape) - len(compat)), compat)
        total_sample_shape = tensorshape_util.concatenate(
            sample_shape, inv_event_shape)

        y = data.draw(tfp_hps.constrained_tensors(
            tfp_hps.identity_fn, total_sample_shape.as_list()))
        hp.note('Trying to constrain inputs {}'.format(y))
        with tfp_hps.no_tf_rank_errors():
            x = bijector(y)
            hp.note('Got constrained samples {}'.format(x))
            support_checks = dist._sample_control_dependencies(x)
            with tf.control_dependencies(support_checks):
                self.evaluate(tensor_util.identity_as_tensor(x))
# Example #10
# 0
  def check_event_space_bijector_constrains(self, dist, data):
    """Checks that the default event-space bijector maps into the support.

    Returns early when the distribution has no default event-space bijector.
    Otherwise it draws a tensor `y` (constrained only by `tfp_hps.identity_fn`)
    whose shape is a random sample shape followed by a shape that broadcasts
    with the bijector's inverse event shape. It then maps `y` through the
    bijector and evaluates the result under
    `dist._sample_control_dependencies`.
    """
    bijector = dist.experimental_default_event_space_bijector()
    if bijector is None:
      return

    sample_shape = data.draw(tfp_hps.shapes())
    # The batch+event shape expressed in the bijector's domain (R**n), i.e.
    # the shape `y` needs so that `bijector(y)` has the distribution's shape.
    inverse_event_shape = bijector.inverse_event_shape(
        tensorshape_util.concatenate(dist.batch_shape, dist.event_shape))
    # One shape that broadcasts with `[batch_shape, inverse_event_shape]`.
    broadcast_shape = data.draw(
        tfp_hps.broadcasting_shapes(inverse_event_shape, n=1))[0]
    total_sample_shape = tensorshape_util.concatenate(
        sample_shape, broadcast_shape)

    y = data.draw(tfp_hps.constrained_tensors(
        tfp_hps.identity_fn, total_sample_shape.as_list()))
    with tfp_hps.no_tf_rank_errors():
      x = bijector(y)
      support_checks = dist._sample_control_dependencies(x)
      with tf.control_dependencies(support_checks):
        self.evaluate(tf.identity(x))
    def _test_slicing(self, data, dist_name, dist):
        """Checks that slicing a distribution's batch dims agrees with numpy.

        Verifies that the sliced distribution's `batch_shape`, the shape of its
        samples, and its `log_prob` all match slicing the corresponding
        quantities of the original distribution.

        Args:
          data: Hypothesis `data` object used to draw the slices.
          dist_name: Key into `SLICING_LOGPROB_ATOL` / `SLICING_LOGPROB_RTOL`.
          dist: The distribution under test.
        """
        # `collections.Sequence` was removed in Python 3.10; use the abc module.
        from collections import abc as collections_abc

        strm = test_util.test_seed_stream()
        batch_shape = dist.batch_shape
        slices = data.draw(dhps.valid_slices(batch_shape))
        slice_str = 'dist[{}]'.format(', '.join(dhps.stringify_slices(slices)))
        # Make sure the slice string appears in Hypothesis' attempted example log
        hp.note('Using slice ' + slice_str)
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_dist = dist[slices]
        hp.note('Using sliced distribution {}.'.format(sliced_dist))

        # Check that slicing modifies batch shape as expected.
        self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

        if not sliced_zeros.size:
            # TODO(b/128924708): Fix distributions that fail on degenerate empty
            #     shapes, e.g. Multinomial, DirichletMultinomial, ...
            return

        # Check that sampling of sliced distributions executes.
        with tfp_hps.no_tf_rank_errors():
            samples = self.evaluate(dist.sample(seed=strm()))
            sliced_dist_samples = self.evaluate(
                sliced_dist.sample(seed=strm()))

        # Come up with the slices for samples (which must also include event dims).
        sample_slices = (tuple(slices)
                         if isinstance(slices, collections_abc.Sequence)
                         else (slices,))
        if Ellipsis not in sample_slices:
            sample_slices += (Ellipsis, )
        sample_slices += tuple([slice(None)] *
                               tensorshape_util.rank(dist.event_shape))

        sliced_samples = samples[sample_slices]

        # Report sub-sliced samples (on which we compare log_prob) to hypothesis.
        hp.note('Sample(s) for testing log_prob ' + str(sliced_samples))

        # Check that sampling a sliced distribution produces the same shape as
        # slicing the samples from the original.
        self.assertAllEqual(sliced_samples.shape, sliced_dist_samples.shape)

        # Compute log_prob of the original samples under the original dist, and
        # of the sliced samples under the sliced dist (up to numerical
        # validation errors on the original).
        with tfp_hps.no_tf_rank_errors():
            try:
                lp = self.evaluate(dist.log_prob(samples))
            except tf.errors.InvalidArgumentError:
                # TODO(b/129271256): d.log_prob(d.sample()) should not fail
                #     validate_args checks.
                # We only tolerate this case for the non-sliced dist.
                return
            sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))

        # Check that the sliced dist's log_prob agrees with slicing the original's
        # log_prob.

        # This `hp.assume` is suppressing array sizes that cause the sliced and
        # non-sliced distribution to follow different Eigen code paths.  Those
        # different code paths lead to arbitrarily large variations in the results
        # at parameter settings that Hypothesis is all too good at finding.  Since
        # the purpose of this test is just to check that we got slicing right, those
        # discrepancies are a distraction.
        # TODO(b/140229057): Remove this `hp.assume`, if and when Eigen's numerics
        # become index-independent.
        all_packetized = (_all_packetized(dist)
                          and _all_packetized(sliced_dist)
                          and _all_packetized(samples)
                          and _all_packetized(sliced_samples))
        hp.note('Packetization check {}'.format(all_packetized))
        all_non_packetized = (_all_non_packetized(dist)
                              and _all_non_packetized(sliced_dist)
                              and _all_non_packetized(samples)
                              and _all_non_packetized(sliced_samples))
        hp.note('Non-packetization check {}'.format(all_non_packetized))
        hp.assume(all_packetized or all_non_packetized)

        self.assertAllClose(lp[slices],
                            sliced_lp,
                            atol=SLICING_LOGPROB_ATOL[dist_name],
                            rtol=SLICING_LOGPROB_RTOL[dist_name])
 def testVmap(self, dist_name, data):
   """Runs `self._test_vectorization` on a randomly drawn distribution."""
   # TODO(b/142826246): Enable validate_args.
   strategy = dhps.distributions(
       dist_name=dist_name, enable_vars=False, validate_args=False)
   dist = data.draw(strategy)
   with tfp_hps.no_tf_rank_errors():
     self._test_vectorization(dist_name, dist)
 def testCompositeTensor(self, dist_name, data):
   """Runs `self._test_sample_and_log_prob` on a randomly drawn distribution.

   NOTE(review): despite the name, the body only delegates to
   `_test_sample_and_log_prob`; confirm this is the intended check.
   """
   strategy = dhps.distributions(
       dist_name=dist_name, enable_vars=False, validate_args=False)
   dist = data.draw(strategy)
   with tfp_hps.no_tf_rank_errors():
     self._test_sample_and_log_prob(dist_name, dist)