コード例 #1
0
    def test_jit_on_nondefault_backend(self):
        """Checks where jitted results land when CPU is not the default backend.

        Without placement options the result follows the input data; the
        `device=` and `backend=` options of `api.jit` override that.
        """
        def placement(value):
            # The device holding the buffer that backs `value`.
            return value.device_buffer.device()

        def sine(x):
            return jnp.sin(x)

        cpu_devices = api.devices("cpu")
        self.assertNotEmpty(cpu_devices)
        first_cpu = cpu_devices[0]

        # The test expects some non-CPU backend to be the default.
        default_device = api.devices()[0]
        self.assertNotEqual(default_device.platform, "cpu")

        on_cpu = api.device_put(1, device=first_cpu)
        self.assertEqual(placement(on_cpu), first_cpu)

        # No placement options: a Python scalar input yields a result on the
        # default device, while a CPU-resident input keeps the result on CPU.
        from_scalar = api.jit(sine)(2)
        self.assertEqual(placement(from_scalar), default_device)
        from_cpu_data = api.jit(sine)(on_cpu)
        self.assertEqual(placement(from_cpu_data), first_cpu)

        # `device=` places the result on the requested device.
        pinned_to_device = api.jit(sine, device=first_cpu)(2)
        self.assertEqual(placement(pinned_to_device), first_cpu)

        # `backend=` places the result on that backend (here its first device).
        pinned_to_backend = api.jit(sine, backend="cpu")(2)
        self.assertEqual(placement(pinned_to_backend), first_cpu)
コード例 #2
0
 def test_closed_over_values_device_placement(self):
   """Checks where a jitted zero-argument function places its output.

   By default the output is not on the first CPU device; with
   `backend='cpu'` it is. Regression test for
   https://github.com/google/jax/issues/1431.
   """
   def add_constants():
     return np.add(3., 4.)

   default_placed = api.jit(add_constants)().device_buffer.device()
   self.assertNotEqual(default_placed, api.devices('cpu')[0])

   cpu_placed = api.jit(add_constants, backend='cpu')().device_buffer.device()
   self.assertEqual(cpu_placed, api.devices('cpu')[0])
コード例 #3
0
    def testJitCpu(self):
        """Arithmetic on the output of a CPU-jitted function stays on CPU."""
        def make_array(scale):
            return scale + np.ones((2, 2))

        cpu_make_array = partial(api.jit, backend='cpu')(make_array)
        base = cpu_make_array(0.1)

        # Derived values built from the CPU-resident `base`.
        derived = (
            base / base.shape[0],
            base + np.ones_like(base),
            base + np.eye(2),
        )

        for value in derived:
            self.assertEqual(value.device_buffer.device(),
                             api.devices('cpu')[0])
コード例 #4
0
    def test_sum(self):
        """Summing an array placed on the first CPU device keeps it there.

        Regression test for https://github.com/google/jax/issues/2905.
        """
        target = api.devices("cpu")[0]
        total = api.device_put(np.ones(2), target).sum()
        self.assertEqual(total.device_buffer.device(), target)
コード例 #5
0
ファイル: host_callback_test.py プロジェクト: stilling/jax
 def helper_set_devices(self, nr_devices):
     """Request `nr_devices` host devices via XLA_FLAGS; return api.devices().

     The device-count flag is appended to whatever XLA_FLAGS already holds.
     """
     extra_flag = f" --xla_force_host_platform_device_count={nr_devices}"
     os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + extra_flag
     # Drop cached backends so a freshly created CPU backend sees the new flag.
     xla_bridge.get_backend.cache_clear()
     return api.devices()
