Exemplo n.º 1
0
    def build_metadata(
        self,
        tensor_sizes: torch.Size,
        tensor_properties: sharded_tensor_meta.TensorProperties,
    ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """Describe how a tensor of ``tensor_sizes`` is chunked along ``self.dim``.

        Each entry of ``self.placements`` is visited in order. A
        ``ShardMetadata`` is emitted for a placement only when
        ``get_chunked_dim_size`` reports a positive size for its index;
        placements with a non-positive chunk size get no entry. Every emitted
        shard keeps the full tensor extent in all dimensions except
        ``self.dim``, where its size is the chunk size and its offset is
        ``split_size * index``.

        Args:
            tensor_sizes: Sizes of the full tensor; its length is used as the
                number of dimensions.
            tensor_properties: Passed through unchanged to the returned
                metadata object.

        Returns:
            A ``ShardedTensorMetadata`` built from the collected shard
            entries, ``tensor_sizes`` and ``tensor_properties``.

        Raises:
            ValueError: If ``self.dim`` lies outside ``[-ndim, ndim)``.
                ``self._verify_dim`` runs first and presumably performs its
                own validation -- its behavior is not visible here.
        """
        rank = len(tensor_sizes)

        self._verify_dim(self.dim)
        # Accept negative dims the same way Python indexing does.
        if not (self.dim < rank and self.dim >= -rank):  # type: ignore[operator]
            raise ValueError(f"Invalid sharding dim: {self.dim}")

        dim_extent = tensor_sizes[self.dim]  # type: ignore[index]
        per_chunk = get_split_size(dim_extent, len(self.placements))

        collected = []
        for chunk_idx, device in enumerate(self.placements):
            extent = get_chunked_dim_size(dim_extent, per_chunk, chunk_idx)
            if extent <= 0:
                # Nothing lands on this placement; emit no shard for it.
                continue

            # Full-size shard everywhere except along the sharding dim.
            sizes = list(tensor_sizes)
            sizes[self.dim] = extent  # type: ignore[index]

            # Shards are laid out back to back along the sharding dim.
            offsets = [0] * rank
            offsets[self.dim] = per_chunk * chunk_idx  # type: ignore[index]

            collected.append(
                ShardMetadata(
                    shard_offsets=offsets,
                    shard_sizes=sizes,
                    placement=device,
                )
            )

        return sharded_tensor_meta.ShardedTensorMetadata(
            collected, tensor_sizes, tensor_properties)
Exemplo n.º 2
0
 def build_metadata(
     self,
     tensor_sizes: torch.Size,
     tensor_properties: sharded_tensor_meta.TensorProperties,
 ) -> sharded_tensor_meta.ShardedTensorMetadata:
     """Wrap the explicitly listed shards in a ``ShardedTensorMetadata``.

     ``check_tensor`` is first called with ``self.shards`` and
     ``tensor_sizes`` to check that the shards form a valid tensor (how it
     reports a failure is not visible here -- presumably by raising).

     Args:
         tensor_sizes: Sizes of the full tensor the shards should cover.
         tensor_properties: Passed through unchanged to the returned
             metadata object.

     Returns:
         A ``ShardedTensorMetadata`` built from ``self.shards``,
         ``tensor_sizes`` and ``tensor_properties``.
     """
     # Validate before constructing anything.
     check_tensor(self.shards, tensor_sizes)

     metadata = sharded_tensor_meta.ShardedTensorMetadata(
         self.shards,
         tensor_sizes,
         tensor_properties,
     )
     return metadata