示例#1
0
def gda_construction_raw(mesh_shape, mesh_axes, state):
    """Benchmark raw ``GlobalDeviceArray`` construction from pre-placed shards.

    Per-device buffers are created with ``jax.device_put`` once, before the
    timed loop, so ``device_put`` time is excluded; only the
    ``GlobalDeviceArray`` constructor runs inside the loop. All the devices
    used here are local.

    Args:
      mesh_shape: mesh shape passed to ``jtu.create_global_mesh`` together
        with the axis names ``("x", "y")``.
      mesh_axes: partitioning spec forwarded to ``gda.get_shard_indices`` and
        to the ``GlobalDeviceArray`` constructor.
      state: loop driver; construction repeats while it is truthy
        (presumably a benchmark-library state object — confirm with caller).
    """
    mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
    shape = (2048, 2048)
    num_elements = prod(shape)
    host_data = np.arange(num_elements).reshape(shape)
    shard_index = gda.get_shard_indices(shape, mesh, mesh_axes)

    # Untimed setup: place each local device's slice of the host data.
    device_buffers = []
    for dev in mesh.local_devices:
        device_buffers.append(jax.device_put(host_data[shard_index[dev]], dev))

    # Timed region: only the constructor call; the result is discarded.
    while state:
        gda.GlobalDeviceArray(shape, mesh, mesh_axes, device_buffers)
示例#2
0
async def create_async_gsda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
    """Build a ``GlobalDeviceArray`` whose local shards come from an async callback.

    Args:
      global_shape: shape of the full (global) array.
      global_mesh: mesh whose ``local_devices`` receive the shards.
      mesh_axes: partitioning spec used to compute each device's shard index.
      data_callback: called once per local device with that device's shard
        index; must return an awaitable producing the shard's data.

    Returns:
      A ``gda.GlobalDeviceArray`` built from the device-placed shards.
    """
    indices = gda.get_shard_indices(global_shape, global_mesh, mesh_axes)
    # Read the device list once so the devices used to request shards are
    # exactly the devices (in the same order) the shards are placed on.
    local_devices = global_mesh.local_devices
    future_arrays = [data_callback(indices[d]) for d in local_devices]
    # Suspend until every callback's data is ready; device_put cannot be
    # applied to unresolved futures.
    local_arrays = await asyncio.gather(*future_arrays)

    dbs = [
        jax.device_put(array, device)
        for array, device in zip(local_arrays, local_devices)
    ]
    return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, dbs)
示例#3
0
async def create_async_gda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
  """Build a ``GlobalDeviceArray`` from async per-shard data, using the fast path.

  Shard indices (and replica ids) are computed once and reused both to request
  data from ``data_callback`` and, via ``gda._GdaFastPathArgs``, by the
  ``GlobalDeviceArray`` constructor.

  Args:
    global_shape: shape of the full (global) array.
    global_mesh: mesh whose ``local_devices`` receive the shards.
    mesh_axes: partitioning spec used to compute shard indices.
    data_callback: called once per local device with that device's shard
      index; must return an awaitable producing the shard's data.

  Returns:
    A ``gda.GlobalDeviceArray`` built from the device-placed shards.
  """
  index_and_replica = gda.get_shard_indices_replica_ids(
      global_shape, global_mesh, mesh_axes)
  devices = global_mesh.local_devices

  # Element 0 of each per-device entry is what the callback receives
  # (typed as gda.Index); wait for every shard's data before placement.
  pending = [data_callback(index_and_replica[dev][0]) for dev in devices]
  host_arrays = await asyncio.gather(*pending)

  device_buffers = [
      jax.device_put(host_array, dev)
      for host_array, dev in zip(host_arrays, devices)
  ]
  fast_path = gda._GdaFastPathArgs(index_and_replica, devices)
  return gda.GlobalDeviceArray(
      global_shape, global_mesh, mesh_axes, device_buffers, fast_path)
示例#4
0
 def _gda(global_shape, pspec, dbs):
     """Wrap the per-device buffers ``dbs`` in a ``GlobalDeviceArray``.

     NOTE(review): despite its name, ``global_shape`` is read via ``.shape``,
     so callers presumably pass an array-like rather than a shape tuple.
     ``global_mesh`` is a free name resolved outside this function.
     """
     construct = gda_lib.GlobalDeviceArray
     return construct(global_shape.shape, global_mesh, pspec, dbs)