def _assert_contiguous_submeshes(self, global_device_mesh):
    """Checks that every process's local devices form a contiguous submesh."""
    global_mesh = Mesh(global_device_mesh, list(range(global_device_mesh.ndim)))
    max_process_index = max(d.process_index
                            for d in global_device_mesh.flatten())
    for p_idx in range(max_process_index + 1):
        # Raises an error if the process's submesh is non-contiguous.
        global_mesh._local_mesh(p_idx)

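# Hedged usage sketch for the contiguity check above, exercised directly
# rather than through the (unshown) test class. The 2x2 grid and the 4-device
# requirement are assumptions for illustration; note that `Mesh._local_mesh`
# is private JAX API and may change between versions.
import numpy as np
import jax
from jax.sharding import Mesh  # older JAX: from jax.experimental.maps import Mesh

devices = np.array(jax.devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, ('x', 'y'))
for p_idx in {d.process_index for d in devices.flatten()}:
    mesh._local_mesh(p_idx)  # raises if that process's submesh is non-contiguous
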
def test_device_mismatch(self):
    devices = jax.devices()
    if len(devices) < 8:
        raise unittest.SkipTest("Test requires 8 global devices.")
    # Deliberately scramble the device order relative to
    # `global_mesh.local_devices` so that GDA construction fails below.
    mesh_devices = np.array([[devices[0], devices[2]],
                             [devices[3], devices[1]],
                             [devices[4], devices[6]],
                             [devices[7], devices[5]]])
    global_mesh = Mesh(mesh_devices, ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x', 'y']
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
    dbs = [
        jax.device_put(global_input_data[indices[d]], d)
        for d in jax.local_devices()
    ]
    with self.assertRaisesRegex(
        ValueError,
        'The `global_mesh.local_devices` and `device_buffers` device order'):
        GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)

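# Hedged counterpart showing the construction the test above contrasts with:
# buffers ordered by `global_mesh.local_devices` rather than
# `jax.local_devices()`. This assumes the (since removed) GlobalDeviceArray
# API from jax.experimental.global_device_array, as in the test itself.
def make_gda(global_input_shape, global_mesh, mesh_axes, global_input_data):
    indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
    dbs = [jax.device_put(global_input_data[indices[d]], d)
           for d in global_mesh.local_devices]  # mesh order, not jax.local_devices()
    return GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)
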
def create_global_mesh(mesh_shape, axis_names):
    size = prod(mesh_shape)
    if len(jax.devices()) < size:
        # `jax.devices()` is the global device list, so the skip message
        # refers to global devices.
        raise unittest.SkipTest(f"Test requires {size} global devices.")
    mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
    global_mesh = Mesh(mesh_devices, axis_names)
    return global_mesh

def create_global_mesh(mesh_shape, axis_names):
    size = prod(mesh_shape)
    if len(api.devices()) < size:
        raise unittest.SkipTest(f"Test requires {size} global devices.")
    devices = sorted(api.devices(), key=lambda d: d.id)
    mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
    global_mesh = Mesh(mesh_devices, axis_names)
    return global_mesh

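# Hedged usage sketch for create_global_mesh: build a 2x2 mesh and run a
# pjit-sharded computation on it. Assumes at least 4 devices and the
# experimental pjit API (`in_axis_resources`) contemporary with these helpers.
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P

mesh = create_global_mesh((2, 2), ('x', 'y'))
with mesh:
    f = pjit(lambda a: a * 2,
             in_axis_resources=P('x', 'y'),
             out_axis_resources=P('x', 'y'))
    out = f(jnp.arange(16.).reshape(8, 2))
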
def run(self):
    print("jax runtime initialization starting")
    import jax
    from jax.experimental.maps import thread_resources, ResourceEnv, Mesh
    import haiku as hk
    from mesh_transformer.checkpoint import write_ckpt, read_ckpt
    from mesh_transformer.transformer_shard import CausalTransformer

    # jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
    thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()))

    start = time.time()
    # print(jax.devices())
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    devices = np.array(jax.devices()).reshape(self.mesh_shape)

    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        start = time.time()
        network: CausalTransformer = self.network_builder()
        param_count = hk.data_structures.tree_size(network.state['params'])
        print(f"Initialized in {time.time() - start:.06}s")
        print(f"Total parameters: {param_count}")

        # Worker loop: block on input_q for an (operation, payload) tuple and
        # reply on output_q.
        while True:
            operation, input = self.input_q.get()
            if operation == "train":
                self.output_q.put(network.train(input))
            elif operation == "eval":
                self.output_q.put(network.eval(input))
            elif operation == "generate":
                self.output_q.put(network.generate(*input))
            elif operation == "write_ckpt":
                path, shard = input
                write_ckpt(network.state, path, shard)
                self.output_q.put(None)
            elif operation == "load_ckpt":
                network.state = read_ckpt(network.state, input, devices.shape[1])
                self.output_q.put(network.state["step"][0])
            elif operation == "get_params":
                self.output_q.put(
                    hk.data_structures.tree_size(network.state['params']))
            elif operation == "move_params":
                # only needed for inference, otherwise first train step does this
                local_shards = max(
                    jax.local_device_count() // self.mesh_shape[1], 1)
                # delete the optimizer states otherwise it OOMs for some reason
                # TODO: use ShardedDeviceArray or something to get around this for bigger models
                del network.state["opt_state"]
                network.state = network.move_xmap(network.state,
                                                  np.zeros(local_shards))
                self.output_q.put(None)
            else:
                raise Exception("Not implemented")

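# Hedged sketch of the driver side of the queue protocol in run() above: a
# host thread enqueues an (operation, payload) tuple on input_q and blocks on
# output_q for the reply. `runner`, `batch`, `path`, and `shard` are
# placeholder names, not part of the original code.
def train_step(runner, batch):
    runner.input_q.put(("train", batch))
    return runner.output_q.get()

def save_checkpoint(runner, path, shard):
    runner.input_q.put(("write_ckpt", (path, shard)))
    runner.output_q.get()  # run() replies with None once the checkpoint is written
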
from contextlib import contextmanager

@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
    """Test utility for setting up meshes given mesh data from `schedules`."""
    # Used as a context manager (via @contextmanager), not as a decorator.
    axis_names, shape = unzip2(named_shape)
    size = prod(shape)
    local_devices = list(api.local_devices())
    if len(local_devices) < size:
        raise unittest.SkipTest(f"Test requires {size} local devices")
    mesh_devices = np.array(local_devices[:size]).reshape(shape)
    with Mesh(mesh_devices, axis_names):
        yield

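# Hedged usage sketch for with_mesh: pair it with an xmap over a 1-D 'x' mesh.
# Assumes at least 2 local devices and the experimental xmap API of the same
# era (axis_resources maps logical axis names to mesh axes). MeshSpec entries
# are (axis_name, size) pairs, per the unzip2 call above.
import numpy as np
from jax.experimental.maps import xmap

with with_mesh([('x', 2)]):
    f = xmap(lambda a: a * 2,
             in_axes=['x', ...],
             out_axes=['x', ...],
             axis_resources={'x': 'x'})
    out = f(np.arange(4.).reshape(2, 2))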