Beispiel #1
0
    def test_new_group(self):
        """
        Shard a 10x5 tensor over a sub-group built from global ranks 1, 2, 3.

        Placement ranks in the spec are ranks *within* ``pg``: the checks
        below expect group rank 0 to be global rank 1 and group rank 2 to be
        global rank 3. Global rank 0 is not a member of ``pg`` and skips all
        checks; global rank 2 is a member but owns no shard.
        """
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="rank:0/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="rank:2/cuda:3",
            ),
        ])

        # Called on every rank, including rank 0 which is not in the group.
        pg = dist.new_group(ranks=[1, 2, 3])

        if self.rank >= 1:
            sharded_tensor = _sharded_tensor.empty(spec,
                                                   10,
                                                   5,
                                                   process_group=pg)
            self.assertEqual((10, 5), sharded_tensor.size())
            if self.rank == 1 or self.rank == 3:
                # Verify local shard.
                local_shard = sharded_tensor.local_shards()[0]
                self.assertEqual(torch.device(f'cuda:{self.rank}'),
                                 local_shard.tensor.device)
                self.assertEqual((5, 5), local_shard.tensor.size())

                # Verify local shard metadata. Group rank is global rank - 1.
                self.assertEqual((self.rank // 2 * 5, 0),
                                 local_shard.metadata.shard_offsets)
                self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
                self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}',
                                 local_shard.metadata.placement)

            # Verify global metadata.
            sharding_metadata = sharded_tensor.sharding_metadata()
            self.assertEqual(2, len(sharding_metadata))
            for rank, shard_metadata in enumerate(sharding_metadata):
                self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
                self.assertEqual((5, 5), shard_metadata.shard_lengths)
                self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}',
                                 shard_metadata.placement)

            # Validate remote shards: owners see only the other owner's
            # shard, rank 2 (no local shard) sees both.
            remote_shards = sharded_tensor.remote_shards
            if self.rank == 1 or self.rank == 3:
                self.assertEqual(1, len(remote_shards))
            else:
                self.assertEqual(2, len(remote_shards))

            for rpc_rank, shards in remote_shards.items():
                self.assertEqual(1, len(shards))

                for remote_shard in shards:
                    self.assertEqual(rpc_rank, remote_shard.owner().id)
                    shard = remote_shard.to_here()
                    self.assertEqual((5, 5), shard.tensor.size())
    def test_partial_world_size(self):
        """
        Row-wise sharding of a 10x5 tensor over ranks 0 and 1 only.

        Ranks above 1 take part in the collective construction but own no
        shard, so they see both shards as remote.
        """
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[idx * 5, 0],
                shard_lengths=[5, 5],
                placement=f"rank:{idx}/cuda:{idx}",
            )
            for idx in range(2)
        ])

        sharded_tensor = _sharded_tensor.empty(spec, 10, 5)
        self.assertEqual((10, 5), sharded_tensor.size())

        owns_shard = self.rank <= 1
        self.assertEqual(1 if owns_shard else 0,
                         len(sharded_tensor.local_shards()))

        if owns_shard:
            # The single local shard: data tensor, then its metadata.
            mine = sharded_tensor.local_shards()[0]
            self.assertEqual(torch.device(f'cuda:{self.rank}'),
                             mine.tensor.device)
            self.assertEqual((5, 5), mine.tensor.size())
            self.assertEqual((self.rank * 5, 0), mine.metadata.shard_offsets)
            self.assertEqual((5, 5), mine.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}',
                             mine.metadata.placement)

        # Global view: identical on every rank.
        all_shards = sharded_tensor.metadata().shards_metadata
        self.assertEqual(2, len(all_shards))
        for owner, meta in enumerate(all_shards):
            self.assertEqual((owner * 5, 0), meta.shard_offsets)
            self.assertEqual((5, 5), meta.shard_lengths)
            self.assertEqual(f'rank:{owner}/cuda:{owner}', meta.placement)

        # Remote view: owners see one foreign shard, non-owners see both.
        remote_shards = sharded_tensor.remote_shards
        self.assertEqual(1 if owns_shard else 2, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                fetched = remote_shard.to_here()
                self.assertEqual((5, 5), fetched.tensor.size())
    def test_with_rpc_names(self):
        """
        2x2 grid of 5x5 shards over a 10x10 tensor whose placements use RPC
        worker names ("workerN/cuda:N") instead of "rank:N/...".
        """
        def grid_position(idx):
            # Shard idx sits at block-row idx // 2, block-column idx % 2.
            block_row, block_col = divmod(idx, 2)
            return block_row * 5, block_col * 5

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=list(grid_position(idx)),
                shard_lengths=[5, 5],
                placement=f"worker{idx}/cuda:{idx}",
            )
            for idx in range(4)
        ])

        sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
        self.assertEqual((10, 10), sharded_tensor.size())
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # The one local shard: data tensor, then its metadata.
        me = self.rank
        mine = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{me}'), mine.tensor.device)
        self.assertEqual((5, 5), mine.tensor.size())
        self.assertEqual(grid_position(me), mine.metadata.shard_offsets)
        self.assertEqual((5, 5), mine.metadata.shard_lengths)
        self.assertEqual(f'worker{me}/cuda:{me}', mine.metadata.placement)

        # Global view of all four shards.
        all_shards = sharded_tensor.metadata().shards_metadata
        self.assertEqual(4, len(all_shards))
        for idx, meta in enumerate(all_shards):
            self.assertEqual(grid_position(idx), meta.shard_offsets)
            self.assertEqual((5, 5), meta.shard_lengths)
            self.assertEqual(f'worker{idx}/cuda:{idx}', meta.placement)

        # Every other worker contributes exactly one remote shard.
        remote_shards = sharded_tensor.remote_shards
        self.assertEqual(3, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                fetched = remote_shard.to_here()
                self.assertEqual((5, 5), fetched.tensor.size())
Beispiel #4
0
    def test_multiple_local_shards(self):
        """
        2x2 grid of 5x5 shards where rank 0 owns the left column and rank 1
        the right column (two shards each); ranks above 1 own nothing.
        """
        self.init_pg()

        # Shard idx sits at block (idx // 2, idx % 2) and belongs to rank idx % 2.
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[idx // 2 * 5, idx % 2 * 5],
                shard_lengths=[5, 5],
                placement=f"rank:{idx % 2}/cuda:{idx % 2}",
            )
            for idx in range(4)
        ])

        sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
        self.assertEqual((10, 10), sharded_tensor.size())

        if self.rank > 1:
            self.assertEqual(0, len(sharded_tensor.local_shards()))
        else:
            self.assertEqual(2, len(sharded_tensor.local_shards()))

            # Local shards come back top block-row first.
            for block_row, mine in enumerate(sharded_tensor.local_shards()):
                self.assertEqual(torch.device(f'cuda:{self.rank}'),
                                 mine.tensor.device)
                self.assertEqual((5, 5), mine.tensor.size())
                self.assertEqual((block_row * 5, self.rank * 5),
                                 mine.metadata.shard_offsets)
                self.assertEqual((5, 5), mine.metadata.shard_lengths)
                self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}',
                                 mine.metadata.placement)

        # Global view: same on every rank.
        all_shards = sharded_tensor.sharding_metadata()
        self.assertEqual(4, len(all_shards))
        for shard_idx, meta in enumerate(all_shards):
            owner = shard_idx % 2
            self.assertEqual((shard_idx // 2 * 5, owner * 5),
                             meta.shard_offsets)
            self.assertEqual((5, 5), meta.shard_lengths)
            self.assertEqual(f'rank:{owner}/cuda:{owner}', meta.placement)
Beispiel #5
0
    def test_grid_sharding(self):
        """
        2x2 grid of 5x5 shards over a 10x10 tensor, one shard per rank 0-3.
        """
        self.init_pg()

        def grid_position(idx):
            # Shard idx sits at block-row idx // 2, block-column idx % 2.
            return idx // 2 * 5, (idx % 2) * 5

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=list(grid_position(idx)),
                shard_lengths=[5, 5],
                placement=f"rank:{idx}/cuda:{idx}",
            )
            for idx in range(4)
        ])

        sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
        self.assertEqual((10, 10), sharded_tensor.size())
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # The one local shard: data tensor, then its metadata.
        me = self.rank
        mine = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{me}'), mine.tensor.device)
        self.assertEqual((5, 5), mine.tensor.size())
        self.assertEqual(grid_position(me), mine.metadata.shard_offsets)
        self.assertEqual((5, 5), mine.metadata.shard_lengths)
        self.assertEqual(f'rank:{me}/cuda:{me}', mine.metadata.placement)

        # Global view of all four shards.
        all_shards = sharded_tensor.sharding_metadata()
        self.assertEqual(4, len(all_shards))
        for idx, meta in enumerate(all_shards):
            self.assertEqual(grid_position(idx), meta.shard_offsets)
            self.assertEqual((5, 5), meta.shard_lengths)
            self.assertEqual(f'rank:{idx}/cuda:{idx}', meta.placement)
