# Module-level imports these helpers rely on (paths as in the
# tensorflow_probability source tree; exact locations may vary by version).
import collections.abc

import hypothesis as hp
from hypothesis import strategies as hps
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.math.psd_kernels import hypothesis_testlib as kernel_hps


def _test_slicing(self, data, kernel_name, kernel, feature_dim,
                  feature_ndims):
  example_ndims = data.draw(hps.integers(min_value=0, max_value=2))
  batch_shape = kernel.batch_shape
  slices = data.draw(tfp_hps.valid_slices(batch_shape))
  slice_str = 'kernel[{}]'.format(', '.join(tfp_hps.stringify_slices(slices)))
  # Make sure the slice string appears in Hypothesis' attempted example log.
  hp.note('Using slice ' + slice_str)
  if not slices:  # Nothing further to check.
    return
  sliced_zeros = np.zeros(batch_shape)[slices]
  sliced_kernel = kernel[slices]
  hp.note('Using sliced kernel {}.'.format(sliced_kernel))
  hp.note('Using sliced zeros {}.'.format(sliced_zeros.shape))

  # Check that slicing modifies batch shape as expected.
  self.assertAllEqual(sliced_zeros.shape, sliced_kernel.batch_shape)

  xs = tf.identity(data.draw(kernel_hps.kernel_input(
      batch_shape=[],
      example_ndims=example_ndims,
      feature_dim=feature_dim,
      feature_ndims=feature_ndims)))

  # Check that apply of sliced kernels executes.
  with tfp_hps.no_tf_rank_errors():
    results = self.evaluate(
        kernel.apply(xs, xs, example_ndims=example_ndims))
    hp.note('Using results shape {}.'.format(results.shape))
    sliced_results = self.evaluate(
        sliced_kernel.apply(xs, xs, example_ndims=example_ndims))

  # Come up with the slices for apply (which must also include example dims).
  apply_slices = (
      tuple(slices) if isinstance(slices, collections.abc.Sequence)
      else (slices,))
  apply_slices += tuple([slice(None)] * example_ndims)

  # Check that applying the sliced kernel produces the same values as
  # slicing the results of applying the original kernel.
  self.assertAllClose(results[apply_slices], sliced_results)
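# A minimal, self-contained sketch of the invariant checked above. The
# ExponentiatedQuadratic kernel and the concrete shapes are illustrative
# assumptions, not part of the test: slicing a batched kernel should narrow
# its batch shape exactly the way the same slice narrows np.zeros(batch_shape).
def _demo_kernel_slicing():
  import numpy as np
  import tensorflow_probability as tfp

  # batch_shape [4, 3], broadcast from the batched amplitude parameter.
  kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
      amplitude=np.ones([4, 3], dtype=np.float32))
  sliced = kernel[1:, 0]
  # The sliced kernel's batch shape matches slicing an array of zeros.
  assert tuple(sliced.batch_shape) == np.zeros([4, 3])[1:, 0].shape  # (3,)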
def _test_slicing(self, data, dist_name, dist):
  strm = test_util.test_seed_stream()
  batch_shape = dist.batch_shape
  slices = data.draw(tfp_hps.valid_slices(batch_shape))
  slice_str = 'dist[{}]'.format(', '.join(tfp_hps.stringify_slices(slices)))
  # Make sure the slice string appears in Hypothesis' attempted example log.
  hp.note('Using slice ' + slice_str)
  if not slices:  # Nothing further to check.
    return
  sliced_zeros = np.zeros(batch_shape)[slices]
  sliced_dist = dist[slices]
  hp.note('Using sliced distribution {}.'.format(sliced_dist))

  # Check that slicing modifies batch shape as expected.
  self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

  if not sliced_zeros.size:
    # TODO(b/128924708): Fix distributions that fail on degenerate empty
    # shapes, e.g. Multinomial, DirichletMultinomial, ...
    return

  # Check that sampling of sliced distributions executes.
  with tfp_hps.no_tf_rank_errors():
    samples = self.evaluate(dist.sample(seed=strm()))
    sliced_dist_samples = self.evaluate(sliced_dist.sample(seed=strm()))

  # Come up with the slices for samples (which must also include event dims).
  sample_slices = (
      tuple(slices) if isinstance(slices, collections.abc.Sequence)
      else (slices,))
  if Ellipsis not in sample_slices:
    sample_slices += (Ellipsis,)
  sample_slices += tuple(
      [slice(None)] * tensorshape_util.rank(dist.event_shape))
  sliced_samples = samples[sample_slices]

  # Report sub-sliced samples (on which we compare log_prob) to Hypothesis.
  hp.note('Sample(s) for testing log_prob ' + str(sliced_samples))

  # Check that sampling a sliced distribution produces the same shape as
  # slicing the samples from the original.
  self.assertAllEqual(sliced_samples.shape, sliced_dist_samples.shape)

  # Check that the sliced dist's log_prob agrees with slicing the original's
  # log_prob. First, make sure the original sample passes the original
  # distribution's validations. We break the bijector cache here because
  # slicing will break it later too.
  with tfp_hps.no_tf_rank_errors():
    try:
      lp = self.evaluate(
          dist.log_prob(samples + tf.constant(0, dtype=samples.dtype)))
    except tf.errors.InvalidArgumentError:
      # TODO(b/129271256): d.log_prob(d.sample()) should not fail
      # validate_args checks.
      # `return` here passes the example. If we `hp.assume(False)` instead,
      # that would demand that Hypothesis find many examples where this
      # check (and the next one) passes; empirically, it seems to complain
      # that that's too hard.
      return

  # This `hp.assume` suppresses array sizes that cause the sliced and
  # non-sliced distributions to follow different Eigen code paths. Those
  # different code paths lead to arbitrarily large variations in the results
  # at parameter settings that Hypothesis is all too good at finding. Since
  # the purpose of this test is just to check that we got slicing right,
  # those discrepancies are a distraction.
  # TODO(b/140229057): Remove this `hp.assume`, if and when Eigen's numerics
  # become index-independent.
  all_packetized = (
      _all_packetized(dist) and _all_packetized(sliced_dist) and
      _all_packetized(samples) and _all_packetized(sliced_samples))
  hp.note('Packetization check {}'.format(all_packetized))
  all_non_packetized = (
      _all_non_packetized(dist) and _all_non_packetized(sliced_dist) and
      _all_non_packetized(samples) and _all_non_packetized(sliced_samples))
  hp.note('Non-packetization check {}'.format(all_non_packetized))
  hp.assume(all_packetized or all_non_packetized)

  # Actually evaluate and test the sliced log_prob.
  with tfp_hps.no_tf_rank_errors():
    sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))

  self.assertAllClose(lp[slices], sliced_lp,
                      atol=SLICING_LOGPROB_ATOL[dist_name],
                      rtol=SLICING_LOGPROB_RTOL[dist_name])
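# A minimal sketch of the end-to-end property the assertions above encode.
# The Normal distribution, the concrete shapes, and the fixed seed are
# illustrative assumptions, not part of the test: slicing a distribution and
# then taking log_prob agrees with taking log_prob on the full batch and
# slicing the result, mirroring the `lp[slices]` vs. `sliced_lp` comparison.
def _demo_distribution_slicing():
  import numpy as np
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dist = tfd.Normal(loc=np.arange(6.).reshape(2, 3), scale=1.)  # batch [2, 3]
  samples = dist.sample(seed=42)
  lp = dist.log_prob(samples)
  # dist[0, 1:] has batch_shape [2]; its log_prob matches the sliced original.
  np.testing.assert_allclose(lp[0, 1:], dist[0, 1:].log_prob(samples[0, 1:]))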