def test_sharded_tensor_reshard_errors(self):
    specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
    spec, reshard_spec = specs[0], specs[1]
    enumerable_sharding_spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])
    st = sharded_tensor.rand(spec, 24, 12)
    with self.assertRaisesRegex(
            NotImplementedError,
            "Only ChunkShardingSpec supported for reshard."):
        st.reshard(enumerable_sharding_spec)
    st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
    with self.assertRaisesRegex(
            NotImplementedError,
            "Only single local shard supported for reshard."):
        st.reshard(reshard_spec)
Example n. 2
def test_partial_tensor_reshard_errors(self):
    enumerable_sharding_spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
        ]
    )
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec supported for reshard."
    ):
        self._run_partial_tensor_n_reshard(
            enumerable_sharding_spec, [13, 21], 4, dist.ReduceOp.SUM
        )
        self._run_partial_tensor_n_reshard(
            enumerable_sharding_spec, [12, 22], 4, dist.ReduceOp.MAX
        )
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    spec = specs[0]
    with self.assertRaisesRegex(
        NotImplementedError, "Only real partial tensor supported for reshard."
    ):
        self._run_partial_tensor_n_reshard(
            spec, [13, 21], 4, dist.ReduceOp.SUM, dtype=torch.cfloat
        )
        self._run_partial_tensor_n_reshard(
            spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat
        )
Example n. 3
def generate_enumerable_sharding_specs_for_test():
    return [
        EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ])
    ]
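The specs returned above pair with the sharded tensor factories used elsewhere on this page; the following is a minimal usage sketch (not part of the original snippet), assuming four CUDA devices and an initialized default process group.

# Minimal usage sketch (assumes 4 CUDA devices and an initialized default
# process group; `generate_enumerable_sharding_specs_for_test` is the helper above).
from torch.distributed._shard import sharded_tensor

spec = generate_enumerable_sharding_specs_for_test()[0]
st = sharded_tensor.empty(spec, 10, 10)  # each rank materializes its own 5x5 local shard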
Example n. 4
    def _infer_enum_sharding_spec_case(self):
        shards_metadata = [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[10, 5],
                placement="cuda:1",
            )
        ]
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)

        shards_metadata = [
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[16],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[16],
                shard_sizes=[9],
                placement="cuda:1",
            )
        ]
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)

        shards_metadata = [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ]
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)
Example n. 5
    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: TensorProperties,
                       ) -> ShardedTensorMetadata:
        tensor_num_dim = len(tensor_sizes)
        assert tensor_num_dim == 2, "only support 2-dim tensor for grid sharding"
        shards_metadata = []

        def chunk_num(dim_size, grid_size):
            assert dim_size % grid_size == 0, "only support dim_size mod grid_size == 0"
            return dim_size // grid_size

        row_chunks = chunk_num(tensor_sizes[0], self.grid_size)
        col_chunks = chunk_num(tensor_sizes[1], self.grid_size)

        assert row_chunks * col_chunks == len(self.placements)
        for row_idx in range(row_chunks):
            for col_idx in range(col_chunks):
                shards_metadata.append(
                    ShardMetadata(
                        shard_offsets=[row_idx * self.grid_size, col_idx * self.grid_size],
                        shard_sizes=[self.grid_size, self.grid_size],
                        placement=self.placements[row_idx * col_chunks + col_idx]  # row-major placement index
                    )
                )
        return ShardedTensorMetadata(
            shards_metadata=shards_metadata,
            size=tensor_sizes,
            tensor_properties=tensor_properties
        )
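To make the offset arithmetic above concrete, here is a small self-contained sketch (illustrative values only, independent of the class above): a 10x10 tensor with grid_size=5 yields a 2x2 grid of 5x5 shards in row-major placement order.

# Self-contained sketch of the grid layout computed by build_metadata above
# (illustrative values: a 10x10 tensor with grid_size=5 -> a 2x2 grid of 5x5 shards).
grid_size = 5
rows, cols = 10 // grid_size, 10 // grid_size
placements = [f"rank:{i}/cuda:{i}" for i in range(rows * cols)]
grid_shards = [
    {
        "shard_offsets": [r * grid_size, c * grid_size],
        "shard_sizes": [grid_size, grid_size],
        "placement": placements[r * cols + c],  # row-major placement index
    }
    for r in range(rows)
    for c in range(cols)
]
# grid_shards[1]["shard_offsets"] == [0, 5]: row 0, column 1 of the grid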
Example n. 6
    def test_math_ops_errors(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_lhs = sharded_tensor.rand(spec, (20, 3))
        sharded_rhs = sharded_tensor.rand(spec, (12, 3))

        with self.assertRaisesRegex(
            RuntimeError, "Implicit broadcasting not supported"
        ):
            torch.add(sharded_lhs, sharded_rhs)

        spec = EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                ),
            ]
        )

        st = sharded_tensor.rand(spec, 10, 10)

        with self.assertRaisesRegex(RuntimeError, "not supported"):
            torch.add(st, sharded_rhs)
Example n. 7
    def _init_chunked(
        self,
        dims,
        tensor_init_params: TensorInitParams,
    ):
        current_rank = dist.get_rank(self._process_group)
        sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]

        # Validate the sharding spec.
        if not isinstance(sharding_dim, int):
            raise ValueError(
                f"Sharding dim needs to be an integer, found: {sharding_dim}")
        if sharding_dim >= len(dims) or sharding_dim < -len(dims):
            raise ValueError(f"Invalid sharding dim: {sharding_dim}")

        dim_size = dims[sharding_dim]
        remote_devices = self._sharding_spec.placements  # type: ignore[attr-defined]
        chunks = len(remote_devices)
        # split_size computed similar to 'torch.chunk'
        split_size = get_split_size(dim_size, chunks)

        shards_metadata = []
        for idx, remote_device in enumerate(remote_devices):
            rank, local_device = _parse_and_validate_remote_device(
                self._process_group, remote_device)

            # Adjust the sharding dim for this rank.
            sharded_dim_size = get_chunked_dim_size(dim_size, split_size, idx)

            if sharded_dim_size > 0:
                # Build sharding_metadata.

                # Copy so we can adjust the size along the sharding dim for this rank.
                rank_dims = dims.copy()

                rank_offsets = [0] * len(dims)
                rank_offsets[sharding_dim] = split_size * idx
                rank_dims[sharding_dim] = sharded_dim_size

                shard_metadata = ShardMetadata(rank_offsets, rank_dims,
                                               remote_device)
                shards_metadata.append(shard_metadata)

                # Build the local shard for the current rank if it is involved in the sharding spec.
                if current_rank == rank:
                    # Initialize the local shard.
                    local_shard = _create_tensor_from_params(
                        *rank_dims,
                        local_device=local_device,
                        tensor_init_params=tensor_init_params)
                    self._local_shards.append(
                        Shard(local_shard, shard_metadata))

        # Build overall metadata
        self._metadata = ShardedTensorMetadata(
            shards_metadata,
            dims,
            tensor_init_params.tensor_properties,
        )
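The split arithmetic above relies on get_split_size and get_chunked_dim_size; the following is a self-contained sketch of the 'torch.chunk'-style computation they perform (illustrative re-implementations, not the library helpers themselves).

# Self-contained sketch of the 'torch.chunk'-style split referenced above.
# These are illustrative re-implementations, not the library helpers themselves.
def _split_size(dim_size: int, chunks: int) -> int:
    return (dim_size + chunks - 1) // chunks  # ceil division

def _chunked_dim_size(dim_size: int, split_size: int, idx: int) -> int:
    # Size of chunk `idx`; trailing chunks may be smaller or empty.
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)

# Splitting a dim of length 10 across 4 ranks yields sizes [3, 3, 3, 1],
# so shard offsets along the sharding dim are [0, 3, 6, 9].
print([_chunked_dim_size(10, _split_size(10, 4), i) for i in range(4)])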
Example n. 8
    def reshard(self, resharding_spec: ShardingSpec) -> ShardedTensor:
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with local aggregated result.
        """
        if not isinstance(resharding_spec, ChunkShardingSpec):
            raise NotImplementedError(
                "Only ChunkShardingSpec supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        if self.local_shard.size(
                sharding_dim) % self.process_group.size() != 0:
            raise ValueError(
                'World size need to divide the length of the dimension.')
        if self.local_shard.is_complex():
            raise NotImplementedError(
                "Only real partial tensor supported for reshard.")

        local_shards = self.local_shard.chunk(self.process_group.size(),
                                              dim=sharding_dim)
        local_result = reduce_scatter(torch.empty_like(local_shards[0]),
                                      list(local_shards),
                                      op=self.reduce_op)

        sharded_tensor_size = self.local_shard.size()
        current_offsets = [0] * len(local_result.size())
        shards = []
        rank = self.process_group.rank()
        for idx, placement in enumerate(
                resharding_spec.placements):  # type: ignore[attr-defined]
            if rank == placement.rank():  # type: ignore[union-attr]
                local_metadata = ShardMetadata(
                    shard_offsets=current_offsets,
                    shard_sizes=list(local_result.size()),
                    placement=placement,
                )
                shards.append(Shard(local_result, local_metadata))
                break
            current_offsets[sharding_dim] += local_result.size(
                sharding_dim)  # type: ignore[index]

        st = ShardedTensor._init_from_local_shards(
            shards,
            tuple(sharded_tensor_size),
            process_group=self.process_group)
        st._sharding_spec = copy.deepcopy(resharding_spec)

        return st
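A plain-tensor sketch of the logic above (no distributed setup, values illustrative): summing the per-rank partials and then chunking along the sharding dim is what the reduce_scatter collective accomplishes in one step, with each rank keeping its own chunk.

# Plain-tensor sketch of "aggregate, then shard" (the two steps reduce_scatter fuses above).
import torch

world_size, sharding_dim = 2, 0
partials = [torch.arange(8.0).reshape(4, 2) + r for r in range(world_size)]  # per-rank partial results
aggregated = torch.stack(partials).sum(dim=0)             # step 1: aggregate
chunks = aggregated.chunk(world_size, dim=sharding_dim)   # step 2: shard
# rank r would keep chunks[r] as the local shard of the resulting ShardedTensor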
Example n. 9
def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    device_per_node: int,
    pg: dist.ProcessGroup,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [
            Shard.from_tensor_and_offsets(local_shard, offsets, rank)
        ]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0]
                              for chunk_size in chunk_sizes]))[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    placements = [
        f"rank:{r}/cuda:{r % device_per_node}" for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size,
                      placement) for offset, size, placement in zip(
                          chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ))
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        sharded_tensor_metadata=sharded_tensor_metadata,
        process_group=pg)
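The offset bookkeeping above can be checked in isolation; for example (illustrative sizes), chunking a length-10 dim 0 across 4 ranks gives chunk sizes [3, 3, 3, 1] and dim-0 offsets [0, 3, 6, 9].

# Sketch of the dim-0 offset computation above for an uneven split (illustrative sizes).
import itertools
import torch

chunks = torch.zeros(10, 4).chunk(4, dim=0)
chunk_sizes = [list(c.size()) for c in chunks]
dim0_offsets = [0] + list(itertools.accumulate(s[0] for s in chunk_sizes))[:-1]
print(chunk_sizes)    # [[3, 4], [3, 4], [3, 4], [1, 4]]
print(dim0_offsets)   # [0, 3, 6, 9]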
Example n. 10
    def from_tensor_and_offsets(cls, tensor: torch.Tensor,
                                shard_offsets: List[int], rank: int):
        """
        Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.

        Args:
            tensor(torch.Tensor): Local tensor for the shard.
            shard_offsets(List[int]): List of integers specifying the offset
                of the shard on each dimension.
            rank(int): Specifies the rank for the shard.
        """
        shard_sizes = list(tensor.size())
        placement = _remote_device(f"rank:{rank}/{str(tensor.device)}")
        shard_meta = ShardMetadata(shard_offsets=shard_offsets,
                                   shard_sizes=shard_sizes,
                                   placement=placement)
        return Shard(tensor, shard_meta)
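A hedged usage sketch of the classmethod above (values illustrative; the import path is the one used by recent PyTorch versions): wrap rank 1's local CPU block that starts at global row 4.

# Usage sketch (illustrative values): wrap rank 1's local CPU block starting at global row 4.
import torch
from torch.distributed._shard.sharded_tensor import Shard

local_block = torch.ones(4, 8)
shard = Shard.from_tensor_and_offsets(local_block, shard_offsets=[4, 0], rank=1)
# shard.metadata: ShardMetadata(shard_offsets=[4, 0], shard_sizes=[4, 8], placement=rank:1/cpu)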
Example n. 11
    def _infer_chunk_sharding_spec_case(self, placements, sharding_dim, st_size):
        world_size = len(placements)
        split_size = get_split_size(st_size[sharding_dim], world_size)
        shards_metadata = [None] * world_size
        for idx, placement in enumerate(placements):
            shard_size = copy.deepcopy(st_size)
            offsets = [0] * len(st_size)
            offsets[sharding_dim] = split_size * idx
            shard_size[sharding_dim] = get_chunked_dim_size(st_size[sharding_dim], split_size, idx)
            shards_metadata[placement.rank()] = ShardMetadata(
                shard_offsets=offsets,
                shard_sizes=shard_size,
                placement=placement,
            )

        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, ChunkShardingSpec))
        self.assertEqual(spec.dim, sharding_dim)
        self.assertEqual(spec.placements, placements)
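As a counterpart to the enumerable layouts in the earlier inference example, a layout of equal chunks along one dim is expected to come back as a ChunkShardingSpec; a small sketch (illustrative sizes, using the same names as the snippet above):

# Sketch (illustrative sizes): two equal dim-0 chunks fully covering a 16x4 tensor
# should be inferred as chunk sharding rather than an enumerable spec.
shards_metadata = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[8, 4], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[8, 0], shard_sizes=[8, 4], placement="rank:1/cuda:1"),
]
spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
# expected: isinstance(spec, ChunkShardingSpec) and spec.dim == 0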
Example n. 12
def build_reshard_metadata(
    st_size: torch.Size,
    sharding_spec: ShardingSpec,
    world_size: int,
) -> Tuple[List[ShardMetadata], List[int]]:
    """
    Based on the given sharding spec, we calculate the offset and local shard size
    for each rank. We then build a ShardMetadata on top of the calculation result.

    Args:
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded.
        world_size (int): number of ranks.

    Returns:
        A Tuple of the following:
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
            A List[int] which contains the ranks in the order of placement.
    """
    shard_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    shards_metadata = [None] * world_size
    ranks = []
    offsets = [0] * len(st_size)
    split_size = get_split_size(st_size[shard_dim], world_size)
    for idx, placement in enumerate(
            sharding_spec.placements):  # type: ignore[attr-defined]
        ranks.append(placement.rank())
        sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size,
                                                idx)
        local_tensor_size = list(st_size)
        local_tensor_size[shard_dim] = sharded_dim_size
        shards_metadata[
            placement.rank()] = ShardMetadata(  # type: ignore[call-overload]
                shard_offsets=copy.deepcopy(offsets),
                shard_sizes=local_tensor_size,
                placement=placement,
            )
        offsets[shard_dim] += sharded_dim_size
    return shards_metadata, ranks  # type: ignore[return-value]
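A hedged usage sketch of the helper above (illustrative values; assumes the helper's own imports are in scope): for a (10, 3) tensor chunked on dim 0 across four ranks, the computed shards have dim-0 offsets 0, 3, 6, 9 and dim-0 sizes 3, 3, 3, 1.

# Usage sketch (illustrative values; assumes build_reshard_metadata's imports are in scope).
import torch
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

reshard_spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)
shards_metadata, ranks = build_reshard_metadata(torch.Size([10, 3]), reshard_spec, world_size=4)
# ranks == [0, 1, 2, 3]; shard sizes along dim 0 are [3, 3, 3, 1] with offsets [0, 3, 6, 9]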
    def test_switch_between_sharded_tensor_to_tensor(self) -> None:
        path = self.get_file_path()
        tensor_size = 32

        specs = [
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[8],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8],
                    shard_sizes=[tensor_size - 8],
                    placement="rank:0",
                ),
            ]),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[10],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[10],
                    shard_sizes=[tensor_size - 10],
                    placement="rank:1",
                ),
            ]),
        ]

        for save_spec in specs:
            for load_spec in specs:
                save_dict = {
                    'sharded':
                    sharded_tensor.rand(save_spec, tensor_size),
                    'replicated':
                    torch.rand(tensor_size, device=f"cpu:{self.rank}")
                }

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

                # "Freaky Friday" the tensors: swap which key is sharded vs. replicated for the load
                load_dict = {
                    'sharded': torch.zeros(tensor_size,
                                           device=f"cpu:{self.rank}"),
                    'replicated': sharded_tensor.zeros(load_spec, tensor_size)
                }

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

                save_dict_sharded = self.load_tensor(save_dict['sharded'])
                load_dict_replicated = self.load_tensor(
                    load_dict['replicated'])

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(save_dict_sharded,
                                       load_dict['sharded']),
                        f"save-spec {save_spec} load-spec {load_spec}")
                    self.assertTrue(
                        torch.allclose(save_dict['replicated'],
                                       load_dict_replicated),
                        f"save-spec {save_spec} load-spec {load_spec}")
    def test_load_with_different_shard_plan(self) -> None:
        path = self.get_file_path()

        # We hardcode the assumption of how many shards are around
        self.assertEqual(self.world_size, dist.get_world_size())

        specs = [
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[2, 0],
                    shard_sizes=[1, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[3, 0],
                    shard_sizes=[3, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[6, 0],
                    shard_sizes=[3, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[9, 0],
                    shard_sizes=[1, 20],
                    placement="rank:0",
                ),
            ]),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[8, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
            ]),
        ]

        for s0 in specs:
            for s1 in specs:
                if s0 == s1:
                    continue

                if dist.get_rank() == 0:
                    shutil.rmtree(path, ignore_errors=True)
                    os.makedirs(path)
                dist.barrier()

                model_to_save = MyShardedModel3(s0)
                model_to_save._register_state_dict_hook(state_dict_hook)
                state_dict_to_save = model_to_save.state_dict()

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=state_dict_to_save,
                                storage_writer=fs_writer)

                dist.barrier()

                model_to_load = MyShardedModel3(s1)
                model_to_load._register_state_dict_hook(state_dict_hook)
                state_dict_to_load_to = model_to_load.state_dict()
                dist.barrier()

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=state_dict_to_load_to,
                                storage_reader=fs_reader)

                dist.barrier()
                store_tensor = self.load_tensor(model_to_save.sharded_tensor)
                dist.barrier()
                load_tensor = self.load_tensor(model_to_load.sharded_tensor)

                if dist.get_rank() == 0:
                    self.assertTrue(torch.allclose(store_tensor, load_tensor),
                                    msg=f"{s0} vs {s1}")
Example n. 15
def _chunk_to_shard_md(chunk_md: ChunkStorageMetadata) -> ShardMetadata:
    return ShardMetadata(
        shard_offsets=list(chunk_md.offsets),
        shard_sizes=list(chunk_md.sizes)
    )
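_chunk_to_shard_md only reads the offsets and sizes fields, so its behavior can be illustrated with a stand-in record (the real ChunkStorageMetadata lives in the distributed checkpoint metadata module, whose import path varies across PyTorch versions):

# Illustration with a stand-in record (the real ChunkStorageMetadata is a
# dataclass exposing `offsets` and `sizes`; only those fields are read above).
from dataclasses import dataclass
from typing import Tuple


@dataclass
class _FakeChunkMd:
    offsets: Tuple[int, ...]
    sizes: Tuple[int, ...]


shard_md = _chunk_to_shard_md(_FakeChunkMd(offsets=(0, 0), sizes=(5, 5)))
# -> ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5]) with no placement set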
Example n. 16
    def _init_from_local_tensor(
        cls,
        local_tensor: torch.Tensor,
        sharding_spec: ShardingSpec,
        *global_size: Sequence[int],
        process_group: dist.ProcessGroup = None,
        init_rrefs=False,
    ) -> "ShardedTensor":
        """
        Initialize a ShardedTensor given only one local tensor, global sharded tensor
        size and sharding spec on each rank.

        Args:
            local_tensor (Tensor): Single tensor of local shard stored in each rank.
            sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how to shard the Tensor.
            global_size (Sequence[int]): Size of the sharded tensor.
            process_group (ProcessGroup, optional): The process group to aggregate on.
                Default: None
            init_rrefs (bool, optional): Whether or not to initialize
                :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
                Need to initialize the RPC Framework if specified as ``True``.
                Default: ``False``.

        Returns:
            A :class:`ShardedTensor` sharded based on the given sharding_spec with local
                tensor stored in the current rank.

        Examples:
            >>> # All tensors below are of torch.int64 type.
            >>> # We have 2 process groups, 2 ranks.
            >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
            >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)
            >>> local_tensor
            tensor([[1, 2, 3, 4]]) # Rank 0
            tensor([[3, 4, 5, 6]]) # Rank 1
            >>> sharding_dim = 0
            >>> sharding_spec = ChunkShardingSpec(
                    dim=sharding_dim,
                    placements=[
                        "rank:0/cuda:0",
                        "rank:1/cuda:1",
                    ],
                )
            >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
            >>> st
            ShardedTensor(
                ShardedTensorMetadata(
                    shards_metadata=[
                        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
                        ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
                    ],
                    size=torch.Size([2, 4])
            )
            >>> st.local_tensor()
            tensor([1, 2, 3, 4]) # Rank 0
            tensor([3, 4, 5, 6]) # Rank 1

        Warning: This API is experimental and subject to change. It lacks full
                 cross-rank validation; we only validate the local shard on the
                 current rank and fully rely on the user to ensure the local
                 tensor is sharded based on the sharding spec.
        """
        if not isinstance(sharding_spec, ChunkShardingSpec):
            raise NotImplementedError('Only ChunkShardingSpec is supported.')
        if not local_tensor.is_contiguous():
            raise ValueError('local_tensor is not a contiguous Tensor.')

        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)

        global_tensor_size = _flatten_tensor_size(global_size)
        sharding_dim = sharding_spec.dim
        split_size = get_split_size(global_tensor_size[sharding_dim],
                                    world_size)  # type: ignore[index]
        current_offsets = [0] * len(global_tensor_size)
        gathered_metadatas = [None] * world_size
        local_shards = []

        for idx, placement in enumerate(sharding_spec.placements):
            chunked_dim_size = get_chunked_dim_size(
                global_tensor_size[sharding_dim],
                split_size,
                idx  # type: ignore[index]
            )
            shard_size = copy.deepcopy(global_tensor_size)
            shard_size[
                sharding_spec.dim] = chunked_dim_size  # type: ignore[index]
            shard_metadata = ShardMetadata(
                shard_offsets=copy.deepcopy(current_offsets),
                shard_sizes=shard_size,
                placement=placement,
            )
            if current_rank == placement.rank():  # type: ignore[union-attr]
                local_shard = local_tensor
            else:
                local_shard = torch.empty(
                    shard_size,
                    device=placement.device(),  # type: ignore[union-attr]
                    requires_grad=local_tensor.requires_grad,
                )
            shards = [
                Shard(
                    tensor=local_shard,
                    metadata=shard_metadata,  # type: ignore[arg-type]
                )
            ]
            if current_rank == placement.rank():  # type: ignore[union-attr]
                local_shards = shards
            gathered_metadatas[placement.rank(
            )] = build_metadata_from_local_shards(  # type: ignore[call-overload, union-attr]
                shards,
                global_tensor_size,
                placement.rank(),
                process_group  # type: ignore[union-attr, arg-type]
            )
            current_offsets[
                sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

        global_sharded_tensor_metadata = build_global_metadata(
            gathered_metadatas)

        return ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards,
            global_sharded_tensor_metadata,
            process_group=process_group,
            init_rrefs=init_rrefs,
        )
Example n. 17
    def test_sharded_linear_errors(self):
        for spec in generate_chunk_sharding_specs_for_test(0):
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(TypeError,
                                        'bias needs to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc2, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Input needs to have at least 1 dim'):
                fc2(torch.tensor(1).cuda(self.rank))

            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc3.weight = torch.nn.Parameter(
                torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(fc3, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Weight needs to have exactly 2 dims'):
                fc3(torch.rand(10, 10).cuda(self.rank))

            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(fc4, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Bias needs to have exactly 1 dim'):
                fc4(torch.rand(10, 10).cuda(self.rank))

            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
            shard_parameter(fc5, "weight", spec)
            with self.assertRaisesRegex(
                    ValueError,
                    'Input dim: 13 does not match appropriate weight dim: 7'):
                fc5(torch.rand(20, 10, 13).cuda(self.rank))

            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
            del fc6.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                )
            ])

            fc6.weight = empty(enumerable_spec, 10, 10)
            # The ShardedTensor args repr has unbalanced parentheses that would break
            # re.compile, so the args portion is matched with a (?s).* wildcard.
            error_msg = (r"torch function 'linear', with args: (?s).* "
                         r"and kwargs: None not supported for ShardedTensor!")
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fc6(torch.rand(10, 10).cuda(self.rank))

            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                    "rank:3/cuda:3",
                ],
            )
            del fc7.weight
            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError,
                                        'Only one local shard supported!'):
                fc7(torch.rand(10, 10).cuda(self.rank))
    def test_enumerable_sharding_spec(self):
        # test valid specs

        # test row-wise sharding
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="cuda:1",
            )
        ])
        check_tensor(spec.shards, torch.rand(10, 5).size())

        # test row and column sharding
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[3, 3],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 3],
                shard_sizes=[3, 3],
                placement="cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[3, 0],
                shard_sizes=[3, 3],
                placement="cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[3, 3],
                shard_sizes=[3, 3],
                placement="cuda:3",
            ),
        ])
        check_tensor(spec.shards, torch.rand(6, 6).size())

        # test uneven shard sizes.
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[2, 4],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 4],
                shard_sizes=[4, 2],
                placement="cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[2, 0],
                shard_sizes=[4, 4],
                placement="cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[4, 4],
                shard_sizes=[2, 2],
                placement="cuda:3",
            ),
        ])
        check_tensor(spec.shards, torch.rand(6, 6).size())

        # test invalid sharding
        with self.assertRaisesRegex(ValueError,
                                    'Could not parse remote_device'):
            ShardMetadata(shard_offsets=[0],
                          shard_sizes=[1],
                          placement="cuda:foo")

        with self.assertRaisesRegex(ValueError, 'same number of elements'):
            ShardMetadata(shard_offsets=[0, 0],
                          shard_sizes=[1],
                          placement="cuda:0")

        with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
            ShardMetadata(shard_offsets=[-1, 0],
                          shard_sizes=[1, 1],
                          placement="cuda:0")

        with self.assertRaisesRegex(ValueError, 'shard_sizes should be >= 0'):
            ShardMetadata(shard_offsets=[0, 0],
                          shard_sizes=[-1, 1],
                          placement="cuda:0")

        with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
            EnumerableShardingSpec([])

        with self.assertRaisesRegex(ValueError,
                                    'Found inconsistent ranks for shards'):
            EnumerableShardingSpec([
                ShardMetadata(shard_offsets=[0, 0],
                              shard_sizes=[1, 1],
                              placement="cpu"),
                ShardMetadata(shard_offsets=[0, 0, 0],
                              shard_sizes=[1, 1, 1],
                              placement="cpu"),
            ])

        with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
            EnumerableShardingSpec([
                ShardMetadata(shard_offsets=[0, 0],
                              shard_sizes=[3, 3],
                              placement="cpu"),
                ShardMetadata(shard_offsets=[2, 0],
                              shard_sizes=[3, 3],
                              placement="cpu"),
            ])

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="cuda:1",
            )
        ])

        with self.assertRaisesRegex(ValueError,
                                    'Rank of tensor is.*but shards rank'):
            check_tensor(spec.shards, torch.rand(10, 10, 10).size())

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="cuda:1",
            )
        ])

        with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
            check_tensor(spec.shards, torch.rand(10, 3).size())

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="cuda:1",
            )
        ])

        with self.assertRaisesRegex(ValueError,
                                    'does not match tensor volume'):
            check_tensor(spec.shards, torch.rand(10, 10).size())