Beispiel #6
0
    def _init_chunked(
        self,
        dims,
        tensor_init_params: TensorInitParams,
    ):
        """
        Build shard metadata and local shards for a chunk-style sharding spec.

        ``dims[self._sharding_spec.dim]`` is split into
        ``len(self._sharding_spec.placements)`` pieces of
        ``ceil(dim_size / num_placements)`` each, the way ``torch.chunk``
        sizes its chunks. The last non-empty piece may be shorter, and
        placements left with no rows get no shard. One ``ShardMetadata`` is
        recorded per non-empty piece. A local tensor is created (via
        ``_create_tensor_from_params``) only for pieces whose placement rank
        equals this process's rank in ``self._process_group``. The result is
        stored in ``self._local_shards`` and ``self._metadata``.

        Args:
            dims: Sizes of the full (unsharded) tensor, one per dimension.
                Any sequence of ints works (list, tuple, torch.Size); it is
                copied per shard and never mutated.
            tensor_init_params: Forwarded to ``_create_tensor_from_params``;
                its ``tensor_properties`` go into the ``ShardedTensorMetadata``.

        Raises:
            ValueError: if the sharding dim is not an int, is out of range for
                ``dims``, or the spec has no placements.
        """
        current_rank = dist.get_rank(self._process_group)
        sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]

        # Validate the sharding spec.
        if not isinstance(sharding_dim, int):
            raise ValueError(
                f"Sharding dim needs to be an integer, found: {sharding_dim}")
        if sharding_dim >= len(dims) or sharding_dim < -len(dims):
            raise ValueError(f"Invalid sharding dim: {sharding_dim}")

        dim_size = dims[sharding_dim]
        remote_devices = self._sharding_spec.placements  # type: ignore[attr-defined]
        chunks = len(remote_devices)
        if chunks == 0:
            # Without this the ceil-division below dies with ZeroDivisionError.
            raise ValueError("Sharding spec must have at least one placement")
        # split_size computed similar to 'torch.chunk' (ceil division).
        split_size = (dim_size + chunks - 1) // chunks

        shards_metadata = []
        for idx, remote_device in enumerate(remote_devices):
            # Every placement is parsed/validated, even those whose chunk
            # turns out to be empty.
            rank, local_device = self._parse_and_validate_remote_device(
                remote_device)

            # Length of this placement's chunk along the sharding dim; zero or
            # negative for trailing placements once dim_size is exhausted.
            offset = split_size * idx
            sharded_dim_size = min(dim_size, offset + split_size) - offset
            if sharded_dim_size <= 0:
                continue

            # Shallow copies are enough: the entries are plain ints. list()
            # (rather than .copy()) also accepts tuple / torch.Size dims.
            rank_dims = list(dims)
            rank_dims[sharding_dim] = sharded_dim_size
            rank_offsets = [0] * len(dims)
            rank_offsets[sharding_dim] = offset

            shard_metadata = ShardMetadata(rank_offsets, rank_dims,
                                           remote_device)
            shards_metadata.append(shard_metadata)

            # Materialize the shard only on the rank that owns it.
            if current_rank == rank:
                local_shard = _create_tensor_from_params(
                    *rank_dims,
                    local_device=local_device,
                    tensor_init_params=tensor_init_params)
                self._local_shards.append(
                    Shard(local_shard, shard_metadata))

        # Build overall metadata
        self._metadata = ShardedTensorMetadata(
            shards_metadata,
            dims,
            tensor_init_params.tensor_properties,
        )
