Example #1
def pmap_shard_device_array_benchmark():
  """Pmap benchmark focusing on shard_args DeviceArray path.

  This is intended to measure how long it takes to dispatch a DeviceArray to
  pmap.
  """

  def get_benchmark_fn(nargs, nshards):
    pmap_fn = pmap(lambda *args: np.sum(args))
    shape = (nshards, 4)
    args = [np.array(onp.random.random(shape)) for _ in range(nargs)]
    assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
    def benchmark_fn():
      for _ in range(10):
        pmap_fn(*args)
    return benchmark_fn

  params = []
  for nargs in (10, 100, 500):
    nshards = min(8, jax.local_device_count())
    params.append({"nargs": nargs, "nshards": nshards})
  for nshards in (2, 4, 8):
    if nshards > jax.local_device_count(): continue
    params.append({"nargs": 100, "nshards": nshards})
  benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_device_array")
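This and the following examples rely on module aliases and a benchmark helper defined outside the listing. A minimal assumed preamble that makes the snippets self-contained, where np aliases jax.numpy, onp aliases ordinary NumPy, and the import path of the benchmark helper is an assumption:

# Assumed preamble, not part of the original listing.
import jax
from jax import pmap
import jax.numpy as np   # "np" is jax.numpy in these snippets
import numpy as onp      # "onp" is ordinary NumPy
from benchmarks import benchmark  # hypothetical import path for the benchmark helper

# Note: the DeviceArray / ShardedDeviceArray classes checked below (via
# jax.xla / jax.pxla) exist only on the older JAX versions these benchmarks
# target; newer releases unified them into jax.Array.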
Example #2
def pmap_shard_outputs_benchmark():
    """Pmap benchmark focusing on array_result_handler path.

  This is intended to measure how long it takes to construct ShardedDeviceArrays
  from pmap.
  """
    def get_benchmark_fn(nouts, nshards):
        pmap_fn = pmap(lambda x: [x + i for i in range(nouts)])
        shape = (nshards, 4)
        arg = np.random.random(shape)

        def benchmark_fn():
            for _ in range(100):
                pmap_fn(arg)

        return benchmark_fn

    params = []
    for nouts in (10, 100, 500, 1000, 5000):
        nshards = min(8, jax.local_device_count())
        params.append({"nouts": nouts, "nshards": nshards})
    for nshards in (2, 4, 8, 100, 500):
        if nshards > jax.local_device_count(): continue
        params.append({"nouts": 100, "nshards": nshards})
    benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
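The benchmark module itself is not shown; only its call signature, benchmark.benchmark_suite(prepare_fn, params, name), is visible from the examples. A hypothetical stand-in that would let the snippets run, timing each prepared benchmark_fn once after a warm-up call, might look like the sketch below (the real helper presumably also handles repetitions and summary statistics):

import time

def benchmark_suite(prepare_fn, params_list, name):
  """Hypothetical stand-in for benchmark.benchmark_suite (sketch only)."""
  print(f"---- {name} ----")
  for params in params_list:
    benchmark_fn = prepare_fn(**params)
    benchmark_fn()                      # warm-up: triggers pmap compilation
    start = time.perf_counter()
    benchmark_fn()                      # timed run
    elapsed = time.perf_counter() - start
    print(f"{params}: {elapsed * 1e3:.2f} ms")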
Example #3
def sharded_device_array_indexing_benchmark():
  """Benchmark focusing on ShardedDeviceArray indexing."""
  def get_benchmark_fn(indices_fn):
    nshards = min(8, jax.local_device_count())
    shape = (nshards, 8, 8)
    def benchmark_fn():
      arr = pmap(lambda x: x)(np.arange(np.prod(shape)).reshape(shape))
      indices = indices_fn()
      for idx in indices:
        arr[idx]
    return benchmark_fn

  num_internal_iters = 1000

  def integer_indices():
    return (i for _ in range(num_internal_iters) for i in range(8))

  def integer_2D_indices():
    return ((i,i) for _ in range(num_internal_iters) for i in range(8))

  params = []
  params.append({"indices_fn": integer_indices})
  params.append({"indices_fn": integer_2D_indices})
  benchmark.benchmark_suite(get_benchmark_fn, params,
                            "ShardedDeviceArray_indexing")
Example #4
def pmap_shard_args_benchmark():
    """Pmap benchmark focusing on shard_args fast path.

  This is intended to measure how long it takes to dispatch a correctly-sharded
  ShardedDeviceArray to pmap.
  """
    def get_benchmark_fn(nargs, nshards):
        pmap_fn = pmap(lambda *args: np.sum(args))
        shape = (nshards, 4)
        args = [onp.random.random(shape) for _ in range(nargs)]
        sharded_args = pmap(lambda x: x)(args)
        assert all(
            isinstance(arg, jax.pxla.ShardedDeviceArray)
            for arg in sharded_args)

        def benchmark_fn():
            for _ in range(100):
                pmap_fn(*sharded_args)

        return benchmark_fn

    params = []
    for nargs in (10, 100, 101, 500):
        nshards = min(4, jax.local_device_count())
        params.append({"nargs": nargs, "nshards": nshards})
    for nshards in (2, 4, 8, 100, 500):
        if nshards > jax.local_device_count(): continue
        params.append({"nargs": 10, "nshards": nshards})
    benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
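Taken together, the four suites cover argument dispatch (Examples #1 and #4), output construction (Example #2), and result indexing (Example #3). A small driver, assuming the hypothetical preamble and benchmark_suite stand-in sketched above, could run them all in sequence:

def run_all_benchmarks():
  pmap_shard_device_array_benchmark()
  pmap_shard_outputs_benchmark()
  sharded_device_array_indexing_benchmark()
  pmap_shard_args_benchmark()

if __name__ == "__main__":
  run_all_benchmarks()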