def test_jit_on_nondefault_backend(self):
    """Check where `jit` outputs are placed when CPU is not the default."""
    cpu_devices = api.devices("cpu")
    self.assertNotEmpty(cpu_devices)
    first_cpu = cpu_devices[0]

    # The rest of the test only makes sense if the default device is not a
    # CPU, so assert that before relying on it.
    default_device = api.devices()[0]
    self.assertNotEqual(default_device.platform, "cpu")

    def device_of(array):
        # Device that holds the buffer backing `array`.
        return array.device_buffer.device()

    cpu_resident = api.device_put(1, device=first_cpu)
    self.assertEqual(device_of(cpu_resident), first_cpu)

    def my_sin(x):
        return jnp.sin(x)

    # With no device spec, the output follows the input data.
    from_python_scalar = api.jit(my_sin)(2)
    self.assertEqual(device_of(from_python_scalar), default_device)
    from_cpu_data = api.jit(my_sin)(cpu_resident)
    self.assertEqual(device_of(from_cpu_data), first_cpu)

    # An explicit `device` spec places the output on that device.
    pinned_by_device = api.jit(my_sin, device=first_cpu)(2)
    self.assertEqual(device_of(pinned_by_device), first_cpu)

    # An explicit `backend` spec places the output on that backend.
    pinned_by_backend = api.jit(my_sin, backend="cpu")(2)
    self.assertEqual(device_of(pinned_by_backend), first_cpu)
def test_closed_over_values_device_placement(self):
    """Check placement of the output of a jitted function with no arguments."""
    # Regression test for https://github.com/google/jax/issues/1431
    def f():
        return np.add(3., 4.)

    def output_device(jitted):
        # Run the jitted function and report which device holds its output.
        return jitted().device_buffer.device()

    # Without a backend spec the output must not be on the first CPU device.
    default_placement = output_device(api.jit(f))
    self.assertNotEqual(default_placement, api.devices('cpu')[0])

    # With `backend='cpu'` the output must be on the first CPU device.
    cpu_placement = output_device(api.jit(f, backend='cpu'))
    self.assertEqual(cpu_placement, api.devices('cpu')[0])
def testJitCpu(self):
    """Results derived from a CPU-jitted array should stay on the CPU."""
    cpu_jit = partial(api.jit, backend='cpu')

    @cpu_jit
    def get_arr(scale):
        return scale + np.ones((2, 2))

    x = get_arr(0.1)

    # Three follow-up operations on `x`; each result is checked below.
    derived = (
        x / x.shape[0],
        x + np.ones_like(x),
        x + np.eye(2),
    )

    for arr in derived:
        self.assertEqual(arr.device_buffer.device(), api.devices('cpu')[0])
def test_sum(self):
    """`.sum()` of an array placed on a CPU device stays on that device."""
    # Regression test for https://github.com/google/jax/issues/2905
    cpu_devices = api.devices("cpu")
    ones_on_cpu = api.device_put(np.ones(2), cpu_devices[0])
    total = ones_on_cpu.sum()
    result_device = total.device_buffer.device()
    self.assertEqual(result_device, cpu_devices[0])
def helper_set_devices(self, nr_devices):
    """Append a host-platform device-count flag to XLA_FLAGS and list devices.

    Args:
      nr_devices: value substituted into the
        `--xla_force_host_platform_device_count` flag.

    Returns:
      The result of `api.devices()` after the backend cache was cleared.
    """
    existing_flags = os.environ.get("XLA_FLAGS", "")
    device_count_flag = (
        " --xla_force_host_platform_device_count={}".format(nr_devices))
    os.environ["XLA_FLAGS"] = existing_flags + device_count_flag

    # Drop cached backends so that a newly created CPU backend re-reads the
    # environment variable set above.
    xla_bridge.get_backend.cache_clear()
    return api.devices()
