Example no. 1
0
    def test_destructure(self):
        """Checks a function that unpacks a two-element argument."""

        def d(key):
            # Unpacking into exactly two names: only the first is returned,
            # the second component is intentionally discarded.
            head, _ = key
            return head

        # Integer sampler bounded by 0 and 10 (exact bound inclusivity is
        # defined by jtu.rand_int, not visible here).
        sampler = jtu.rand_int(self.rng(), 0, 10)
        # NOTE(review): positional meanings come from self.check (not shown);
        # presumably input shape spec '2', output shape spec '', an empty
        # environment, padded input shape (2,), and dtype 'int_' — confirm.
        self.check(d, ['2'], '', {}, [(2,)], ['int_'], sampler)
Example no. 2
0
  def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
    """Checks `lsp_special.sph_harm` under JIT and against `osp_special`.

    Args:
      l_max: Bound on the sampled order `m`; also passed to the JAX
        implementation as its `n_max` keyword.
      num_z: Number of sample points generated for every argument.
      dtype: Dtype in which the order `m` is sampled.
    """
    max_degree = l_max
    # NOTE(review): presumably a half-open range like numpy's randint, so
    # m lies in [-l_max, l_max] — confirm against jtu.rand_int.
    sample_ints = jtu.rand_int(self.rng(), -l_max, l_max + 1)

    jax_sph_harm = partial(lsp_special.sph_harm, n_max=max_degree)

    def args_maker():
      order = sample_ints((num_z,), dtype)
      # Taking |m| as the degree keeps it non-negative and never below |m|.
      degree = abs(order)
      # Evenly spaced angle samples, num_z points each.
      theta, phi = (
          jnp.linspace(-4.0, 5.0, num_z),
          jnp.linspace(-2.0, 1.0, num_z),
      )
      return order, degree, theta, phi

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(jax_sph_harm, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_special.sph_harm, jax_sph_harm, args_maker)