def test_sharded_tensor_reshard_errors(self):
     """Check the ``NotImplementedError`` paths of ``reshard`` on a sharded tensor.

     Two rejections are exercised: a target spec that is an
     ``EnumerableShardingSpec`` rather than a chunk spec, and a tensor whose
     ``_local_shards`` list has been forced to hold two entries.
     """
     chunk_specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
     source_spec = chunk_specs[0]
     target_spec = chunk_specs[1]
     row_halves = [
         ShardMetadata(
             shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"
         ),
         ShardMetadata(
             shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:1/cuda:1"
         ),
     ]
     non_chunk_spec = EnumerableShardingSpec(row_halves)
     tensor = sharded_tensor.rand(source_spec, 24, 12)

     unsupported_spec_error = "Only ChunkShardingSpec supported for reshard."
     with self.assertRaisesRegex(NotImplementedError, unsupported_spec_error):
         tensor.reshard(non_chunk_spec)

     # Duplicate the first local shard so the tensor holds two local shards.
     tensor._local_shards = [tensor.local_shards()[0] for _ in range(2)]
     multi_shard_error = "Only single local shard supported for reshard."
     with self.assertRaisesRegex(NotImplementedError, multi_shard_error):
         tensor.reshard(target_spec)
Example #2
0
 def test_partial_tensor_reshard_errors(self):
     """Verify the ``NotImplementedError`` cases of partial-tensor resharding.

     Two rejections are checked, each for a SUM and a MAX reduction:
     resharding to an ``EnumerableShardingSpec``, and resharding a
     ``torch.cfloat`` partial tensor.

     Every failing call gets its own ``assertRaisesRegex`` context. The
     context manager exits at the first exception, so a second call placed
     in the same ``with`` body would never execute.
     """
     enumerable_sharding_spec = EnumerableShardingSpec(
         [
             ShardMetadata(
                 shard_offsets=[0, 0],
                 shard_sizes=[5, 5],
                 placement="rank:0/cuda:0",
             ),
             ShardMetadata(
                 shard_offsets=[5, 0],
                 shard_sizes=[5, 5],
                 placement="rank:1/cuda:1",
             ),
         ]
     )
     # (input size, reduce op) pairs shared by both error scenarios.
     reshard_cases = (
         ([13, 21], dist.ReduceOp.SUM),
         ([12, 22], dist.ReduceOp.MAX),
     )

     for input_size, reduce_op in reshard_cases:
         with self.assertRaisesRegex(
             NotImplementedError, "Only ChunkShardingSpec supported for reshard."
         ):
             self._run_partial_tensor_n_reshard(
                 enumerable_sharding_spec, input_size, 4, reduce_op
             )

     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     spec = specs[0]
     for input_size, reduce_op in reshard_cases:
         with self.assertRaisesRegex(
             NotImplementedError, "Only real partial tensor supported for reshard."
         ):
             self._run_partial_tensor_n_reshard(
                 spec, input_size, 4, reduce_op, dtype=torch.cfloat
             )
Example #3
0
def generate_enumerable_sharding_specs_for_test():
    """Return a one-element list holding a 2x2 grid ``EnumerableShardingSpec``.

    The spec is made of four 5x5 shards at offsets (0, 0), (5, 0), (0, 5)
    and (5, 5), placed on ranks 0 through 3 with the matching CUDA device.
    """
    # (shard offsets, placement) for each of the four grid cells.
    grid_cells = (
        ([0, 0], "rank:0/cuda:0"),
        ([5, 0], "rank:1/cuda:1"),
        ([0, 5], "rank:2/cuda:2"),
        ([5, 5], "rank:3/cuda:3"),
    )
    shards = [
        ShardMetadata(
            shard_offsets=cell_offsets,
            shard_sizes=[5, 5],
            placement=cell_placement,
        )
        for cell_offsets, cell_placement in grid_cells
    ]
    return [EnumerableShardingSpec(shards)]
