def pmap_shard_device_array_benchmark():
    """Benchmark pmap dispatch along the shard_args DeviceArray path.

    The goal is to time how long it takes to hand DeviceArray arguments to a
    pmapped function. Two parameter sweeps are registered under the suite
    name "pmap_shard_device_array": one over the number of arguments and one
    over the number of shards.
    """

    def get_benchmark_fn(nargs, nshards):
        """Return a zero-argument callable that calls pmap 10 times.

        Args:
          nargs: number of array arguments passed to the pmapped function.
          nshards: size of the leading (mapped) axis of every argument.
        """
        pmap_fn = pmap(lambda *xs: np.sum(xs))
        host_arrays = (onp.random.random((nshards, 4)) for _ in range(nargs))
        device_args = [np.array(host_array) for host_array in host_arrays]
        # Sanity check: the benchmark is only meaningful if every argument
        # really is a DeviceArray.
        for device_arg in device_args:
            assert isinstance(device_arg, jax.xla.DeviceArray)

        def benchmark_fn():
            for _ in range(10):
                pmap_fn(*device_args)

        return benchmark_fn

    # Sweep over the argument count, using up to 8 local devices.
    params = [
        {"nargs": nargs, "nshards": min(8, jax.local_device_count())}
        for nargs in (10, 100, 500)
    ]
    # Sweep over the shard count, keeping only counts the host can provide.
    params.extend(
        {"nargs": 100, "nshards": nshards}
        for nshards in (2, 4, 8)
        if nshards <= jax.local_device_count()
    )
    benchmark.benchmark_suite(
        get_benchmark_fn, params, "pmap_shard_device_array")
def pmap_shard_outputs_benchmark():
    """Pmap benchmark focusing on array_result_handler path.

    This is intended to measure how long it takes to construct
    ShardedDeviceArrays from pmap.

    Registers the "pmap_shard_outputs" suite with two parameter sweeps: one
    over the number of outputs and one over the number of shards. Shard
    counts larger than jax.local_device_count() are skipped.
    """

    def get_benchmark_fn(nouts, nshards):
        """Return a zero-argument callable that calls pmap 100 times.

        Args:
          nouts: number of outputs produced by the pmapped function.
          nshards: size of the leading (mapped) axis of the input.
        """
        pmap_fn = pmap(lambda x: [x + i for i in range(nouts)])
        shape = (nshards, 4)
        # Random input comes from `onp`, as in the other pmap benchmarks in
        # this file. `np` is the module whose arrays are DeviceArrays there,
        # and it does not provide `random`.
        arg = onp.random.random(shape)

        def benchmark_fn():
            for _ in range(100):
                pmap_fn(arg)

        return benchmark_fn

    params = []
    # Sweep over the output count, using up to 8 local devices.
    for nouts in (10, 100, 500, 1000, 5000):
        nshards = min(8, jax.local_device_count())
        params.append({"nouts": nouts, "nshards": nshards})
    # Sweep over the shard count, keeping only counts the host can provide.
    for nshards in (2, 4, 8, 100, 500):
        if nshards > jax.local_device_count():
            continue
        params.append({"nouts": 100, "nshards": nshards})
    benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
def sharded_device_array_indexing_benchmark():
    """Benchmark indexing into a ShardedDeviceArray.

    Registers the "ShardedDeviceArray_indexing" suite with two cases: plain
    integer indices and 2-tuples of integers.
    """
    num_internal_iters = 1000

    def integer_indices():
        """Yield 0..7, repeated num_internal_iters times."""
        for _ in range(num_internal_iters):
            for i in range(8):
                yield i

    def integer_2D_indices():
        """Yield (0, 0)..(7, 7), repeated num_internal_iters times."""
        for _ in range(num_internal_iters):
            for i in range(8):
                yield (i, i)

    def get_benchmark_fn(indices_fn):
        """Return a zero-argument callable that indexes a pmap output.

        Args:
          indices_fn: zero-argument callable returning an iterable of indices.
        """
        # NOTE(review): the index generators always walk 0..7, while the
        # leading axis has size min(8, local_device_count()) -- confirm this
        # is intended on hosts with fewer than 8 devices.
        shape = (min(8, jax.local_device_count()), 8, 8)

        def benchmark_fn():
            # The array is rebuilt inside the benchmarked callable, so its
            # construction is part of every run.
            values = np.arange(np.prod(shape)).reshape(shape)
            arr = pmap(lambda x: x)(values)
            for idx in indices_fn():
                arr[idx]  # Result discarded; only the indexing is exercised.

        return benchmark_fn

    params = [
        {"indices_fn": indices_fn}
        for indices_fn in (integer_indices, integer_2D_indices)
    ]
    benchmark.benchmark_suite(
        get_benchmark_fn, params, "ShardedDeviceArray_indexing")
def pmap_shard_args_benchmark():
    """Benchmark pmap dispatch along the shard_args fast path.

    The goal is to time how long it takes to hand correctly-sharded
    ShardedDeviceArray arguments to a pmapped function. Two parameter sweeps
    are registered under the suite name "pmap_shard_args": one over the
    number of arguments and one over the number of shards.
    """

    def get_benchmark_fn(nargs, nshards):
        """Return a zero-argument callable that calls pmap 100 times.

        Args:
          nargs: number of array arguments passed to the pmapped function.
          nshards: size of the leading (mapped) axis of every argument.
        """
        pmap_fn = pmap(lambda *xs: np.sum(xs))
        host_args = [onp.random.random((nshards, 4)) for _ in range(nargs)]
        # Push the whole list through an identity pmap so that every
        # argument is already a ShardedDeviceArray before it is used below.
        sharded_args = pmap(lambda x: x)(host_args)
        for sharded_arg in sharded_args:
            assert isinstance(sharded_arg, jax.pxla.ShardedDeviceArray)

        def benchmark_fn():
            for _ in range(100):
                pmap_fn(*sharded_args)

        return benchmark_fn

    # Sweep over the argument count, using up to 4 local devices.
    params = [
        {"nargs": nargs, "nshards": min(4, jax.local_device_count())}
        for nargs in (10, 100, 101, 500)
    ]
    # Sweep over the shard count, keeping only counts the host can provide.
    params.extend(
        {"nargs": 10, "nshards": nshards}
        for nshards in (2, 4, 8, 100, 500)
        if nshards <= jax.local_device_count()
    )
    benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")