def _get_backend():
    """Return an ``xla_client.BackendSpec`` chosen by the command-line flags.

    The backend type is looked up in ``_backend_flag_to_type`` using the value
    of ``FLAGS.jax_xla_backend``, and ``FLAGS.jax_backend_target`` is passed
    through as the target. A failed lookup (a KeyError, assuming
    ``_backend_flag_to_type`` is a plain dict -- TODO confirm) propagates to
    the caller unchanged.
    """
    # Resolve the constructor first, matching the original evaluation order.
    make_spec = xla_client.BackendSpec
    backend_type = _backend_flag_to_type[FLAGS.jax_xla_backend]
    return make_spec(backend_type, FLAGS.jax_backend_target)
Example #2
0
# Registry of available backends: backend name -> factory callable that is
# invoked with no arguments to build the backend object.
_backends = {}


def register_backend(name, factory):
  """Make ``factory`` available under the backend name ``name``.

  Args:
    name: registry key; ``_get_backend`` looks it up using the value of the
      ``jax_xla_backend`` flag.
    factory: callable taking no arguments that constructs the backend.

  Registering a name that is already present replaces its previous factory.
  """
  _backends.update({name: factory})


if not hasattr(xla_client, 'XlaLocalBackend'):
  # TODO(phawkins): this case is for cross-version compatibility. Delete this
  # case after a Jaxlib update.
  # This xla_client has no XlaLocalBackend class, so each backend is described
  # with a BackendSpec instead of being constructed directly.
  register_backend(
      'xla',
      lambda: xla_client.BackendSpec(xla_client.BackendType.XLA_LOCAL, ''))
  register_backend(
      'xrt',
      lambda: xla_client.BackendSpec(xla_client.BackendType.XRT,
                                     FLAGS.jax_backend_target))
else:
  # Construct backend objects directly. The lambdas defer both construction
  # and the read of FLAGS.jax_backend_target until the factory is invoked
  # (presumably so flags parsed after import are honored -- confirm).
  register_backend('xla', lambda: xla_client.XlaLocalBackend())
  register_backend(
      'xrt', lambda: xla_client.XrtBackend(FLAGS.jax_backend_target))


@memoize_thunk
def _get_backend():
  """Build the backend named by the ``jax_xla_backend`` flag.

  Looks the flag value up in the ``_backends`` registry and calls the
  registered factory with no arguments. ``memoize_thunk`` presumably caches
  the result so the factory runs at most once (confirm against its
  definition).

  Raises:
    ValueError: if no (non-None) factory is registered under the flag value.
  """
  requested = FLAGS.jax_xla_backend
  factory = _backends.get(requested)
  if factory is not None:
    return factory()
  raise ValueError('Unknown jax_xla_backend value "{}".'.format(requested))


def device_put(pyval, replica=0):