def _get_backend():
  """Builds an `xla_client.BackendSpec` from the backend-related flags.

  The backend type is obtained by indexing `_backend_flag_to_type` with
  FLAGS.jax_xla_backend, and the target comes from FLAGS.jax_backend_target.

  Raises:
    Whatever indexing `_backend_flag_to_type` raises for an unrecognized
    flag value (presumably KeyError if it is a dict -- confirm).
  """
  # NOTE(review): the module defines `_get_backend` again further down
  # (memoized, registry-based); that later definition rebinds this name.
  make_spec = xla_client.BackendSpec
  backend_type = _backend_flag_to_type[FLAGS.jax_xla_backend]
  target = FLAGS.jax_backend_target
  return make_spec(backend_type, target)
_backends = {} def register_backend(name, factory): _backends[name] = factory if hasattr(xla_client, 'XlaLocalBackend'): register_backend('xla', lambda: xla_client.XlaLocalBackend()) register_backend('xrt', lambda: xla_client.XrtBackend(FLAGS.jax_backend_target)) else: # TODO(phawkins): this case is for cross-version compatibility. Delete this # case after a Jaxlib update. register_backend( 'xla', lambda: xla_client.BackendSpec(xla_client.BackendType.XLA_LOCAL, '')) register_backend( 'xrt', lambda: xla_client.BackendSpec(xla_client.BackendType.XRT, FLAGS.jax_backend_target)) @memoize_thunk def _get_backend(): backend = _backends.get(FLAGS.jax_xla_backend) if backend is None: msg = 'Unknown jax_xla_backend value "{}".' raise ValueError(msg.format(FLAGS.jax_xla_backend)) return backend() def device_put(pyval, replica=0):