def _maybe_jit(jit_type, func, *args, **kwargs): if jit_type == "python": return api._python_jit(func, *args, **kwargs) elif jit_type == "cpp": return api._cpp_jit(func, *args, **kwargs) elif jit_type is None: return func else: raise ValueError(f"Unrecognized jit_type={jit_type!r}")
def test_signature_support(self):
  """The C++ jit wrapper must report the same signature as the wrapped function."""
  def add_three(a, b, c):
    return a + b + c

  wrapped = api._cpp_jit(add_three)
  original_sig = inspect.signature(add_three)
  wrapped_sig = inspect.signature(wrapped)
  self.assertEqual(original_sig, wrapped_sig)