def outfeed_receiver(*,
                     timeout_sec=10,
                     backends: Optional[Sequence[str]] = None,
                     devices: Optional[Sequence[XlaDevice]] = None,
                     receiver_name=""):
    """Starts receivers for the :func:`id_tap` outfeed from several devices.

    The receivers will run in a threadpool. The tapped functions will be
    invoked in those threads. If a tap function raises an exception, an error
    is logged, but the receiving continues until the body of the context
    manager terminates and all outfeeds from all devices have been received.
    Only then will a :exc:`TapFunctionException` be raised.

    Args:
      timeout_sec: (optional) number of seconds to wait, once the body of the
        context manager has terminated, for all the receiver loops to finish.
      backends: (optional) sequence of backend names for which to listen.
        Will listen to all devices on those backends. By default, listen to
        all devices on all known backends.
      devices: (optional) sequence of devices to listen to. At most one
        of `backends` or `devices` must be given.
      receiver_name: (optional) a name to use with debug logging.

    Raises:
      ValueError: if both `devices` and `backends` are given, or if another
        outfeed receiver is already running.
      TapFunctionException: on exit, if any tap function raised an exception
        or an outfeed was received for an unknown tap consumer.

    Usage::

      with outfeed_receiver():
        jax.jit(func)(args)
        ...
        jax.pmap(another_func)(args)

    The ``outfeed_receiver`` must be started outside any jitted computation.
    """
    # TODO: better timeout management.
    # NOTE(review): this is a generator; it is presumably wrapped with a
    # context-manager decorator that is not visible in this chunk -- confirm.
    global _outfeed_receiver_started

    if devices and backends:
        raise ValueError(
            "At most one of `devices` or `backends` must be given.")
    # Reject a second receiver *before* any receiver thread is started.
    # Otherwise the rejected call would leave behind threads that poll the
    # same device outfeeds as the receiver that is already running.
    if _outfeed_receiver_started:
        raise ValueError(
            "At most one outfeed_receiver can be running at once.")

    if not devices:
        backends = backends or xla_client._get_local_backends().keys()
        devices = tuple(
            itertools.chain(*[api.devices(backend) for backend in backends]))

    # One worker per device: every receiver loop blocks until its device
    # sends the end-of-outfeed marker.
    executor = futures.ThreadPoolExecutor(
        thread_name_prefix=f"outfeed_receiver_{receiver_name}",
        max_workers=len(devices))

    count_tap_exceptions = 0

    def device_receiver_loop(device: XlaDevice) -> XlaDevice:
        """Polls the outfeed for a device in a loop."""
        nonlocal count_tap_exceptions
        while True:
            consumer_id, arrays = _receive_outfeed(device, receiver_name)
            if _LOGGING:
                logging.info(
                    f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} "
                    + (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
            if consumer_id == _end_consumer:
                assert not arrays
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{device}] Outfeed received END_OUTFEED"
                    )
                return device
            consumer = _consumer_registry_by_id.get(consumer_id)
            if consumer is None:
                logging.error(
                    "Ignoring received outfeed for unknown tap consumer")
                count_tap_exceptions += 1
                continue  # We need to read the entire outfeed
            try:
                arg = api.tree_unflatten(consumer.arg_treedef, arrays)
                consumer.func(arg, **dict(
                    consumer.kwargs))  # type: ignore[attribute-error]
            except Exception as e:
                # Deliberately broad: a failing tap function must not stop
                # the loop, because the outfeed still has to be drained.
                logging.error(
                    f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}"
                )
                count_tap_exceptions += 1
                # We continue for now, we need to keep reading the outfeed

    receiver_futures = [
        executor.submit(device_receiver_loop, d) for d in devices
    ]
    # Register a callback to raise errors if any. These exceptions come from
    # bugs in our code, not from the tap functions.
    for rf in receiver_futures:
        rf.add_done_callback(lambda rf: rf.result())

    _outfeed_receiver_started = True
    xla.can_execute_outfeed_computations = True
    try:
        yield
    finally:
        for d in devices:  # Signal the end of printing
            api.jit(lambda x: id_tap(_end_consumer, None, result=x),
                    device=d)(0)  # type: ignore[arg-type]
        xla.can_execute_outfeed_computations = False
        _outfeed_receiver_started = False
        for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
            finished_device = f.result()  # Throw exceptions here
            if _LOGGING:
                logging.info(
                    f"[{receiver_name}:{finished_device}] Outfeed receiver finished"
                )
        if count_tap_exceptions > 0:
            raise TapFunctionException
def kernel_fn(x1, x2=None, *args, **kwargs):
    """Call `_kernel_fn` and put its result on the first 'cpu' device.

    All arguments are forwarded unchanged to `_kernel_fn`.
    """
    kernel = _kernel_fn(x1, x2, *args, **kwargs)
    first_cpu = devices('cpu')[0]
    return device_put(kernel, first_cpu)