Beispiel #7
0
    def from_tensor_and_offsets(cls, tensor: torch.Tensor,
                                shard_offsets: List[int], rank: int):
        """
        Wrap a local ``torch.Tensor`` as a ``Shard`` of a ShardedTensor.

        The shard's lengths are taken from ``tensor.size()``. Its placement is
        ``rank:<rank>/<tensor.device>``, wrapped in ``_remote_device``.

        Args:
            tensor(torch.Tensor): Local tensor holding the shard's data.
            shard_offsets(List[int]): Per-dimension offset of this shard
                within the full tensor.
            rank(int): Rank that owns the shard.

        Returns:
            Shard: ``tensor`` paired with a newly built ``ShardMetadata``.
        """
        device_name = str(tensor.device)
        placement = _remote_device(f"rank:{rank}/{device_name}")
        metadata = ShardMetadata(
            shard_offsets=shard_offsets,
            shard_lengths=list(tensor.size()),
            placement=placement,
        )
        return Shard(tensor, metadata)
Beispiel #8
0
    def _init_chunked(
        self,
        dtype,
        layout,
        requires_grad,
        pin_memory,
        memory_format,
    ):
        """
        Fill ``self._sharding_metadata`` and ``self._local_shards`` for a
        chunk-style sharding spec.

        ``self._dims[spec.dim]`` is divided into ``len(spec.placements)``
        pieces of ``ceil(dim_size / num_placements)``, mirroring the chunk size
        ``torch.chunk`` uses. The final non-empty piece may be shorter, and
        placements left with no rows are skipped. Metadata is appended for
        every non-empty piece. A ``torch.empty`` tensor with the requested
        options is allocated only for pieces placed on this process's rank in
        ``self._process_group``.

        Raises:
            ValueError: if the sharding dim is not an int, is out of range for
                ``self._dims``, or a placement is not a valid device.
        """
        my_rank = dist.get_rank(self._process_group)
        dim = self._sharding_spec.dim  # type: ignore[attr-defined]

        # Reject malformed specs before doing any work.
        if not isinstance(dim, int):
            raise ValueError(
                f"Sharding dim needs to be an integer, found: {dim}"
            )
        ndim = len(self._dims)
        if not -ndim <= dim < ndim:
            raise ValueError(f"Invalid sharding dim: {dim}")

        total = self._dims[dim]
        placements = self._sharding_spec.placements  # type: ignore[attr-defined]
        num_chunks = len(placements)
        # Ceil-division: the chunk length torch.chunk would pick.
        chunk_len = (total + num_chunks - 1) // num_chunks

        for chunk_idx, placement in enumerate(placements):
            if not is_valid_device(placement):
                raise ValueError(f"{placement} is not a valid device")

            owner_rank, owner_device = self._parse_and_validate_remote_device(placement)

            start = chunk_len * chunk_idx
            length = min(total, start + chunk_len) - start
            if length <= 0:
                # Trailing placements may receive no rows at all.
                continue

            # Shallow copy is fine: entries are plain ints.
            shard_sizes = self._dims.copy()
            shard_sizes[dim] = length
            shard_offsets = [0] * ndim
            shard_offsets[dim] = start

            meta = ShardMetadata(shard_offsets, shard_sizes, placement)
            self._sharding_metadata.append(meta)

            if owner_rank != my_rank:
                continue

            # This rank owns the piece: allocate its local storage.
            storage = torch.empty(
                *shard_sizes,
                dtype=dtype,
                layout=layout,
                device=owner_device,
                requires_grad=requires_grad,
                memory_format=memory_format,
                pin_memory=pin_memory,
            )
            self._local_shards.append(Shard(storage, meta))
