    def _infer_enum_sharding_spec_case(self):
        shards_metadata = [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[10, 5],
                placement="cuda:1",
            )
        ]
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)

        shards_metadata = [
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[16],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[16],
                shard_sizes=[9],
                placement="cuda:1",
            )
        ]
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)

        shards_metadata = [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ]
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)
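
For context, here is a minimal sketch of calling the inference helper directly with the same kinds of shard layouts the test exercises. The import paths are assumptions and may differ across PyTorch versions; the shard metadata mirrors the cases above.

import torch
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)
from torch.distributed._shard.sharding_spec.api import (
    _infer_sharding_spec_from_shards_metadata,
)

# Unevenly sized shards along dim 0 cannot be described by a single chunked
# dimension, so the helper falls back to an EnumerableShardingSpec.
uneven = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[5, 0], shard_sizes=[10, 5], placement="rank:1/cuda:1"),
]
spec = _infer_sharding_spec_from_shards_metadata(uneven)
assert isinstance(spec, EnumerableShardingSpec)

# An even split along a single dimension is recognized as a chunk-style
# sharding instead, which is what the chunk test further below verifies.
even = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:1/cuda:1"),
]
spec = _infer_sharding_spec_from_shards_metadata(even)
assert isinstance(spec, ChunkShardingSpec)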
Example #2
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False,
    ):
        # STEP 1: Validate the ShardMetadata of the local shards locally
        process_group = (
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        current_rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)

        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
        global_tensor_size = _flatten_tensor_size(global_size)

        if len(local_shards) > 0:
            local_sharded_tensor_metadata = \
                build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

        # STEP 2: Validate metadata across ranks, and build a global sharded tensor
        # metadata by gathering local ShardedTensorMetadata
        gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
        if world_size > 1:
            gathered_metadatas = [None for _ in range(world_size)]

            dist.all_gather_object(
                gathered_metadatas,
                local_sharded_tensor_metadata,
                group=process_group
            )
        else:
            gathered_metadatas = [local_sharded_tensor_metadata]

        global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
        tensor_properties = global_sharded_tensor_metadata.tensor_properties

        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
        # prepare initialization
        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
            global_sharded_tensor_metadata.shards_metadata
        )
        sharded_tensor = cls.__new__(cls,
                                     spec,
                                     global_sharded_tensor_metadata.size,
                                     dtype=tensor_properties.dtype,
                                     layout=tensor_properties.layout,
                                     pin_memory=tensor_properties.pin_memory,
                                     requires_grad=tensor_properties.requires_grad)
        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        # attach local_shards to the ShardedTensor created
        sharded_tensor._local_shards = local_shards

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
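
As a usage sketch for the classmethod above: each rank builds its own Shard objects and joins the collective construction, since the per-rank metadata is all-gathered in STEP 2. The public init_from_local_shards wrapper, the module paths, and the 4x4 two-rank layout are assumptions for illustration.

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, init_from_local_shards
from torch.distributed._shard.sharding_spec import ShardMetadata

def build_sharded_tensor():
    # Requires torch.distributed to be initialized (e.g. via init_process_group)
    # and must be called on every rank of the group.
    rank = dist.get_rank()

    # Each rank owns one 2x4 row slice of a 4x4 global tensor.
    shard_md = ShardMetadata(
        shard_offsets=[rank * 2, 0],
        shard_sizes=[2, 4],
        placement=f"rank:{rank}/cuda:{rank}",
    )
    local_shard = Shard(
        tensor=torch.full((2, 4), float(rank), device=f"cuda:{rank}"),
        metadata=shard_md,
    )

    # The wrapper delegates to the classmethod shown above, which validates the
    # local shards and all-gathers the per-rank metadata before building the
    # global ShardedTensor.
    return init_from_local_shards([local_shard], 4, 4)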
Example #3
    def _infer_chunk_sharding_spec_case(self, placements, sharding_dim, st_size):
        world_size = len(placements)
        split_size = get_split_size(st_size[sharding_dim], world_size)
        shards_metadata = [None] * world_size
        for idx, placement in enumerate(placements):
            shard_size = copy.deepcopy(st_size)
            offsets = [0] * len(st_size)
            offsets[sharding_dim] = split_size * idx
            shard_size[sharding_dim] = get_chunked_dim_size(st_size[sharding_dim], split_size, idx)
            shards_metadata[placement.rank()] = ShardMetadata(
                shard_offsets=offsets,
                shard_sizes=shard_size,
                placement=placement,
            )

        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, ChunkShardingSpec))
        self.assertEqual(spec.dim, sharding_dim)
        self.assertEqual(spec.placements, placements)
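
The chunk inference above leans on get_split_size and get_chunked_dim_size to reproduce how a dimension is partitioned across placements. Below is a plain-Python sketch of that arithmetic, under the assumption that the split size is a ceiling division and the trailing chunk absorbs the remainder; split_size and chunked_dim_size here are local stand-ins, not the real utilities.

def split_size(dim_size: int, chunks: int) -> int:
    # Ceiling division: every chunk except possibly the last gets this many rows.
    return (dim_size + chunks - 1) // chunks

def chunked_dim_size(dim_size: int, split: int, idx: int) -> int:
    # Size of chunk `idx`; the final chunk may be smaller (or empty).
    return max(min(dim_size, split * (idx + 1)) - split * idx, 0)

# For a dimension of length 13 sharded across 4 placements:
s = split_size(13, 4)                                    # -> 4
sizes = [chunked_dim_size(13, s, i) for i in range(4)]   # -> [4, 4, 4, 1]
assert sizes == [4, 4, 4, 1]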
Example #4
    def _init_from_local_shards_and_global_metadata(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
    ) -> "ShardedTensor":
        """
        Initialize a ShardedTensor with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
                 not do cross-rank validation and fully relies on the user
                 for the correctness of sharded_tensor_metadata on each rank.
        """
        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)

        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError(
                'Only torch.strided layout is currently supported')

        sharded_tensor = cls.__new__(cls)
        sharded_tensor._prepare_init(process_group=process_group,
                                     init_rrefs=init_rrefs)

        sharded_tensor._metadata = sharded_tensor_metadata

        local_shard_metadatas = []

        def _raise_if_mismatch(expected,
                               actual,
                               prop_name,
                               rank,
                               is_property=False):
            tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
            if expected != actual:
                raise ValueError(
                    f"Local shards' tensor {prop_name} property is incompatible with "
                    f"{tensor_property_or_metadata} on rank {rank}: "
                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
                    f"local shard tensor {prop_name}={actual}.")

        # collect local shard metadatas from the global sharded_tensor_metadata
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_metadata.placement)

            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        if len(local_shards) != len(local_shard_metadatas):
            raise RuntimeError(
                f'Number of local shards ({len(local_shards)}) does not match number of local '
                f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                f'on rank ({current_rank}) ')

        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_meta.placement)

            # validate if shard_meta in the metadatas collected from sharded_tensor_metadata
            assert shard_meta in local_shard_metadatas, \
                "local shard metadata not in sharded_tensor_metadata!"

            _raise_if_mismatch(tensor_properties.layout,
                               local_shard_tensor.layout, "layout",
                               current_rank, True)
            if not local_shard_tensor.is_contiguous():
                raise ValueError(
                    'Only torch.contiguous_format memory_format is currently supported'
                )

            _raise_if_mismatch(shard_meta.shard_sizes,
                               list(local_shard_tensor.size()), "size",
                               current_rank)
            _raise_if_mismatch(tensor_properties.pin_memory,
                               local_shard_tensor.is_pinned(), "pin_memory",
                               current_rank, True)
            _raise_if_mismatch(local_device, local_shard_tensor.device,
                               "device", current_rank)
            _raise_if_mismatch(tensor_properties.dtype,
                               local_shard_tensor.dtype, "dtype", current_rank,
                               True)
            _raise_if_mismatch(tensor_properties.requires_grad,
                               local_shard_tensor.requires_grad,
                               "requires_grad", current_rank, True)

        # check if shards_metadata have overlap shards
        validate_non_overlapping_shards_metadata(shards_metadata)

        # check if the shards_metadata is compatible with overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # done validation, add local_shards
        sharded_tensor._local_shards = local_shards
        sharded_tensor._sharding_spec = _infer_sharding_spec_from_shards_metadata(
            shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
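
A per-rank sketch of calling this classmethod directly: every rank constructs the same global ShardedTensorMetadata and supplies only its own shard. The import paths, the TensorProperties and ShardedTensorMetadata field names, and the row-sliced layout are assumptions for illustration; as the docstring warns, no cross-rank validation happens, so the metadata must already be identical and correct on every rank.

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import (
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata

def from_global_metadata():
    # Assumes torch.distributed is initialized with one CUDA device per rank.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # The global view: one 2x4 row slice per rank, identical on all ranks.
    shards_metadata = [
        ShardMetadata(
            shard_offsets=[r * 2, 0],
            shard_sizes=[2, 4],
            placement=f"rank:{r}/cuda:{r}",
        )
        for r in range(world_size)
    ]
    metadata = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([2 * world_size, 4]),
        tensor_properties=TensorProperties(dtype=torch.float32),
    )

    # Each rank contributes only the shard it actually owns.
    local_tensor = torch.zeros(2, 4, device=f"cuda:{rank}")
    local_shards = [Shard(tensor=local_tensor, metadata=shards_metadata[rank])]

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, metadata
    )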
Example #5
    def _init_from_local_shards_and_global_metadata(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        sharding_spec=None,
    ) -> "ShardedTensor":
        """
        Initialize a ShardedTensorBase with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
                 not do cross-rank validation and fully relies on the user
                 for the correctness of sharded_tensor_metadata on each rank.
        """
        """
        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError(
                "Only torch.strided layout is currently supported")

        if sharding_spec is None:
            spec = shard_spec._infer_sharding_spec_from_shards_metadata(
                shards_metadata)
        else:
            spec = sharding_spec

        sharded_tensor_base = ShardedTensor.__new__(
            ShardedTensor,
            spec,
            sharded_tensor_metadata.size,
            dtype=tensor_properties.dtype,
            layout=tensor_properties.layout,
            pin_memory=tensor_properties.pin_memory,
            requires_grad=tensor_properties.requires_grad,
        )

        def _raise_if_mismatch(expected,
                               actual,
                               prop_name,
                               rank,
                               is_property=False):
            tensor_property_or_metadata = ("tensor property" if is_property
                                           else "local ShardMetadata")
            if expected != actual:
                raise ValueError(
                    f"Local shards' tensor {prop_name} property is incompatible with "
                    f"{tensor_property_or_metadata} on rank {rank}: "
                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
                    f"local shard tensor {prop_name}={actual}.")

        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            placement = shard_meta.placement
            assert placement is not None, "Must specify placement for `Shard`!"
            rank = placement.rank()
            local_device = placement.device()

            _raise_if_mismatch(
                tensor_properties.layout,
                local_shard_tensor.layout,
                "layout",
                rank,
                True,
            )
            if not local_shard_tensor.is_contiguous():
                raise ValueError(
                    "Only torch.contiguous_format memory_format is currently supported"
                )

            _raise_if_mismatch(
                shard_meta.shard_sizes,
                list(local_shard_tensor.size()),
                "size",
                rank,
            )
            _raise_if_mismatch(
                tensor_properties.pin_memory,
                local_shard_tensor.is_pinned(),
                "pin_memory",
                rank,
                True,
            )
            _raise_if_mismatch(local_device, local_shard_tensor.device,
                               "device", rank)
            _raise_if_mismatch(
                tensor_properties.dtype,
                local_shard_tensor.dtype,
                "dtype",
                rank,
                True,
            )
            _raise_if_mismatch(
                tensor_properties.requires_grad,
                local_shard_tensor.requires_grad,
                "requires_grad",
                rank,
                True,
            )

        # check if shards_metadata have overlap shards
        validate_non_overlapping_shards_metadata(shards_metadata)

        # check if the shards_metadata is compatible with overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # done validation, add local_shards
        sharded_tensor_base._local_shards = local_shards
        return sharded_tensor_base
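
Compared with Example #4, this variant trusts its inputs even further: it never touches a process group, resolving the owning rank and device directly from each shard's placement via placement.rank() and placement.device(); it accepts an explicit sharding_spec instead of always inferring one; it skips the count check of local shards against the global shards_metadata; and it does not run _prepare_init or _post_init, so no RPC or map registration happens. That makes it usable without an initialized process group, at the cost of placing the entire correctness burden on the caller.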