Example #1
class _CholeskyExtend(tf.test.TestCase):
    def testCholeskyExtension(self):
        xs = np.random.random(7).astype(self.dtype)[:, tf.newaxis]
        xs = tf1.placeholder_with_default(
            xs, shape=xs.shape if self.use_static_shape else None)
        k = tfp.positive_semidefinite_kernels.MaternOneHalf()
        mat = k.matrix(xs, xs)
        chol = tf.linalg.cholesky(mat)

        ys = np.random.random(3).astype(self.dtype)[:, tf.newaxis]
        ys = tf1.placeholder_with_default(
            ys, shape=ys.shape if self.use_static_shape else None)

        xsys = tf.concat([xs, ys], 0)
        new_chol_expected = tf.linalg.cholesky(k.matrix(xsys, xsys))

        new_chol = tfp.math.cholesky_concat(chol, k.matrix(xsys, ys))
        self.assertAllClose(new_chol_expected, new_chol)

    @hp.given(hps.data())
    @hp.settings(deadline=None,
                 max_examples=10,
                 derandomize=tfp_test_util.derandomize_hypothesis())
    def testCholeskyExtensionRandomized(self, data):
        jitter = lambda n: tf.linalg.eye(n, dtype=self.dtype) * 1e-5
        target_bs = data.draw(hpnp.array_shapes())
        prev_bs, new_bs = data.draw(
            tfp_test_util.broadcasting_shapes(target_bs, 2))
        ones = tf.TensorShape([1] * len(target_bs))
        smallest_shared_shp = tuple(
            np.min([
                tf.broadcast_static_shape(ones, shp).as_list()
                for shp in [prev_bs, new_bs]
            ],
                   axis=0))

        z = data.draw(hps.integers(min_value=1, max_value=12))
        n = data.draw(hps.integers(min_value=0, max_value=z - 1))
        m = z - n

        np.random.seed(
            data.draw(hps.integers(min_value=0, max_value=2**32 - 1)))
        xs = np.random.uniform(size=smallest_shared_shp + (n, ))
        # Draw and discard xs so it appears in Hypothesis' attempted example log.
        data.draw(hps.just(xs))
        xs = (xs + np.zeros(prev_bs.as_list() + [n]))[..., np.newaxis]
        xs = xs.astype(self.dtype)
        xs = tf1.placeholder_with_default(
            xs, shape=xs.shape if self.use_static_shape else None)

        k = tfp.positive_semidefinite_kernels.MaternOneHalf()
        mat = k.matrix(xs, xs) + jitter(n)
        chol = tf.linalg.cholesky(mat)

        ys = np.random.uniform(size=smallest_shared_shp + (m, ))
        # Draw and discard ys so it appears in Hypothesis' attempted example log.
        data.draw(hps.just(ys))
        ys = (ys + np.zeros(new_bs.as_list() + [m]))[..., np.newaxis]
        ys = ys.astype(self.dtype)
        ys = tf1.placeholder_with_default(
            ys, shape=ys.shape if self.use_static_shape else None)

        xsys = tf.concat([
            xs + tf.zeros(target_bs + (n, 1), dtype=self.dtype),
            ys + tf.zeros(target_bs + (m, 1), dtype=self.dtype)
        ],
                         axis=-2)
        new_chol_expected = tf.linalg.cholesky(
            k.matrix(xsys, xsys) + jitter(z))

        new_chol = tfp.math.cholesky_concat(
            chol,
            k.matrix(xsys, ys) + jitter(z)[:, n:])
        self.assertAllClose(new_chol_expected, new_chol, rtol=1e-5, atol=1e-5)
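

# A minimal NumPy sketch (hypothetical helper, for illustration only) of the
# block update that `tfp.math.cholesky_concat` is tested against above: given
# L = chol(A) and the new columns [[B], [C]] of the extended matrix
# [[A, B], [B.T, C]], the extended factor is
# [[L, 0], [B.T @ inv(L).T, chol(C - B.T @ inv(A) @ B)]].
def _cholesky_extend_reference(chol, new_cols):
    """Reference block-Cholesky update; `new_cols` has shape [n + m, m]."""
    n = chol.shape[0]
    b, c = new_cols[:n], new_cols[n:]
    x = np.linalg.solve(chol, b)  # x = inv(L) @ B, so x.T == B.T @ inv(L).T.
    schur = c - x.T @ x  # Schur complement: C - B.T @ inv(A) @ B.
    top = np.concatenate([chol, np.zeros((n, c.shape[0]))], axis=1)
    bottom = np.concatenate([x.T, np.linalg.cholesky(schur)], axis=1)
    return np.concatenate([top, bottom], axis=0)
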
class DistributionParamsAreVarsTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.parameters((dname, ) for dname in TF2_FRIENDLY_DISTS)
    @hp.given(hps.data())
    @hp.settings(deadline=None,
                 max_examples=hypothesis_max_examples(),
                 suppress_health_check=[hp.HealthCheck.too_slow],
                 derandomize=tfp_test_util.derandomize_hypothesis())
    def testDistribution(self, dist_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        tf.compat.v1.set_random_seed(
            data.draw(
                hpnp.arrays(dtype=np.int64,
                            shape=[]).filter(lambda x: x != 0)))
        dist, batch_shape = data.draw(
            distributions(dist_name=dist_name, enable_vars=True))
        batch_shape2 = data.draw(broadcast_compatible_shape(batch_shape))
        dist2, _ = data.draw(
            distributions(dist_name=dist_name,
                          batch_shape=batch_shape2,
                          event_dim=get_event_dim(dist),
                          enable_vars=True))
        del batch_shape
        logging.info(
            'distribution: %s; parameters used: %s', dist,
            [k for k, v in six.iteritems(dist.parameters) if v is not None])
        self.evaluate([var.initializer for var in dist.variables])
        for k, v in six.iteritems(dist.parameters):
            if not tensor_util.is_mutable(v):
                continue
            try:
                self.assertIs(getattr(dist, k), v)
            except AssertionError as e:
                raise AssertionError(
                    'No attr found for parameter {} of distribution {}: \n{}'.
                    format(k, dist_name, e))

        for stat in data.draw(
                hps.sets(hps.one_of(
                    map(hps.just, [
                        'covariance', 'entropy', 'mean', 'mode', 'stddev',
                        'variance'
                    ])),
                         min_size=3,
                         max_size=3)):
            logging.info('%s.%s', dist_name, stat)
            try:
                VAR_USAGES.clear()
                getattr(dist, stat)()
                assert_no_excessive_var_usage('statistic `{}` of `{}`'.format(
                    stat, dist))
            except NotImplementedError:
                pass

        VAR_USAGES.clear()
        with tf.GradientTape() as tape:
            sample = dist.sample()
        assert_no_excessive_var_usage('method `sample` of `{}`'.format(dist))
        if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            grads = tape.gradient(sample, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for distribution {}'.format(
                            var, dist_name))

        # Turn off validations, since log_prob can choke on dist's own samples.
        # Also, to relax conversion counts for KL (might do >2 w/ validate_args).
        dist = dist.copy(validate_args=False)
        dist2 = dist2.copy(validate_args=False)

        try:
            for d1, d2 in (dist, dist2), (dist2, dist):
                VAR_USAGES.clear()
                with tf.GradientTape() as tape:
                    kl = d1.kl_divergence(d2)
                assert_no_excessive_var_usage(
                    '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.
                    format(d1, d1.variables, d2, d2.variables),
                    max_permissible=1)  # No validation => 1 convert per var.
                wrt_vars = list(d1.variables) + list(d2.variables)
                grads = tape.gradient(kl, wrt_vars)
                for grad, var in zip(grads, wrt_vars):
                    if grad is None and dist_name not in NO_KL_PARAM_GRADS:
                        raise AssertionError(
                            'Missing KL({} || {}) -> {} grad:\n'
                            '{} vars: {}\n{} vars: {}'.format(
                                d1, d2, var, d1, d1.variables, d2,
                                d2.variables))
        except NotImplementedError:
            pass

        if dist_name not in NO_LOG_PROB_PARAM_GRADS:
            with tf.GradientTape() as tape:
                lp = dist.log_prob(tf.stop_gradient(sample))
            grads = tape.gradient(lp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing log_prob -> {} grad for distribution {}'.
                        format(var, dist_name))

        for evaluative in data.draw(
                hps.sets(hps.one_of(
                    map(hps.just, [
                        'log_prob', 'prob', 'log_cdf', 'cdf',
                        'log_survival_function', 'survival_function'
                    ])),
                         min_size=3,
                         max_size=3)):
            logging.info('%s.%s', dist_name, evaluative)
            try:
                VAR_USAGES.clear()
                getattr(dist, evaluative)(sample)
                assert_no_excessive_var_usage(
                    'evaluative `{}` of `{}`'.format(evaluative, dist),
                    max_permissible=1)  # No validation => 1 convert.
            except NotImplementedError:
                pass
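

# `VAR_USAGES` and `assert_no_excessive_var_usage` above are module-level
# helpers not shown in this excerpt. A minimal sketch of one possible
# implementation (hypothetical; the real harness may differ): record a stack
# trace each time a Variable is converted to a Tensor, then fail when any
# variable is converted more often than expected.
import traceback

VAR_USAGES = {}


def _record_var_usage(var):
    VAR_USAGES.setdefault(var, []).append(traceback.format_stack())


def assert_no_excessive_var_usage(name, max_permissible=2):
    """Raises if any variable was converted more than `max_permissible` times."""
    if any(len(usages) > max_permissible for usages in VAR_USAGES.values()):
        raise AssertionError('Excessive tensor conversions in {}: {}'.format(
            name, {var: len(usages) for var, usages in VAR_USAGES.items()}))
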
class DistributionSlicingTest(tf.test.TestCase):
    def _test_slicing(self, data, dist, batch_shape):
        slices = data.draw(valid_slices(batch_shape))
        slice_str = 'dist[{}]'.format(', '.join(stringify_slices(slices)))
        logging.info('slice used: %s', slice_str)
        # Make sure the slice string appears in Hypothesis' attempted example log,
        # by drawing and discarding it.
        data.draw(hps.just(slice_str))
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_dist = dist[slices]
        self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

        try:
            seed = data.draw(
                hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0))
            samples = self.evaluate(dist.sample(seed=maybe_seed(seed)))

            if not sliced_zeros.size:
                # TODO(b/128924708): Fix distributions that fail on degenerate empty
                #     shapes, e.g. Multinomial, DirichletMultinomial, ...
                return

            sliced_samples = self.evaluate(
                sliced_dist.sample(seed=maybe_seed(seed)))
        except NotImplementedError as e:
            # TODO(b/34701635): Binomial needs a sampler.
            if 'sample_n is not implemented: Binomial' in str(e):
                return
            raise
        except tf.errors.UnimplementedError as e:
            if 'Unhandled input dimensions' in str(e) or 'rank not in' in str(
                    e):
                # Some cases can fail with 'Unhandled input dimensions \d+' or
                # 'inputs rank not in [0,6]: \d+'
                return
            raise

        # Come up with the slices for samples (which must also include event dims).
        sample_slices = (tuple(slices) if isinstance(
            slices, collections.abc.Sequence) else (slices, ))
        if Ellipsis not in sample_slices:
            sample_slices += (Ellipsis, )
        sample_slices += tuple([slice(None)] *
                               tensorshape_util.rank(dist.event_shape))

        # Report sub-sliced samples (on which we compare log_prob) to hypothesis.
        data.draw(hps.just(samples[sample_slices]))
        self.assertAllEqual(samples[sample_slices].shape, sliced_samples.shape)
        try:
            try:
                lp = self.evaluate(dist.log_prob(samples))
            except tf.errors.InvalidArgumentError:
                # TODO(b/129271256): d.log_prob(d.sample()) should not fail
                #     validate_args checks.
                # We only tolerate this case for the non-sliced dist.
                return
            sliced_lp = self.evaluate(
                sliced_dist.log_prob(samples[sample_slices]))
        except tf.errors.UnimplementedError as e:
            if 'Unhandled input dimensions' in str(e) or 'rank not in' in str(
                    e):
                # Some cases can fail with 'Unhandled input dimensions \d+' or
                # 'inputs rank not in [0,6]: \d+'
                return
            raise
        # TODO(b/128708201): Better numerics for Geometric/Beta?
        # Eigen can return quite different results for packet vs non-packet ops.
        # To work around this, we use a much larger rtol for the last 3
        # (assuming packet size 4) elements.
        packetized_lp = lp[slices].reshape(-1)[:-3]
        packetized_sliced_lp = sliced_lp.reshape(-1)[:-3]
        rtol = (0.1 if any(x in dist.name for x in ('Geometric', 'Beta',
                                                    'Dirichlet')) else 0.02)
        self.assertAllClose(packetized_lp, packetized_sliced_lp, rtol=rtol)
        possibly_nonpacket_lp = lp[slices].reshape(-1)[-3:]
        possibly_nonpacket_sliced_lp = sliced_lp.reshape(-1)[-3:]
        rtol = 0.4
        self.assertAllClose(possibly_nonpacket_lp,
                            possibly_nonpacket_sliced_lp,
                            rtol=rtol)

    def _run_test(self, data):
        tf.compat.v1.set_random_seed(  # TODO(b/129287396): drop the int(..)
            int(
                data.draw(
                    hpnp.arrays(dtype=np.int64,
                                shape=[]).filter(lambda x: x != 0))))
        # TODO(b/128974935): Avoid passing in data.draw using hps.composite
        # dist, batch_shape = data.draw(distributions())
        dist, batch_shape = distributions(data.draw)
        logging.info(
            'distribution: %s; parameters used: %s', dist,
            [k for k, v in six.iteritems(dist.parameters) if v is not None])
        self.assertAllEqual(batch_shape, dist.batch_shape)

        with self.assertRaisesRegex(TypeError, 'not iterable'):
            iter(dist)  # __getitem__ magically makes an object iterable.

        self._test_slicing(data, dist, batch_shape)

        # TODO(bjp): Enable sampling and log_prob checks. Currently, too many errors
        #     from out-of-domain samples.
        # self.evaluate(dist.log_prob(dist.sample()))

    @hp.given(hps.data())
    @hp.settings(deadline=None,
                 max_examples=hypothesis_max_examples(),
                 suppress_health_check=[hp.HealthCheck.too_slow],
                 derandomize=tfp_test_util.derandomize_hypothesis())
    def testDistributions(self, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'): return
        self._run_test(data)
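

# A concrete, standalone instance (illustrative snippet, not part of the
# suite) of the slicing contract `_test_slicing` checks: indexing a
# distribution slices its batch shape exactly as NumPy would slice an array
# of that shape, while the event shape is left untouched.
def _slicing_contract_demo():
    dist = tfd.Normal(loc=tf.zeros([5, 4]), scale=1.)  # batch_shape: [5, 4]
    sliced = dist[1:3, ..., tf.newaxis]
    # Same shape arithmetic as the test's `np.zeros(batch_shape)[slices]`.
    expected = np.zeros([5, 4])[1:3, ..., np.newaxis].shape  # (2, 4, 1)
    assert tuple(sliced.batch_shape.as_list()) == expected
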
class DistributionParamsAreVarsTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.parameters((dname, ) for dname in TF2_FRIENDLY_DISTS)
    @hp.given(hps.data())
    @hp.settings(deadline=None,
                 max_examples=hypothesis_max_examples(),
                 suppress_health_check=[hp.HealthCheck.too_slow],
                 derandomize=tfp_test_util.derandomize_hypothesis())
    def testDistribution(self, dist_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        tf.compat.v1.set_random_seed(
            data.draw(
                hpnp.arrays(dtype=np.int64,
                            shape=[]).filter(lambda x: x != 0)))
        dist, batch_shape = data.draw(
            distributions(dist_name=dist_name, enable_vars=True))
        del batch_shape
        logging.info(
            'distribution: %s; parameters used: %s', dist,
            [k for k, v in six.iteritems(dist.parameters) if v is not None])
        self.evaluate([var.initializer for var in dist.variables])
        for k, v in six.iteritems(dist.parameters):
            if not tensor_util.is_mutable(v):
                continue
            try:
                self.assertIs(getattr(dist, k), v)
            except AssertionError as e:
                raise AssertionError(
                    'No attr found for parameter {} of distribution {}: \n{}'.
                    format(k, dist_name, e))

        stat = data.draw(
            hps.one_of(
                map(hps.just,
                    ['mean', 'mode', 'variance', 'covariance', 'entropy'])))
        try:
            VAR_USAGES.clear()
            getattr(dist, stat)()
            var_nusages = {
                var: len(usages)
                for var, usages in VAR_USAGES.items()
            }
            max_permissible = 2  # TODO(jvdillon): Reduce this to 1.
            if any(
                    len(usages) > max_permissible
                    for usages in VAR_USAGES.values()):
                for var, usages in six.iteritems(VAR_USAGES):
                    if len(usages) > max_permissible:
                        print(
                            'While executing statistic `{}` of `{}`, detected '
                            '{} Tensor conversions for `{}`:'.format(
                                stat, dist, len(usages), var))
                        for i, usage in enumerate(usages):
                            print('Conversion {} of {}:\n{}'.format(
                                i + 1, len(usages), ''.join(usage)))
                raise AssertionError(
                    'Excessive tensor conversions detected for {} {}: {}'.
                    format(dist_name, stat, var_nusages))
        except NotImplementedError:
            pass

        if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            with tf.GradientTape() as tape:
                samp = dist.sample()
            grads = tape.gradient(samp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for distribution {}'.format(
                            var, dist_name))

        if dist_name not in NO_LOG_PROB_PARAM_GRADS:
            # Turn off validations, since log_prob can choke on dist's own samples.
            dist = dist.copy(validate_args=False)
            with tf.GradientTape() as tape:
                lp = dist.log_prob(tf.stop_gradient(dist.sample()))
            grads = tape.gradient(lp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing log_prob -> {} grad for distribution {}'.
                        format(var, dist_name))
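

# A minimal standalone instance (illustrative, not part of the suite) of the
# FULLY_REPARAMETERIZED gradient check above: for a reparameterized family
# such as Normal, sample = loc + scale * eps, so gradients flow from the
# sample back to variable parameters.
def _reparameterized_sample_grad_demo():
    loc = tf.Variable(0.5)
    dist = tfd.Normal(loc=loc, scale=1.)
    assert dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED
    with tf.GradientTape() as tape:
        sample = dist.sample()
    grad = tape.gradient(sample, loc)
    assert grad is not None  # d(sample)/d(loc) == 1 for a location family.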