Example #4
0
 def _create_enumerate_spec(self, tensor):
     """Build an ``EnumerableShardingSpec`` made of two 1-D shards.

     The shards start at offsets 0 and 101 with sizes 101 and 900.
     ``tensor`` is accepted but not consulted by this helper.
     """
     # Placement is not used, so every shard is pinned to rank0 to mimic
     # the actual usage.
     shared_placement = "rank0/cuda:0"
     shard_layout = (([0], [101]), ([101], [900]))
     return EnumerableShardingSpec(
         [
             ShardMetadata(offsets, sizes, placement=shared_placement)
             for offsets, sizes in shard_layout
         ]
     )
Example #5
0
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False,
    ):
        """Create an instance of ``cls`` from this rank's local shards.

        Args:
            local_shards: the shards held by the calling rank (may be empty).
            *global_size: overall tensor size, flattened through
                ``_flatten_tensor_size``.
            process_group: group to operate on; the default group is used
                when ``None``.
            init_rrefs: forwarded to ``_prepare_init``.

        Returns:
            The new instance, with ``_metadata``, ``_local_shards`` and an
            ``EnumerableShardingSpec`` as ``_sharding_spec`` populated.
        """
        # STEP 1: Validate the Shardmetadatas locally
        if process_group is None:
            process_group = distributed_c10d._get_default_group()
        rank = dist.get_rank(process_group)
        group_size = dist.get_world_size(process_group)

        flat_global_size = _flatten_tensor_size(global_size)
        # Stays None on ranks that hold no shard.
        local_metadata: Optional[ShardedTensorMetadata] = None
        if len(local_shards) > 0:
            local_metadata = build_metadata_from_local_shards(
                local_shards, flat_global_size, rank, process_group)

        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
        # metadata by gathering local ShardedTensorMetadata
        per_rank_metadata: List[Optional[ShardedTensorMetadata]] = []
        if group_size > 1:
            per_rank_metadata = [None] * group_size
            dist.all_gather_object(
                per_rank_metadata, local_metadata, group=process_group)
        else:
            # Single-rank group: no collective needed.
            per_rank_metadata = [local_metadata]

        global_metadata = build_global_metadata(per_rank_metadata)

        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
        instance = cls.__new__(cls)
        instance._prepare_init(
            process_group=process_group, init_rrefs=init_rrefs)

        instance._metadata = global_metadata
        instance._local_shards = local_shards
        # Tensors initialized from this API always get an EnumerableShardingSpec.
        # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
        #       see issue https://github.com/pytorch/pytorch/issues/67244
        instance._sharding_spec = EnumerableShardingSpec(
            global_metadata.shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        instance._post_init()
        return instance
Example #6
0
    def test_math_ops_errors(self):
        """Check that ``torch.add`` raises ``RuntimeError`` for two bad operand pairs.

        First pair: two tensors sharded with the same chunk spec but created
        with sizes (20, 3) and (12, 3). Second pair: a 10x10 tensor sharded
        with an ``EnumerableShardingSpec`` added to the (12, 3) chunk-sharded
        tensor.
        """
        four_rank_placements = [
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ]
        chunk_spec = ChunkShardingSpec(dim=0, placements=four_rank_placements)
        lhs = sharded_tensor.rand(chunk_spec, (20, 3))
        rhs = sharded_tensor.rand(chunk_spec, (12, 3))

        with self.assertRaisesRegex(
            RuntimeError, "Implicit broadcasting not supported"
        ):
            torch.add(lhs, rhs)

        # (shard offsets, placement) for four 5x5 shards tiling a 10x10 grid.
        grid_cells = (
            ([0, 0], "rank:0/cuda:0"),
            ([0, 5], "rank:1/cuda:1"),
            ([5, 0], "rank:2/cuda:2"),
            ([5, 5], "rank:3/cuda:3"),
        )
        grid_spec = EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=cell_offsets,
                    shard_sizes=[5, 5],
                    placement=cell_placement,
                )
                for cell_offsets, cell_placement in grid_cells
            ]
        )

        grid_tensor = sharded_tensor.rand(grid_spec, 10, 10)

        with self.assertRaisesRegex(RuntimeError, "not supported"):
            torch.add(grid_tensor, rhs)
    def test_switch_between_sharded_tensor_to_tensor(self) -> None:
        """Save a sharded/plain tensor pair and load each into the other kind.

        For every (save spec, load spec) combination the entry saved as a
        sharded tensor is loaded into a plain tensor and the entry saved as a
        plain tensor is loaded into a sharded tensor. Rank 0 then compares
        the saved and loaded values with ``torch.allclose``.
        """
        path = self.get_file_path()
        tensor_size = 32

        def split_spec(boundary, head_placement, tail_placement):
            # 1-D spec split into [0, boundary) and [boundary, tensor_size).
            return EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[boundary],
                    placement=head_placement,
                ),
                ShardMetadata(
                    shard_offsets=[boundary],
                    shard_sizes=[tensor_size - boundary],
                    placement=tail_placement,
                ),
            ])

        specs = [
            ChunkShardingSpec(dim=0, placements=["rank:0", "rank:1"]),
            ChunkShardingSpec(
                dim=0,
                placements=["rank:0", "rank:1", "rank:1", "rank:0"],
            ),
            split_spec(8, "rank:1", "rank:0"),
            split_spec(10, "rank:0", "rank:1"),
        ]

        for save_spec in specs:
            for load_spec in specs:
                local_device = f"cpu:{self.rank}"

                saved = {
                    'sharded': sharded_tensor.rand(save_spec, tensor_size),
                    'replicated': torch.rand(tensor_size, device=local_device),
                }
                save_state_dict(
                    state_dict=saved,
                    storage_writer=FileSystemWriter(path=path),
                )

                # Swap the kinds: the key saved as a sharded tensor is loaded
                # into a plain tensor, and vice versa.
                loaded = {
                    'sharded': torch.zeros(tensor_size, device=local_device),
                    'replicated': sharded_tensor.zeros(load_spec, tensor_size),
                }
                load_state_dict(
                    state_dict=loaded,
                    storage_reader=FileSystemReader(path=path),
                )

                saved_sharded_full = self.load_tensor(saved['sharded'])
                loaded_replicated_full = self.load_tensor(loaded['replicated'])

                # Only rank 0 performs the comparisons.
                if dist.get_rank() != 0:
                    continue

                failure_note = f"save-spec {save_spec} load-spec {load_spec}"
                self.assertTrue(
                    torch.allclose(saved_sharded_full, loaded['sharded']),
                    failure_note)
                self.assertTrue(
                    torch.allclose(saved['replicated'], loaded_replicated_full),
                    failure_note)
    def test_load_with_different_shard_plan(self) -> None:
        """Save a model under one sharding spec and load it under another.

        Every ordered pair of distinct specs from ``specs`` is tried: a
        ``MyShardedModel3`` built with the first spec is saved to ``path``,
        a second model built with the other spec is loaded from it, and
        rank 0 asserts that the two models' ``sharded_tensor`` values (as
        returned by ``self.load_tensor``) are close.
        """
        path = self.get_file_path()

        # We hardcode the assumption of how many shards are around
        self.assertEqual(self.world_size, dist.get_world_size())

        # All placements below only reference rank 0 and rank 1.
        specs = [
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            # This requires the tensors to be [10, 20]
            # Five row bands (2, 1, 3, 3, 1 rows) alternating between rank 0 and rank 1.
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[2, 0],
                    shard_sizes=[1, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[3, 0],
                    shard_sizes=[3, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[6, 0],
                    shard_sizes=[3, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[9, 0],
                    shard_sizes=[1, 20],
                    placement="rank:0",
                ),
            ]),
            # This requires the tensors to be [10, 20]
            # Two row bands: 8 rows on rank 1, 2 rows on rank 0.
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[8, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
            ]),
        ]

        for s0 in specs:
            for s1 in specs:
                # Only save/load pairs with differing specs are exercised.
                if s0 == s1:
                    continue

                # Rank 0 resets the checkpoint directory; the barrier keeps the
                # other ranks from proceeding until that is done.
                if dist.get_rank() == 0:
                    shutil.rmtree(path, ignore_errors=True)
                    os.makedirs(path)
                dist.barrier()

                # Save a model sharded according to s0.
                model_to_save = MyShardedModel3(s0)
                model_to_save._register_state_dict_hook(state_dict_hook)
                state_dict_to_save = model_to_save.state_dict()

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=state_dict_to_save,
                                storage_writer=fs_writer)

                # All ranks finish saving before anyone starts loading.
                dist.barrier()

                # Load the checkpoint into a model sharded according to s1.
                model_to_load = MyShardedModel3(s1)
                model_to_load._register_state_dict_hook(state_dict_hook)
                state_dict_to_load_to = model_to_load.state_dict()
                dist.barrier()

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=state_dict_to_load_to,
                                storage_reader=fs_reader)

                dist.barrier()
                # NOTE(review): self.load_tensor presumably materializes the full
                # tensor for comparison -- confirm against its definition.
                store_tensor = self.load_tensor(model_to_save.sharded_tensor)
                dist.barrier()
                load_tensor = self.load_tensor(model_to_load.sharded_tensor)

                # The comparison is only made on rank 0.
                if dist.get_rank() == 0:
                    self.assertTrue(torch.allclose(store_tensor, load_tensor),
                                    msg=f"{s0} vs {s1}")