Beispiel #9
0
    def test_enumerable_sharding_spec(self):
        """
        Exercise EnumerableShardingSpec / ShardMetadata validation.

        First, valid specs (row-wise, 2x2 grid, uneven grid) are checked
        against matching tensor sizes. Then each kind of malformed metadata,
        spec, or spec/tensor mismatch is asserted to raise ValueError with
        the expected message.
        """
        def meta(offsets, lengths, placement):
            # Same keyword construction the inline calls would use.
            return ShardMetadata(
                shard_offsets=offsets,
                shard_lengths=lengths,
                placement=placement,
            )

        def two_row_shards(second_offset):
            # Two 5x5 shards on cuda:0 / cuda:1; the second one is placed at
            # second_offset.
            return EnumerableShardingSpec([
                meta([0, 0], [5, 5], "cuda:0"),
                meta(second_offset, [5, 5], "cuda:1"),
            ])

        # ---- valid specs ----

        # Row-wise sharding of a 10x5 tensor.
        two_row_shards([5, 0]).check_tensor(torch.rand(10, 5).size())

        # Row and column sharding: 2x2 grid of 3x3 blocks over 6x6.
        grid = EnumerableShardingSpec([
            meta([row, col], [3, 3], f"cuda:{idx}")
            for idx, (row, col) in enumerate([(0, 0), (0, 3), (3, 0), (3, 3)])
        ])
        grid.check_tensor(torch.rand(6, 6).size())

        # Uneven shard sizes that still tile 6x6 exactly.
        uneven = EnumerableShardingSpec([
            meta([0, 0], [2, 4], "cuda:0"),
            meta([0, 4], [4, 2], "cuda:1"),
            meta([2, 0], [4, 4], "cuda:2"),
            meta([4, 4], [2, 2], "cuda:3"),
        ])
        uneven.check_tensor(torch.rand(6, 6).size())

        # ---- invalid ShardMetadata ----
        bad_metadata = [
            ('not a valid device', [0], [1], "cuda:foo"),
            ('same number of elements', [0, 0], [1], "cuda:0"),
            ('shard_offsets should be >=0', [-1, 0], [1, 1], "cuda:0"),
            ('shard_lengths should be > 0', [0, 0], [0, 1], "cuda:0"),
        ]
        for pattern, offsets, lengths, placement in bad_metadata:
            with self.assertRaisesRegex(ValueError, pattern):
                meta(offsets, lengths, placement)

        # ---- invalid EnumerableShardingSpec ----
        with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
            EnumerableShardingSpec([])

        with self.assertRaisesRegex(ValueError,
                                    'Found inconsistent ranks for shards'):
            EnumerableShardingSpec([
                meta([0, 0], [1, 1], "cpu"),
                meta([0, 0, 0], [1, 1, 1], "cpu"),
            ])

        with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
            EnumerableShardingSpec([
                meta([0, 0], [3, 3], "cpu"),
                meta([2, 0], [3, 3], "cpu"),
            ])

        # ---- spec / tensor mismatches (spec built outside the assert) ----
        row_spec = two_row_shards([5, 0])
        with self.assertRaisesRegex(ValueError,
                                    'Rank of tensor is.*but shards rank'):
            row_spec.check_tensor(torch.rand(10, 10, 10).size())

        row_spec = two_row_shards([5, 0])
        with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
            row_spec.check_tensor(torch.rand(10, 3).size())

        diagonal_spec = two_row_shards([5, 5])
        with self.assertRaisesRegex(ValueError,
                                    'does not match tensor volume'):
            diagonal_spec.check_tensor(torch.rand(10, 10).size())
