Example #1
def _assert_contiguous_submeshes(self, global_device_mesh):
  global_mesh = Mesh(global_device_mesh, list(range(global_device_mesh.ndim)))
  max_process_index = max(d.process_index
                          for d in global_device_mesh.flatten())
  for p_idx in range(max_process_index + 1):
    # Raises an error if non-contiguous
    global_mesh._local_mesh(p_idx)
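A minimal usage sketch for the helper above, assuming the snippet sits at module level and that Mesh exposes the private _local_mesh method as in the snippet; the 1 x N mesh below is illustrative, not from the original test:

import numpy as np
import jax

# The helper never touches `self`, so for illustration it can be called
# directly. In a single-process runtime the whole mesh belongs to one
# process, so the contiguity check trivially passes.
all_devices = np.array(jax.devices()).reshape(1, -1)
_assert_contiguous_submeshes(None, all_devices)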
Example #2
    def test_device_mismatch(self):
        devices = jax.devices()
        if len(devices) < 8:
            raise unittest.SkipTest("Test requires 8 global devices.")
        mesh_devices = np.array([[devices[0], devices[2]],
                                 [devices[3], devices[1]],
                                 [devices[4], devices[6]],
                                 [devices[7], devices[5]]])
        global_mesh = Mesh(mesh_devices, ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = ['x', 'y']
        global_input_data = np.arange(
            prod(global_input_shape)).reshape(global_input_shape)
        indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

        # Build the buffers in jax.local_devices() order, which (because of the
        # shuffled mesh above) differs from global_mesh.local_devices order.
        dbs = [
            jax.device_put(global_input_data[indices[d]], d)
            for d in jax.local_devices()
        ]

        with self.assertRaisesRegex(
                ValueError,
                'The `global_mesh.local_devices` and `device_buffers` device order'
        ):
            GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)
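For contrast, a hedged sketch of the construction the error message asks for, reusing the locals of the test above and the same (since-removed) GlobalDeviceArray API: the constructor expects the buffers ordered by global_mesh.local_devices rather than jax.local_devices().

# Hedged sketch, not part of the original test: buffers ordered by
# global_mesh.local_devices match the mesh, so no ValueError is raised.
dbs_ok = [
    jax.device_put(global_input_data[indices[d]], d)
    for d in global_mesh.local_devices
]
gda = GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs_ok)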
Example #3
def create_global_mesh(mesh_shape, axis_names):
  size = prod(mesh_shape)
  if len(jax.devices()) < size:
    raise unittest.SkipTest(f"Test requires {size} global devices")
  mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
  global_mesh = Mesh(mesh_devices, axis_names)
  return global_mesh
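A hypothetical usage sketch for the helper above, assuming a reasonably recent jax where jax.sharding provides NamedSharding and accepts the Mesh built here; the 2x2 shape and axis names are illustrative:

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = create_global_mesh((2, 2), ('x', 'y'))   # skips the test if < 4 devices
arr = jax.device_put(np.arange(16.0).reshape(4, 4),
                     NamedSharding(mesh, P('x', 'y')))
print(arr.sharding)   # each device holds one 2x2 shard of the 4x4 array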
Example #4
def create_global_mesh(mesh_shape, axis_names):
    size = prod(mesh_shape)
    if len(api.devices()) < size:
        raise unittest.SkipTest(f"Test requires {size} global devices.")
    # Sort by id so every process builds the mesh in the same device order.
    devices = sorted(api.devices(), key=lambda d: d.id)
    mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
    global_mesh = Mesh(mesh_devices, axis_names)
    return global_mesh
Example #5
    def run(self):
        print("jax runtime initialization starting")
        import jax
        from jax.experimental.maps import thread_resources, ResourceEnv, Mesh
        import haiku as hk
        from mesh_transformer.checkpoint import write_ckpt, read_ckpt
        from mesh_transformer.transformer_shard import CausalTransformer
        # jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True

        thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object),
                                                ()))

        start = time.time()
        # print(jax.devices())
        print(f"jax devices: {jax.device_count()}")
        print(f"jax runtime initialized in {time.time() - start:.06}s")
        devices = np.array(jax.devices()).reshape(self.mesh_shape)

        with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
            start = time.time()
            network: CausalTransformer = self.network_builder()
            param_count = hk.data_structures.tree_size(network.state['params'])
            print(f"Initialized in {time.time() - start:.06}s")
            print(f"Total parameters: {param_count}")

            while True:
                operation, input = self.input_q.get()
                if operation == "train":
                    self.output_q.put(network.train(input))
                elif operation == "eval":
                    self.output_q.put(network.eval(input))
                elif operation == "generate":
                    self.output_q.put(network.generate(*input))
                elif operation == "write_ckpt":
                    path, shard = input
                    write_ckpt(network.state, path, shard)
                    self.output_q.put(None)
                elif operation == "load_ckpt":
                    network.state = read_ckpt(network.state, input,
                                              devices.shape[1])
                    self.output_q.put(network.state["step"][0])
                elif operation == "get_params":
                    self.output_q.put(
                        hk.data_structures.tree_size(network.state['params']))
                elif operation == "move_params":
                    # only needed for inference, otherwise first train step does this
                    local_shards = max(
                        jax.local_device_count() // self.mesh_shape[1], 1)

                    # delete the optimizer states otherwise it OOMs for some reason
                    # TODO: use ShardedDeviceArray or something to get around this for bigger models
                    del network.state["opt_state"]
                    network.state = network.move_xmap(network.state,
                                                      np.zeros(local_shards))
                    self.output_q.put(None)
                else:
                    raise Exception("Not implemented")
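A hedged sketch of how this worker loop is typically driven from the host side (in mesh-transformer-jax the class runs in its own process/thread); the worker, batch, and checkpoint names below are illustrative, not from the original source:

# Hypothetical driver code, assuming input_q/output_q are shared queues.
worker.input_q.put(("train", batch))                 # enqueue one training step
loss = worker.output_q.get()                         # blocks until the step returns

worker.input_q.put(("write_ckpt", (ckpt_path, 0)))   # (path, shard), per the loop above
worker.output_q.get()                                # None signals the checkpoint is done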
Example #6
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
    """Test utility for setting up meshes given mesh data from `schedules`."""
    # This is similar to the `with_mesh` function above, but isn't a decorator.
    axis_names, shape = unzip2(named_shape)
    size = prod(shape)
    local_devices = list(api.local_devices())
    if len(local_devices) < size:
        raise unittest.SkipTest(f"Test requires {size} local devices")
    mesh_devices = np.array(local_devices[:size]).reshape(shape)
    with Mesh(mesh_devices, axis_names):
        yield
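As written, the helper is a plain generator function, so a hedged usage sketch is to wrap it with contextlib.contextmanager (or apply that decorator at the definition) before using it in a with statement; the (('x', 2), ('y', 1)) spec and test name below are illustrative:

from contextlib import contextmanager

with_mesh_ctx = contextmanager(with_mesh)       # hypothetical wrapper name

def test_something_under_mesh():
    with with_mesh_ctx((('x', 2), ('y', 1))):   # skips if fewer than 2 local devices
        ...  # body runs with the ('x', 'y') mesh installed as the ambient mesh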