Example #1
0
def execute_compiled(compiled, partitioner, handlers, dim_vals, args):
  """Run a compiled XLA computation and post-process its output buffers.

  Device-puts every value in ``dim_vals`` followed by every value in ``args``
  (flattening the per-value buffer lists), executes ``compiled`` on the
  resulting flat buffer list, splits the outputs with ``partitioner``, and
  applies one handler per buffer group.
  """
  input_bufs = []
  # Order matters: dimension values first, then the regular arguments,
  # mirroring the calling convention of the compiled executable.
  for value in it.chain(dim_vals, args):
    input_bufs.extend(xla.device_put(value, None))
  out_bufs = compiled.execute(input_bufs)
  dims_dict, grouped_bufs = partitioner(out_bufs)
  results = []
  for handler, bufs in zip(handlers, grouped_bufs):
    results.append(handler(dims_dict, bufs))
  return results
Example #2
0
 def _prefetch(xs):
     """Shard ``xs`` across ``devices``, one leading-axis slice per device."""
     # jax >= 0.2.0 exposes a public helper that does exactly this.
     if hasattr(jax.api, "device_put_sharded"):
         return jax.api.device_put_sharded(list(xs), devices)
     # Older jax: assemble the ShardedDeviceArray by hand.
     aval = jax.xla.abstractify(xs)
     assert xs.shape[0] == len(devices), (
         "The first dimension of the iterator's ndarrays is not "
         "equal to the number of devices.")
     per_device = [xla.device_put(slice_, dev)
                   for slice_, dev in zip(xs, devices)]
     return jax.pxla.ShardedDeviceArray(aval, per_device)
Example #3
0
    def testReshardInput(self):
        """pmap must transparently reshard an input ShardedDeviceArray whose
        on-device layout does not match the sharding the pmap expects."""
        if xla_bridge.device_count() < 6:
            raise SkipTest("testReshardInput requires 6 devices")
        # Hand-build a ShardedDeviceArray with the *wrong* sharding for the
        # pmap below: 4 buffers of shape (3, 2) tiled 2x2 with replication 2.
        chunk_shape = (3, 2)
        chunk = np.arange(np.prod(chunk_shape)).reshape(chunk_shape)
        device_bufs = [xla.device_put(chunk, dev)
                       for dev in xla_bridge.devices()[:4]]
        spec = pxla.ShardingSpec(shards_per_axis=(2, 2),
                                 is_axis_materialized=(True, True),
                                 replication_factor=2)
        arr = pxla.ShardedDeviceArray(ShapedArray((6, 4), chunk.dtype),
                                      spec, device_bufs)

        result = pmap(lambda x: x + 1)(arr)
        self.assertAllClose(result, arr + 1, check_dtypes=True)
        # After resharding, one buffer per pmap device is expected.
        self.assertEqual(len(result.device_buffers), 6)
Example #4
0
def _replicate(x, devices=None):
  """Return ``x`` replicated once per device as a ShardedDeviceArray.

  Args:
    x: array-like value to replicate; converted with ``jax.numpy.asarray``.
    devices: optional explicit device list. When ``None``, mirrors pmap's
      default device assignment.
  """
  x = jax.numpy.asarray(x)
  if devices is None:
    # Match the default device assignment used by pmap:
    #   single-host: XLA's default device assignment (local devices only)
    #   multi-host:  the order of jax.local_devices()
    if jax.host_count() == 1:
      assignment = xb.get_backend().get_default_device_assignment(
          jax.device_count())
      devices = [dev for dev in assignment if dev.host_id == jax.host_id()]
    else:
      devices = jax.local_devices()
  # jax >= 0.2.0 exposes a public helper for sharded placement.
  if hasattr(jax.api, "device_put_sharded"):
    return jax.api.device_put_sharded([x] * len(devices), devices)
  # Older jax: build the replicated ShardedDeviceArray by hand.
  replicated_shape = (len(devices),) + x.shape
  aval = jax.ShapedArray(replicated_shape, x.dtype)
  bufs = [xla.device_put(x, device=dev) for dev in devices]
  return jax.pxla.ShardedDeviceArray(aval, bufs)