def test_device_mismatch(self):
  """Building a GDA from buffers in the wrong device order must fail.

  The mesh below is laid out from a non-sequential selection of devices,
  while the shard buffers are produced by walking `jax.local_devices()` in
  its own order. `GlobalDeviceArray` is expected to reject this with a
  `ValueError` about mismatched device order.
  """
  all_devices = jax.devices()
  if len(all_devices) < 8:
    raise unittest.SkipTest("Test requires 8 global devices.")

  # Rows of the (4, 2) device grid, given as index pairs into `all_devices`.
  row_pairs = ((0, 2), (3, 1), (4, 6), (7, 5))
  grid = np.array([[all_devices[a], all_devices[b]] for a, b in row_pairs])
  mesh = Mesh(grid, ('x', 'y'))

  shape = (8, 2)
  axes = ['x', 'y']
  data = np.arange(prod(shape)).reshape(shape)
  shard_index_of = get_shard_indices(shape, mesh, axes)

  # Buffers are ordered by jax.local_devices(), which is expected to differ
  # from the device order the scrambled mesh above reports.
  buffers = []
  for dev in jax.local_devices():
    buffers.append(jax.device_put(data[shard_index_of[dev]], dev))

  with self.assertRaisesRegex(
      ValueError,
      'The `global_mesh.local_devices` and `device_buffers` device order'):
    GlobalDeviceArray(shape, mesh, axes, buffers)
def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
  """Map each device to the index of `global_shape` its shard covers.

  This delegates to `global_device_array.get_shard_indices` (which is
  cached), passing this object's `mesh` and `spec`.

  Args:
    global_shape: shape of the full, unsharded array.

  Returns:
    A mapping from device to its index (or None).
  """
  # TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
  # Imported at call time rather than module scope, presumably to avoid an
  # import cycle -- TODO confirm.
  from jax.experimental import global_device_array
  shard_indices_fn = global_device_array.get_shard_indices
  mesh, spec = self.mesh, self.spec
  return shard_indices_fn(global_shape, mesh, spec)
def gda_construction_raw(mesh_shape, mesh_axes, state):
  """Benchmark bare `GlobalDeviceArray` construction from ready buffers.

  The mesh, the (2048, 2048) input and its per-device shards are all built
  before the timed loop. Only the `gda.GlobalDeviceArray(...)` call runs
  inside `while state:`.
  """
  # `device_put` time is not measured in this benchmark. All the devices here
  # are local.
  mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  shape = (2048, 2048)
  data = np.arange(prod(shape)).reshape(shape)
  index_by_device = gda.get_shard_indices(shape, mesh, mesh_axes)

  buffers = []
  for dev in mesh.local_devices:
    shard = data[index_by_device[dev]]
    buffers.append(jax.device_put(shard, dev))

  while state:
    gda.GlobalDeviceArray(shape, mesh, mesh_axes, buffers)
async def create_async_gsda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
  """Build a `GlobalDeviceArray` whose local shards come from async callbacks.

  `data_callback` is invoked once per local device of `global_mesh` with that
  device's shard index. All returned futures are awaited together, and each
  resolved array is then `device_put` onto its device.

  Args:
    global_shape: shape of the full, unsharded array.
    global_mesh: mesh describing the device layout.
    mesh_axes: how array dimensions map onto mesh axes.
    data_callback: called with a shard index; returns a future of that shard.

  Returns:
    The assembled `gda.GlobalDeviceArray`.
  """
  shard_index = gda.get_shard_indices(global_shape, global_mesh, mesh_axes)
  pending = [data_callback(shard_index[dev])
             for dev in global_mesh.local_devices]

  # Yield to the event loop until every callback's data is ready; device_put
  # cannot be applied to an unresolved future.
  arrays = await asyncio.gather(*pending)

  buffers = []
  for dev, arr in zip(global_mesh.local_devices, arrays):
    buffers.append(jax.device_put(arr, dev))
  return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, buffers)