Example 1
    def test_pjit_TwoMeshAxisSharding(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=P(("x", "y"),),
                           out_axis_resources=P(("x", "y"),))
        def jax_func(x, y):
            return x @ y

        x_shape = (24, 8)
        y_shape = (8, 2)
        x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
        y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
        self._check_sharding_annotations(
            jax_func,
            [x, y],
            expected=[
                r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
                r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
                r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3",  # output
            ],
            expected_opt=[
                # TODO: relax ordering
                r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3|f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",
                r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3|f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",
                # TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3?
                r"f32\[1,6,2\]",  # output
            ],
            num_partitions=4)
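
For context, a minimal sketch of the setup such a test assumes: a 2x2 device mesh with axes ("x", "y"), where the single spec P(("x", "y"),) folds both mesh axes onto dimension 0, so the four devices each hold a (6, 8) row-shard of x. The mesh construction below is an assumption (in the test it is owned by the harness), and it requires a runtime with at least four devices; it uses the same era's maps.Mesh / in_axis_resources API as the snippets here.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P

# Assumed setup: arrange 4 devices into a 2x2 mesh named ("x", "y").
devices = np.array(jax.devices()[:4]).reshape(2, 2)

# A single P(("x", "y"),) spec is broadcast to every argument: both mesh
# axes are combined onto dimension 0, i.e. 4-way row sharding.
f = pjit.pjit(lambda a, b: a @ b,
              in_axis_resources=P(("x", "y"),),
              out_axis_resources=P(("x", "y"),))

with maps.Mesh(devices, ("x", "y")):
    out = f(jnp.ones((24, 8), jnp.float32), jnp.ones((8, 2), jnp.float32))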
Example 2
    def test_pjit_basic2D(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=(P(None, "x", "y"), P("y")),
                           out_axis_resources=P("x"))
        def jax_func(x, y):
            return x @ y

        x_shape = (8, 6, 4)
        y_shape = (4, 2)
        x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
        y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
        self._check_sharding_annotations(
            jax_func,
            [x, y],
            expected=[
                r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
                r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
                r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate",  # output
            ],
            expected_opt=[
                # TODO: relax ordering
                r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate|f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",
                r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate|f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",
                # TODO: why can't we see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
                r"bf16\[4,6,2\]",  # output
            ],
            num_partitions=4)
Example 3
    def test_pjit_basic1D(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=(P("x"), P("x")),
                           out_axis_resources=None)
        def jax_func(x, y):
            return x + y

        shape = (8, 10)
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
        print(f"HLO is {hlo}")
        print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
        self._check_sharding_annotations(
            jax_func,
            [x, x],
            expected=[
                r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
                r"f32\[8,10\].*sharding={replicated",  # output
            ],
            expected_opt=[
                r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
                # TODO: why don't we see "sharding={replicated"
                r"f32\[8,10\]",  # output
            ],
            num_partitions=2)
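
The same inspection can be done outside the harness. A minimal sketch, assuming at least two devices; the mesh below is an assumption (pjit needs an active mesh to trace), while the xla_computation call mirrors the debug lines in the test itself:

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P

f = pjit.pjit(lambda a, b: a + b,
              in_axis_resources=(P("x"), P("x")),
              out_axis_resources=None)  # None => replicated output

x = np.arange(80, dtype=np.float32).reshape(8, 10)
with maps.Mesh(np.array(jax.devices()[:2]), ("x",)):
    # The unoptimized HLO text carries the sharding={...} annotations
    # that the `expected` regexes above match against.
    print(jax.xla_computation(f)(x, x).as_hlo_text())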
Example 4
  def _pjit(inp):
    if isinstance(inp, GlobalDeviceArray):
      if inp.is_fully_replicated:
        return inp.local_data(0).to_py()
      global_mesh = inp.mesh
      in_axis_resources = FROM_GDA
    else:
      # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
      # Shape of local_mesh will always be (1, local_device_count())
      devices = np.array(jax.devices()).reshape(jax.process_count(),
                                                jax.local_device_count())
      global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
      in_axis_resources = P('processes')
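      # Note: `tiled` is a free variable captured from the enclosing
      # function's arguments; it is not defined in this snippet.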
      if inp.ndim == 0 or not tiled:
        inp = np.expand_dims(inp, axis=0)

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
      out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
                 out_axis_resources=None)(inp)
    return out.local_data(0).to_py()
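
The core trick here is the identity pjit: shard the input along one mesh axis, request out_axis_resources=None (fully replicated), and the partitioner inserts the gather. A minimal single-process sketch of that trick; the 1-D mesh over all local devices is an assumed setup, not the multi-host mesh the helper builds:

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P

devices = np.array(jax.devices())
x = np.arange(4 * len(devices), dtype=np.float32)

with maps.Mesh(devices, ("d",)):
    # Identity function: sharded in along "d", replicated out, so every
    # device ends up holding the full array afterwards.
    out = pjit.pjit(lambda a: a,
                    in_axis_resources=P("d"),
                    out_axis_resources=None)(x)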
Example 5
def jax_func(x):  # x: f32[12, 8]
    y = jnp.tile(x, (2, 1))  # y: f32[24, 8]
    y = pjit.with_sharding_constraint(y, P("x", "y"))
    return y[0:y.shape[0] // 4]  # res: f32[6, 8]
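
Running this still requires a mesh and a pjit wrapper around it. A minimal sketch with an assumed 2x1 mesh, so "x" is 2-way and "y" is trivial; only jax_func itself comes from the example:

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P

def jax_func(x):  # as defined above
    y = jnp.tile(x, (2, 1))
    # Tell the partitioner to lay y out as ("x", "y") before slicing.
    y = pjit.with_sharding_constraint(y, P("x", "y"))
    return y[0:y.shape[0] // 4]

devices = np.array(jax.devices()[:2]).reshape(2, 1)
with maps.Mesh(devices, ("x", "y")):
    f = pjit.pjit(jax_func,
                  in_axis_resources=P("x", "y"),
                  out_axis_resources=P("x", "y"))
    print(f(jnp.ones((12, 8), jnp.float32)).shape)  # (6, 8)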
Example 6
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for JAX `api` functions."""

from functools import partial

import google_benchmark
import jax
from jax._src import test_util as jtu
from jax._src.util import prod
from jax.interpreters.pxla import PartitionSpec as P
from jax.experimental import global_device_array as gda
import numpy as np

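# (mesh_shape, partition_spec) pairs to benchmark. For example,
# ((256, 8), P("x", "y")) partitions a global array across both axes of a
# 256x8 device mesh, while P(None) leaves it replicated.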
mesh_shapes_axes = [
    ((256, 8), P("x", "y")),
    ((256, 8), P(None)),
    ((256, 8), P("x")),
    ((256, 8), P("y")),
    ((256, 8), P(("x", "y"))),
    ((128, 8), P("x", "y")),
    ((4, 2), P("x", "y")),
    ((16, 4), P("x", "y")),
    ((16, 4), P(("x", "y"))),
]


def gda_construction_callback(mesh_axes, state):
    # Keep the mesh containing 8 local devices as using >8 local devices is
    # unrealistic. Since `from_callback` measures `device_put` time as well, it
    # dominates when local devices are for example 2048 (local devices will never