Exemple #1
0
def get_backend(platform=None):
    """Return a backend for ``platform`` built by the configured XLA factory.

    Args:
        platform: ``None``, a platform-name string, or an object that is
            already a backend. Anything that is neither ``None`` nor a ``str``
            is handed back unchanged without taking the lock.

    Returns:
        ``platform`` itself when it is not ``None``/``str``; otherwise the
        object produced by calling the factory stored in ``_backends`` under
        the key ``FLAGS.jax_xla_backend`` with ``platform`` as argument.

    Raises:
        ValueError: if ``FLAGS.jax_xla_backend`` has no entry in ``_backends``.
    """
    # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
    # 'backend' values are handled
    if platform is not None and not isinstance(platform, str):
        return platform

    with _backend_lock:
        xla_backend_name = FLAGS.jax_xla_backend
        make_backend = _backends.get(xla_backend_name)
        if make_backend is None:
            template = 'Unknown jax_xla_backend value "{}".'
            raise ValueError(template.format(xla_backend_name))
        backend = make_backend(platform)
        util.distributed_debug_log(
            ("Initialized backend", backend.platform),
            ("process_index", backend.process_index()),
            ("device_count", backend.device_count()),
            ("local_devices", backend.local_devices()),
        )
    return backend
Exemple #2
0
def backends():
    """Initialize every registered backend once and return the available ones.

    On the first call, each ``(factory, priority)`` entry of
    ``_backend_factories`` is invoked while holding ``_backend_lock``.
    Backends reporting at least one device are stored by name, and the
    highest-priority such backend becomes ``_default_backend``. Failures of
    backends other than 'cpu' and 'interpreter' are logged and their messages
    recorded in ``_backends_errors``. Later calls return the cached mapping.

    Returns:
        dict mapping backend name to backend object.

    Raises:
        Exception: whatever the 'cpu' or 'interpreter' factory (or the
            follow-up queries on its result) raises. In that case the
            module-level caches are left untouched, so a later call retries
            initialization instead of returning a partial mapping.
    """
    global _backends
    global _backends_errors
    global _default_backend

    with _backend_lock:
        if _backends is not None:
            return _backends

        default_priority = -1000
        default_backend = None
        # Collect into locals and publish only after the loop completes:
        # assigning the globals up front meant a re-raised cpu/interpreter
        # failure left a half-filled ``_backends`` that every later call
        # would silently return.
        available = {}
        errors = {}
        for name, (factory, priority) in _backend_factories.items():
            logging.vlog(1, "Initializing backend '%s'" % name)
            try:
                backend = factory()
                if backend is not None:
                    has_devices = backend.device_count() > 0
                    if has_devices:
                        available[name] = backend
                    util.distributed_debug_log(
                        ("Initialized backend", backend.platform),
                        ("process_index", backend.process_index()),
                        ("device_count", backend.device_count()),
                        ("local_devices", backend.local_devices()))
                    logging.vlog(1, "Backend '%s' initialized" % name)
                    # A backend with no devices is not exposed in the mapping,
                    # so it must not be eligible to become the default either.
                    if has_devices and priority > default_priority:
                        default_backend = backend
                        default_priority = priority
            except Exception as err:
                if name in ('cpu', 'interpreter'):
                    # We always expect the CPU and interpreter backends to initialize
                    # successfully.
                    raise
                # If the backend isn't built into the binary, or if it has no devices,
                # we expect a RuntimeError.
                logging.info("Unable to initialize backend '%s': %s" %
                             (name, err))
                errors[name] = str(err)

        _backends = available
        _backends_errors = errors
        if default_backend is not None:
            _default_backend = default_backend
        # Guard against no backend having been selected; previously this
        # dereferenced ``None`` and raised AttributeError.
        if (_default_backend is not None
                and _default_backend.platform == "cpu"
                and FLAGS.jax_platform_name != 'cpu'):
            logging.warning(
                'No GPU/TPU found, falling back to CPU. '
                '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
        return _backends
Exemple #3
0
def _init_backend(platform):
  """Create, validate and return the backend registered under ``platform``.

  Args:
    platform: key looked up in ``_backend_factories``; the stored value is
      unpacked as ``(factory, priority)`` and only the factory is used.

  Returns:
    The object returned by calling the registered factory with no arguments.

  Raises:
    RuntimeError: if no factory is registered for ``platform``, if the factory
      returns ``None``, or if the resulting backend reports zero devices.
  """
  factory, _ = _backend_factories.get(platform, (None, None))
  if factory is None:
    raise RuntimeError(f"Unknown backend '{platform}'")

  logging.vlog(1, "Initializing backend '%s'" % platform)
  backend = factory()

  # TODO(skye): consider raising more descriptive errors directly from backend
  # factories instead of returning None.
  if backend is None:
    raise RuntimeError(f"Could not initialize backend '{platform}'")
  if backend.device_count() == 0:
    raise RuntimeError(f"Backend '{platform}' provides no devices.")

  util.distributed_debug_log(
      ("Initialized backend", backend.platform),
      ("process_index", backend.process_index()),
      ("device_count", backend.device_count()),
      ("local_devices", backend.local_devices()),
  )
  logging.vlog(1, "Backend '%s' initialized" % platform)
  return backend
Exemple #4
0
def backends():
    """Initialize every registered backend once and return the available ones.

    On the first call, each factory in ``_backend_factories`` is invoked while
    holding ``_backend_lock``. Backends reporting at least one device are
    stored by name in insertion order. RuntimeError/ImportError from backends
    other than 'cpu' and 'interpreter' are logged and skipped. Later calls
    return the cached mapping.

    Returns:
        dict mapping backend name to backend object (possibly empty).

    Raises:
        RuntimeError, ImportError: if the 'cpu' or 'interpreter' backend fails.
            The module-level cache is then left untouched, so a later call
            retries instead of returning a partial mapping.
    """
    global _backends

    with _backend_lock:
        if _backends is not None:
            return _backends

        # Collect into a local and publish only after the loop completes, so
        # a re-raised cpu/interpreter failure cannot leave a half-filled
        # ``_backends`` behind for later calls to return.
        available = {}
        for name, factory in _backend_factories.items():
            logging.vlog(1, "Initializing backend '%s'" % name)
            try:
                backend = factory()
                if backend is not None:
                    if backend.device_count() > 0:
                        available[name] = backend
                    util.distributed_debug_log(
                        ("Initialized backend", backend.platform),
                        ("process_index", backend.process_index()),
                        ("device_count", backend.device_count()),
                        ("local_devices", backend.local_devices()))
                    logging.vlog(1, "Backend '%s' initialized" % name)
            except (RuntimeError, ImportError) as err:
                if name in ('cpu', 'interpreter'):
                    # We always expect the CPU and interpreter backends to initialize
                    # successfully.
                    raise
                # If the backend isn't built into the binary, or if it has no devices,
                # we expect a RuntimeError.
                logging.info("Unable to initialize backend '%s': %s" %
                             (name, err))

        _backends = available
        # NOTE(review): this heuristic assumes accelerator backends are
        # registered after 'cpu', so 'cpu' being the last available entry
        # means no accelerator came up — confirm against registration order.
        # The emptiness check avoids the IndexError the bare ``[-1]`` raised
        # when no backend had any devices.
        if _backends and list(_backends)[-1] == "cpu":
            logging.warning(
                'No GPU/TPU found, falling back to CPU. '
                '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
        return _backends