Example #1
    def _init_enumerable(self, dims, tensor_init_params: TensorInitParams):
        # Validate the sharding spec is compatible with the tensor.
        check_tensor(self._sharding_spec.shards,
                     dims)  # type: ignore[attr-defined]

        current_rank = dist.get_rank(self._process_group)

        shards_metadata = []
        for shard_metadata in self._sharding_spec.shards:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(
                self._process_group, shard_metadata.placement)
            shards_metadata.append(shard_metadata)

            if current_rank == rank:
                # Initialize the local shard.
                local_shard = _create_tensor_from_params(
                    *shard_metadata.shard_sizes,
                    local_device=local_device,
                    tensor_init_params=tensor_init_params)
                self._local_shards.append(Shard(local_shard, shard_metadata))

        # Build overall metadata
        self._metadata = ShardedTensorMetadata(
            shards_metadata,
            dims,
            tensor_init_params.tensor_properties,
        )
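
The loop above only materializes a tensor on the rank that owns each shard. As a standalone illustration of that ownership pattern, the hypothetical helper below (not part of the source) collects the shard metadata entries owned by the calling rank, assuming the same dist and _parse_and_validate_remote_device names used in the example are in scope:

def shards_owned_by_current_rank(sharding_spec, process_group):
    # Resolve each shard's placement and keep only the entries whose
    # target rank matches the calling process.
    current_rank = dist.get_rank(process_group)
    owned = []
    for shard_metadata in sharding_spec.shards:
        rank, _local_device = _parse_and_validate_remote_device(
            process_group, shard_metadata.placement)
        if rank == current_rank:
            owned.append(shard_metadata)
    return owned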
Example #2
def build_global_metadata(
        gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if global_sharded_tensor_metadata is None:
            global_sharded_tensor_metadata = rank_metadata
            global_metadata_rank = rank
        else:
            _raise_if_mismatch(global_sharded_tensor_metadata.size,
                               rank_metadata.size,
                               "global_size", [global_metadata_rank, rank],
                               is_local=False)

            # No need to check layout and memory format; they were already checked in the local shards validation stage.
            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.dtype,
                rank_metadata.tensor_properties.dtype,
                "dtype", [global_metadata_rank, rank],
                is_local=False)

            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.requires_grad,
                rank_metadata.tensor_properties.requires_grad,
                "requires_grad", [global_metadata_rank, rank],
                is_local=False)

            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.pin_memory,
                rank_metadata.tensor_properties.pin_memory,
                "pin_memory", [global_metadata_rank, rank],
                is_local=False)
            # All validations passed; extend the shards metadata.
            global_sharded_tensor_metadata.shards_metadata.extend(
                rank_metadata.shards_metadata)

    if global_sharded_tensor_metadata is not None:
        # Check that shards_metadata has no overlapping shards.
        validate_non_overlapping_shards_metadata(
            global_sharded_tensor_metadata.shards_metadata)

        # Check that shards_metadata is compatible with the global size of the sharded tensor.
        check_tensor(global_sharded_tensor_metadata.shards_metadata,
                     global_sharded_tensor_metadata.size)
    else:
        raise ValueError("ShardedTensor have no local shards on all ranks!")

    return global_sharded_tensor_metadata
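
build_global_metadata expects one (possibly None) ShardedTensorMetadata entry per rank. Below is a minimal sketch of assembling that list, assuming each rank has already built its own local metadata (or None if it holds no shards), the process group is initialized, and build_global_metadata from the example above is in scope:

import torch.distributed as dist

def gather_and_build(local_metadata, process_group=None):
    world_size = dist.get_world_size(process_group)
    gathered_metadatas = [None] * world_size
    # Collect every rank's (possibly None) metadata object into a list
    # indexed by rank, then merge and validate it into a single global view.
    dist.all_gather_object(gathered_metadatas, local_metadata, group=process_group)
    return build_global_metadata(gathered_metadatas)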
Example #3
File: api.py Project: vors/pytorch
    def _init_enumerable(
        self,
        dims,
        dtype,
        layout,
        requires_grad,
        pin_memory,
        memory_format,
    ):
        # Validate the sharding spec is compatible with the tensor.
        check_tensor(self._sharding_spec.shards,
                     dims)  # type: ignore[attr-defined]

        current_rank = dist.get_rank(self._process_group)

        shards_metadata = []
        for shard_metadata in self._sharding_spec.shards:  # type: ignore[attr-defined]
            rank, local_device = self._parse_and_validate_remote_device(
                shard_metadata.placement)
            shards_metadata.append(shard_metadata)

            if current_rank == rank:
                # Initialize the local shard.
                local_shard = torch.empty(
                    *shard_metadata.shard_lengths,
                    dtype=dtype,
                    layout=layout,
                    device=local_device,
                    requires_grad=requires_grad,
                    memory_format=memory_format,
                    pin_memory=pin_memory,
                )

                self._local_shards.append(Shard(local_shard, shard_metadata))

        # Build overall metadata
        self._metadata = ShardedTensorMetadata(
            shards_metadata,
            dims,
            dtype,
            layout,
            requires_grad,
            memory_format,
            pin_memory,
        )
Example #4
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
    ):
        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError('Only torch.strided layout is currently supported')

        sharded_tensor = cls.__new__(cls)

        # prepare initialization
        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        sharded_tensor._metadata = sharded_tensor_metadata

        # No sharding spec for sharded tensors that are initialized
        # from this API.
        sharded_tensor._sharding_spec = None

        current_rank = dist.get_rank(sharded_tensor._process_group)

        local_shard_metadatas = []

        # Collect the local shard metadata entries from the global sharded_tensor_metadata.
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = sharded_tensor._parse_and_validate_remote_device(shard_metadata.placement)

            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        if len(local_shards) != len(local_shard_metadatas):
            raise RuntimeError(
                f'Number of local shards ({len(local_shards)}) does not match number of local '
                f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                f'on rank ({current_rank}) '
            )

        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            rank, local_device = sharded_tensor._parse_and_validate_remote_device(shard_meta.placement)

            # Validate that shard_meta is among the metadata entries collected from sharded_tensor_metadata.
            assert shard_meta in local_shard_metadatas, \
                "local shard metadata not in sharded_tensor_metadata!"

            if local_shard_tensor.layout != tensor_properties.layout:
                raise ValueError(
                    f'Local shard tensor layout does not match with tensor_properties! '
                    f'local shard tensor layout: {local_shard_tensor.layout}, '
                    f'tensor_properties layout: {tensor_properties.layout}'
                )

            if not local_shard_tensor.is_contiguous():
                raise ValueError('Only torch.contiguous_format memory_format is currently supported')

            if shard_meta.shard_lengths != list(local_shard_tensor.size()):
                raise ValueError(
                    f'Local shard tensor is incompatible with local ShardMetadata! '
                    f'local shard tensor size: {local_shard_tensor.size()}, '
                    f'local ShardMetadata shard lengths: {shard_meta.shard_lengths}'
                )

            if local_shard_tensor.is_pinned() != tensor_properties.pin_memory:
                raise ValueError(
                    f'Local shard tensor pin_memory does not match with tensor_properties! '
                    f'local shard tensor pin_memory: {local_shard_tensor.is_pinned()}, '
                    f'tensor_properties pin_memory: {tensor_properties.pin_memory}'
                )

            if local_shard_tensor.device != local_device:
                raise ValueError(
                    f'Local shard tensor device does not match with local Shard placement! '
                    f'local shard tensor device: {local_shard_tensor.device}, '
                    f'local shard metadata placement device: {local_device}'
                )

            if local_shard_tensor.dtype != tensor_properties.dtype:
                raise ValueError(
                    f'Local shard tensor dtype does not match with tensor_properties! '
                    f'local shard tensor dtype: {local_shard_tensor.dtype}, '
                    f'tensor_properties dtype: {tensor_properties.dtype}'
                )

            if local_shard_tensor.requires_grad != tensor_properties.requires_grad:
                raise ValueError(
                    f'Local shard tensor requires_grad does not match with tensor_properties! '
                    f'local shard tensor requires_grad: {local_shard_tensor.requires_grad}, '
                    f'tensor_properties requires_grad: {tensor_properties.requires_grad}'
                )

        # Check that shards_metadata has no overlapping shards.
        validate_non_overlapping_shards_metadata(shards_metadata)

        # Check that shards_metadata is compatible with the overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # Validation done; attach local_shards.
        sharded_tensor._local_shards = local_shards

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
Example #5
    def test_enumerable_sharding_spec(self):
        # test valid specs

        # test row-wise sharding
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="cuda:1",
            )
        ])
        check_tensor(spec.shards, torch.rand(10, 5).size())

        # test row and column sharding
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[3, 3],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 3],
                shard_lengths=[3, 3],
                placement="cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[3, 0],
                shard_lengths=[3, 3],
                placement="cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[3, 3],
                shard_lengths=[3, 3],
                placement="cuda:3",
            ),
        ])
        check_tensor(spec.shards, torch.rand(6, 6).size())

        # test uneven shard sizes.
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[2, 4],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 4],
                shard_lengths=[4, 2],
                placement="cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[2, 0],
                shard_lengths=[4, 4],
                placement="cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[4, 4],
                shard_lengths=[2, 2],
                placement="cuda:3",
            ),
        ])
        check_tensor(spec.shards, torch.rand(6, 6).size())

        # test invalid sharding
        with self.assertRaisesRegex(ValueError,
                                    'Could not parse remote_device'):
            ShardMetadata(shard_offsets=[0],
                          shard_lengths=[1],
                          placement="cuda:foo")

        with self.assertRaisesRegex(ValueError, 'same number of elements'):
            ShardMetadata(shard_offsets=[0, 0],
                          shard_lengths=[1],
                          placement="cuda:0")

        with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
            ShardMetadata(shard_offsets=[-1, 0],
                          shard_lengths=[1, 1],
                          placement="cuda:0")

        with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
            ShardMetadata(shard_offsets=[0, 0],
                          shard_lengths=[0, 1],
                          placement="cuda:0")

        with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
            EnumerableShardingSpec([])

        with self.assertRaisesRegex(ValueError,
                                    'Found inconsistent ranks for shards'):
            EnumerableShardingSpec([
                ShardMetadata(shard_offsets=[0, 0],
                              shard_lengths=[1, 1],
                              placement="cpu"),
                ShardMetadata(shard_offsets=[0, 0, 0],
                              shard_lengths=[1, 1, 1],
                              placement="cpu"),
            ])

        with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
            EnumerableShardingSpec([
                ShardMetadata(shard_offsets=[0, 0],
                              shard_lengths=[3, 3],
                              placement="cpu"),
                ShardMetadata(shard_offsets=[2, 0],
                              shard_lengths=[3, 3],
                              placement="cpu"),
            ])

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="cuda:1",
            )
        ])

        with self.assertRaisesRegex(ValueError,
                                    'Rank of tensor is.*but shards rank'):
            check_tensor(spec.shards, torch.rand(10, 10, 10).size())

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="cuda:1",
            )
        ])

        with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
            check_tensor(spec.shards, torch.rand(10, 3).size())

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_lengths=[5, 5],
                placement="cuda:1",
            )
        ])

        with self.assertRaisesRegex(ValueError,
                                    'does not match tensor volume'):
            check_tensor(spec.shards, torch.rand(10, 10).size())
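
The specs in this test are written out shard by shard. As a reference, the small hypothetical helper below builds the same kind of even row-wise EnumerableShardingSpec programmatically; it assumes rows divides evenly across world_size and one CUDA device per rank:

def rowwise_spec(rows, cols, world_size):
    rows_per_shard = rows // world_size
    # One shard per rank, each covering a contiguous block of rows.
    return EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[rank * rows_per_shard, 0],
            shard_lengths=[rows_per_shard, cols],
            placement=f"cuda:{rank}",
        )
        for rank in range(world_size)
    ])

# rowwise_spec(10, 5, 2) reproduces the first (row-wise) spec used above.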
Example #6
    def _init_from_local_shards_and_global_metadata(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
    ):
        """
        Initialize a ShardedTensor with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
                 not perform cross-rank validation and fully relies on the
                 user for the correctness of sharded_tensor_metadata on
                 each rank.
        """
        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)

        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError(
                'Only torch.strided layout is currently supported')

        sharded_tensor = cls.__new__(cls)
        sharded_tensor._prepare_init(process_group=process_group,
                                     init_rrefs=init_rrefs)

        sharded_tensor._metadata = sharded_tensor_metadata

        local_shard_metadatas = []

        def _raise_if_mismatch(expected,
                               actual,
                               prop_name,
                               rank,
                               is_property=False):
            tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
            if expected != actual:
                raise ValueError(
                    f"Local shards' tensor {prop_name} property is incompatible with "
                    f"{tensor_property_or_metadata} on rank {rank}: "
                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
                    f"local shard tensor {prop_name}={actual}.")

        # Collect the local shard metadata entries from the global sharded_tensor_metadata.
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_metadata.placement)

            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        if len(local_shards) != len(local_shard_metadatas):
            raise RuntimeError(
                f'Number of local shards ({len(local_shards)}) does not match number of local '
                f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                f'on rank ({current_rank}) ')

        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_meta.placement)

            # Validate that shard_meta is among the metadata entries collected from sharded_tensor_metadata.
            assert shard_meta in local_shard_metadatas, \
                "local shard metadata not in sharded_tensor_metadata!"

            _raise_if_mismatch(tensor_properties.layout,
                               local_shard_tensor.layout, "layout",
                               current_rank, True)
            if not local_shard_tensor.is_contiguous():
                raise ValueError(
                    'Only torch.contiguous_format memory_format is currently supported'
                )

            _raise_if_mismatch(shard_meta.shard_sizes,
                               list(local_shard_tensor.size()), "size",
                               current_rank)
            _raise_if_mismatch(tensor_properties.pin_memory,
                               local_shard_tensor.is_pinned(), "pin_memory",
                               current_rank, True)
            _raise_if_mismatch(local_device, local_shard_tensor.device,
                               "device", current_rank)
            _raise_if_mismatch(tensor_properties.dtype,
                               local_shard_tensor.dtype, "dtype", current_rank,
                               True)
            _raise_if_mismatch(tensor_properties.requires_grad,
                               local_shard_tensor.requires_grad,
                               "requires_grad", current_rank, True)

        # Check that shards_metadata has no overlapping shards.
        validate_non_overlapping_shards_metadata(shards_metadata)

        # Check that shards_metadata is compatible with the overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # Validation done; attach local_shards.
        sharded_tensor._local_shards = local_shards
        # Make an EnumerableShardingSpec for sharded tensors that are initialized from this API.
        # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
        #       see issue https://github.com/pytorch/pytorch/issues/67244
        sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
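
A hedged sketch of the intended call pattern for this API. It assumes the caller has already built, identically on every rank, a global ShardedTensorMetadata plus, per rank, the matching local tensor and its ShardMetadata entry, and that ShardedTensor is the class this classmethod is defined on; constructor details for the metadata objects vary across PyTorch versions:

def init_from_prebuilt(local_tensor, local_shard_metadata, global_metadata,
                       process_group=None):
    # Wrap the rank-local tensor together with its metadata entry; there is
    # one Shard per metadata entry owned by this rank.
    local_shards = [Shard(local_tensor, local_shard_metadata)]
    # Every rank must pass the same global_metadata; as the docstring warns,
    # no cross-rank validation is performed.
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        sharded_tensor_metadata=global_metadata,
        process_group=process_group,
    )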