Example #1
    def _test_slicing(self, data, kernel_name, kernel, feature_dim,
                      feature_ndims):
        example_ndims = data.draw(hps.integers(min_value=0, max_value=2))
        batch_shape = kernel.batch_shape
        slices = data.draw(tfp_hps.valid_slices(batch_shape))
        slice_str = 'kernel[{}]'.format(', '.join(
            tfp_hps.stringify_slices(slices)))
        # Make sure the slice string appears in Hypothesis' attempted example log
        hp.note('Using slice ' + slice_str)
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_kernel = kernel[slices]
        hp.note('Using sliced kernel {}.'.format(sliced_kernel))
        hp.note('Using sliced zeros {}.'.format(sliced_zeros.shape))

        # Check that slicing modifies batch shape as expected.
        self.assertAllEqual(sliced_zeros.shape, sliced_kernel.batch_shape)

        xs = tf.identity(
            data.draw(
                kernel_hps.kernel_input(batch_shape=[],
                                        example_ndims=example_ndims,
                                        feature_dim=feature_dim,
                                        feature_ndims=feature_ndims)))

        # Check that apply of sliced kernels executes.
        with tfp_hps.no_tf_rank_errors():
            results = self.evaluate(
                kernel.apply(xs, xs, example_ndims=example_ndims))
            hp.note('Using results shape {}.'.format(results.shape))
            sliced_results = self.evaluate(
                sliced_kernel.apply(xs, xs, example_ndims=example_ndims))

        # Come up with the slices for apply (which must also include example dims).
        apply_slices = (
            tuple(slices) if isinstance(slices, collections.abc.Sequence)
            else (slices,))
        apply_slices += tuple([slice(None)] * example_ndims)
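        # For example, slices == (0,) with example_ndims == 2 yields
        # apply_slices == (0, slice(None), slice(None)): the example
        # dimensions are kept whole.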

        # Check that applying the sliced kernel gives the same results as
        # slicing the results of applying the original kernel.
        self.assertAllClose(results[apply_slices], sliced_results)
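
The test above relies on TFP-internal Hypothesis helpers (tfp_hps, kernel_hps) that are not shown here. The property it exercises can be seen on a concrete kernel; below is a minimal sketch (not part of the test) using ExponentiatedQuadratic, where the kernel class, parameter values, and input shapes are illustrative choices assuming the usual TensorFlow Probability psd_kernels API.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfpk = tfp.math.psd_kernels

# A kernel with batch_shape [3]: one amplitude per batch member.
kernel = tfpk.ExponentiatedQuadratic(amplitude=tf.constant([1., 2., 3.]))

# Slicing the kernel slices its batch dimensions exactly like slicing a
# NumPy array of the same batch shape.
sliced_kernel = kernel[1:]
assert tuple(sliced_kernel.batch_shape) == np.zeros([3])[1:].shape  # (2,)

# Applying the sliced kernel agrees with slicing the result of applying the
# original kernel; the example dimension is left untouched.
xs = tf.random.normal([5, 1])  # 5 examples, feature_dim = 1
full = kernel.apply(xs, xs, example_ndims=1)           # shape [3, 5]
sliced = sliced_kernel.apply(xs, xs, example_ndims=1)  # shape [2, 5]
np.testing.assert_allclose(full[1:], sliced, rtol=1e-6)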

Example #2

    def _test_slicing(self, data, dist_name, dist):
        strm = test_util.test_seed_stream()
        batch_shape = dist.batch_shape
        slices = data.draw(tfp_hps.valid_slices(batch_shape))
        slice_str = 'dist[{}]'.format(', '.join(
            tfp_hps.stringify_slices(slices)))
        # Make sure the slice string appears in Hypothesis' attempted example log
        hp.note('Using slice ' + slice_str)
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_dist = dist[slices]
        hp.note('Using sliced distribution {}.'.format(sliced_dist))
        # Check that slicing modifies batch shape as expected.
        self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

        if not sliced_zeros.size:
            # TODO(b/128924708): Fix distributions that fail on degenerate empty
            #     shapes, e.g. Multinomial, DirichletMultinomial, ...
            return

        # Check that sampling of sliced distributions executes.
        with tfp_hps.no_tf_rank_errors():
            samples = self.evaluate(dist.sample(seed=strm()))
            sliced_dist_samples = self.evaluate(
                sliced_dist.sample(seed=strm()))

        # Come up with the slices for samples (which must also include event dims).
        sample_slices = (
            tuple(slices) if isinstance(slices, collections.abc.Sequence)
            else (slices,))
        if Ellipsis not in sample_slices:
            sample_slices += (Ellipsis, )
        sample_slices += tuple([slice(None)] *
                               tensorshape_util.rank(dist.event_shape))
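        # For example, slices == (0,) with a rank-2 event yields
        # sample_slices == (0, Ellipsis, slice(None), slice(None)): the
        # Ellipsis absorbs any remaining batch dimensions and the trailing
        # full slices keep the event dimensions whole.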

        sliced_samples = samples[sample_slices]

        # Report sub-sliced samples (on which we compare log_prob) to Hypothesis.
        hp.note('Sample(s) for testing log_prob ' + str(sliced_samples))

        # Check that sampling a sliced distribution produces the same shape as
        # slicing the samples from the original.
        self.assertAllEqual(sliced_samples.shape, sliced_dist_samples.shape)

        # Check that the sliced dist's log_prob agrees with slicing the original's
        # log_prob.
        # First, we make sure that the original sample we have passes the
        # original distribution's validations.  We break the bijector cache here
        # because slicing will break it later too.
        with tfp_hps.no_tf_rank_errors():
            try:
                lp = self.evaluate(
                    dist.log_prob(samples +
                                  tf.constant(0, dtype=samples.dtype)))
            except tf.errors.InvalidArgumentError:
                # TODO(b/129271256): d.log_prob(d.sample()) should not fail
                #     validate_args checks.
                # `return` here passes the example.  If we `hp.assume(False)`
                # instead, that would demand from Hypothesis that it find many
                # examples where this check (and the next one) passes;
                # empirically, it seems to complain that that's too hard.
                return

        # This `hp.assume` is suppressing array sizes that cause the sliced and
        # non-sliced distribution to follow different Eigen code paths.  Those
        # different code paths lead to arbitrarily large variations in the results
        # at parameter settings that Hypothesis is all too good at finding.  Since
        # the purpose of this test is just to check that we got slicing right, those
        # discrepancies are a distraction.
        # TODO(b/140229057): Remove this `hp.assume`, if and when Eigen's numerics
        # become index-independent.
        all_packetized = (_all_packetized(dist)
                          and _all_packetized(sliced_dist)
                          and _all_packetized(samples)
                          and _all_packetized(sliced_samples))
        hp.note('Packetization check {}'.format(all_packetized))
        all_non_packetized = (_all_non_packetized(dist)
                              and _all_non_packetized(sliced_dist)
                              and _all_non_packetized(samples)
                              and _all_non_packetized(sliced_samples))
        hp.note('Non-packetization check {}'.format(all_non_packetized))
        hp.assume(all_packetized or all_non_packetized)

        # Actually evaluate and test the sliced log_prob
        with tfp_hps.no_tf_rank_errors():
            sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))

        self.assertAllClose(lp[slices],
                            sliced_lp,
                            atol=SLICING_LOGPROB_ATOL[dist_name],
                            rtol=SLICING_LOGPROB_RTOL[dist_name])
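
The same invariants can be seen on a small concrete distribution. Below is a minimal sketch (not part of the test) using tfd.Normal; the distribution, parameter values, and slice are illustrative choices assuming the standard TensorFlow Probability distributions API.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# A Normal with batch_shape [2, 3]: one scalar distribution per loc entry.
dist = tfd.Normal(loc=tf.reshape(tf.range(6.), [2, 3]), scale=1.)

# Slicing the distribution slices its batch dimensions exactly like slicing
# a NumPy array of the same batch shape.
sliced_dist = dist[0, 1:]
assert tuple(sliced_dist.batch_shape) == np.zeros([2, 3])[0, 1:].shape  # (2,)

# log_prob of the sliced distribution on correspondingly sliced samples
# agrees with slicing the original distribution's log_prob.  Normal has a
# scalar event, so no extra trailing event-dim slices are needed here.
samples = dist.sample(seed=42)                    # shape [2, 3]
lp = dist.log_prob(samples)                       # shape [2, 3]
sliced_lp = sliced_dist.log_prob(samples[0, 1:])  # shape [2]
np.testing.assert_allclose(lp[0, 1:], sliced_lp, rtol=1e-6)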