Example #9
0
    def test_sharded_linear_errors(self):
        """Exercise the error paths of ``torch.nn.Linear`` with sharded parameters.

        For every chunk sharding spec from
        ``generate_chunk_sharding_specs_for_test(0)`` the test builds a series
        of linear layers, each misconfigured in one way, and asserts the
        exception type and message raised by the forward call.
        """
        for spec in generate_chunk_sharding_specs_for_test(0):
            # Sharding the bias is rejected at forward time.
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(TypeError,
                                        'bias needs to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            # Zero-dimensional input.
            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc2, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Input needs to have at least 1 dim'):
                fc2(torch.tensor(1).cuda(self.rank))

            # Three-dimensional weight.
            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc3.weight = torch.nn.Parameter(
                torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(fc3, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Weight needs to have exactly 2 dims'):
                fc3(torch.rand(10, 10).cuda(self.rank))

            # Two-dimensional bias.
            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(fc4, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Bias needs to have exactly 1 dim'):
                fc4(torch.rand(10, 10).cuda(self.rank))

            # Input feature dimension (13) differs from the weight's (7).
            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
            shard_parameter(fc5, "weight", spec)
            with self.assertRaisesRegex(
                    ValueError,
                    'Input dim: 13 does not match appropriate weight dim: 7'):
                fc5(torch.rand(20, 10, 13).cuda(self.rank))

            # Weight sharded with an EnumerableShardingSpec (2x2 grid of 5x5).
            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
            del fc6.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                )
            ])

            fc6.weight = empty(enumerable_spec, 10, 10)
            # Sharded Tensor metadata has parenthesis imbalance issue when using
            # re.compile, so the args portion is matched with a wildcard.
            # The two halves must be parenthesized to form ONE pattern: a raw
            # string on a line of its own is a no-op expression statement.
            # The global (?s) flag must lead the pattern; Python 3.11+ rejects
            # it anywhere else.
            error_msg = (
                r"(?s)torch function 'linear', with args: .* "
                r"and kwargs: None not supported for ShardedTensor!"
            )
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fc6(torch.rand(10, 10).cuda(self.rank))

            # Weight chunked so that every rank holds two local shards.
            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                    "rank:3/cuda:3",
                ],
            )
            del fc7.weight
            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError,
                                        'Only one local shard supported!'):
                fc7(torch.rand(10, 10).cuda(self.rank))
    def test_enumerable_sharding_spec(self):
        """Cover valid and invalid ``EnumerableShardingSpec`` / ``ShardMetadata`` inputs.

        Valid layouts are passed through ``check_tensor`` with a matching
        tensor size. Invalid ones must raise ``ValueError`` with the expected
        message, either at construction time or from ``check_tensor``.
        """
        def make_shard(offsets, sizes, placement):
            return ShardMetadata(shard_offsets=offsets,
                                 shard_sizes=sizes,
                                 placement=placement)

        def two_row_blocks():
            # Two 5x5 shards stacked along dim 0, rebuilt on every call.
            return EnumerableShardingSpec([
                make_shard([0, 0], [5, 5], "cuda:0"),
                make_shard([5, 0], [5, 5], "cuda:1"),
            ])

        # ---- valid specs ----

        # row-wise sharding
        row_spec = two_row_blocks()
        check_tensor(row_spec.shards, torch.rand(10, 5).size())

        # row and column sharding
        grid_spec = EnumerableShardingSpec([
            make_shard([0, 0], [3, 3], "cuda:0"),
            make_shard([0, 3], [3, 3], "cuda:1"),
            make_shard([3, 0], [3, 3], "cuda:2"),
            make_shard([3, 3], [3, 3], "cuda:3"),
        ])
        check_tensor(grid_spec.shards, torch.rand(6, 6).size())

        # uneven shard sizes
        uneven_spec = EnumerableShardingSpec([
            make_shard([0, 0], [2, 4], "cuda:0"),
            make_shard([0, 4], [4, 2], "cuda:1"),
            make_shard([2, 0], [4, 4], "cuda:2"),
            make_shard([4, 4], [2, 2], "cuda:3"),
        ])
        check_tensor(uneven_spec.shards, torch.rand(6, 6).size())

        # ---- invalid sharding ----

        # (expected message, offsets, sizes, placement) rejected by ShardMetadata.
        bad_metadata_cases = (
            ('Could not parse remote_device', [0], [1], "cuda:foo"),
            ('same number of elements', [0, 0], [1], "cuda:0"),
            ('shard_offsets should be >=0', [-1, 0], [1, 1], "cuda:0"),
            ('shard_sizes should be >= 0', [0, 0], [-1, 1], "cuda:0"),
        )
        for expected_message, offsets, sizes, placement in bad_metadata_cases:
            with self.assertRaisesRegex(ValueError, expected_message):
                make_shard(offsets, sizes, placement)

        with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
            EnumerableShardingSpec([])

        with self.assertRaisesRegex(ValueError,
                                    'Found inconsistent ranks for shards'):
            EnumerableShardingSpec([
                make_shard([0, 0], [1, 1], "cpu"),
                make_shard([0, 0, 0], [1, 1, 1], "cpu"),
            ])

        with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
            EnumerableShardingSpec([
                make_shard([0, 0], [3, 3], "cpu"),
                make_shard([2, 0], [3, 3], "cpu"),
            ])

        # 2-D shards checked against a 3-D tensor size.
        rank_mismatch_spec = two_row_blocks()
        with self.assertRaisesRegex(ValueError,
                                    'Rank of tensor is.*but shards rank'):
            check_tensor(rank_mismatch_spec.shards,
                         torch.rand(10, 10, 10).size())

        # Shards are 5 wide but the tensor only has 3 columns.
        too_wide_spec = two_row_blocks()
        with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
            check_tensor(too_wide_spec.shards, torch.rand(10, 3).size())

        # Two diagonal 5x5 shards cover only half of a 10x10 tensor.
        diagonal_spec = EnumerableShardingSpec([
            make_shard([0, 0], [5, 5], "cuda:0"),
            make_shard([5, 5], [5, 5], "cuda:1"),
        ])
        with self.assertRaisesRegex(ValueError,
                                    'does not match tensor volume'):
            check_tensor(diagonal_spec.shards, torch.rand(10, 10).size())
