コード例 #1
0
    def testLogProbParam(self, dist_name, data):
        """Runs `_test_transformation` on log_prob viewed per parameter.

        Blocklist handling: the test is skipped when membership of
        `dist_name` in `self.logprob_param_blocklist` differs from
        `FLAGS.blocklists_only`. Normally it skips blocklisted
        distributions. With `blocklists_only` set, it runs ONLY the
        blocklisted ones.

        Flow:
          1. Draw unconstrained parameters and a batch shape via `data.draw`
             (presumably a hypothesis `st.data()` object -- TODO confirm).
          2. Build a distribution from the constrained parameters and take
             one sample with the fixed key `random.PRNGKey(0)`.
          3. Hand `self._param_func_generator` a closure that evaluates
             `log_prob` of that fixed sample. Bind each yielded `func` to its
             parameter name and pass it, together with the parameter value,
             to `self._test_transformation`.

        Args:
          dist_name: Name of the distribution under test.
          data: Object whose `draw` method samples from strategies.
        """
        is_blocklisted = dist_name in self.logprob_param_blocklist
        if is_blocklisted != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')

        unconstrained, batch_shape = data.draw(
            dhps.base_distribution_unconstrained_params(
                enable_vars=False, dist_name=dist_name))
        constrained = dhps.constrain_params(unconstrained, dist_name)

        source_dist = data.draw(
            dhps.base_distributions(
                batch_shape=batch_shape,
                enable_vars=False,
                dist_name=dist_name,
                params=constrained))
        # One sample drawn with a fixed key; every per-parameter check below
        # evaluates log_prob at this same point.
        fixed_sample = source_dist.sample(seed=random.PRNGKey(0))

        def _log_prob(dist):
            return dist.log_prob(fixed_sample)

        # NOTE(review): the generator receives the *unconstrained* params;
        # presumably it constrains them itself -- TODO confirm.
        per_param = self._param_func_generator(
            data, dist_name, unconstrained, batch_shape, _log_prob)
        for param_name, param, _, func in per_param:
            bound_func = functools.partial(func, param_name)
            self._test_transformation(bound_func, param, msg=param_name)
コード例 #2
0
    def testLogProbSample(self, dist_name, data):
        """Runs `_test_transformation` on log_prob as a function of a sample.

        Blocklist handling: the test is skipped when membership of
        `dist_name` in `self.logprob_sample_blocklist` differs from
        `FLAGS.blocklists_only`. Normally it skips blocklisted
        distributions. With `blocklists_only` set, it runs ONLY the
        blocklisted ones.

        A distribution is built from constrained, freshly drawn parameters.
        One sample is taken with the fixed key `random.PRNGKey(0)`.
        Integer-dtype samples are skipped because they have no derivative.
        Otherwise `self._test_transformation` receives a closure computing
        `dist.log_prob` and the sample.

        Args:
          dist_name: Name of the distribution under test.
          data: Object whose `draw` method samples from strategies
            (presumably a hypothesis `st.data()` object -- TODO confirm).
        """
        in_blocklist = dist_name in self.logprob_sample_blocklist
        if in_blocklist != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')

        raw_params, batch_shape = data.draw(
            dhps.base_distribution_unconstrained_params(
                enable_vars=False, dist_name=dist_name))
        valid_params = dhps.constrain_params(raw_params, dist_name)

        dist = data.draw(
            dhps.base_distributions(
                batch_shape=batch_shape,
                enable_vars=False,
                dist_name=dist_name,
                params=valid_params))

        sample = dist.sample(seed=random.PRNGKey(0))
        has_integer_dtype = np.issubdtype(sample.dtype, np.integer)
        if has_integer_dtype:
            # Differentiating w.r.t. a discrete sample is meaningless.
            reason = '{} has integer samples; no derivative.'.format(dist_name)
            self.skipTest(reason)

        def _log_prob(sample):
            return dist.log_prob(sample)

        self._test_transformation(_log_prob, sample)
コード例 #3
0
 def _make_distribution(self, dist_name, params,
                        batch_shape, override_params=None):
   """Builds a base distribution from `params`, applying optional overrides.

   Entries of `override_params` replace same-named entries of `params` or
   add new ones. `params` itself is not mutated. The merged dict then goes
   through `dhps.constrain_params` and
   `dhps.modify_params(..., validate_args=False)`.

   Args:
     dist_name: Name of the distribution to build.
     params: Mapping of parameter name to (unconstrained) value.
     batch_shape: Batch shape forwarded to `dhps.base_distributions`.
     override_params: Optional mapping whose entries take precedence over
       `params`. `None` or empty means no overrides.

   Returns:
     The result of `dhps.base_distributions(...)` with `enable_vars=False`
     and `validate_args=False`. Presumably this is a hypothesis strategy
     that callers `draw` from -- TODO confirm.
   """
   merged = dict(params)
   # Overrides win on key collisions; `.items()` keeps the original's
   # requirement that overrides be a mapping.
   merged.update((override_params or {}).items())
   merged = dhps.constrain_params(merged, dist_name)
   merged = dhps.modify_params(merged, dist_name, validate_args=False)
   return dhps.base_distributions(
       enable_vars=False,
       dist_name=dist_name,
       params=merged,
       batch_shape=batch_shape,
       validate_args=False)