def testLogProb(self, dist_name, data):
  """Checks that a vmapped `log_prob` agrees with the direct batched call.

  Args:
    dist_name: Name of the distribution passed to `dhps.distributions`.
    data: Object whose `draw` method produces the distribution under test.
  """
  # Distributions listed here are skipped unconditionally.
  if dist_name in VMAP_LOGPROB_BLACKLIST:
    self.skipTest('Distribution currently broken.')
  strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
  distribution = data.draw(strategy)
  # Ten draws give `jax.vmap` a leading axis of size 10 to map over.
  draws = distribution.sample(sample_shape=10, seed=test_util.test_seed())
  mapped_log_prob = jax.vmap(distribution.log_prob)
  # Evaluate the vmapped version first, then the reference, as before.
  actual = mapped_log_prob(draws)
  expected = distribution.log_prob(draws)
  self.assertAllClose(actual, expected, rtol=1e-6, atol=1e-6)
def testSample(self, dist_name, data):
  """Smoke-tests `sample` under `jax.vmap` over ten split PRNG keys.

  Nothing is asserted about the values; the test only requires that the
  vmapped sampling call runs without raising.

  Args:
    dist_name: Name of the distribution passed to `dhps.distributions`.
    data: Object whose `draw` method produces the distribution under test.
  """
  if dist_name in VMAP_SAMPLE_BLACKLIST:
    self.skipTest('Distribution currently broken.')
  distribution = data.draw(
      dhps.distributions(enable_vars=False, dist_name=dist_name))

  def draw_one(key):
    return distribution.sample(seed=key)

  base_key = test_util.test_seed()
  keys = random.split(base_key, 10)
  jax.vmap(draw_one)(keys)
def testLogProb(self, dist_name, data):
  """Checks that a vmapped `log_prob` agrees with the direct batched call.

  The blocklist can be bypassed with `FLAGS.ignore_blocklists`. When
  `FLAGS.execute_only` is set, the vmapped computation is run but not
  compared against the reference.

  Args:
    dist_name: Name of the distribution passed to `dhps.distributions`.
    data: Object whose `draw` method produces the distribution under test.
  """
  # The flag is only consulted for blocklisted names (short-circuit kept).
  blocked = dist_name in VMAP_LOGPROB_BLOCKLIST
  if blocked and not FLAGS.ignore_blocklists:
    self.skipTest('Distribution currently broken.')
  strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
  distribution = data.draw(strategy)
  draws = distribution.sample(sample_shape=10, seed=test_util.test_seed())
  vmapped_result = jax.vmap(distribution.log_prob)(draws)
  if FLAGS.execute_only:
    return
  reference = distribution.log_prob(draws)
  self.assertAllClose(vmapped_result, reference, rtol=1e-6, atol=1e-6)
def testSample(self, dist_name, data):
  """Checks that `jax.jit`-compiled sampling matches eager sampling.

  Both paths use the same seed, so their outputs are compared directly.

  Args:
    dist_name: Name of the distribution passed to `dhps.distributions`.
    data: Object whose `draw` method produces the distribution under test.
  """
  if dist_name in JIT_SAMPLE_BLACKLIST:
    self.skipTest('Distribution currently broken.')
  distribution = data.draw(
      dhps.distributions(enable_vars=False, dist_name=dist_name))

  def draw_once(key):
    return distribution.sample(seed=key)

  key = test_util.test_seed()
  # Eager evaluation runs first, then the compiled one (original order).
  eager = draw_once(key)
  compiled = jax.jit(draw_once)(key)
  self.assertAllClose(eager, compiled, rtol=1e-6, atol=1e-6)
def testSample(self, dist_name, data):
  """Checks that `jax.jit`-compiled sampling matches eager sampling.

  The blocklist can be bypassed with `FLAGS.ignore_blocklists`. When
  `FLAGS.execute_only` is set, only the compiled call is executed and no
  comparison is made.

  Args:
    dist_name: Name of the distribution passed to `dhps.distributions`.
    data: Object whose `draw` method produces the distribution under test.
  """
  # The flag is only consulted for blocklisted names (short-circuit kept).
  blocked = dist_name in JIT_SAMPLE_BLOCKLIST
  if blocked and not FLAGS.ignore_blocklists:
    self.skipTest('Distribution currently broken.')
  distribution = data.draw(
      dhps.distributions(enable_vars=False, dist_name=dist_name))

  def draw_once(key):
    return distribution.sample(seed=key)

  key = test_util.test_seed()
  compiled = jax.jit(draw_once)(key)
  if FLAGS.execute_only:
    return
  # The eager reference is computed only when a comparison is wanted.
  eager = draw_once(key)
  self.assertAllClose(eager, compiled, rtol=1e-6, atol=1e-6)