Beispiel #1
0
def make_array_from_callback(shape: Shape, sharding: Sharding,
                             data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
  """Builds an ``Array`` by asking ``data_callback`` for each addressable shard.

  Iterates over ``sharding.addressable_devices`` in order. For each device, the
  index from ``sharding.device_indices(device, shape)`` is handed to
  ``data_callback``, and the returned data is placed on that device with
  ``device_put``. The per-device results are wrapped in an ``Array`` created
  with ``committed=True``.

  Args:
    shape: Global shape of the resulting array.
    sharding: Sharding that decides which index each device receives.
    data_callback: Maps the index produced by ``sharding.device_indices`` to
      the array-like data for that device's shard.

  Returns:
    An ``Array`` built from the addressable devices' shards.
  """
  shard_buffers = []
  for dev in sharding.addressable_devices:
    shard_index = sharding.device_indices(dev, shape)
    shard_buffers.append(device_put(data_callback(shard_index), dev))
  return Array(shape, sharding, shard_buffers, committed=True)
Beispiel #2
0
    def test_sum(self):
        # Regression test for https://github.com/google/jax/issues/2905:
        # reducing an array that lives on a CPU device must leave the result
        # on that same CPU device.
        cpu_device = api.devices("cpu")[0]
        total = api.device_put(np.ones(2), cpu_device).sum()
        self.assertEqual(total.device_buffer.device(), cpu_device)
Beispiel #3
0
  def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                    mesh_axes: MeshAxes, data_callback: Callable[[Index],
                                                                 ArrayLike]):
    """Constructs a GlobalDeviceArray from data produced by ``data_callback``.

    ``data_callback`` is called once for each local device of ``global_mesh``
    and supplies that device's slice of the returned GlobalDeviceArray.

    Example::

      global_input_shape = (8, 2)
      global_input_data = np.arange(np.prod(global_input_shape)).reshape(global_input_shape)
      def cb(index):
        return global_input_data[index]
      gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)

    Args:
      global_shape : The global shape of the array.
      global_mesh : The global mesh representing devices across multiple
        processes.
      mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
      data_callback : Callback that receives indices into the global array value
        and returns the matching data of the global array value. The data may be
        any array-like object, e.g. a ``numpy.ndarray``.

    Returns:
      A new ``cls`` instance holding one device buffer per local device.
    """
    index_and_replica_by_device = get_shard_indices_replica_ids(
        global_shape, global_mesh, mesh_axes)
    local_devices = global_mesh.local_devices
    shards = []
    for dev in local_devices:
      # Element 0 of each per-device entry is what gets passed to the callback.
      shard_index = index_and_replica_by_device[dev][0]
      shards.append(device_put(data_callback(shard_index), dev))
    fast_path_args = _GdaFastPathArgs(index_and_replica_by_device, local_devices)
    return cls(global_shape, global_mesh, mesh_axes, shards,
               _gda_fast_path_args=fast_path_args)
Beispiel #4
0
    def test_jit_on_nondefault_backend(self):
        cpu_devices = api.devices("cpu")
        self.assertNotEmpty(cpu_devices)

        # The test requires a non-CPU default backend; otherwise CPU placement
        # could not be told apart from default placement.
        default_device = api.devices()[0]
        self.assertNotEqual(default_device.platform, "cpu")

        def placed_on(arr):
            return arr.device_buffer.device()

        def sine(v):
            return jnp.sin(v)

        cpu0 = cpu_devices[0]
        x_on_cpu = api.device_put(1, device=cpu0)
        self.assertEqual(placed_on(x_on_cpu), cpu0)

        # No device/backend argument: the result follows where the input lives.
        self.assertEqual(placed_on(api.jit(sine)(2)), default_device)
        self.assertEqual(placed_on(api.jit(sine)(x_on_cpu)), cpu0)

        # An explicit `device` argument pins the output to that device.
        self.assertEqual(placed_on(api.jit(sine, device=cpu0)(2)), cpu0)

        # An explicit `backend` argument pins the output to that backend.
        self.assertEqual(placed_on(api.jit(sine, backend="cpu")(2)), cpu0)
