def testSample(self, dist_name, data):
    """Check that sampling `dist_name` executes under `self.map`.

    The test body runs only when membership of `dist_name` in
    `self.sample_blocklist` equals `FLAGS.blocklists_only`; otherwise the
    test is skipped.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    is_blocklisted = dist_name in self.sample_blocklist
    if is_blocklisted != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')

    strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
    dist = data.draw(strategy)

    def draw_one(key):
        return dist.sample(seed=key)

    root_key = test_util.test_seed()
    mapped_draw = self.map(draw_one)
    # No value assertion: the test passes if the mapped sampler simply runs
    # over `self.batch_size` split seeds.
    mapped_draw(random.split(root_key, self.batch_size))
def testSample(self, dist_name, data):
    """Check that sampling `dist_name` executes under `jax.vmap`.

    The test body runs only when membership of `dist_name` in
    `VMAP_SAMPLE_BLOCKLIST` equals `FLAGS.blocklists_only`; otherwise the
    test is skipped.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    is_blocklisted = dist_name in VMAP_SAMPLE_BLOCKLIST
    if is_blocklisted != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')

    strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
    dist = data.draw(strategy)

    def draw_one(key):
        return dist.sample(seed=key)

    root_key = test_util.test_seed()
    vmapped_draw = jax.vmap(draw_one)
    # No value assertion: the test passes if the vmapped sampler simply runs
    # over 10 split seeds.
    vmapped_draw(random.split(root_key, 10))
def testLogProb(self, dist_name, data):
    """Compare `log_prob` under `jax.jit` against the un-jitted call.

    The test body runs only when membership of `dist_name` in
    `JIT_LOGPROB_BLOCKLIST` equals `FLAGS.blocklists_only`; otherwise the
    test is skipped. When `FLAGS.execute_only` is set, the jitted function
    is executed but its output is not compared.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    is_blocklisted = dist_name in JIT_LOGPROB_BLOCKLIST
    if is_blocklisted != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')

    strategy = dhps.distributions(
        enable_vars=False,
        dist_name=dist_name,
        eligibility_filter=lambda dname: dname not in JIT_LOGPROB_BLOCKLIST,
    )
    dist = data.draw(strategy)

    point = dist.sample(seed=test_util.test_seed())
    jitted_log_prob = jax.jit(dist.log_prob)
    jitted_value = jitted_log_prob(point)

    if FLAGS.execute_only:
        return
    eager_value = dist.log_prob(point)
    self.assertAllClose(eager_value, jitted_value, rtol=1e-6, atol=1e-6)
def testSample(self, dist_name, data):
    """Compare sampling under `jax.jit` against the un-jitted call.

    The test body runs only when membership of `dist_name` in
    `JIT_SAMPLE_BLOCKLIST` equals `FLAGS.blocklists_only`; otherwise the
    test is skipped. When `FLAGS.execute_only` is set, the jitted sampler
    is executed but its output is not compared.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    is_blocklisted = dist_name in JIT_SAMPLE_BLOCKLIST
    if is_blocklisted != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')

    strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
    dist = data.draw(strategy)

    def draw_one(key):
        return dist.sample(seed=key)

    key = test_util.test_seed()
    jitted_value = jax.jit(draw_one)(key)

    if FLAGS.execute_only:
        return
    # Same seed for both calls, so the two draws are expected to agree.
    eager_value = draw_one(key)
    self.assertAllClose(eager_value, jitted_value, rtol=1e-6, atol=1e-6)
def testInputOutputOfJittedFunction(self, dist_name, data):
    """Pass a distribution into and back out of a `jax.jit`-compiled function.

    The test body runs only when membership of `dist_name` in
    `PYTREE_BLOCKLIST` equals `FLAGS.blocklists_only`; otherwise the test is
    skipped. There is no value assertion: the test passes if the jitted
    call accepts the distribution as an argument and returns it alongside a
    sample without raising.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    is_blocklisted = dist_name in PYTREE_BLOCKLIST
    if is_blocklisted != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')

    def echo_with_sample(d):
        return d, d.sample(seed=test_util.test_seed())

    jitted_echo = jax.jit(echo_with_sample)

    strategy = dhps.distributions(
        enable_vars=False,
        dist_name=dist_name,
        validate_args=False,
        eligibility_filter=lambda dname: dname not in PYTREE_BLOCKLIST,
    )
    dist = data.draw(strategy)
    jitted_echo(dist)
def testLogProb(self, dist_name, data):
    """Compare `log_prob` mapped via `self.map` against the direct call.

    The test body runs only when membership of `dist_name` in
    `self.logprob_blocklist` equals `FLAGS.blocklists_only`; otherwise the
    test is skipped. 'NegativeBinomial' is always skipped after that check.
    When `FLAGS.execute_only` is set, the mapped function is executed but
    its output is not compared.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    is_blocklisted = dist_name in self.logprob_blocklist
    if is_blocklisted != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')
    if dist_name == 'NegativeBinomial':
        self.skipTest('Skip never-terminating negative binomial vmap logprob.')

    strategy = dhps.distributions(
        enable_vars=False,
        dist_name=dist_name,
        eligibility_filter=lambda dname: dname not in self.logprob_blocklist,
    )
    dist = data.draw(strategy)

    # Draw `self.batch_size` samples at once so there is a leading axis for
    # `self.map` to act on.
    batch = dist.sample(
        seed=test_util.test_seed(), sample_shape=self.batch_size)
    mapped_log_prob = self.map(dist.log_prob)
    mapped_value = mapped_log_prob(batch)

    if FLAGS.execute_only:
        return
    direct_value = dist.log_prob(batch)
    self.assertAllClose(mapped_value, direct_value, rtol=1e-6, atol=1e-6)
def testFlattenUnflatten(self, dist_name, data):
    """Round-trip a distribution through `tree_flatten` / `tree_unflatten`.

    The test body runs only when membership of `dist_name` in
    `PYTREE_BLOCKLIST` equals `FLAGS.blocklists_only`; otherwise the test is
    skipped. After the round trip, the rebuilt distribution must expose the
    same number of leaves as the original flattening, each leaf must have
    the same type, and `np.ndarray` leaves must keep their shape.

    Args:
        dist_name: Name forwarded to `dhps.distributions` to select the
            distribution under test.
        data: Object whose `draw` method turns the distribution strategy
            into a concrete distribution instance.
    """
    if (dist_name in PYTREE_BLOCKLIST) != FLAGS.blocklists_only:
        self.skipTest('Distribution currently broken.')

    dist = data.draw(dhps.distributions(
        enable_vars=False,
        dist_name=dist_name,
        validate_args=False,
        eligibility_filter=lambda dname: dname not in PYTREE_BLOCKLIST))

    flat_dist, dist_tree = jax.tree_util.tree_flatten(dist)
    new_dist = jax.tree_util.tree_unflatten(dist_tree, flat_dist)
    new_leaves = jax.tree_util.tree_leaves(new_dist)

    # zip() stops at the shorter input, so without this check a round trip
    # that lost or gained leaves would pass the pairwise loop unnoticed.
    self.assertEqual(len(flat_dist), len(new_leaves))

    for old_param, new_param in zip(flat_dist, new_leaves):
        self.assertEqual(type(old_param), type(new_param))
        if isinstance(old_param, np.ndarray):
            self.assertTupleEqual(old_param.shape, new_param.shape)