Example #11
0
    def _init_from_local_shards_and_global_metadata(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
    ) -> "ShardedTensor":
        """
        Initialize a ShardedTensor with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
                 not do cross rank validations, and fully rely on the user
                 for the correctness of sharded_tensor_metadata on each rank

        Args:
            local_shards: shards held by the calling rank.
            sharded_tensor_metadata: global metadata; its ``shards_metadata``,
                ``tensor_properties`` and ``size`` are read here.
            process_group: group to operate on; the default group is used
                when ``None``.
            init_rrefs: forwarded to ``_prepare_init``.

        Returns:
            The new instance, with ``_metadata``, ``_local_shards`` and an
            ``EnumerableShardingSpec`` as ``_sharding_spec`` populated.

        Raises:
            ValueError: if ``shards_metadata`` is empty, the layout is not
                ``torch.strided``, a local shard tensor is not contiguous, or
                a local shard's layout, size, pin_memory, device, dtype or
                requires_grad differs from the expected value.
            RuntimeError: if the number of local shards differs from the
                number of metadata entries placed on the current rank.
            AssertionError: if a local shard's metadata is not among the
                metadata entries placed on the current rank.
        """
        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)

        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError(
                'Only torch.strided layout is currently supported')

        # The instance is created before validation because the checks below
        # read sharded_tensor._process_group.
        sharded_tensor = cls.__new__(cls)
        sharded_tensor._prepare_init(process_group=process_group,
                                     init_rrefs=init_rrefs)

        sharded_tensor._metadata = sharded_tensor_metadata

        local_shard_metadatas = []

        # Raise a ValueError describing the mismatch when expected != actual.
        # ``is_property`` only selects the wording of the message.
        def _raise_if_mismatch(expected,
                               actual,
                               prop_name,
                               rank,
                               is_property=False):
            tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
            if expected != actual:
                raise ValueError(
                    f"Local shards' tensor {prop_name} property is incompatible with "
                    f"{tensor_property_or_metadata} on rank {rank}: "
                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
                    f"local shard tensor {prop_name}={actual}.")

        # collect local shard metadatas from the global sharded_tensor_metadata
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_metadata.placement)

            # Keep only the entries whose placement is on the calling rank.
            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        if len(local_shards) != len(local_shard_metadatas):
            raise RuntimeError(
                f'Number of local shards ({len(local_shards)}) does not match number of local '
                f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                f'on rank ({current_rank}) ')

        # Validate each local shard against the global metadata and properties.
        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            # Only local_device is used below; the parsed rank is not consulted.
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_meta.placement)

            # validate if shard_meta in the metadatas collected from sharded_tensor_metadata
            assert shard_meta in local_shard_metadatas, \
                "local shard metadata not in sharded_tensor_metadata!"

            _raise_if_mismatch(tensor_properties.layout,
                               local_shard_tensor.layout, "layout",
                               current_rank, True)
            if not local_shard_tensor.is_contiguous():
                raise ValueError(
                    'Only torch.contiguous_format memory_format is currently supported'
                )

            _raise_if_mismatch(shard_meta.shard_sizes,
                               list(local_shard_tensor.size()), "size",
                               current_rank)
            _raise_if_mismatch(tensor_properties.pin_memory,
                               local_shard_tensor.is_pinned(), "pin_memory",
                               current_rank, True)
            _raise_if_mismatch(local_device, local_shard_tensor.device,
                               "device", current_rank)
            _raise_if_mismatch(tensor_properties.dtype,
                               local_shard_tensor.dtype, "dtype", current_rank,
                               True)
            _raise_if_mismatch(tensor_properties.requires_grad,
                               local_shard_tensor.requires_grad,
                               "requires_grad", current_rank, True)

        # check if shards_metadata have overlap shards
        # NOTE(review): presumably raises on overlap -- confirm in the helper.
        validate_non_overlapping_shards_metadata(shards_metadata)

        # check if the shards_metadata is compatible with overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # done validation, add local_shards
        sharded_tensor._local_shards = local_shards
        # make a EnumerableShardingSpec for sharded tensors that initialized from this API.
        # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
        #       see issue https://github.com/pytorch/pytorch/issues/67244
        sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor