def test_pjit(self):
  """Checks the number of cache files written after pjit calls.

  The test expects one file in the cache directory after the first call and
  a second file after calling the same function with a different input dtype.
  """
  with tempfile.TemporaryDirectory() as cache_dir:
    cc.initialize_cache(cache_dir)

    @partial(pjit,
             in_axis_resources=(P('x'), P('x')),
             out_axis_resources=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    # `expected_entries` counts 1, 2: one more cache file per dtype tried.
    for expected_entries, dtype in enumerate((np.int64, np.float32), start=1):
      x = np.arange(prod(shape), dtype=dtype).reshape(shape)
      f(x, x + 1)
      files_in_directory = len(os.listdir(cache_dir))
      self.assertEqual(files_in_directory, expected_entries)
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Args:
    key: a PRNG key; it must satisfy `_is_threefry_prng_key`.
    bit_width: width in bits of each sample; one of 8, 16, 32 or 64.
    shape: the requested output shape. It is normalized with
      `core.as_named_shape`, so it may carry named axes in addition to its
      positional dimensions.

  Returns:
    The generated bits reshaped to `shape`. For widths other than 32 the bits
    are converted to `UINT_DTYPES[bit_width]` before the final reshape.

  Raises:
    TypeError: if `key` is not a threefry key or `bit_width` is unsupported.
    ValueError: if the declared size of a named axis differs from the size
      measured with `lax.psum(1, name)`.
  """
  if not _is_threefry_prng_key(key):
    raise TypeError("threefry_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    # Summing 1 over the named axis yields its actual size, which is checked
    # against the size declared in `shape`.
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(
          f"The shape of axis {name} was specified as {size}, "
          f"but it really is {real_size}")
    # Fold the index along the named axis into the key (presumably so that
    # each position along the axis draws different bits).
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  # From here on `size` is the number of elements in the positional dims.
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  # Split the 32-bit word count into `nblocks` chunks of
  # `jnp.iinfo(np.uint32).max` words plus a remainder `rem`. A non-constant
  # (polymorphic) count is always treated as a single chunk.
  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    # One subkey per full chunk plus one for the remainder; the chunk outputs
    # are flattened and concatenated with the remainder's output.
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    # Split the words into two halves and combine them element-wise: the
    # first half supplies the high 32 bits, the second half the low 32 bits.
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    # Each 32-bit word is shifted right by multiples of `bit_width` and masked
    # with the maximum value of `dtype`, giving 32 // bit_width narrow values
    # per word; the result is then flattened and truncated to `size`.
    bits = lax.bitwise_and(
        np.uint32(np.iinfo(dtype).max),
        lax.shift_right_logical(
            lax.broadcast(bits, (1, )),
            lax.mul(
                np.uint32(bit_width),
                lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
    bits = lax.reshape(bits, ((max_count * 32 // bit_width), ), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
def psum_bind(*args, axis_name, axis_index_groups):
  """Bind `psum_p`, short-circuiting when no argument is a tracer.

  If at least one argument is a `core.Tracer`, the call is forwarded to
  `core.Primitive.bind`. Otherwise every argument is multiplied by the size of
  the reduction group and the products are returned as a tuple.
  """
  if any(isinstance(arg, core.Tracer) for arg in args):
    return core.Primitive.bind(
        psum_p, *args, axis_name=axis_name,
        axis_index_groups=axis_index_groups)

  # No tracers involved: work out how many participants are being summed over.
  if axis_index_groups is not None:
    group_size = len(axis_index_groups[0])
  elif isinstance(axis_name, (list, tuple)):
    group_size = prod(
        [core.axis_frame(name).size for name in axis_name])  # type: ignore
  else:
    group_size = core.axis_frame(axis_name).size  # type: ignore
  return tuple(group_size * arg for arg in args)
def testCompilationCache(self):
  """Two calls with the same input must trigger exactly one compilation."""
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  increment = lambda x: x + 1
  sharded_increment = sharded_jit(increment, in_parts=P(2), out_parts=P(2))

  dims = (2,)
  inputs = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  with jtu.assert_num_jit_and_pmap_compilations(1):
    for _ in range(2):
      sharded_increment(inputs)
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
  """Test utility for setting up meshes given mesh data from `schedules`.

  Builds a device mesh from the leading local devices and yields once while
  that mesh is active. Raises `SkipTest` when fewer local devices are
  available than the mesh needs.
  """
  # This is similar to the `with_mesh` function above, but isn't a decorator.
  axis_names, shape = unzip2(named_shape)
  size = prod(shape)
  available_devices = list(jax.local_devices())
  if size > len(available_devices):
    raise SkipTest(f"Test requires {size} local devices")
  device_grid = np.array(available_devices[:size]).reshape(shape)
  with mesh(device_grid, axis_names):
    yield
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
  """Return the per-shard shape of an array partitioned over a mesh.

  Args:
    global_shape: sequence with the dimensions of the whole array.
    global_mesh: object whose `shape` maps a mesh axis name to its size.
    mesh_axes: one entry per leading array dimension. A falsy entry leaves the
      dimension whole, a tuple of names divides it by the product of those
      mesh axis sizes, and any other entry divides it by that axis's size.

  Returns:
    A tuple of shard dimensions. Dimensions beyond `len(mesh_axes)` are copied
    from `global_shape` unchanged. Division is floor division.
  """
  def _shard_dim(axis, dim):
    if not axis:
      return dim
    if isinstance(axis, tuple):
      return dim // prod([global_mesh.shape[name] for name in axis])
    return dim // global_mesh.shape[axis]

  sharded = [_shard_dim(axis, dim) for axis, dim in zip(mesh_axes, global_shape)]
  # Trailing dimensions without a mesh axis entry are not partitioned.
  unpartitioned_tail = global_shape[len(sharded):]
  return tuple(sharded) + tuple(unpartitioned_tail)
def _parse_dim(spec):
  """Parse a single dimension specification string.

  The string is tried against these forms, in order:
    * a sum `a+b+...`: each term is parsed recursively and the results summed;
    * a product `a*b*...`: each factor is parsed recursively and multiplied;
    * an integer literal with an optional leading `-`: `_parse_lit`;
    * a string whose first character is in `_identifiers`: `_parse_id`;
    * the string `_`: `_monomorphic_dim`.

  Args:
    spec: the text of one dimension.

  Returns:
    The parsed dimension, as produced by the helpers listed above.

  Raises:
    ShapeSyntaxError: if `spec` is empty or matches none of the forms.
  """
  if not spec:
    # An empty term (e.g. the tail of "3+") would otherwise fail with an
    # IndexError on `spec[0]` below instead of a syntax error.
    raise ShapeSyntaxError(spec)
  if '+' in spec:
    # Build a real list: np.sum does not reduce a lazy `map` iterator, so the
    # result must not depend on `map` being shadowed by a list-returning one.
    return np.sum([_parse_dim(term) for term in spec.split('+')])
  elif '*' in spec:
    return prod([_parse_dim(factor) for factor in spec.split('*')])
  elif spec.isdigit() or (spec.startswith('-') and spec[1:].isdigit()):
    return _parse_lit(spec)
  elif spec[0] in _identifiers:
    return _parse_id(spec)
  elif spec == '_':
    return _monomorphic_dim
  else:
    raise ShapeSyntaxError(spec)
def test_gda_block_until_ready(self):
  """`block_until_ready` must return the very same GDA object."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  partition = P(('x', 'y'))
  data = np.arange(prod(shape)).reshape(shape)

  def shard_for(index):
    return data[index]

  gda = GlobalDeviceArray.from_callback(shape, mesh, partition, shard_for)
  returned = gda.block_until_ready()
  self.assertTrue(returned is gda)
def gda_construction_callback(mesh_axes, state):
  """Benchmark loop that builds a GDA with `from_callback` while `state` holds."""
  # Keep the mesh containing 8 local devices as using >8 local devices is
  # unrealistic. Since `from_callback` measures `device_put` time as well, it
  # dominates when local devices are for example 2048 (local devices will never
  # be 2048).
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (2048, 2048)
  data = np.arange(prod(shape)).reshape(shape)

  def shard_for(index):
    return data[index]

  while state:
    gda.GlobalDeviceArray.from_callback(shape, mesh, mesh_axes, shard_for)
def test_gda_equality_raises_not_implemented(self):
  """Comparing two GDAs with `==` must raise NotImplementedError."""
  mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  shape = (8, 2)
  partition = P(None,)
  data = np.arange(prod(shape)).reshape(shape)

  def shard_for(index):
    return data[index]

  def build():
    return GlobalDeviceArray.from_callback(shape, mesh, partition, shard_for)

  input_gda = build()
  same_input_gda = build()
  with self.assertRaisesRegex(
      NotImplementedError,
      'GlobalDeviceArray equality is intentionally unimplemented.'):
    # Evaluated only for the exception it is expected to raise.
    input_gda == same_input_gda
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
                            global_shape=None, dtype=None):
  """Load an array from TensorStore through `create_async_gda_from_callback`.

  Args:
    mesh: mesh handed to `gda.get_shard_shape` and
      `create_async_gda_from_callback`.
    mesh_axes: partitioning of the array over `mesh`.
    tensorstore_spec: spec passed to `ts.Spec` to open the stored array.
    global_shape: shape of the array to build. When None, the stored array's
      shape is used. When it has more elements than the stored array, each
      shard is zero-filled and only the region that overlaps the stored
      domain is read.
    dtype: if not None, each shard is cast to this dtype on the host before
      being returned from the callback.

  Returns:
    The result of awaiting `create_async_gda_from_callback`.
  """
  # Await the open rather than calling `.result()` on the returned future:
  # blocking inside a coroutine stalls the event loop and prevents other
  # deserializations scheduled on it from making progress.
  t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
  shape = t.shape if global_shape is None else global_shape
  requires_padding = prod(shape) > prod(t.shape)

  if requires_padding:
    new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)

  async def cb(index):
    if requires_padding:
      # This is needed because the shape the array was saved with is smaller
      # than the requested shape of the array in which it will be reloaded. So
      # the extra values will be filled with 0s.
      out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
      requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
      restricted_domain = t.domain.intersect(requested_domain)
      await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
          restricted_domain].write(t[restricted_domain])
    else:
      out = await t[index].read()

    if dtype is not None:
      # Cast while reloading on process to avoid 2 copies on device if the
      # casting is done on device.
      return out.astype(dtype)
    return out

  return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
