コード例 #1
0
    def test_device_mismatch(self):
        """GDA construction must reject buffers given in the wrong device order.

        The mesh below is built from a scrambled permutation of the devices,
        while the buffers are created by iterating ``jax.local_devices()``.
        The constructor is expected to raise a ValueError about the mismatch.
        """
        all_devices = jax.devices()
        if len(all_devices) < 8:
            raise unittest.SkipTest("Test requires 8 global devices.")
        # Each pair is one row of the 4x2 device grid (deliberately not in
        # plain 0..7 order).
        row_pairs = ((0, 2), (3, 1), (4, 6), (7, 5))
        mesh_devices = np.array(
            [[all_devices[left], all_devices[right]] for left, right in row_pairs])
        mesh = Mesh(mesh_devices, ('x', 'y'))

        shape = (8, 2)
        axes = ['x', 'y']
        host_data = np.arange(prod(shape)).reshape(shape)
        shard_index = get_shard_indices(shape, mesh, axes)

        # Buffers follow jax.local_devices() order rather than the mesh order.
        buffers = [
            jax.device_put(host_data[shard_index[dev]], dev)
            for dev in jax.local_devices()
        ]

        with self.assertRaisesRegex(
                ValueError,
                'The `global_mesh.local_devices` and `device_buffers` device order'
        ):
            GlobalDeviceArray(shape, mesh, axes, buffers)
コード例 #2
0
  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    """Map each device to its index (or None) into an array of `global_shape`.

    Delegates to `global_device_array.get_shard_indices`, passing this
    object's `mesh` and `spec`. That helper is cached, so repeated calls with
    the same arguments are cheap.
    """
    # TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
    # Function-scope import; presumably avoids an import cycle (confirm).
    from jax.experimental.global_device_array import get_shard_indices
    return get_shard_indices(global_shape, self.mesh, self.spec)
コード例 #3
0
def gda_construction_raw(mesh_shape, mesh_axes, state):
    """Benchmark raw `GlobalDeviceArray` construction from prebuilt buffers.

    Args:
      mesh_shape: mesh shape passed to `jtu.create_global_mesh`; the mesh
        axes are named ("x", "y").
      mesh_axes: mesh axes used to shard the (2048, 2048) input.
      state: benchmark state; construction is repeated while it is truthy.
    """
    # The device_put calls happen once, before the timed loop, so they are
    # not measured. Every device used here is local.
    mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
    shape = (2048, 2048)
    host_data = np.arange(prod(shape)).reshape(shape)
    shard_index = gda.get_shard_indices(shape, mesh, mesh_axes)
    buffers = [
        jax.device_put(host_data[shard_index[dev]], dev)
        for dev in mesh.local_devices
    ]

    while state:
        gda.GlobalDeviceArray(shape, mesh, mesh_axes, buffers)
コード例 #4
0
ファイル: serialization.py プロジェクト: ahoenselaar/jax
async def create_async_gsda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
    """Build a `GlobalDeviceArray` from shards produced by an async callback.

    Args:
      global_shape: shape of the full (global) array.
      global_mesh: mesh the array is laid out on.
      mesh_axes: mesh axes describing how the array is sharded.
      data_callback: called once per local device with that device's index
        into the global array. It must return an awaitable that resolves to
        the host data for that shard.

    Returns:
      A `gda.GlobalDeviceArray` whose device buffers are the awaited callback
      results. The results are placed, in order, on `global_mesh.local_devices`.
    """
    # Read the local device list once. The indices handed to the callback and
    # the device_put targets then come from the same sequence, and the
    # property is not evaluated twice.
    local_devices = list(global_mesh.local_devices)
    indices = gda.get_shard_indices(global_shape, global_mesh, mesh_axes)
    future_arrays = [data_callback(indices[d]) for d in local_devices]
    # Suspend until every shard's data is ready; device_put cannot take the
    # unresolved awaitables. asyncio.gather returns results in input order,
    # so local_arrays[i] belongs to local_devices[i].
    local_arrays = await asyncio.gather(*future_arrays)

    dbs = [
        jax.device_put(array, device)
        for array, device in zip(local_arrays, local_devices)
    ]
    return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, dbs)