def test_pjit_TwoMeshAxisSharding(self):
  @functools.partial(pjit.pjit,
                     in_axis_resources=P(("x", "y"),),
                     out_axis_resources=P(("x", "y"),))
  def jax_func(x, y):
    return x @ y

  x_shape = (24, 8)
  y_shape = (8, 2)
  x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
  y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
  self._check_sharding_annotations(
      jax_func, [x, y],
      expected=[
          r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
          r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3",   # y
          r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3",  # output
      ],
      expected_opt=[
          # TODO: relax ordering
          r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3|f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",
          r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3|f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",
          # TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3?
          r"f32\[1,6,2\]",  # output
      ],
      num_partitions=4)
def test_pjit_basic2D(self):
  @functools.partial(pjit.pjit,
                     in_axis_resources=(P(None, "x", "y"), P("y")),
                     out_axis_resources=P("x"))
  def jax_func(x, y):
    return x @ y

  x_shape = (8, 6, 4)
  y_shape = (4, 2)
  x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
  y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
  self._check_sharding_annotations(
      jax_func, [x, y],
      expected=[
          r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
          r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
          r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate",  # output
      ],
      expected_opt=[
          # TODO: relax ordering
          r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate|f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",
          r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate|f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",
          # TODO: why can't we see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
          r"bf16\[4,6,2\]",  # output
      ],
      num_partitions=4)
def test_pjit_basic1D(self):
  @functools.partial(pjit.pjit,
                     in_axis_resources=(P("x"), P("x")),
                     out_axis_resources=None)
  def jax_func(x, y):
    return x + y

  shape = (8, 10)
  x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
  hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
  print(f"HLO is {hlo}")
  print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
  self._check_sharding_annotations(
      jax_func, [x, x],
      expected=[
          r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
          r"f32\[8,10\].*sharding={replicated",  # output
      ],
      expected_opt=[
          r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
          # TODO: why don't we see "sharding={replicated"
          r"f32\[8,10\]",  # output
      ],
      num_partitions=2)
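# The tests above call self._check_sharding_annotations, which is defined
# elsewhere in the test class. Below is a minimal sketch of what such a
# helper might look like -- an assumption for illustration, not the actual
# implementation. The real helper presumably also compiles the computation
# for `num_partitions` partitions and matches `expected_opt` against the
# optimized HLO, which this sketch omits.
import re

def _check_sharding_annotations_sketch(self, f_jax, args, *, expected,
                                       expected_opt, num_partitions):
  # Match each expected regex, in order, against the non-optimized HLO text.
  hlo_text = jax.xla_computation(f_jax)(*args).as_hlo_text()
  start = 0
  for regex in expected:
    m = re.search(regex, hlo_text[start:])
    self.assertIsNotNone(m, f"Did not find {regex} in HLO")
    start += m.end()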
def _pjit(inp):
  if isinstance(inp, GlobalDeviceArray):
    if inp.is_fully_replicated:
      return inp.local_data(0).to_py()
    global_mesh = inp.mesh
    in_axis_resources = FROM_GDA
  else:
    # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
    # Shape of local_mesh will always be (1, local_device_count()).
    devices = np.array(jax.devices()).reshape(jax.process_count(),
                                              jax.local_device_count())
    global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
    in_axis_resources = P('processes')
    if inp.ndim == 0 or not tiled:
      inp = np.expand_dims(inp, axis=0)

  with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
    out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
               out_axis_resources=None)(inp)
  return out.local_data(0).to_py()
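# `_pjit` reads a `tiled` flag from an enclosing scope, so it is presumably
# nested inside a wrapper function. A hedged sketch of what that wrapper
# might look like (the `process_allgather` name, the `tiled` default, and
# the tree_map over pytree leaves are assumptions for illustration):
def process_allgather(in_tree, tiled=False):
  def _gather(inp):
    ...  # the body of `_pjit` above, which closes over `tiled`
  # Gather every array leaf of the input pytree across processes.
  return jax.tree_util.tree_map(_gather, in_tree)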
def jax_func(x):  # x: f32[12, 8]
  y = jnp.tile(x, (2, 1))  # y: f32[24, 8]
  y = pjit.with_sharding_constraint(y, P("x", "y"))
  return y[0:y.shape[0] // 4]  # res: f32[6, 8]
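# with_sharding_constraint only takes effect when jax_func runs under pjit
# inside a mesh context. A minimal driver sketch (the 2x2 device mesh and
# the chosen in/out axis resources are illustrative assumptions):
devices = np.array(jax.devices()[:4]).reshape((2, 2))
with maps.Mesh(devices, ("x", "y")):
  pjit_func = pjit.pjit(jax_func,
                        in_axis_resources=P("x", "y"),
                        out_axis_resources=P("x", "y"))
  res = pjit_func(jnp.arange(12 * 8, dtype=np.float32).reshape((12, 8)))
  # res: f32[6, 8], sharded over the ("x", "y") mesh axes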
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for JAX `api` functions."""

from functools import partial

import google_benchmark
import jax
from jax._src import test_util as jtu
from jax._src.util import prod
from jax.interpreters.pxla import PartitionSpec as P
from jax.experimental import global_device_array as gda
import numpy as np

mesh_shapes_axes = [
    ((256, 8), P("x", "y")),
    ((256, 8), P(None)),
    ((256, 8), P("x")),
    ((256, 8), P("y")),
    ((256, 8), P(("x", "y"))),
    ((128, 8), P("x", "y")),
    ((4, 2), P("x", "y")),
    ((16, 4), P("x", "y")),
    ((16, 4), P(("x", "y"))),
]


def gda_construction_callback(mesh_axes, state):
  # Keep the mesh containing 8 local devices as using >8 local devices is
  # unrealistic. Since `from_callback` measures `device_put` time as well, it
  # dominates when local devices are for example 2048 (local devices will never
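# A hedged sketch of the `from_callback` construction the comment above
# refers to: each shard is produced by indexing a global numpy array with
# that shard's index. The `jtu.create_global_mesh` helper, the mesh shape,
# and the array shape here are illustrative assumptions.
def gda_from_callback_example():
  global_shape = (2048, 2048)
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  global_input = np.arange(prod(global_shape)).reshape(global_shape)
  return gda.GlobalDeviceArray.from_callback(
      global_shape, mesh, P("x", "y"),
      lambda index: global_input[index])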