def test_async_checkpointing(self):
  """Round-trips a GDA through the async manager, then re-serializes it.

  The second serialization targets a final directory that is created first
  and is expected to raise.
  """
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  partition = P('x', 'y')

  # First GDA
  data = np.arange(util.prod(shape)).reshape(shape)

  def shard_for(index):
    return data[index]

  gda1 = GlobalDeviceArray.from_callback(shape, mesh, partition, shard_for)

  temp_dir = pathlib.Path(self.create_tempdir('temp_first').full_path)
  final_dir = str(temp_dir).replace('temp_first', 'first')
  save_specs = jax.tree_map(serialization.get_tensorstore_spec,
                            [str(temp_dir)])

  manager = serialization.GlobalAsyncCheckpointManager()
  manager.serialize([gda1], save_specs,
                    temp_checkpoint_dir=temp_dir,
                    final_checkpoint_dir=final_dir)
  manager.wait_until_finished()

  load_specs = jax.tree_map(serialization.get_tensorstore_spec,
                            [str(final_dir)])
  restored, = manager.deserialize([mesh], [partition], load_specs)

  expected_shards = (np.array([[0], [2]]), np.array([[1], [3]]))
  for i, expected in enumerate(expected_shards):
    self.assertArraysEqual(restored.local_shards[i].data.to_py(), expected)
  self.assertEqual(restored.local_shards[0].data.shape, (2, 1))
  self.assertEqual(restored.dtype, np.int32)

  # Will throw `file already exists` error when `tf.io.gfile.rename`.
  # `wait_until_finished` will raise the error.
  with self.assertRaises(Exception):
    existing_final_dir = pathlib.Path(self.create_tempdir('first').full_path)
    second_manager = serialization.GlobalAsyncCheckpointManager()
    second_manager.serialize([gda1], save_specs,
                             temp_checkpoint_dir=temp_dir,
                             final_checkpoint_dir=existing_final_dir)
    second_manager.wait_until_finished()
