예제 #1
0
def _maybe_jit(jit_type, func, *args, **kwargs):
  if jit_type == "python":
    return api._python_jit(func, *args, **kwargs)
  elif jit_type == "cpp":
    return api._cpp_jit(func, *args, **kwargs)
  elif jit_type is None:
    return func
  else:
    raise ValueError(f"Unrecognized jit_type={jit_type!r}")
예제 #2
0
파일: jax_jit_test.py 프로젝트: wayfeng/jax
    def test_signature_support(self):
        """The C++ jit wrapper must report the same signature as its target."""
        def add_three(a, b, c):
            return a + b + c

        wrapped = api._cpp_jit(add_three)
        original_sig = inspect.signature(add_three)
        wrapped_sig = inspect.signature(wrapped)
        self.assertEqual(original_sig, wrapped_sig)