def run(self):
        """Build the network on a device mesh and serve commands from a queue.

        Initializes JAX, builds the network with ``self.network_builder()``
        inside a ('dp', 'mp') mesh shaped by ``self.mesh_shape``, then loops
        forever reading ``(operation, payload)`` tuples from ``self.input_q``
        and putting exactly one result per command on ``self.output_q``.

        Supported operations:
            "train"       -> puts ``network.train(payload)``
            "eval"        -> puts ``network.eval(payload)``
            "generate"    -> puts ``network.generate(*payload)``
            "write_ckpt"  -> payload is ``(path, shard)``; puts ``None``
            "load_ckpt"   -> replaces ``network.state`` via ``read_ckpt`` with
                             ``payload``; puts ``network.state["step"][0]``
            "get_params"  -> puts the parameter count of ``state['params']``
            "move_params" -> drops the optimizer state, re-places the state
                             with ``network.move_xmap``; puts ``None``

        Raises:
            NotImplementedError: for an unrecognized operation (a subclass of
                ``Exception``, so existing ``except Exception`` handlers still
                match).
        """
        print("jax runtime initialization starting")
        import jax
        from jax.experimental.maps import thread_resources, ResourceEnv, Mesh
        import haiku as hk
        from mesh_transformer.checkpoint import write_ckpt, read_ckpt
        from mesh_transformer.transformer_shard import CausalTransformer

        # Reset the resource env to an empty mesh; the real device mesh is
        # entered with `jax.experimental.maps.mesh(...)` below.
        thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object),
                                                ()))

        start = time.time()
        # NOTE(review): presumably jax.device_count() is what triggers backend
        # initialization, so this timing covers it — confirm.
        print(f"jax devices: {jax.device_count()}")
        print(f"jax runtime initialized in {time.time() - start:.06}s")
        devices = np.array(jax.devices()).reshape(self.mesh_shape)

        with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
            start = time.time()
            network: CausalTransformer = self.network_builder()
            param_count = hk.data_structures.tree_size(network.state['params'])
            print(f"Initialized in {time.time() - start:.06}s")
            print(f"Total parameters: {param_count}")

            while True:
                # `payload` rather than `input` so the builtin is not shadowed.
                operation, payload = self.input_q.get()
                if operation == "train":
                    self.output_q.put(network.train(payload))
                elif operation == "eval":
                    self.output_q.put(network.eval(payload))
                elif operation == "generate":
                    self.output_q.put(network.generate(*payload))
                elif operation == "write_ckpt":
                    path, shard = payload
                    write_ckpt(network.state, path, shard)
                    self.output_q.put(None)
                elif operation == "load_ckpt":
                    # devices.shape[1] is the size of the 'mp' mesh axis.
                    network.state = read_ckpt(network.state, payload,
                                              devices.shape[1])
                    self.output_q.put(network.state["step"][0])
                elif operation == "get_params":
                    self.output_q.put(
                        hk.data_structures.tree_size(network.state['params']))
                elif operation == "move_params":
                    # only needed for inference, otherwise first train step does this
                    local_shards = max(
                        jax.local_device_count() // self.mesh_shape[1], 1)

                    # Delete the optimizer state; the original author notes it
                    # OOMs otherwise. pop() instead of del so that a repeated
                    # "move_params" does not raise KeyError once it is gone.
                    # TODO: use ShardedDeviceArray or something to get around this for bigger models
                    network.state.pop("opt_state", None)
                    network.state = network.move_xmap(network.state,
                                                      np.zeros(local_shards))
                    self.output_q.put(None)
                else:
                    raise NotImplementedError(
                        f"Not implemented: {operation!r}")
# Example #2
# 0
def save(network, step, bucket, path, mp, aux=None, keep_n=3, delete_old=True):
    """Write a sharded checkpoint to GCS and update the run's ``meta.json``.

    Args:
        network: object whose ``state`` is written with ``write_ckpt``.
        step: training step; names the checkpoint directory (``step_{step}``)
            and is recorded as ``meta["step"]`` and in ``meta["checkpoints"]``.
        bucket: GCS bucket name.
        path: prefix inside the bucket; must be non-empty.
        mp: number of shards to write (``write_ckpt`` is called for shard ids
            ``0..mp-1``).
        aux: extra JSON-serializable data stored as ``meta["aux"][str(step)]``;
            defaults to ``{}``.
        keep_n: maximum number of checkpoints kept in ``meta["checkpoints"]``;
            the oldest entries are dropped first.
        delete_old: if True, also delete the blobs of dropped checkpoints;
            otherwise they are only removed from the metadata.
    """
    assert path
    client = storage.Client()

    if aux is None:
        aux = {}

    # NOTE(review): the builtin open() cannot read gs:// URLs; presumably the
    # module imports a GCS-aware `open` (e.g. smart_open) — verify.
    meta_path = f"gs://{bucket}/{path}/meta.json"

    try:
        # Only checks that an existing meta.json is readable JSON; its content
        # is re-read after the shards are written.
        with open(meta_path, "r") as f:
            json.load(f)
    except Exception:
        # Missing or unreadable: create a fresh metadata file.
        with open(meta_path, "w") as f:
            json.dump({"step": 0, "checkpoints": [], "aux": {}}, f)

    # do sharded checkpoint writing
    start = time.time()
    for shard_id in range(mp):
        write_ckpt(network.state, f"gs://{bucket}/{path}/step_{step}/",
                   shard_id)

    print(f"Wrote checkpoint in {time.time() - start:.06}s")

    with open(meta_path, "r") as f:
        meta = json.load(f)

    meta["step"] = step
    meta["checkpoints"].append(step)
    all_aux = meta.get("aux", {})

    while len(meta["checkpoints"]) > keep_n:
        ckpt_to_delete = meta["checkpoints"].pop(0)

        try:
            # aux keys are strings once round-tripped through JSON.
            del all_aux[str(ckpt_to_delete)]
        except KeyError:
            print(f"failed to delete the aux state for {ckpt_to_delete}")

        if delete_old:
            print(f"deleting checkpoint {ckpt_to_delete}")
            for blob in client.list_blobs(
                    bucket, prefix=f"{path}/step_{ckpt_to_delete}/"):
                # Safety net: never delete a blob outside this run's path.
                assert path in blob.name
                blob.delete()
        else:
            print(f"keeping checkpoint {ckpt_to_delete}")

    # Store under a string key: json.dump turns int keys into strings, so an
    # int key would collide with (and duplicate) an existing "step" entry.
    all_aux[str(step)] = aux
    meta["aux"] = all_aux

    with open(meta_path, "w") as f:
        json.dump(meta, f)

    # NOTE(review): everything below looks like a separate checkpoint-slimming
    # script pasted into this function. It reads names that save() does not
    # define (`cores_per_replica`, `model_dir`, `params`, `to_bf16`) and
    # rebinds `meta` and `network`. Unless those names are module-level
    # globals, reaching this point raises NameError — confirm and move it out.
    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)

    ckpt_step = meta["checkpoints"][-1]
    print(f"using checkpoint {ckpt_step}")

    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        network = CausalTransformer(params)

        start = time.time()
        network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
        print(f"network loaded in {time.time() - start:.06}s")

        start = time.time()
        del network.state["opt_state"]

        network.state["params"] = to_bf16(network.state["params"])
        print(f"network converted in {time.time() - start:.06}s")

        for i in range(cores_per_replica):
            write_ckpt(network.state, f"gs://{bucket}/{model_dir}_slim/step_{ckpt_step}/", i)
            print(f"written shard {i}")
 def write_ckpt(self, path, shard):
     """Write ``self.state`` via the module-level ``write_ckpt`` helper.

     ``path`` and ``shard`` are forwarded unchanged. The bare name
     ``write_ckpt`` below resolves to the module-level function, not this
     method, because class scope is not searched for bare names.
     """
     current_state = self.state
     write_ckpt(current_state, path, shard)