Пример #1
0
def fast_approx_rand(out: FloatTensorType) -> None:
    """Fill ``out`` in place with (approximately) standard-normal samples.

    Tensors with fewer than 1_000_003 elements are filled with fresh
    ``torch.randn`` samples. Larger tensors are filled by drawing a single
    block of 1_000_003 samples and tiling it over the flattened tensor, so
    the values repeat with that period; this trades statistical quality
    for speed, hence "approx".

    Args:
        out: The tensor to overwrite. It may have any shape and does not
            need to be contiguous.
    """
    # A non-contiguous tensor cannot always be flattened into a view; writing
    # into a flattened *copy* would silently leave the caller's tensor
    # untouched. Fill a contiguous staging tensor and copy it back instead.
    if not out.is_contiguous():
        staging = torch.empty_like(out, memory_format=torch.contiguous_format)
        fast_approx_rand(staging)
        out.copy_(staging)
        return

    period = 1_000_003
    # For a contiguous tensor this is guaranteed to alias ``out``'s storage.
    flat = out.view(-1)
    numel = flat.numel()
    if numel < period:
        torch.randn(numel, out=flat)
        return

    block = torch.randn(period)
    excess = numel % period
    # Using just `-excess` would give bad results when excess == 0.
    flat[:numel - excess].view(-1, period)[...] = block
    flat[numel - excess:] = block[:excess]
    def store(
        self,
        entity: EntityName,
        part: Partition,
        embs: FloatTensorType,
        optim_state: Optional[bytes],
    ) -> None:
        """Store one partition's embeddings (and optimizer state) on the servers.

        The embeddings are flattened and split into shards of at most
        ``self.shard_size`` bytes. Shard ``idx`` is sent to the client at
        index ``(idx + offset) % num_clients`` under the key
        ``"{entity}_{part}__embs_{idx}"``. The optimizer state, when given,
        is sent unsharded to the client at ``offset % num_clients`` under the
        key ``"{entity}_{part}__optim"``. The call blocks until every store
        request has completed.

        Args:
            entity: Name of the entity type; also used to look up the number
                of partitions in ``self.entities``.
            part: Index of the partition being stored.
            embs: The embeddings tensor to store.
            optim_state: Serialized optimizer state, or ``None`` to skip it.
        """
        key = f"{entity}_{part}"

        # embs sharded across all servers.
        num_clients = len(self._clients)
        flattened_embs = embs.flatten()
        numel_per_shard = self.shard_size // flattened_embs.element_size()
        # Space out the start of each partition's chunks of servers evenly.
        offset = (part * num_clients) // self.entities[entity].num_partitions
        futs = [
            self._async_store(
                (idx + offset) % num_clients,
                f"{key}__embs_{idx}",
                flattened_embs_shard,
            )
            for idx, flattened_embs_shard in enumerate(
                torch.split(flattened_embs, numel_per_shard)
            )
        ]

        # optim_state unsharded due to its smaller size.
        optim_state_tensor = None
        optim_state_fut = None
        if optim_state is not None:
            optim_state_tensor = bytes_to_bytetensor(optim_state)
            optim_state_fut = self._async_store(
                offset % num_clients, key + "__optim", optim_state_tensor
            )

        # The timers start only after every request has been dispatched, so
        # the reported durations measure time spent waiting for completion.
        t0 = time.monotonic()

        for fut in futs:
            fut.result()

        t1 = time.monotonic()

        if optim_state_fut is not None:
            optim_state_fut.result()

        t2 = time.monotonic()

        if not self.log_stats:
            return

        def throughput(num_bytes: int, seconds: float) -> float:
            # The requests may already be finished by the time we start
            # waiting, making the elapsed time zero; don't let a statistics
            # line crash an otherwise successful store.
            return num_bytes / seconds if seconds > 0 else float("inf")

        embs_size = embs.numel() * embs.element_size()
        embs_time = t1 - t0
        embs_speed = throughput(embs_size, embs_time)
        message = (
            f"Stored (entity {entity}, partition {part}). Embs of size {embs_size} "
            f"bytes stored in {embs_time} seconds ({embs_speed:,.0f} B/s)."
        )

        if optim_state_tensor is not None:
            optim_state_size = (
                optim_state_tensor.numel() * optim_state_tensor.element_size()
            )
            optim_state_time = t2 - t1
            optim_state_speed = throughput(optim_state_size, optim_state_time)
            # Leading space keeps the two sentences from running together.
            message += (
                f" optim_state of size {optim_state_size} bytes stored in "
                f"{optim_state_time} seconds ({optim_state_speed:,.0f} B/s)."
            )

        logger.info(message)