def test_mixed_nesting(self):
    """Checks that mixing the global and host-to-device guards nests correctly.

    Each scenario disallows host-to-device transfers via one guard in an outer
    scope, then re-allows them via the *other* guard in an inner scope. The
    test expects ``jnp.ones(1)`` to be allowed inside the inner scope and
    disallowed again once that inner scope has exited.
    """
    # (outer guard set to "disallow", inner guard set to "allow")
    scenarios = (
        (jax.transfer_guard_host_to_device, jax.transfer_guard),
        (jax.transfer_guard, jax.transfer_guard_host_to_device),
    )
    for outer_guard, inner_guard in scenarios:
        with outer_guard("disallow"):
            # Entered left-to-right, exited right-to-left: equivalent to
            # nesting the assertion context inside the inner guard.
            with inner_guard("allow"), self.assertAllows("host_to_device_jnp_ones"):
                jnp.ones(1)
            # Back under only the outer "disallow" setting.
            with self.assertDisallows("host_to_device_jnp_ones"):
                jnp.ones(1)
def _device_to_device_funcs():
    """Builds (name, is_explicit_transfer, fn) triples for device-to-device copies.

    Each ``fn`` copies a one-element array created with ``jnp.ones`` onto
    ``jax.local_devices()[1]``. At least two local devices are needed for such
    a copy, so an empty list is returned when fewer are present.

    Returns:
      A list of ``(function name, is an explicit transfer?, function)`` tuples,
      or ``[]`` when fewer than two local devices exist.
    """
    device_count = len(jax.local_devices())
    if device_count < 2:
        # device-to-device tests require at least 2 devices.
        return []

    # Creating the source arrays is itself a host-to-device transfer, so it is
    # explicitly allowed here regardless of any enclosing guard setting.
    with jax.transfer_guard_host_to_device("allow"):
        put_src, jit_src = [jnp.ones(1) for _ in range(2)]

    def device_put_to_second_device():
        # Explicit transfer: device_put with a target device.
        return jax.device_put(put_src, device=jax.local_devices()[1])

    def jit_on_second_device():
        # Implicit transfer: the jitted identity runs on the second device,
        # so its argument has to be moved there.
        identity_on_second = jax.jit(lambda x: x, device=jax.local_devices()[1])
        return identity_on_second(jit_src)

    # (function name, is an explicit transfer?, function)
    return [
        ("device_to_device_jax_device_put", True, device_put_to_second_device),
        ("device_to_device_jax_jit", False, jit_on_second_device),
    ]
def _device_to_host_funcs():
    """Builds (name, is_explicit_transfer, fn) triples for device-to-host reads.

    Each ``fn`` reads back a one-element array created with ``jnp.ones``
    through a different API: ``jax.device_get``, ``np.asarray``,
    ``copy_to_host_async``, a NumPy ufunc, ``str`` and ``pickle.dumps``. Only
    ``jax.device_get`` is marked as an explicit transfer.

    Returns:
      A list of ``(function name, is an explicit transfer?, function)`` tuples,
      or ``[]`` on the CPU backend.
    """
    if jax.default_backend() == "cpu":
        # Reading an array back to the host incurs no transfer on the CPU
        # backend, so there is nothing to guard.
        return []

    # Each function gets its own source array. Presumably this is so one
    # function's host copy cannot satisfy another's read (TODO confirm).
    # Creating them is a host-to-device transfer, so it is explicitly allowed.
    with jax.transfer_guard_host_to_device("allow"):
        (get_src, asarray_src, async_src,
         add_src, str_src, pickle_src) = (jnp.ones(1) for _ in range(6))

    # (function name, is an explicit transfer?, function)
    return [
        ("device_to_host_jax_device_get", True,
         lambda: jax.device_get(get_src)),
        ("device_to_host_np_asarray", False,
         lambda: np.asarray(asarray_src)),
        ("device_to_host_copy_to_host_async", False,
         lambda: async_src.copy_to_host_async()),
        ("device_to_host_np_add", False,
         lambda: np.add(add_src, 1)),
        ("device_to_host_str", False,
         lambda: str(str_src)),
        ("device_to_host_pickle_dumps", False,
         lambda: pickle.dumps(pickle_src)),
    ]