def execute_compiled(compiled, partitioner, handlers, dim_vals, args):
    """Run ``compiled`` on device buffers and post-process its outputs.

    Args:
        compiled: object exposing ``execute(list_of_buffers)``.
        partitioner: callable taking the raw output buffers and returning a
            ``(dims_dict, grouped_bufs)`` pair.
        handlers: callables invoked as ``handler(dims_dict, group)``; paired
            positionally with ``grouped_bufs`` via ``zip``.
        dim_vals: values transferred first, in order.
        args: values transferred after ``dim_vals``, in order.

    Returns:
        A list with one handler result per ``(handler, group)`` pair.
    """
    # Each operand may expand to several buffers: the result of
    # ``xla.device_put`` is iterated and flattened into a single list,
    # dimension values first and positional arguments after them.
    input_bufs = []
    for operand in it.chain(dim_vals, args):
        input_bufs.extend(xla.device_put(operand, None))

    raw_outputs = compiled.execute(input_bufs)
    dims_dict, grouped_bufs = partitioner(raw_outputs)

    results = []
    for handler, group in zip(handlers, grouped_bufs):
        results.append(handler(dims_dict, group))
    return results
def _prefetch(xs):
    """Shard ``xs`` along its leading axis, one slice per entry of ``devices``.

    ``devices`` is not a parameter: it is read from the enclosing scope.

    Args:
        xs: array-like that is iterated along its first dimension; on the
            legacy path it must also expose ``.shape``.

    Returns:
        The result of ``device_put_sharded(list(xs), devices)`` when that
        function is available, otherwise a ``jax.pxla.ShardedDeviceArray``
        built from one ``xla.device_put`` buffer per slice.

    Raises:
        AssertionError: on the legacy path, when ``xs.shape[0]`` differs
            from ``len(devices)``.
    """
    # Probe for the function without dereferencing ``jax.api`` directly:
    # touching a missing ``jax.api`` attribute would raise AttributeError
    # before the capability check even runs.  Look on the top-level package
    # first, then on ``jax.api`` (the location originally probed) if present.
    put_sharded = getattr(jax, "device_put_sharded", None)
    if put_sharded is None:
        put_sharded = getattr(
            getattr(jax, "api", None), "device_put_sharded", None)

    if put_sharded is not None:  # jax>=0.2.0
        return put_sharded(list(xs), devices)

    # Legacy path for JAX versions without device_put_sharded.
    aval = jax.xla.abstractify(xs)
    assert xs.shape[0] == len(devices), (
        "The first dimension of the iterator's ndarrays is not "
        "equal to the number of devices.")
    buffers = [xla.device_put(x, devices[i]) for i, x in enumerate(xs)]
    return jax.pxla.ShardedDeviceArray(aval, buffers)
def testReshardInput(self):
    # pmap must accept an input whose existing sharding does not match the
    # sharding the mapped computation needs, and produce a 6-buffer result.
    if xla_bridge.device_count() < 6:
        raise SkipTest("testReshardInput requires 6 devices")

    # Build one (3, 2) shard and place a copy of it on each of the first
    # four devices.
    shard_shape = (3, 2)
    num_elements = np.prod(shard_shape)
    shard = np.arange(num_elements).reshape(shard_shape)
    device_buffers = []
    for device in xla_bridge.devices()[:4]:
        device_buffers.append(xla.device_put(shard, device))

    # Hand-assemble a ShardedDeviceArray with the wrong sharding for the
    # pmap below: a (6, 4) array split 2 x 2.
    # NOTE(review): replication_factor=2 is combined with only 4 buffers;
    # confirm against ShardingSpec's contract that this is intended.
    wrongly_sharded = pxla.ShardedDeviceArray(
        ShapedArray((6, 4), shard.dtype),
        pxla.ShardingSpec(
            shards_per_axis=(2, 2),
            is_axis_materialized=(True, True),
            replication_factor=2),
        device_buffers)

    add_one = pmap(lambda x: x + 1)
    result = add_one(wrongly_sharded)

    self.assertAllClose(result, wrongly_sharded + 1, check_dtypes=True)
    self.assertEqual(len(result.device_buffers), 6)
def _replicate(x, devices=None): x = jax.numpy.asarray(x) if devices is None: # match the default device assignments used in pmap: # for single-host, that's the XLA default device assignment # for multi-host, it's the order of jax.local_devices() if jax.host_count() == 1: devices = [d for d in xb.get_backend().get_default_device_assignment( jax.device_count()) if d.host_id == jax.host_id()] else: devices = jax.local_devices() if hasattr(jax.api, "device_put_sharded"): # jax >= 0.2.0 return jax.api.device_put_sharded(len(devices) * [x], devices) else: aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype) buffers = [xla.device_put(x, device=d) for d in devices] return jax.pxla.ShardedDeviceArray(aval, buffers)