def gda_construction_raw(mesh_shape, mesh_axes, state):
  """Benchmark raw `GlobalDeviceArray` construction from pre-placed shards.

  The per-device buffers are built once, before the timed loop, so the cost
  of `jax.device_put` is not part of the measurement. Only the
  `gda.GlobalDeviceArray` constructor runs inside the loop. All devices used
  here are local to this process (`global_mesh.local_devices`).

  Args:
    mesh_shape: shape of the ("x", "y") device mesh handed to
      `jtu.create_global_mesh`.
    mesh_axes: partitioning spec used both to compute shard indices and to
      construct the `GlobalDeviceArray`.
    state: benchmark state; construction repeats for as long as it is truthy.
  """
  mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  shape = (2048, 2048)
  host_data = np.arange(prod(shape)).reshape(shape)
  shard_index = gda.get_shard_indices(shape, mesh, mesh_axes)

  # Place each device's slice of the host data on that device up front.
  device_buffers = []
  for dev in mesh.local_devices:
    device_buffers.append(jax.device_put(host_data[shard_index[dev]], dev))

  while state:
    gda.GlobalDeviceArray(shape, mesh, mesh_axes, device_buffers)
async def create_async_gsda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
  """Asynchronously assemble a `GlobalDeviceArray` from per-shard futures.

  For each local device of `global_mesh`, `data_callback` is invoked with that
  device's shard index and must return an awaitable producing the shard's
  data. All awaitables are gathered, then each result is placed on its device
  and the buffers are wrapped in a `gda.GlobalDeviceArray`.

  Args:
    global_shape: shape of the full (global) array.
    global_mesh: device mesh the array is distributed over.
    mesh_axes: partitioning spec mapping array dimensions to mesh axes.
    data_callback: maps a shard index to an awaitable of that shard's data.

  Returns:
    The `gda.GlobalDeviceArray` built from the device-placed shards.
  """
  shard_indices = gda.get_shard_indices(global_shape, global_mesh, mesh_axes)
  devices = global_mesh.local_devices
  pending = [data_callback(shard_indices[dev]) for dev in devices]

  # Yield to the event loop until every shard is ready: `jax.device_put`
  # needs the concrete data, not the awaitables.
  host_shards = await asyncio.gather(*pending)

  device_buffers = []
  for shard, dev in zip(host_shards, devices):
    device_buffers.append(jax.device_put(shard, dev))
  return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes,
                               device_buffers)
async def create_async_gda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
  """Asynchronously assemble a `GlobalDeviceArray` using the fast-path args.

  Like the plain async constructor, but it computes (index, replica id)
  pairs per device with `gda.get_shard_indices_replica_ids`. It then forwards
  those pairs, together with the local device list, to the constructor via
  `gda._GdaFastPathArgs`. Only element `[0]` of each pair (the shard index) is
  passed to `data_callback`.

  Args:
    global_shape: shape of the full (global) array.
    global_mesh: device mesh the array is distributed over.
    mesh_axes: partitioning spec mapping array dimensions to mesh axes.
    data_callback: maps a shard index to an awaitable of that shard's data.

  Returns:
    The `gda.GlobalDeviceArray` built from the device-placed shards.
  """
  index_and_replica = gda.get_shard_indices_replica_ids(
      global_shape, global_mesh, mesh_axes)
  devices = global_mesh.local_devices
  pending = [data_callback(index_and_replica[dev][0]) for dev in devices]

  # Yield to the event loop until every shard is ready: `jax.device_put`
  # needs the concrete data, not the awaitables.
  host_shards = await asyncio.gather(*pending)

  device_buffers = []
  for shard, dev in zip(host_shards, devices):
    device_buffers.append(jax.device_put(shard, dev))

  fast_path = gda._GdaFastPathArgs(index_and_replica, devices)
  return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes,
                               device_buffers, fast_path)
def _gda(global_shape, pspec, dbs):
  """Wrap device buffers `dbs` in a `gda_lib.GlobalDeviceArray`.

  NOTE(review): despite its name, `global_shape` is not itself a shape; its
  `.shape` attribute is what gets passed on. Presumably callers pass the full
  global array here; confirm against call sites. `global_mesh` is not a
  parameter; it is looked up from an enclosing or module scope.
  """
  array_shape = global_shape.shape
  return gda_lib.GlobalDeviceArray(array_shape, global_mesh, pspec, dbs)