def test_gda_str_repr(self):
  """Pins the exact `str` and `repr` text of a GlobalDeviceArray."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  partition = P(('x', 'y'))
  data = np.arange(prod(shape)).reshape(shape)

  def shard_for(index):
    return data[index]

  gda = GlobalDeviceArray.from_callback(shape, mesh, partition, shard_for)

  expected_str = 'GlobalDeviceArray(shape=(8, 2), dtype=int32)'
  expected_repr = ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                   "global_mesh_shape={'x': 4, 'y': 2}, "
                   "mesh_axes=PartitionSpec(('x', 'y'),))")
  self.assertEqual(str(gda), expected_str)
  self.assertEqual(repr(gda), expected_repr)
def testPyTreeOutputs(self):
  """pmap(sharded_jit(f)) must match pmap(f) for nested tuple outputs."""
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  dims = (2, 4, 4)
  inputs = np.arange(prod(dims)).reshape(dims)

  in_parts = (P(2, 1),)
  out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
  sharded_f = sharded_jit(f, in_parts=in_parts, out_parts=out_parts)

  result = pmap(sharded_f)(inputs)
  expected = pmap(f)(inputs)
  self.assertAllClose(result, expected, check_dtypes=False)
def testBasic1D(self):
  """Adds two arrays under pjit and checks the value and device buffers."""
  @partial(pjit,
           in_axis_resources=(P('x'), P('x')),
           out_axis_resources=None)
  def f(x, y):
    return x + y

  dims = (8, 8)
  lhs = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  rhs = lhs + 1

  actual = f(lhs, rhs)
  expected = lhs + rhs

  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertIsInstance(actual, pxla.ShardedDeviceArray)
  self.assertLen(actual.device_buffers, 2)
  first_buffer = actual.device_buffers[0].to_py()
  self.assertAllClose(first_buffer, expected, check_dtypes=False)
def gda_construction_raw(mesh_shape, mesh_axes, state):
  """Benchmark loop that calls the GDA constructor on pre-placed buffers."""
  # `device_put` time is not measured in this benchmark. All the devices here
  # are local.
  mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  shape = (2048, 2048)
  data = np.arange(prod(shape)).reshape(shape)
  indices = gda.get_shard_indices(shape, mesh, mesh_axes)

  # Place each device's slice up front so the loop only builds the GDA.
  dbs = []
  for device in mesh.local_devices:
    dbs.append(jax.device_put(data[indices[device]], device))

  while state:
    gda.GlobalDeviceArray(shape, mesh, mesh_axes, dbs)
def testNestedShardingConstraint(self):
  """Uses a sharding constraint inside a while_loop body under jit."""
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  dims = (8, 8)

  @jit
  def f(x):
    def keep_going(i):
      return i[0, 0] < 10.

    def step(i):
      return with_sharding_constraint(i + 1., P(2, 1))

    return lax.while_loop(keep_going, step, x)

  inputs = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  expected = inputs + 10.
  actual = sharded_jit(f, in_parts=None, out_parts=None)(inputs)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
def testBasic(self):
  """Adds two arrays under sharded_jit and checks value and device buffers."""
  if jax.device_count() < 2:
    raise SkipTest

  @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
  def f(x, y):
    return x + y

  dims = (8, 8)
  lhs = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  rhs = lhs + 1

  actual = f(lhs, rhs)
  expected = lhs + rhs

  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertIsInstance(actual, pxla.ShardedDeviceArray)
  self.assertLen(actual.device_buffers, 2)
  first_buffer = actual.device_buffers[0].to_py()
  self.assertAllClose(first_buffer, expected, check_dtypes=False)
def init(key, shape, dtype=dtype):
  """Build `scale * Q` of the given shape from the QR factor of a normal draw.

  `shape` is treated as a matrix whose column count is `shape[column_axis]`
  and whose row count is the product of the remaining dimensions. Raises
  ValueError for shapes with fewer than two dimensions.
  """
  if len(shape) < 2:
    raise ValueError(
        "orthogonal initializer requires at least a 2D shape")

  n_cols = shape[column_axis]
  n_rows = prod(shape) // n_cols
  # The matrix handed to QR always has the larger extent first, so a wide
  # case is factorized transposed and flipped back afterwards.
  is_wide = n_rows < n_cols
  if is_wide:
    matrix_shape = (n_cols, n_rows)
  else:
    matrix_shape = (n_rows, n_cols)

  A = random.normal(key, matrix_shape, dtype)
  Q, R = jnp.linalg.qr(A)
  diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
  Q *= diag_sign  # needed for a uniform distribution
  if is_wide:
    Q = Q.T

  # Restore the requested layout: the column axis goes last for the reshape
  # and is then moved back to `column_axis`.
  leading_dims = tuple(np.delete(shape, column_axis))
  Q = jnp.reshape(Q, leading_dims + (shape[column_axis], ))
  Q = jnp.moveaxis(Q, -1, column_axis)
  return scale * Q
def testDeviceBufferAval(self):
  """Identity under pjit; also exercises `repr` of the device buffers."""
  @partial(pjit, in_axis_resources=None, out_axis_resources=P('x'))
  def f(x):
    return x

  dims = (2, 2)
  inputs = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  actual = f(inputs)
  expected = inputs

  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertIsInstance(actual, pxla.ShardedDeviceArray)
  self.assertLen(actual.device_buffers, 1)
  first_buffer = actual.device_buffers[0].to_py()
  self.assertAllClose(first_buffer, expected, check_dtypes=False)
  # Repro for a bug on device_buffer aval
  _ = repr(actual.device_buffers)
def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
  """Compares pmap(sharded_jit(f, ...)) to pmap(f)"""
  shape = (2, 4, 4)
  # Devices needed: the leading (pmapped) dimension times the number of
  # partitions of the first input.
  num_shards = shape[0] * np.prod(in_partitions[0])
  if num_shards > jax.local_device_count():
    raise SkipTest("requires %d devices" % num_shards)

  x = np.arange(prod(shape)).reshape(shape)
  y = x + 1

  sharded_f = sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions)
  result = pmap(sharded_f)(x, y)
  expected = pmap(f)(x, y)
  self.assertAllClose(result, expected, check_dtypes=False)

  leaves = tree_util.tree_flatten(result)[0]
  for leaf in leaves:
    self.assertTrue(isinstance(leaf, pxla.ShardedDeviceArray))
    self.assertEqual(len(leaf.device_buffers), num_shards)
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids):
  """Checks indices, data, shard shape and replica ids of a 1D GDA."""
  mesh = create_global_mesh((8,), ('x'))
  shape = (16,)
  data = np.arange(prod(shape)).reshape(-1)

  def shard_for(index):
    return data[index]

  gda = GlobalDeviceArray.from_callback(shape, mesh, mesh_axes, shard_for)

  # Only the first two local shards are inspected.
  for i in range(2):
    self.assertEqual(gda.local_shards[i].index, expected_index[i])
    self.assertArraysEqual(gda.local_data(i), data[expected_index[i]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)

  replica_ids = [shard.replica_id for shard in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
def _irfft_transpose(t, fft_lengths):
  """Transpose rule for IRFFT applied to the cotangent `t`."""
  # The transpose of IRFFT is the RFFT of the cotangent times a scaling
  # factor and a mask. The mask scales the cotangent for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these components
  # are de-duplicated in the RFFT.
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  n = x.shape[-1]
  is_odd = fft_lengths[-1] % 2

  # (fill value, length) of each consecutive piece of the mask.
  segments = (
      (1.0, 1),
      (2.0, n - 2 + is_odd),
      (1.0, 1 - is_odd),
  )
  mask = lax.concatenate(
      [lax.full_like(t, value, dtype=t.dtype, shape=(length, ))
       for value, length in segments],
      dimension=0)

  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  return out
def test_pjit_gsda_mesh_mismatch(self):
  """Passing a GDA to pjit is expected to fail with a mesh-mismatch error."""
  mesh = create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  partition = ['x', 'y']
  data = np.arange(prod(shape), dtype=np.float32).reshape(shape)

  def shard_for(index):
    return data[index]

  gda_obj = global_device_array.GlobalDeviceArray.from_callback(
      shape, mesh, partition, shard_for)

  with self.assertRaisesRegex(
      ValueError, "Pjit's mesh and GDA's mesh should be equal."):

    @partial(pjit, in_axis_resources=FROM_GDA,
             out_axis_resources=P('x', 'y'))
    def f(x):
      return x

    f(gda_obj)
def test_gda_batched_callback(self):
  """Builds a GDA with `from_batched_callback` and checks two shards."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  partition = P(('x', 'y'))
  data = np.arange(prod(shape)).reshape(shape)

  def batched_shards(indices):
    # The callback must receive one index per local device in a single call.
    self.assertEqual(len(indices), len(mesh.local_devices))
    return [data[index] for index in indices]

  gda = GlobalDeviceArray.from_batched_callback(
      shape, mesh, partition, batched_shards)

  expected_first_shard_value = np.array([[0, 1]])
  self.assertArraysEqual(gda.local_data(0).to_py(),
                         expected_first_shard_value)
  expected_second_shard_value = np.array([[2, 3]])
  self.assertArraysEqual(gda.local_data(1).to_py(),
                         expected_second_shard_value)