コード例 #6
0
def outfeed_receiver(*,
                     timeout_sec=10,
                     backends: Optional[Sequence[str]] = None,
                     devices: Optional[Sequence[XlaDevice]] = None,
                     receiver_name=""):
    """Starts receivers for the :func:`id_tap` outfeed from several devices.

    The receivers run in a threadpool, one thread per device. The tapped
    functions are invoked in those threads. If a tap function raises an
    exception, an error is logged, but the receiving continues until the body
    of the context manager terminates and all outfeeds from all devices have
    been received. Only then is a :exc:`TapFunctionException` raised.

    Args:
      timeout_sec: how long to wait, after the body terminates, for all the
        device receivers to finish reading their outfeeds.
      backends: (optional) sequence of backend names for which to listen.
        Will listen to all devices on those backends. By default, listens to
        all devices on all known backends.
      devices: (optional) sequence of devices to listen to. At most one
        of `backends` or `devices` may be given.
      receiver_name: (optional) a name to use with debug logging

    Raises:
      ValueError: if both `backends` and `devices` are given, or if another
        ``outfeed_receiver`` is already running.
      TapFunctionException: on exit, if any tap function raised or an outfeed
        for an unknown tap consumer was received.
      concurrent.futures.TimeoutError: on exit, if the receivers do not all
        finish within `timeout_sec`.
      Exceptions raised inside a receiver loop (bugs in this module, not in
      tap functions) are re-raised on exit.

    Usage::

      with outfeed_receiver():
        jax.jit(func)(args)
        ...
        jax.pmap(another_func)(args)

    The ``outfeed_receiver`` must be started outside any jitted computation.
    """
    # TODO: better timeout management.
    # NOTE(review): this is a generator (it yields once). The ``with`` usage
    # above requires it to be wrapped by contextlib.contextmanager, which is
    # not visible here -- confirm at the definition site.
    import threading  # Only needed for the exception-counter lock below.

    global _outfeed_receiver_started
    if not devices:
        backends = backends or xla_client._get_local_backends().keys()
        devices = tuple(
            itertools.chain(*[api.devices(backend) for backend in backends]))
    elif backends:
        raise ValueError(
            "At most one of `devices` or `backends` must be given.")
    # Check before spawning any receiver threads. Raising after they were
    # submitted would leave them running and competing for the outfeed with
    # the receiver that is already active.
    if _outfeed_receiver_started:
        raise ValueError(
            "At most one outfeed_receiver can be running at once.")

    executor = futures.ThreadPoolExecutor(
        thread_name_prefix=f"outfeed_receiver_{receiver_name}",
        # ThreadPoolExecutor rejects max_workers=0 (e.g. no devices found).
        max_workers=max(1, len(devices)))

    count_tap_exceptions = 0
    # The per-device receiver threads update the counter concurrently.
    count_lock = threading.Lock()

    def record_tap_exception():
        """Thread-safely counts one failed or unroutable tap."""
        nonlocal count_tap_exceptions
        with count_lock:
            count_tap_exceptions += 1

    def device_receiver_loop(device: XlaDevice) -> XlaDevice:
        """Polls the outfeed for a device until the END marker arrives."""
        while True:
            consumer_id, arrays = _receive_outfeed(device, receiver_name)
            if _LOGGING:
                logging.info(
                    f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} "
                    + (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
            if consumer_id == _end_consumer:
                assert not arrays
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{device}] Outfeed received END_OUTFEED"
                    )
                return device
            consumer = _consumer_registry_by_id.get(consumer_id)
            if consumer is None:
                logging.error(
                    f"Ignoring received outfeed for unknown tap consumer {consumer_id}")
                record_tap_exception()
                continue  # We need to read the entire outfeed
            try:
                arg = api.tree_unflatten(consumer.arg_treedef, arrays)
                consumer.func(arg, **dict(
                    consumer.kwargs))  # type: ignore[attribute-error]
            except Exception as e:
                logging.error(
                    f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}"
                )
                record_tap_exception()
                # We continue for now, we need to keep reading the outfeed

    receiver_futures = [
        executor.submit(device_receiver_loop, d) for d in devices
    ]
    # Register a callback to raise errors if any. These exceptions come from
    # bugs in our code, not from the tap functions.
    for rf in receiver_futures:
        rf.add_done_callback(lambda fut: fut.result())
    _outfeed_receiver_started = True
    xla.can_execute_outfeed_computations = True
    try:
        yield
    finally:
        try:
            try:
                for d in devices:  # Signal the end of printing
                    api.jit(lambda x: id_tap(_end_consumer, None, result=x),
                            device=d)(0)  # type: ignore[arg-type]
            finally:
                # Reset the global state even if sending an END marker
                # failed, so that a later outfeed_receiver can be started.
                xla.can_execute_outfeed_computations = False
                _outfeed_receiver_started = False
            for f in futures.as_completed(receiver_futures,
                                          timeout=timeout_sec):
                finished_device = f.result()  # Throw exceptions here
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{finished_device}] Outfeed receiver finished"
                    )
        finally:
            # Release the pool's worker threads without blocking on receivers
            # that may still be running (e.g. after a timeout).
            executor.shutdown(wait=False)
        if count_tap_exceptions > 0:
            raise TapFunctionException
コード例 #7
0
 def kernel_fn(x1, x2=None, *args, **kwargs):
   """Evaluate `_kernel_fn` on the inputs, then move the result to CPU.

   All arguments are forwarded unchanged to `_kernel_fn`; its result is
   placed on the first device reported by `devices('cpu')`.
   """
   result = _kernel_fn(x1, x2, *args, **kwargs)
   first_cpu = devices('cpu')[0]
   return device_put(result, first_cpu)