Beispiel #10
0
    def test_sharded_linear_errors(self):
        """
        Check the errors raised when calling a Linear module whose parameters
        are sharded, for bad inputs, bad parameter shapes and unsupported
        sharding specs. Repeated for every chunk spec on dim 0.
        """
        dev = self.rank
        for spec in generate_chunk_sharding_specs_for_test(0):
            # Sharded bias -> TypeError.
            linear = torch.nn.Linear(10, 10).cuda(dev)
            shard_parameter(linear, "bias", spec)
            with self.assertRaisesRegex(
                    TypeError, 'input and bias need to be torch.Tensor'):
                linear(torch.rand(10, 10).cuda(dev))

            # 0-d input.
            linear = torch.nn.Linear(10, 10).cuda(dev)
            shard_parameter(linear, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Input needs to have at least 1 dim'):
                linear(torch.tensor(1).cuda(dev))

            # 3-d weight.
            linear = torch.nn.Linear(10, 10).cuda(dev)
            linear.weight = torch.nn.Parameter(
                torch.rand(10, 10, 10).cuda(dev))
            shard_parameter(linear, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Weight needs to have exactly 2 dims'):
                linear(torch.rand(10, 10).cuda(dev))

            # 2-d bias.
            linear = torch.nn.Linear(10, 10).cuda(dev)
            linear.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(dev))
            shard_parameter(linear, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Bias needs to have exactly 1 dim'):
                linear(torch.rand(10, 10).cuda(dev))

            # Input's last dim (13) differs from in_features (7).
            linear = torch.nn.Linear(7, 10).cuda(dev)
            shard_parameter(linear, "weight", spec)
            with self.assertRaisesRegex(
                    ValueError,
                    'Input dim: 13 does not match appropriate weight dim: 7'):
                linear(torch.rand(20, 10, 13).cuda(dev))

            # Weight backed by an EnumerableShardingSpec (2x2 grid of 5x5).
            linear = torch.nn.Linear(10, 10).cuda(dev)
            del linear.weight
            grid_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[row, col],
                    shard_sizes=[5, 5],
                    placement=f"rank:{idx}/cuda:{idx}",
                )
                for idx, (row, col) in enumerate(
                    [(0, 0), (0, 5), (5, 0), (5, 5)])
            ])
            linear.weight = empty(grid_spec, 10, 10)
            with self.assertRaisesRegex(
                    ValueError,
                    'Only ChunkShardingSpec supported for ShardedTensor ops!'):
                linear(torch.rand(10, 10).cuda(dev))

            # Two chunks per rank -> two local shards on each rank.
            linear = torch.nn.Linear(10, 80).cuda(dev)
            two_per_rank = ChunkShardingSpec(
                dim=0,
                placements=[
                    f"rank:{r}/cuda:{r}" for r in range(4) for _ in range(2)
                ],
            )
            del linear.weight
            linear.weight = empty(two_per_rank, 80, 10)
            with self.assertRaisesRegex(ValueError,
                                        'Only one local shard supported!'):
                linear(torch.rand(10, 10).cuda(dev))
    def test_uneven_shards(self):
        """
        A 6x6 tensor tiled by four differently sized shards, one per rank 0-3.
        """
        self.init_pg()

        # rank -> (offsets, lengths) of the shard that rank owns.
        layout = {
            0: ((0, 0), (2, 4)),
            1: ((0, 4), (4, 2)),
            2: ((2, 0), (4, 4)),
            3: ((4, 4), (2, 2)),
        }
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_lengths=list(lengths),
                placement=f"rank:{owner}/cuda:{owner}",
            )
            for owner, (offsets, lengths) in layout.items()
        ])

        sharded_tensor = _sharded_tensor.empty(spec, 6, 6)
        self.assertEqual((6, 6), sharded_tensor.size())
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        def verify_size(rank, tensor_dims):
            # Ranks outside the layout are silently not checked.
            if rank in layout:
                self.assertEqual(layout[rank][1], tensor_dims)

        def verify_offsets(rank, offsets):
            if rank in layout:
                self.assertEqual(layout[rank][0], offsets)

        # The one local shard: data tensor, then its metadata.
        mine = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{self.rank}'), mine.tensor.device)
        verify_size(self.rank, mine.tensor.size())
        verify_offsets(self.rank, mine.metadata.shard_offsets)
        verify_size(self.rank, mine.metadata.shard_lengths)
        self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', mine.metadata.placement)

        # Global view of all four shards.
        all_shards = sharded_tensor.metadata().shards_metadata
        self.assertEqual(4, len(all_shards))
        for owner, meta in enumerate(all_shards):
            verify_offsets(owner, meta.shard_offsets)
            verify_size(owner, meta.shard_lengths)
            self.assertEqual(f'rank:{owner}/cuda:{owner}', meta.placement)
    def test_sharded_tensor_metadata(self):
        """
        Check that the tensor options given to ``_sharded_tensor.empty``
        (dtype, layout, requires_grad, memory_format, pin_memory) are what
        ``metadata()`` reports, along with the overall size.
        """
        def grid_spec(device_of):
            # 2x2 grid of 5x5 shards over 10x10; shard idx goes to rank idx
            # on the device string returned by device_of(idx).
            return EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[row, col],
                    shard_lengths=[5, 5],
                    placement=f"rank:{idx}/{device_of(idx)}",
                )
                for idx, (row, col) in enumerate(
                    [(0, 0), (0, 5), (5, 0), (5, 5)])
            ])

        cuda_spec = grid_spec(lambda idx: f"cuda:{idx}")

        # Default options.
        meta = _sharded_tensor.empty(cuda_spec, 10, 10).metadata()
        self.assertEqual(torch.Size([10, 10]), meta.size)
        self.assertEqual(torch.float, meta.dtype)
        self.assertEqual(torch.strided, meta.layout)
        self.assertEqual(False, meta.requires_grad)
        self.assertEqual(torch.contiguous_format, meta.memory_format)
        self.assertEqual(False, meta.pin_memory)

        # Explicit overrides, one at a time.
        meta = _sharded_tensor.empty(cuda_spec, 10, 10, requires_grad=True).metadata()
        self.assertEqual(True, meta.requires_grad)

        meta = _sharded_tensor.empty(cuda_spec, 10, 10, dtype=torch.double).metadata()
        self.assertEqual(torch.double, meta.dtype)

        # Need CPU for pin_memory
        cpu_spec = grid_spec(lambda idx: "cpu")
        meta = _sharded_tensor.empty(cpu_spec, 10, 10, pin_memory=True).metadata()
        self.assertEqual(True, meta.pin_memory)