Ejemplo n.º 1
0
 def from_batched_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                           mesh_axes: MeshAxes,
                           data_callback: Callable[[Sequence[Index]],
                                                   Sequence[ArrayLike]]):
     """Builds an instance from data fetched by one batched callback call.

     Shard indices are computed with ``get_shard_indices``. The indices that
     belong to ``global_mesh.local_devices`` are passed to ``data_callback`` in
     a single call, and the arrays it returns are placed on those same devices
     with ``pxla.device_put``.

     Args:
       global_shape: Shape of the full (global) array.
       global_mesh: Mesh whose ``local_devices`` receive the fetched data.
       mesh_axes: Forwarded to ``get_shard_indices`` and to ``cls``.
       data_callback: Called once with the list of indices for the local
         devices, in the order of ``global_mesh.local_devices``. It returns
         the data for those indices as a sequence of array-likes.

     Returns:
       ``cls(global_shape, global_mesh, mesh_axes, buffers)``, where
       ``buffers`` is the result of ``pxla.device_put``.
     """
     devices = global_mesh.local_devices
     shard_index_of = get_shard_indices(global_shape, global_mesh, mesh_axes)
     # Only the shards owned by the local devices are requested from the callback.
     batch = [shard_index_of[device] for device in devices]
     buffers = pxla.device_put(data_callback(batch), devices)
     return cls(global_shape, global_mesh, mesh_axes, buffers)
Ejemplo n.º 2
0
    def from_batched_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                              mesh_axes: MeshAxes,
                              data_callback: Callable[[Sequence[Index]],
                                                      Sequence[ArrayLike]]):
        """Constructs a GlobalDeviceArray via batched data fetched from ``data_callback``.

        Like ``from_callback``, except the callback function is called only once
        to fetch all data local to this process.

        Example (assumes 8 devices, since they are reshaped into a 4x2 mesh):

          >>> import jax
          >>> import numpy as np
          >>> from jax.experimental.maps import Mesh
          >>> from jax.experimental import PartitionSpec as P
          >>> global_input_shape = (8, 2)
          >>> mesh_axes = P('x')
          >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
          >>> global_input_data = np.arange(
          ...     np.prod(global_input_shape)).reshape(global_input_shape)
          >>> def batched_cb(indices):
          ...   assert len(indices) == len(global_mesh.local_devices)
          ...   return [global_input_data[index] for index in indices]
          ...
          >>> gda = GlobalDeviceArray.from_batched_callback(
          ...     global_input_shape, global_mesh, mesh_axes, batched_cb)
          >>> gda.local_data(0).shape
          (2, 2)

        Args:
          global_shape: The global shape of the array.
          global_mesh: The global mesh representing devices across multiple
            processes.
          mesh_axes: See the ``mesh_axes`` parameter of GlobalDeviceArray.
          data_callback: Callback that takes a batch of indices into the global
            array value, with length equal to the number of local devices, and
            returns the corresponding data for each index. The data can be
            returned as any array-like objects, e.g. ``numpy.ndarray``.

        Returns:
          An instance of ``cls`` whose device buffers hold the callback's data,
          one per device in ``global_mesh.local_devices``.
        """
        # Per-device lookup for the whole mesh. Element [0] of each entry is the
        # shard index handed to the callback; judging by the helper's name, the
        # other element is presumably the replica id.
        global_indices_rid = get_shard_indices_replica_ids(
            global_shape, global_mesh, mesh_axes)
        local_devices = global_mesh.local_devices
        local_indices = [global_indices_rid[d][0] for d in local_devices]
        # Single batched fetch of all process-local data.
        local_arrays = data_callback(local_indices)
        dbs = pxla.device_put(local_arrays, local_devices)
        # The already-computed lookup and device list are forwarded, presumably
        # so the constructor can skip recomputing them (the "fast path").
        return cls(global_shape,
                   global_mesh,
                   mesh_axes,
                   dbs,
                   _gda_fast_path_args=_GdaFastPathArgs(
                       global_indices_rid, local_devices))