Example #1
    def shard(self,
              tensor: torch.Tensor,
              src_rank: int = 0,
              process_group=None) -> "ShardedTensor":
        # Delayed import to avoid circular dependency.
        from torch.distributed._shard.sharded_tensor import ShardedTensor
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned())
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []

        current_rank = dist.get_rank(process_group)
        # Scatter the shards. Since NCCL doesn't support scatter, broadcast the
        # full tensor instead (this is very inefficient).
        dist.broadcast(tensor, src=src_rank, group=process_group)

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                process_group, shard_meta.placement)
            if rank == current_rank:
                shard_offsets = shard_meta.shard_offsets
                shard_sizes = shard_meta.shard_sizes
                local_tensor = tensor
                for idx, (offset,
                          size) in enumerate(zip(shard_offsets, shard_sizes)):
                    if size < tensor.size(idx):
                        # Narrow to get the shard for this rank. Detach so autograd
                        # doesn't record the narrow op and the local shard stays a
                        # leaf variable in the autograd graph.
                        local_tensor = local_tensor.narrow(
                            idx, shard_offsets[idx],
                            shard_sizes[idx]).clone().detach().contiguous()
                # Sync requires_grad to local_shard.
                local_tensor.requires_grad = tensor.requires_grad
                local_shards.append(
                    Shard(tensor=local_tensor, metadata=shard_meta))

        st = ShardedTensor._init_from_local_shards(local_shards,
                                                   tensor.size(),
                                                   process_group=process_group)
        # Manually set sharding_spec
        st._sharding_spec = self

        return st
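
A minimal usage sketch for context (my assumption, not part of the example): a shard() implementation like the one above is normally reached through a sharding spec such as ChunkShardingSpec, with every rank passing the same full tensor and src_rank holding the authoritative copy.

# Usage sketch, assuming an already-initialized process group (e.g. via torchrun)
# and a ChunkShardingSpec; the placement strings are illustrative.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def shard_weight(weight: torch.Tensor):
    spec = ChunkShardingSpec(
        dim=0,  # shard along the first dimension
        placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
    )
    # Every rank calls shard() with the same full tensor; src_rank owns the copy
    # that gets broadcast/scattered to the shard-owning ranks.
    return spec.shard(weight, src_rank=0)
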
Example #2
def _create_shard_for(tensor: Tensor) -> Shard:
    return Shard(
        tensor=tensor,
        metadata=_create_shard_metadata(tensor.size()),
    )
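
For reference, a hedged sketch of what a helper like _create_shard_metadata generally has to produce (the body below is an assumption, not the library's implementation): a ShardMetadata recording the shard's offsets, sizes, and placement, paired with the local tensor in a Shard.

# Illustrative only: offsets and placement are placeholders, not the real helper.
import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard

def _example_create_shard(local: torch.Tensor, rank: int) -> Shard:
    metadata = ShardMetadata(
        shard_offsets=[0] * local.dim(),       # where this shard starts in the global tensor
        shard_sizes=list(local.size()),        # extent of this shard
        placement=f"rank:{rank}/cuda:{rank}",  # owning rank and device
    )
    return Shard(tensor=local, metadata=metadata)
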
Example #3
    def shard(self,
              tensor: torch.Tensor,
              src_rank: int = 0,
              process_group=None) -> "ShardedTensor":
        # Delayed import to avoid circular dependency.
        from torch.distributed._shard.sharded_tensor import ShardedTensor
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned())
        current_rank = dist.get_rank(process_group)
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []
        local_tensor = None
        local_metadata = None
        tensors_to_scatter = [None] * dist.get_world_size(process_group)

        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
        chunks = len(self.placements)
        split_size = get_split_size(sharding_dim_size, chunks)
        scatter_shape = list(tensor.size())
        scatter_shape[self.dim] = split_size  # type: ignore[index]

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                process_group, shard_meta.placement)
            if current_rank == src_rank:
                # Narrow to get the shard for this rank. Detach so autograd
                # doesn't record the narrow op and the local shard stays a
                # leaf variable in the autograd graph.
                narrowed_tensor = narrow_tensor(tensor, shard_meta)
                if shard_meta.shard_sizes[
                        self.dim] < split_size:  # type: ignore[index]
                    # The last shard may be smaller than the other shards. Resize
                    # the narrowed tensor to the common scatter shape and use it
                    # for the scatter collective, since dist.scatter requires
                    # same-size inputs on every rank.
                    tensor_to_scatter = narrowed_tensor.detach().clone(
                    ).resize_(scatter_shape)
                else:
                    tensor_to_scatter = narrowed_tensor.detach().clone(
                    ).contiguous()

                tensors_to_scatter[rank] = tensor_to_scatter

            if current_rank == rank:
                local_tensor = torch.empty(scatter_shape,
                                           dtype=tensor.dtype,
                                           layout=tensor.layout,
                                           device=device)
                local_metadata = shard_meta

        # Each rank should have local_tensor and local_metadata initialized if the
        # metadata list was built correctly.
        assert local_tensor is not None
        assert local_metadata is not None

        # Scatter the shards to all ranks in the process group.
        dist.scatter(local_tensor,
                     scatter_list=tensors_to_scatter
                     if current_rank == src_rank else None,
                     src=src_rank,
                     group=process_group)

        if list(local_tensor.size()) != local_metadata.shard_sizes:
            # Resize back to the true shard size and detach again after receiving
            # so the local shard remains a leaf in the autograd graph.
            local_tensor = local_tensor.resize_(
                local_metadata.shard_sizes).detach()

        # Sync requires_grad to local_shard.
        local_tensor.requires_grad = tensor.requires_grad

        local_shards.append(Shard(tensor=local_tensor,
                                  metadata=local_metadata))

        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards, tensor_meta, process_group=process_group)

        # Manually set sharding_spec
        st._sharding_spec = self

        return st
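
The resize/pad dance around dist.scatter above exists because the collective requires same-sized tensors on every rank, while the last chunk along the sharding dimension can be smaller. A self-contained sketch of the split arithmetic this relies on (assumed semantics, shown only for illustration; the real helpers live alongside the sharding spec internals):

# Assumed split semantics, mirroring torch.Tensor.chunk-style splitting.
def get_split_size(dim_size: int, chunks: int) -> int:
    # Size of the "full" chunks when splitting dim_size into `chunks` pieces.
    return (dim_size + chunks - 1) // chunks

def get_chunked_dim_size(dim_size: int, split_size: int, idx: int) -> int:
    # Actual size of chunk `idx`; the trailing chunk may be smaller (or empty).
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)

# Splitting a dimension of size 10 across 4 ranks: the scatter buffer is sized
# for 3 rows on every rank, and the last rank trims its shard back down to 1 row.
split = get_split_size(10, 4)                                    # 3
sizes = [get_chunked_dim_size(10, split, i) for i in range(4)]   # [3, 3, 3, 1]
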
Example #4
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        # Delayed import to avoid circular dependency.
        from torch.distributed._shard.sharded_tensor import (
            ShardedTensor
        )
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned()
        )
        current_rank = dist.get_rank(process_group)
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []
        local_tensor = None
        local_metadata = None
        tensors_to_scatter = []

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
            shard_offsets = shard_meta.shard_offsets
            shard_sizes = shard_meta.shard_sizes
            if current_rank == src_rank:
                narrowed_tensor = tensor
                for idx, (offset, size) in enumerate(zip(shard_offsets, shard_sizes)):
                    if size < tensor.size(idx):
                        # Narrow to get the shard for this rank. Detach so autograd
                        # doesn't record the narrow op and the local shard stays a
                        # leaf variable in the autograd graph.
                        narrowed_tensor = narrowed_tensor.narrow(
                            idx,
                            shard_offsets[idx],
                            shard_sizes[idx]
                        ).clone().detach().contiguous()
                tensors_to_scatter.append(narrowed_tensor)

            if current_rank == rank:
                local_tensor = torch.empty(
                    shard_sizes, dtype=tensor.dtype, layout=tensor.layout, device=device)
                local_metadata = shard_meta

        # Each rank should have local_tensor and local_metadata initialized if the
        # metadata list was built correctly; assert before handing them to scatter.
        assert local_tensor is not None
        assert local_metadata is not None

        # Scatter the shards to all ranks in the process group.
        dist.scatter(
            local_tensor,
            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
            src=src_rank,
            group=process_group
        )
        # Sync requires_grad to local_shard.
        local_tensor.requires_grad = tensor.requires_grad

        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

        st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=process_group)
        # Manually set sharding_spec
        st._sharding_spec = self

        return st
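
Once shard() returns, each rank works with the result through the ShardedTensor accessors. A short inspection sketch (the weight and spec names are placeholders carried over from the usage sketch after Example #1):

# Inspection sketch: each rank sees only its own local shard plus global metadata.
st = spec.shard(weight, src_rank=0)

for shard in st.local_shards():              # shards owned by this rank
    print(shard.metadata.shard_offsets, shard.metadata.shard_sizes)
    print(shard.tensor.shape, shard.tensor.requires_grad)

print(st.size())                             # global (unsharded) size
print(st.sharding_spec() is spec)            # the spec attached by shard()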