def update_thread_local_jit_state(**kw):
  """Overwrite selected fields of the thread-local extra JIT context.

  Reads the current thread's ``extra_jit_context`` from
  ``jax_jit.thread_local_state()``. If it is falsy (e.g. not yet set), it
  starts from a default ``_ThreadLocalExtraJitContext()``. It then stores back
  a copy with the fields named in ``kw`` replaced, via ``_replace``.

  Args:
    **kw: field-name/value pairs forwarded to the context's ``_replace``
      (NamedTuple-style; presumably must be existing field names — confirm).
  """
  state = jax_jit.thread_local_state()
  # After xla_client._version >= 70, the thread_local object will necessarily
  # be initialized when accessed. The fallback below can be removed when the
  # minimum jaxlib version is past version 70.
  current = state.extra_jit_context
  if not current:
    current = _ThreadLocalExtraJitContext()
  state.extra_jit_context = current._replace(**kw)
def update_thread_local_jit_state(**kw):
  """Overwrite selected fields of the thread-local JIT state.

  Reads the current thread's ``extra_jit_context`` from
  ``jax_jit.thread_local_state()``. If it is falsy (e.g. not yet set), it
  starts from a default ``ThreadLocalJitState()``. It then stores back a copy
  with the fields named in ``kw`` replaced, via ``_replace``.

  Args:
    **kw: field-name/value pairs forwarded to the context's ``_replace``
      (NamedTuple-style; presumably must be existing field names — confirm).
  """
  state = jax_jit.thread_local_state()
  current = state.extra_jit_context
  if not current:
    # Fall back to a default-constructed state when none is set yet.
    current = ThreadLocalJitState()
  state.extra_jit_context = current._replace(**kw)