Esempio n. 1
0
    def testSample(self, dist_name, data):
        """Checks that `dist.sample` runs when mapped over a batch of seeds.

        Draws a distribution named `dist_name`, splits the test seed into
        `self.batch_size` keys, and calls `dist.sample` once per key under
        `self.map`.
        """
        # Skip when blocklist membership disagrees with the flag: normally the
        # blocklisted names are skipped; with `blocklists_only` set, only the
        # blocklisted names run.
        is_blocklisted = dist_name in self.sample_blocklist
        if is_blocklisted != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')
        strategy = dhps.distributions(enable_vars=False, dist_name=dist_name)
        dist = data.draw(strategy)

        keys = random.split(test_util.test_seed(), self.batch_size)
        self.map(lambda key: dist.sample(seed=key))(keys)
Esempio n. 2
0
    def testSample(self, dist_name, data):
        """Checks that `dist.sample` can be vectorized with `jax.vmap`.

        The test seed is split into 10 keys and one sample is drawn per key.
        """
        # Skip blocklisted names normally; with `blocklists_only`, run only
        # the blocklisted names.
        skip = (dist_name in VMAP_SAMPLE_BLOCKLIST) != FLAGS.blocklists_only
        if skip:
            self.skipTest('Distribution currently broken.')
        dist = data.draw(
            dhps.distributions(enable_vars=False, dist_name=dist_name))

        def draw_one(key):
            return dist.sample(seed=key)

        keys = random.split(test_util.test_seed(), 10)
        jax.vmap(draw_one)(keys)
Esempio n. 3
0
 def testLogProb(self, dist_name, data):
   """Checks that a jitted `dist.log_prob` matches eager evaluation.

   Unless `FLAGS.execute_only` is set, the jitted result is compared to the
   eager `log_prob` of the same sample with rtol=atol=1e-6.
   """
   # Skip blocklisted names normally; with `blocklists_only`, run only the
   # blocklisted names.
   skip = (dist_name in JIT_LOGPROB_BLOCKLIST) != FLAGS.blocklists_only
   if skip:
     self.skipTest('Distribution currently broken.')

   def eligible(dname):
     # Filter handed to dhps; rejects any blocklisted distribution name.
     return dname not in JIT_LOGPROB_BLOCKLIST

   dist = data.draw(dhps.distributions(
       enable_vars=False,
       dist_name=dist_name,
       eligibility_filter=eligible))
   sample = dist.sample(seed=test_util.test_seed())
   jitted_log_prob = jax.jit(dist.log_prob)
   result = jitted_log_prob(sample)
   if FLAGS.execute_only:
     return
   expected = dist.log_prob(sample)
   self.assertAllClose(expected, result, rtol=1e-6, atol=1e-6)
Esempio n. 4
0
    def testSample(self, dist_name, data):
        """Checks that a jitted `dist.sample` matches eager sampling.

        Both calls use the same seed; unless `FLAGS.execute_only` is set the
        results are compared with rtol=atol=1e-6.
        """
        # Skip blocklisted names normally; with `blocklists_only`, run only
        # the blocklisted names.
        blocklisted = dist_name in JIT_SAMPLE_BLOCKLIST
        if blocklisted != FLAGS.blocklists_only:
            self.skipTest('Distribution currently broken.')
        dist = data.draw(
            dhps.distributions(enable_vars=False, dist_name=dist_name))

        def draw(key):
            return dist.sample(seed=key)

        key = test_util.test_seed()
        jitted_sample = jax.jit(draw)(key)
        if FLAGS.execute_only:
            return
        self.assertAllClose(draw(key), jitted_sample, rtol=1e-6, atol=1e-6)
Esempio n. 5
0
  def testInputOutputOfJittedFunction(self, dist_name, data):
    """Checks that a distribution can be passed into and out of `jax.jit`.

    The jitted function receives the drawn distribution as its argument and
    returns it together with one sample.
    """
    # Skip blocklisted names normally; with `blocklists_only`, run only the
    # blocklisted names.
    if (dist_name in PYTREE_BLOCKLIST) != FLAGS.blocklists_only:
      self.skipTest('Distribution currently broken.')

    def roundtrip(d):
      return d, d.sample(seed=test_util.test_seed())

    dist = data.draw(dhps.distributions(
        enable_vars=False,
        dist_name=dist_name,
        validate_args=False,
        eligibility_filter=lambda dname: dname not in PYTREE_BLOCKLIST))
    jax.jit(roundtrip)(dist)
Esempio n. 6
0
 def testLogProb(self, dist_name, data):
   """Checks that mapping `dist.log_prob` over a batch matches eager eval.

   Draws a sample with `sample_shape=self.batch_size`, evaluates `log_prob`
   under `self.map`, and (unless `FLAGS.execute_only`) compares it against
   the unmapped `log_prob` with rtol=atol=1e-6. NegativeBinomial is always
   skipped.
   """
   # Skip blocklisted names normally; with `blocklists_only`, run only the
   # blocklisted names.
   skip = (dist_name in self.logprob_blocklist) != FLAGS.blocklists_only
   if skip:
     self.skipTest('Distribution currently broken.')
   if dist_name == 'NegativeBinomial':
     self.skipTest('Skip never-terminating negative binomial vmap logprob.')

   def eligible(dname):
     # Filter handed to dhps; rejects any name in the log-prob blocklist.
     return dname not in self.logprob_blocklist

   dist = data.draw(dhps.distributions(
       enable_vars=False,
       dist_name=dist_name,
       eligibility_filter=eligible))
   sample = dist.sample(seed=test_util.test_seed(),
                        sample_shape=self.batch_size)
   mapped = self.map(dist.log_prob)(sample)
   if FLAGS.execute_only:
     return
   self.assertAllClose(mapped, dist.log_prob(sample), rtol=1e-6, atol=1e-6)
Esempio n. 7
0
  def testFlattenUnflatten(self, dist_name, data):
    """Checks that a distribution survives a pytree flatten/unflatten trip.

    The rebuilt distribution must have the same number of leaves as the
    original. Each leaf must keep its type, and `np.ndarray` leaves must
    also keep their shape.
    """
    # Skip blocklisted names normally; with `blocklists_only`, run only the
    # blocklisted names.
    if (dist_name in PYTREE_BLOCKLIST) != FLAGS.blocklists_only:
      self.skipTest('Distribution currently broken.')

    dist = data.draw(dhps.distributions(
        enable_vars=False,
        dist_name=dist_name,
        validate_args=False,
        eligibility_filter=lambda dname: dname not in PYTREE_BLOCKLIST))
    flat_dist, dist_tree = jax.tree_util.tree_flatten(dist)
    new_dist = jax.tree_util.tree_unflatten(dist_tree, flat_dist)
    new_leaves = jax.tree_util.tree_leaves(new_dist)
    # `zip` stops at the shorter sequence. Without this check, a round trip
    # that drops or adds leaves would pass unnoticed.
    self.assertEqual(len(flat_dist), len(new_leaves))
    for old_param, new_param in zip(flat_dist, new_leaves):
      self.assertEqual(type(old_param), type(new_param))
      if isinstance(old_param, np.ndarray):
        self.assertTupleEqual(old_param.shape, new_param.shape)