Beispiel #5
0
 def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                   mesh_axes: MeshAxes, data_callback: Callable[[Index],
                                                                ArrayLike]):
     """Constructs an instance of ``cls`` from per-device data.

     For each device in ``global_mesh.local_devices``, the index computed by
     ``get_shard_indices`` is passed to ``data_callback``. The returned data is
     placed on that device with ``device_put``.

     Args:
       global_shape: The global shape of the array.
       global_mesh: The mesh whose local devices receive the shards.
       mesh_axes: Partitioning of the array over ``global_mesh``.
       data_callback: Maps a shard index to array-like data for that shard.

     Returns:
       ``cls(global_shape, global_mesh, mesh_axes, buffers)``.
     """
     shard_indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
     local_buffers = []
     for dev in global_mesh.local_devices:
         local_buffers.append(device_put(data_callback(shard_indices[dev]), dev))
     return cls(global_shape, global_mesh, mesh_axes, local_buffers)
Beispiel #6
0
 def testBooleanIndexingArray2D(self):
   # A 2-D NumPy boolean mask applied to a device array must select the same
   # elements as NumPy does on the host array.
   mask = np.array([[True, False], [False, True], [False, False], [True, True]])
   host = np.arange(8).reshape(4, 2)
   actual = api.device_put(host)[mask]
   self.assertAllClose(actual, host[mask], check_dtypes=False)
Beispiel #7
0
    def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes, data_callback: Callable[[Index],
                                                                   ArrayLike]):
        """Constructs a GlobalDeviceArray via data fetched from ``data_callback``.

        ``data_callback`` is used to fetch the data for each local slice of the
        returned GlobalDeviceArray.

        Example (requires exactly 8 devices, since they are reshaped into a
        2x4 mesh):

          >>> import jax
          >>> import numpy as np
          >>> from jax.experimental.maps import Mesh
          >>> from jax.experimental import PartitionSpec as P
          >>> global_input_shape = (8, 8)
          >>> mesh_axes = P('x', 'y')
          >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
          >>> global_input_data = np.arange(np.prod(global_input_shape)).reshape(global_input_shape)
          >>> def cb(index):
          ...  return global_input_data[index]
          ...
          >>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
          >>> gda.local_data(0).shape
          (4, 2)

        Args:
          global_shape : The global shape of the array
          global_mesh : The global mesh representing devices across multiple
            processes.
          mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
          data_callback : Callback that takes indices into the global array value as input and
            returns the corresponding data of the global array value.  The data can be returned
            as any array-like object, e.g. a ``numpy.ndarray``.

        Returns:
          A new ``cls`` instance with one device buffer per local device of
          ``global_mesh``.
        """
        global_indices_rid = get_shard_indices_replica_ids(
            global_shape, global_mesh, mesh_axes)
        local_devices = global_mesh.local_devices
        # Element 0 of each per-device entry is the index passed to the callback.
        dbs = [
            device_put(data_callback(global_indices_rid[device][0]), device)
            for device in local_devices
        ]
        return cls(global_shape,
                   global_mesh,
                   mesh_axes,
                   dbs,
                   _gda_fast_path_args=_GdaFastPathArgs(
                       global_indices_rid, local_devices))
Beispiel #8
0
def _asarray(a):
    """Returns ``a`` unchanged if it is an ``ndarray``; otherwise ``api.device_put(a)``.

    A simplified stand-in for ``jnp.asarray()`` used locally.
    """
    if isinstance(a, ndarray):
        return a
    return api.device_put(a)
Beispiel #9
0
 def testBooleanIndexingList2DBroadcast(self):
   # Indexing a 2-D device array with a plain Python list of bools must raise
   # a TypeError whose message matches ARRAY_MSG.
   mask = [True, True, False, True]
   host = np.arange(8).reshape(4, 2)
   with self.assertRaisesRegex(TypeError, ARRAY_MSG):
     on_device = api.device_put(host)
     on_device[mask]
Beispiel #10
0
 def testBooleanIndexingList1D(self):
   # A plain Python list of bools is not accepted as an index; a TypeError
   # matching ARRAY_MSG is expected.
   vals = api.device_put(np.arange(3))
   with self.assertRaisesRegex(TypeError, ARRAY_MSG):
     vals[[True, True, False]]
Beispiel #11
0
 def testBooleanIndexingArray1D(self):
   # A 1-D NumPy boolean mask on a device array must pick the same elements
   # as the same mask on the NumPy source.
   keep = np.array([True, True, False])
   device_vals = api.device_put(np.arange(3))
   self.assertAllClose(device_vals[keep], np.arange(3)[keep], check_dtypes=False)