def make_array_from_callback(shape: Shape, sharding: Sharding,
                             data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
  """Builds a committed ``Array`` whose per-device data comes from a callback.

  For every device in ``sharding.addressable_devices``, the index that
  ``sharding.device_indices`` reports for that device and ``shape`` is passed
  to ``data_callback``; the value it returns is placed on that device with
  ``device_put``.

  Args:
    shape: Shape of the resulting array (also passed to
      ``sharding.device_indices``).
    sharding: Sharding that decides which index each addressable device holds.
    data_callback: Called once per addressable device with that device's
      index; returns the array-like data to store on that device.

  Returns:
    ``Array(shape, sharding, <per-device buffers>, committed=True)``.
  """
  per_device_buffers = []
  for dev in sharding.addressable_devices:
    dev_index = sharding.device_indices(dev, shape)
    per_device_buffers.append(device_put(data_callback(dev_index), dev))
  return Array(shape, sharding, per_device_buffers, committed=True)
def test_sum(self):
  # Regression test for https://github.com/google/jax/issues/2905:
  # summing an array that was put on a CPU device must leave the result on
  # that same CPU device.
  first_cpu = api.devices("cpu")[0]
  total = api.device_put(np.ones(2), first_cpu).sum()
  self.assertEqual(total.device_buffer.device(), first_cpu)
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                  mesh_axes: MeshAxes,
                  data_callback: Callable[[Index], ArrayLike]):
  """Constructs a GlobalDeviceArray from data supplied by ``data_callback``.

  ``data_callback`` is called once per local device of ``global_mesh`` to
  fetch the data for that device's slice of the returned GlobalDeviceArray.

  Example::

    global_input_shape = (8, 2)
    global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)

  Args:
    global_shape : The global shape of the array
    global_mesh : The global mesh representing devices across multiple
      processes.
    mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
    data_callback : Callback that takes indices into the global array value as
      input and returns the corresponding data of the global array value.
      The data can be returned as any array-like object, e.g. a
      ``numpy.ndarray``.

  Returns:
    An instance of ``cls`` built from the per-device buffers.
  """
  indices_and_rids = get_shard_indices_replica_ids(
      global_shape, global_mesh, mesh_axes)
  devices = global_mesh.local_devices
  device_buffers = []
  for dev in devices:
    # Only element [0] of each per-device entry is given to the callback;
    # presumably it is the shard index (and [1] the replica id) -- confirm.
    shard_index = indices_and_rids[dev][0]
    device_buffers.append(device_put(data_callback(shard_index), dev))
  fast_path_args = _GdaFastPathArgs(indices_and_rids, devices)
  return cls(global_shape, global_mesh, mesh_axes, device_buffers,
             _gda_fast_path_args=fast_path_args)
def test_jit_on_nondefault_backend(self):
  cpu_devices = api.devices("cpu")
  self.assertNotEmpty(cpu_devices)

  # The test only makes sense when some non-CPU backend is the default.
  default_device = api.devices()[0]
  self.assertNotEqual(default_device.platform, "cpu")

  cpu0 = cpu_devices[0]

  def placement(arr):
    return arr.device_buffer.device()

  cpu_scalar = api.device_put(1, device=cpu0)
  self.assertEqual(placement(cpu_scalar), cpu0)

  def my_sin(x):
    return jnp.sin(x)

  # Without a device/backend argument, jit follows the data: a plain Python
  # input lands on the default device, a CPU-placed input stays on the CPU.
  self.assertEqual(placement(api.jit(my_sin)(2)), default_device)
  self.assertEqual(placement(api.jit(my_sin)(cpu_scalar)), cpu0)

  # An explicit `device` argument places the result on that device.
  self.assertEqual(placement(api.jit(my_sin, device=cpu0)(2)), cpu0)

  # An explicit `backend` argument places the result on that backend.
  self.assertEqual(placement(api.jit(my_sin, backend="cpu")(2)), cpu0)
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                  mesh_axes: MeshAxes,
                  data_callback: Callable[[Index], ArrayLike]):
  """Constructs an instance of ``cls`` from data produced by ``data_callback``.

  ``get_shard_indices`` assigns an index to each device; for every device in
  ``global_mesh.local_devices`` that index is passed to ``data_callback``,
  and the returned data is placed on the device with ``device_put``.

  Args:
    global_shape: Shape of the global array.
    global_mesh: Mesh whose ``local_devices`` receive the fetched data.
    mesh_axes: Forwarded to ``get_shard_indices`` and to ``cls``.
    data_callback: Maps a device's index to the array-like data for it.

  Returns:
    ``cls(global_shape, global_mesh, mesh_axes, <per-device buffers>)``.
  """
  shard_indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
  device_buffers = []
  for dev in global_mesh.local_devices:
    shard_data = data_callback(shard_indices[dev])
    device_buffers.append(device_put(shard_data, dev))
  return cls(global_shape, global_mesh, mesh_axes, device_buffers)