def testPyTreeOutputs(self):
  """sharded_jit(f) must match f for nested tuple outputs.

  Runs once with explicit output partitions and once with `out_parts=None`.
  """
  if jax.device_count() < 2:
    raise SkipTest

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  dims = (4, 4)
  inputs = np.arange(prod(dims)).reshape(dims)
  in_parts = (P(2, 1),)
  expected = f(inputs)

  explicit_out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
  for out_parts in (explicit_out_parts, None):
    result = sharded_jit(f, in_parts, out_parts)(inputs)
    self.assertAllClose(result, expected, check_dtypes=False)
def testShardingConstraint(self):
  """with_sharding_constraint under matching, mismatched and replicated jit."""
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    y = x + 1
    y = with_sharding_constraint(y, P(1,2))
    return y * 2

  dims = (8, 8)
  inputs = np.arange(prod(dims)).reshape(dims)
  expected = (inputs + 1) * 2

  # Matching sharded_jit partitions
  actual = sharded_jit(f, in_parts=P(2,1), out_parts=P(2,1))(inputs)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
  # the default.
  for i in range(2):
    buf = actual.device_buffers[i]
    shape_fn = getattr(buf, "xla_shape", buf.shape)
    self.assertEqual(shape_fn().dimensions(), (4, 8))

  # Mismatched sharded_jit partitions
  with self.assertRaisesRegex(
      ValueError,
      r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
      r"\(total partitions: 2\) doesn't match expected number of partitions: "
      r"4. If these partitions look right, check outer sharded_jit and/or "
      r"other with_sharding_constraint calls."):
    sharded_jit(f, in_parts=P(2,2), out_parts=P(2,2))(inputs)

  # Replicated sharded_jit
  actual = sharded_jit(f, in_parts=None, out_parts=None)(inputs)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  first_buffer = actual.device_buffers[0].to_py()
  second_buffer = actual.device_buffers[1].to_py()
  self.assertAllClose(first_buffer, second_buffer, check_dtypes=False)
