def testLogProb(self, dist_name, data):
   """Checks `jax.vmap(dist.log_prob)` against a direct `log_prob` call.

   Draws a distribution of family `dist_name` (with `enable_vars=False`),
   takes 10 samples, and asserts that mapping `log_prob` over the leading
   sample axis agrees with a single direct call on all samples to within
   rtol=atol=1e-6.

   Args:
     dist_name: Name of the distribution family under test; families listed
       in `VMAP_LOGPROB_BLACKLIST` are skipped.
     data: Presumably a Hypothesis `data()` object used to draw the
       distribution -- TODO confirm against the test decorator.
   """
   if dist_name in VMAP_LOGPROB_BLACKLIST:
     self.skipTest('Distribution currently broken.')
   strategy = dhps.distributions(dist_name=dist_name, enable_vars=False)
   dist = data.draw(strategy)
   samples = dist.sample(sample_shape=10, seed=test_util.test_seed())
   # `jax.vmap` maps over axis 0, i.e. the 10 prepended sample draws.
   vmapped_log_prob = jax.vmap(dist.log_prob)(samples)
   direct_log_prob = dist.log_prob(samples)
   self.assertAllClose(vmapped_log_prob, direct_log_prob,
                       atol=1e-6, rtol=1e-6)
 def testSample(self, dist_name, data):
   """Smoke-tests that `dist.sample` runs under `jax.vmap` over 10 seeds.

   The vmapped result is discarded. The test passes as long as tracing and
   executing the vmapped sampler does not raise.

   Args:
     dist_name: Name of the distribution family under test; families listed
       in `VMAP_SAMPLE_BLACKLIST` are skipped.
     data: Presumably a Hypothesis `data()` object used to draw the
       distribution -- TODO confirm against the test decorator.
   """
   if dist_name in VMAP_SAMPLE_BLACKLIST:
     self.skipTest('Distribution currently broken.')
   dist = data.draw(
       dhps.distributions(dist_name=dist_name, enable_vars=False))
   # Ten keys split from one test seed; `random` is presumably `jax.random`
   # (module imports are outside this view).
   keys = random.split(test_util.test_seed(), 10)
   jax.vmap(lambda key: dist.sample(seed=key))(keys)
Example #3
0
 def testLogProb(self, dist_name, data):
   """Checks `jax.vmap(dist.log_prob)` against a direct `log_prob` call.

   Two flags change the behavior. `FLAGS.ignore_blocklists` disables the
   skip for families in `VMAP_LOGPROB_BLOCKLIST`. `FLAGS.execute_only` runs
   the vmapped `log_prob` without comparing it to the direct call.
   Otherwise the two results must agree to within rtol=atol=1e-6.

   Args:
     dist_name: Name of the distribution family under test.
     data: Presumably a Hypothesis `data()` object used to draw the
       distribution -- TODO confirm against the test decorator.
   """
   is_blocked = dist_name in VMAP_LOGPROB_BLOCKLIST
   if is_blocked and not FLAGS.ignore_blocklists:
     self.skipTest('Distribution currently broken.')
   dist = data.draw(
       dhps.distributions(dist_name=dist_name, enable_vars=False))
   samples = dist.sample(sample_shape=10, seed=test_util.test_seed())
   vmapped_log_prob = jax.vmap(dist.log_prob)(samples)
   if FLAGS.execute_only:
     # Only verify that the vmapped computation runs.
     return
   self.assertAllClose(vmapped_log_prob, dist.log_prob(samples),
                       atol=1e-6, rtol=1e-6)
 def testSample(self, dist_name, data):
   """Checks that `jax.jit`-compiled sampling matches eager sampling.

   Both paths are given the same seed, and their outputs must agree to
   within rtol=atol=1e-6.

   Args:
     dist_name: Name of the distribution family under test; families listed
       in `JIT_SAMPLE_BLACKLIST` are skipped.
     data: Presumably a Hypothesis `data()` object used to draw the
       distribution -- TODO confirm against the test decorator.
   """
   if dist_name in JIT_SAMPLE_BLACKLIST:
     self.skipTest('Distribution currently broken.')
   dist = data.draw(
       dhps.distributions(dist_name=dist_name, enable_vars=False))

   def sample_fn(key):
     return dist.sample(seed=key)

   key = test_util.test_seed()
   eager_sample = sample_fn(key)
   jitted_sample = jax.jit(sample_fn)(key)
   self.assertAllClose(eager_sample, jitted_sample, atol=1e-6, rtol=1e-6)
Example #5
0
 def testSample(self, dist_name, data):
   """Checks that `jax.jit`-compiled sampling matches eager sampling.

   Two flags change the behavior. `FLAGS.ignore_blocklists` disables the
   skip for families in `JIT_SAMPLE_BLOCKLIST`. `FLAGS.execute_only` runs
   the jitted sampler without comparing it to the eager one. Otherwise both
   paths use the same seed and must agree to within rtol=atol=1e-6.

   Args:
     dist_name: Name of the distribution family under test.
     data: Presumably a Hypothesis `data()` object used to draw the
       distribution -- TODO confirm against the test decorator.
   """
   is_blocked = dist_name in JIT_SAMPLE_BLOCKLIST
   if is_blocked and not FLAGS.ignore_blocklists:
     self.skipTest('Distribution currently broken.')
   dist = data.draw(
       dhps.distributions(dist_name=dist_name, enable_vars=False))

   def sample_fn(key):
     return dist.sample(seed=key)

   key = test_util.test_seed()
   jitted_sample = jax.jit(sample_fn)(key)
   if FLAGS.execute_only:
     # Only verify that the jitted sampler runs.
     return
   self.assertAllClose(sample_fn(key), jitted_sample, atol=1e-6, rtol=1e-6)