def testBooleanIndexingArray2D(self):
  # Indexing a device array with a 2-D NumPy boolean mask must select the
  # same elements NumPy selects from the host array.
  mask = np.array([[True, False],
                   [False, True],
                   [False, False],
                   [True, True]])
  host_x = np.arange(8).reshape(4, 2)
  expected = host_x[mask]
  actual = api.device_put(host_x)[mask]
  self.assertAllClose(actual, expected, check_dtypes=False)
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                  mesh_axes: MeshAxes,
                  data_callback: Callable[[Index], ArrayLike]):
  """Constructs a GlobalDeviceArray via data fetched from ``data_callback``.

  ``data_callback`` is used to fetch the data for each local slice of the
  returned GlobalDeviceArray.

  Example (needs exactly 8 devices, because they are reshaped into a 2x4
  mesh):

    >>> import jax
    >>> import numpy as np
    >>> from jax.experimental.maps import Mesh
    >>> from jax.experimental import PartitionSpec as P
    ...
    >>> global_input_shape = (8, 8)
    >>> mesh_axes = P('x', 'y')
    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> global_input_data = np.arange(
    ...     np.prod(global_input_shape)).reshape(global_input_shape)
    ...
    >>> def cb(index):
    ...   return global_input_data[index]
    ...
    >>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
    >>> gda.local_data(0).shape
    (4, 2)

  Args:
    global_shape : The global shape of the array
    global_mesh : The global mesh representing devices across multiple
      processes.
    mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
    data_callback : Callback that takes indices into the global array value as
      input and returns the corresponding data of the global array value.
      The data can be returned as any array-like object, e.g. a
      ``numpy.ndarray``.

  Returns:
    An instance of ``cls`` built from one buffer per local device of
    ``global_mesh``.
  """
  global_indices_rid = get_shard_indices_replica_ids(
      global_shape, global_mesh, mesh_axes)
  local_devices = global_mesh.local_devices
  # Only element [0] of each per-device entry is handed to the callback;
  # presumably it is the shard index (and [1] the replica id) -- confirm.
  dbs = [
      device_put(data_callback(global_indices_rid[device][0]), device)
      for device in local_devices
  ]
  # The already-computed indices and device list are forwarded to the
  # constructor. Presumably this lets it skip recomputing them, judging by
  # the `_gda_fast_path_args` name -- confirm against the constructor.
  return cls(global_shape, global_mesh, mesh_axes, dbs,
             _gda_fast_path_args=_GdaFastPathArgs(
                 global_indices_rid, local_devices))
def _asarray(a): # simplified version of jnp.asarray() for local use. return a if isinstance(a, ndarray) else api.device_put(a)
def testBooleanIndexingList2DBroadcast(self):
  # Indexing with a plain Python list of booleans (not an array) must raise
  # a TypeError whose message matches ARRAY_MSG.
  bool_list = [True, True, False, True]
  host_x = np.arange(8).reshape(4, 2)
  with self.assertRaisesRegex(TypeError, ARRAY_MSG):
    device_x = api.device_put(host_x)
    device_x[bool_list]
def testBooleanIndexingList1D(self):
  # A Python list of booleans is rejected as an index: expect a TypeError
  # whose message matches ARRAY_MSG.
  device_x = api.device_put(np.arange(3))
  bool_list = [True, True, False]
  with self.assertRaisesRegex(TypeError, ARRAY_MSG):
    _ = device_x[bool_list]
def testBooleanIndexingArray1D(self):
  # A 1-D NumPy boolean mask must select the same elements from a device
  # array as it does from the equivalent NumPy array.
  mask = np.array([True, True, False])
  expected = np.arange(3)[mask]
  actual = api.device_put(np.arange(3))[mask]
  self.assertAllClose(actual, expected, check_dtypes=False)