Example #1
def _validate_sharded_tensor(
    tensor_md: ShardedTensorMetadata, checkpoint_md: ShardedTensorStorageMetadata
) -> None:
    # We assume the incoming tensor has been validated during construction.

    # To ensure the checkpoint can satisfy loading this ShardedTensor, we
    # compute the loading plans for all shards and verify they are satisfiable.
    validate_non_overlapping_shards_metadata(
        checkpoint_md.tensor_metadata.shards_metadata
    )

    for shard_md in tensor_md.shards_metadata:
        read_volume = 0
        for storage_md in checkpoint_md.storage_metadata:
            shard_md_from_storage = storage_md.shard_metadata

            if not _check_shard_metadata_pair_overlap(shard_md, shard_md_from_storage):
                continue

            shard_volume = 1
            for _, _, _, length in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=shard_md_from_storage, current_shard=shard_md
            ):
                shard_volume *= length
            read_volume += shard_volume

        shard_volume = 1
        for size in shard_md.shard_sizes:
            shard_volume *= size
        if read_volume != shard_volume:
            raise ValueError(
                f"Shard {shard_md} only has {read_volume} available" +
                "elements but needs {shard_volume}"
            )
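
For intuition, here is a standalone, pure-Python sketch of the coverage check above, reduced to one dimension. `Shard1D`, `overlap_len`, and `check_coverage` are illustrative names invented for this sketch; they are not part of the torch API.

from typing import List, NamedTuple

class Shard1D(NamedTuple):
    offset: int  # start index of the shard within the global tensor
    length: int  # number of elements in the shard

def overlap_len(a: Shard1D, b: Shard1D) -> int:
    # Length of the intersection of two 1-D ranges (0 if disjoint).
    return max(0, min(a.offset + a.length, b.offset + b.length)
               - max(a.offset, b.offset))

def check_coverage(target: Shard1D, persisted: List[Shard1D]) -> None:
    # Mirrors the volume check: assuming the persisted shards do not overlap
    # one another, they must jointly cover every element of the target shard.
    read_volume = sum(overlap_len(target, p) for p in persisted)
    if read_volume != target.length:
        raise ValueError(
            f"Shard {target} only has {read_volume} available "
            f"elements but needs {target.length}"
        )

# A target shard [4, 12) fully covered by persisted shards [0, 8) and [8, 16).
check_coverage(Shard1D(offset=4, length=8),
               [Shard1D(offset=0, length=8), Shard1D(offset=8, length=8)])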
Example #2
def _prepare_sharded_tensor_read(
        metadata: ShardedTensorStorageMetadata,
        sharded_tensor_out: ShardedTensor) -> List[TensorReadRequest]:
    """
    Prepare sharded tensor read.

    Args:
        metadata: Metadata describing the persisted sharded tensor. Normally,
                  this is generated by :func:`_prepare_sharded_tensor_write`.
        sharded_tensor_out: The destination sharded tensor.

    Returns:
        A list of :class:`TensorReadRequest`. When fulfilled, these requests
        load `sharded_tensor_out`'s local shards from the persisted sharded
        tensor.
    """
    read_reqs = []
    # This is a naive quadratic algorithm that can be optimized later.
    for shard in sharded_tensor_out.local_shards():
        # Scan all storage metadata looking for overlapping chunks.
        for storage_md in metadata.storage_metadata:
            shard_md_from_storage = storage_md.shard_metadata

            # do they overlap?
            if not _check_shard_metadata_pair_overlap(shard.metadata,
                                                      shard_md_from_storage):
                continue

            storage_key = storage_md.storage_key
            target_tensor = shard.tensor.detach()
            offsets = []
            lengths = []
            for (
                    dim,
                    offset_for_saved_tensor,
                    offset_for_current_tensor,
                    length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                    saved_shard=shard_md_from_storage,
                    current_shard=shard.metadata):
                # Note that we do NOT want to make any tensor copy;
                # all operations must be view-only.
                target_tensor = torch.narrow(target_tensor, dim,
                                             offset_for_current_tensor, length)
                offsets.append(offset_for_saved_tensor)
                lengths.append(length)

            read_reqs.append(
                TensorReadRequest(
                    tensor=target_tensor,
                    storage_key=storage_key,
                    offsets=tuple(offsets),
                    lengths=tuple(lengths),
                ))
    return read_reqs
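
The view-only constraint in the loop above matters because whoever fulfills the read request writes into `tensor` in place, and those writes must land in the original local shard. Since `torch.narrow` returns a view sharing storage with its input, that is exactly what happens. A minimal standalone illustration (not part of the example above):

import torch

base = torch.zeros(4, 4)
# narrow() returns a view that shares storage with `base`.
view = torch.narrow(base, 0, 1, 2)  # rows 1..2
view = torch.narrow(view, 1, 2, 2)  # columns 2..3
view.fill_(7.0)

# Writes through the view are visible in the original tensor.
assert base[1:3, 2:4].eq(7.0).all()
assert base[0].eq(0.0).all()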
Example #3
def _prepare_generic_tensor_read(
    fqn: str,
    checkpoint_shards: List[ChunkStorageMetadata],
    local_shards: List[Shard],
    storage_metadata: Dict[MetadataIndex, str]
) -> List[TensorReadRequest]:
    read_reqs = []
    # This is a naive quadratic algorithm that can be optimized later.
    for shard in local_shards:
        # Scan all checkpoint shard metadata looking for overlapping chunks.
        for storage_md in checkpoint_shards:
            shard_md_from_storage = _chunk_to_shard_md(storage_md)

            # do they overlap?
            if not _check_shard_metadata_pair_overlap(
                shard.metadata, shard_md_from_storage
            ):
                continue

            storage_key = storage_metadata[MetadataIndex(fqn, storage_md.offsets)]
            target_tensor = shard.tensor.detach()
            offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=shard_md_from_storage, current_shard=shard.metadata
            ):
                # Note that we do NOT want to make any tensor copy;
                # all operations must be view-only.
                target_tensor = torch.narrow(
                    target_tensor, dim, offset_for_current_tensor, length
                )
                offsets.append(offset_for_saved_tensor)
                lengths.append(length)

            read_reqs.append(
                TensorReadRequest(
                    tensor=target_tensor,
                    storage_key=storage_key,
                    offsets=tuple(offsets),
                    lengths=tuple(lengths),
                )
            )
    return read_reqs
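
All three examples lean on `_shards_get_overlap_region_wrt_saved_tensor`, whose body is not shown here. As a rough mental model, here is a hypothetical pure-Python stand-in (not torch's implementation) that intersects the saved and current shards dimension by dimension:

from typing import List, Sequence, Tuple

def overlap_region_wrt_saved(
    saved_offsets: Sequence[int], saved_sizes: Sequence[int],
    current_offsets: Sequence[int], current_sizes: Sequence[int],
) -> List[Tuple[int, int, int, int]]:
    # For each dimension, yield (dim, offset_in_saved, offset_in_current,
    # length) describing the intersection of the two shards. Callers are
    # expected to have already checked that the shards overlap at all.
    regions = []
    for dim, (s_off, s_len, c_off, c_len) in enumerate(
        zip(saved_offsets, saved_sizes, current_offsets, current_sizes)
    ):
        start = max(s_off, c_off)
        end = min(s_off + s_len, c_off + c_len)
        regions.append((dim, start - s_off, start - c_off, end - start))
    return regions

# Saved shard spans rows [0, 4) x cols [0, 8); current spans [2, 6) x [4, 12).
print(overlap_region_wrt_saved((0, 0), (4, 8), (2, 4), (4, 8)))
# -> [(0, 2, 0, 2), (1, 4, 0, 4)]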