def testGradOfShardingConstraint(self):
  """The VJP through a sharding constraint must match the unconstrained VJP."""
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  @partial(sharded_jit, in_parts=P(4,1), out_parts=None)
  def f(x):
    y = x + 1

    def constrained_sin(z):
      return jnp.sin(with_sharding_constraint(z, P(2,2)))

    p, vjp_f = vjp(constrained_sin, y)
    return vjp_f(p)

  def expected_f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(z), y)
    return vjp_f(p)

  dims = (4, 4)
  inputs = jnp.arange(prod(dims), dtype=jnp.float32).reshape(dims)
  actual = f(inputs)
  expected = expected_f(inputs)
  self.assertAllClose(actual, expected, check_dtypes=False)
def testManyArgs(self):
  """pmap(sharded_jit(f)) with 200 positional arguments matches pmap(f)."""
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  num_args = 200

  def f(*args):
    return jnp.asarray(args).sum()

  dims = (2, 4, 4)
  # The same array object is passed for every argument.
  args = [np.arange(prod(dims)).reshape(dims)] * num_args
  in_partitions = (P(2, 1),) * num_args
  out_partitions = None

  sharded_f = sharded_jit(
      f, in_parts=in_partitions, out_parts=out_partitions)
  result = pmap(sharded_f)(*args)
  expected = pmap(f)(*args)

  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
  self.assertEqual(len(result.device_buffers), 4)
def testInAxesNone(self):
  """pmap over sharded_jit where only the last argument is mapped."""
  shape = (4, 4)
  replicas = 2
  in_partitions = (P(2, 1), None, None)
  out_partitions = P(2, 1)
  in_axes = (None, None, 0)

  # Devices needed: replicas times the partitions of the first input.
  num_shards = replicas * np.prod(in_partitions[0])
  if num_shards > jax.local_device_count():
    raise SkipTest("requires %d devices" % num_shards)

  x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  dummy = np.arange(replicas, dtype=np.float32) + 1

  def f(x, y, _):
    return x @ y

  sharded_f = sharded_jit(
      f, in_parts=in_partitions, out_parts=out_partitions)
  result = pmap(sharded_f, in_axes=in_axes)(x, y, dummy)
  expected = pmap(f, in_axes=in_axes)(x, y, dummy)
  self.assertAllClose(result, expected, check_dtypes=True)