示例#1
0
    def testLogProbParam(self, dist_name, data):
        """Checks transformations of `log_prob` viewed as a function of params.

        One fixed-seed sample is drawn from a distribution built with
        constrained parameters. `_param_func_generator` then yields
        `(param_name, param, dist_func, func)` tuples. Each `func`, with its
        `param_name` bound, is handed to `_test_transformation` together with
        `param`. `func` presumably rebuilds the distribution from that
        parameter and evaluates `log_prob` on the fixed sample (confirm
        against `_param_func_generator`).

        Args:
          dist_name: Name of the distribution under test.
          data: Hypothesis data object used to draw params and the
            distribution.
        """
        # Normally skip blocklisted distributions. With --blocklists_only,
        # the inequality flips so that ONLY blocklisted ones run.
        blocklisted = dist_name in self.logprob_param_blocklist
        if blocklisted != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')

        param_strategy = dhps.base_distribution_unconstrained_params(
            enable_vars=False, dist_name=dist_name)
        params, batch_shape = data.draw(param_strategy)

        dist_strategy = dhps.base_distributions(
            batch_shape=batch_shape,
            enable_vars=False,
            dist_name=dist_name,
            params=dhps.constrain_params(params, dist_name))
        sampling_dist = data.draw(dist_strategy)
        # A single sample shared by every per-parameter check below.
        sample = sampling_dist.sample(seed=random.PRNGKey(0))

        def _log_prob(dist):
            return dist.log_prob(sample)

        generated = self._param_func_generator(data, dist_name, params,
                                               batch_shape, _log_prob)
        for param_name, param, _, func in generated:
            self._test_transformation(functools.partial(func, param_name),
                                      param,
                                      msg=param_name)
示例#2
0
    def testLogProbSample(self, dist_name, data):
        """Checks transformations of `log_prob` viewed as a function of the sample.

        A distribution is drawn with constrained parameters and sampled once
        with a fixed seed. `log_prob` as a function of that sample is then
        passed to `_test_transformation`. The test is skipped when the sample
        has an integer dtype, because there is no derivative with respect to
        integer values.

        Args:
          dist_name: Name of the distribution under test.
          data: Hypothesis data object used to draw params and the
            distribution.
        """
        # Normally skip blocklisted distributions. With --blocklists_only,
        # the inequality flips so that ONLY blocklisted ones run.
        blocklisted = dist_name in self.logprob_sample_blocklist
        if blocklisted != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')

        param_strategy = dhps.base_distribution_unconstrained_params(
            enable_vars=False, dist_name=dist_name)
        params, batch_shape = data.draw(param_strategy)

        dist = data.draw(
            dhps.base_distributions(
                batch_shape=batch_shape,
                enable_vars=False,
                dist_name=dist_name,
                params=dhps.constrain_params(params, dist_name)))

        sample = dist.sample(seed=random.PRNGKey(0))
        if np.issubdtype(sample.dtype, np.integer):
            self.skipTest(
                '{} has integer samples; no derivative.'.format(dist_name))

        def _log_prob(value):
            return dist.log_prob(value)

        self._test_transformation(_log_prob, sample)
示例#3
0
    def testSample(self, dist_name, data):
        """Checks transformations of `sample` viewed as a function of params.

        Unconstrained parameters are drawn, and `_param_func_generator`
        yields `(param_name, unconstrained_param, dist_func, func)` tuples.
        For each tuple, the distribution built by `dist_func` must be
        FULLY_REPARAMETERIZED; otherwise the whole test is skipped. `func`,
        with its `param_name` bound, is then handed to `_test_transformation`
        with the unconstrained parameter.

        Args:
          dist_name: Name of the distribution under test.
          data: Hypothesis data object used to draw params.
        """
        # Normally skip blocklisted distributions. With --blocklists_only,
        # the inequality flips so that ONLY blocklisted ones run.
        blocklisted = dist_name in self.sample_blocklist
        if blocklisted != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')

        params_unconstrained, batch_shape = data.draw(
            dhps.base_distribution_unconstrained_params(enable_vars=False,
                                                        dist_name=dist_name))

        def _sample(dist):
            # Every call uses the same key, so sampling is deterministic.
            return dist.sample(seed=random.PRNGKey(0))

        generated = self._param_func_generator(data, dist_name,
                                               params_unconstrained,
                                               batch_shape, _sample)
        for param_name, unconstrained_param, dist_func, func in generated:
            built = dist_func(param_name, unconstrained_param)
            # Sampling is only differentiable when fully reparameterized.
            if (built.reparameterization_type
                    != reparameterization.FULLY_REPARAMETERIZED):
                self.skipTest('{} is not reparameterized.'.format(dist_name))
            self._test_transformation(functools.partial(func, param_name),
                                      unconstrained_param,
                                      msg=param_name)