Пример #1
0
    def testSample(self, dist_name, data):
        """Check that sampling runs when mapped over a batch of seeds via `self.map`.

        The test is skipped when `dist_name`'s membership in
        `self.sample_blocklist` disagrees with `FLAGS.blocklists_only`. In the
        normal mode only non-blocklisted distributions run. In blocklists-only
        mode only blocklisted ones run.

        Args:
          dist_name: Name of the distribution family to draw.
          data: Object exposing `.draw(strategy)`; presumably hypothesis's
            interactive `data()` draw object -- confirm against the decorator.
        """
        listed = dist_name in self.sample_blocklist
        if listed != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')
        strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
        distribution = data.draw(strategy)

        def _draw_one(key):
            return distribution.sample(seed=key)

        root_key = test_util.test_seed()
        # NOTE(review): `self.map` is presumably a vectorizing transform
        # (e.g. jax.vmap / pmap) chosen by the enclosing class -- confirm.
        mapped_sampler = self.map(_draw_one)
        # The output is discarded: this only checks that mapped sampling runs.
        mapped_sampler(random.split(root_key, self.batch_size))
Пример #2
0
    def testSample(self, dist_name, data):
        """Check that `dist.sample` runs under `jax.vmap` across 10 split seeds.

        Skipped when membership of `dist_name` in `VMAP_SAMPLE_BLOCKLIST`
        differs from `FLAGS.blocklists_only`.

        Args:
          dist_name: Name of the distribution family to draw.
          data: Object exposing `.draw(strategy)`; presumably hypothesis's
            interactive draw object -- confirm against the decorator.
        """
        in_blocklist = dist_name in VMAP_SAMPLE_BLOCKLIST
        if in_blocklist != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')
        distribution = data.draw(
            dhps.distributions(enable_vars=False, dist_name=dist_name))

        def _draw_one(key):
            return distribution.sample(seed=key)

        root_key = test_util.test_seed()
        vectorized = jax.vmap(_draw_one)
        # Only execution is checked; the batched samples are not inspected.
        vectorized(random.split(root_key, 10))
Пример #3
0
 def testLogProb(self, dist_name, data):
   """Check that a jitted `log_prob` matches the eager one on a drawn sample.

   Skipped when membership of `dist_name` in `JIT_LOGPROB_BLOCKLIST` differs
   from `FLAGS.blocklists_only`. Sub-distributions drawn by the strategy are
   filtered to names outside the same blocklist. When `FLAGS.execute_only`
   is set, only execution is checked and the numerical comparison is skipped.

   Args:
     dist_name: Name of the distribution family to draw.
     data: Object exposing `.draw(strategy)`; presumably hypothesis's
       interactive draw object -- confirm against the decorator.
   """
   listed = dist_name in JIT_LOGPROB_BLOCKLIST
   if listed != FLAGS.blocklists_only:
     self.skipTest('Distribution currently broken.')

   def _is_eligible(dname):
     return dname not in JIT_LOGPROB_BLOCKLIST

   dist = data.draw(dhps.distributions(
       enable_vars=False,
       dist_name=dist_name,
       eligibility_filter=_is_eligible))
   sample = dist.sample(seed=test_util.test_seed())
   jitted_log_prob = jax.jit(dist.log_prob)
   result = jitted_log_prob(sample)
   if FLAGS.execute_only:
     return
   eager = dist.log_prob(sample)
   self.assertAllClose(eager, result, rtol=1e-6, atol=1e-6)
Пример #4
0
    def testSample(self, dist_name, data):
        """Check that a jitted sampler reproduces eager sampling for one seed.

        Skipped when membership of `dist_name` in `JIT_SAMPLE_BLOCKLIST`
        differs from `FLAGS.blocklists_only`. When `FLAGS.execute_only` is
        set, the jitted call runs but its result is not compared.

        Args:
          dist_name: Name of the distribution family to draw.
          data: Object exposing `.draw(strategy)`; presumably hypothesis's
            interactive draw object -- confirm against the decorator.
        """
        blocked = dist_name in JIT_SAMPLE_BLOCKLIST
        if blocked != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')
        distribution = data.draw(
            dhps.distributions(enable_vars=False, dist_name=dist_name))

        def _draw(key):
            return distribution.sample(seed=key)

        key = test_util.test_seed()
        jitted_sample = jax.jit(_draw)(key)
        if FLAGS.execute_only:
            return
        # The same seed is reused, so eager and jitted draws should agree.
        eager_sample = _draw(key)
        self.assertAllClose(eager_sample, jitted_sample, rtol=1e-6, atol=1e-6)
Пример #5
0
 def testLogProb(self, dist_name, data):
   """Check that `log_prob` mapped via `self.map` matches a direct batched call.

   Skipped when membership of `dist_name` in `self.logprob_blocklist` differs
   from `FLAGS.blocklists_only`. Also always skipped for 'NegativeBinomial'.
   A batch of `self.batch_size` samples is drawn. The mapped `log_prob` is
   then compared against calling `dist.log_prob` on the whole batch, unless
   `FLAGS.execute_only` is set.

   Args:
     dist_name: Name of the distribution family to draw.
     data: Object exposing `.draw(strategy)`; presumably hypothesis's
       interactive draw object -- confirm against the decorator.
   """
   listed = dist_name in self.logprob_blocklist
   if listed != FLAGS.blocklists_only:
     self.skipTest('Distribution currently broken.')
   if dist_name == 'NegativeBinomial':
     self.skipTest('Skip never-terminating negative binomial vmap logprob.')

   def _is_eligible(dname):
     # Reads the blocklist at call time, as the original lambda did.
     return dname not in self.logprob_blocklist

   dist = data.draw(dhps.distributions(
       enable_vars=False,
       dist_name=dist_name,
       eligibility_filter=_is_eligible))
   seed = test_util.test_seed()
   batch = dist.sample(seed=seed, sample_shape=self.batch_size)
   mapped_log_prob = self.map(dist.log_prob)
   mapped = mapped_log_prob(batch)
   if FLAGS.execute_only:
     return
   self.assertAllClose(mapped, dist.log_prob(batch), rtol=1e-6, atol=1e-6)
Пример #6
0
 def dist_and_sample(dist):
   """Return `(dist, sample)` where `sample` is drawn with the test seed."""
   drawn = dist.sample(seed=test_util.test_seed())
   return dist, drawn