def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg):
    """
    For col-wise sharding of weight, we need to distribute
    results to each rank. We do them in this function.
    Note that, if the index in the Sharding Spec is not equal to
    the rank number, we need to do the rearrangement based on the
    order given by the Sharding Spec (placement).

    Args:
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor the op was applied to. NOTE: unused in this function;
            kept only so all handlers share a uniform signature.
        world_size: number of ranks.
        weight: sharded weight tensor.
        pg: process group.

    Return: column rearranged result.
    """
    # Process results and outputs for all2all.
    sharding_dim = weight._sharding_spec.dim
    sharding_dim_size = weight.size(sharding_dim)
    # Hoist the placements lookup; it is used three times below.
    placements = weight._sharding_spec.placements
    dims = list(results[0].size())
    dims[0] = sharding_dim_size
    combined_results = torch.cat(results)
    output = torch.empty(
        *dims, device=combined_results.device, dtype=combined_results.dtype)

    # Compute output splits: entry i is the number of rows this rank
    # receives from the placement whose rank() is i.
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for idx, placement in enumerate(placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, idx)

    # distribute the outputs using all2all.
    output = all_to_all_single(
        output, combined_results, output_split_sizes=output_split_sizes, group=pg)

    # Rearrange columns only if some placement's position in the spec
    # differs from its rank (generator instead of a throwaway list).
    rearrange_columns = any(
        idx != placement.rank() for idx, placement in enumerate(placements))
    if not rearrange_columns:
        return output

    indices = []
    for placement in placements:
        dim_size = output_split_sizes[placement.rank()]
        # Offset of this placement's chunk in `output` = sum of the chunk
        # sizes owned by all smaller ranks (prefix sum via slice; the
        # original comprehension shadowed `split_size`).
        start = sum(output_split_sizes[: placement.rank()])
        indices += list(range(start, start + dim_size))

    return output.index_select(0, torch.tensor(indices, device=output.device))
def generate_local_weight_sharding_params_for_test(
    local_weight, sharded_dim, gpu_num, spec, rank
):
    """
    Shard the local weight based on the given spec, so we can compare against
    the one from sharded tensor.

    Args:
        local_weight: weight matrix to be sharded.
        sharded_dim: The dimension which we shard on.
        gpu_num: number of ranks.
        spec: sharding spec.
        rank: # of cuda process.

    Returns:
        start_pos: start position of sharded weight on the given rank.
        chunk_size: chunk size of sharded weight on the given rank.
    """
    dim_len = local_weight.size(sharded_dim)
    per_chunk = get_split_size(dim_len, gpu_num)
    # Walk the placements in spec order, accumulating the offset until we
    # reach the chunk owned by `rank`.
    offset = 0
    start_pos = 0
    for chunk_idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(dim_len, per_chunk, chunk_idx)
        if placement.rank() == rank:
            start_pos = offset
            break
        offset += chunk_size
    return start_pos, chunk_size
def _init_chunked(
    self,
    dims,
    tensor_init_params: TensorInitParams,
):
    """
    Build shard metadata for every placement of a ChunkShardingSpec-sharded
    tensor of global size ``dims``, and materialize the local shard on the
    current rank (if the current rank appears in the spec).

    Args:
        dims: global size of the tensor, as a list of ints.
        tensor_init_params: describes how to create the local shard
            (init op and tensor properties).
    """
    current_rank = dist.get_rank(self._process_group)
    sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]

    # Validate the sharding spec.
    if not isinstance(sharding_dim, int):
        raise ValueError(
            f"Sharding dim needs to be an integer, found: {sharding_dim}")
    if sharding_dim >= len(dims) or sharding_dim < -len(dims):
        raise ValueError(f"Invalid sharding dim: {sharding_dim}")

    dim_size = dims[sharding_dim]
    remote_devices = self._sharding_spec.placements  # type: ignore[attr-defined]
    chunks = len(remote_devices)
    # split_size computed similar to 'torch.chunk': first N-1 chunks are
    # full-sized, the last one gets the remainder (possibly 0).
    split_size = get_split_size(dim_size, chunks)

    shards_metadata = []
    for idx, remote_device in enumerate(remote_devices):
        rank, local_device = _parse_and_validate_remote_device(
            self._process_group, remote_device)

        # Adjust the sharding dim for this rank.
        sharded_dim_size = get_chunked_dim_size(dim_size, split_size, idx)

        # Empty chunks (possible when dim_size < chunks) produce no
        # metadata and no local shard.
        if sharded_dim_size > 0:
            # Build sharding_metadata.

            # deepcopy for modification.
            rank_dims = dims.copy()

            rank_offsets = [0] * len(dims)
            rank_offsets[sharding_dim] = split_size * idx
            rank_dims[sharding_dim] = sharded_dim_size

            shard_metadata = ShardMetadata(rank_offsets, rank_dims, remote_device)
            shards_metadata.append(shard_metadata)

            # Build the local shard for the current rank if it is involved in the sharding spec.
            if current_rank == rank:
                # Initialize the local shard.
                local_shard = _create_tensor_from_params(
                    *rank_dims, local_device=local_device,
                    tensor_init_params=tensor_init_params)
                self._local_shards.append(Shard(local_shard, shard_metadata))

    # Build overall metadata
    self._metadata = ShardedTensorMetadata(
        shards_metadata,
        dims,
        tensor_init_params.tensor_properties,
    )
def _infer_chunk_sharding_spec_case(self, placements, sharding_dim, st_size):
    """
    Round-trip check: build per-rank ShardMetadata according to
    ``placements`` and verify the spec inferred from that metadata is a
    ChunkShardingSpec with the same dim and placements.
    """
    world_size = len(placements)
    chunk = get_split_size(st_size[sharding_dim], world_size)
    metadata_by_rank = [None] * world_size
    for chunk_idx, placement in enumerate(placements):
        # Offsets/sizes follow torch.chunk semantics along sharding_dim.
        shard_offsets = [0] * len(st_size)
        shard_offsets[sharding_dim] = chunk * chunk_idx
        sizes = copy.deepcopy(st_size)
        sizes[sharding_dim] = get_chunked_dim_size(
            st_size[sharding_dim], chunk, chunk_idx)
        metadata_by_rank[placement.rank()] = ShardMetadata(
            shard_offsets=shard_offsets,
            shard_sizes=sizes,
            placement=placement,
        )
    inferred = _infer_sharding_spec_from_shards_metadata(metadata_by_rank)
    self.assertTrue(isinstance(inferred, ChunkShardingSpec))
    self.assertEqual(inferred.dim, sharding_dim)
    self.assertEqual(inferred.placements, placements)
def build_reshard_metadata(
    st_size: torch.Size,
    sharding_spec: ShardingSpec,
    world_size: int,
) -> Tuple[List[ShardMetadata], List[int]]:
    """
    Compute the per-rank offsets and local shard sizes implied by the given
    sharding spec, and wrap each in a ShardMetadata.

    Args:
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the tensor is sharded.
        world_size (int): number of ranks.

    Returns:
        A Tuple of the followings:
            A List[`ShardMetadata`] which contains the metadata for the shard,
                including offsets, lengths and device placement,
                indexed by rank.
            A List[int] which contains the ranks in the order of placement.
    """
    shard_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    shards_metadata = [None] * world_size
    ranks = []
    offsets = [0] * len(st_size)
    split_size = get_split_size(st_size[shard_dim], world_size)
    for chunk_idx, placement in enumerate(
            sharding_spec.placements):  # type: ignore[attr-defined]
        dest_rank = placement.rank()
        ranks.append(dest_rank)
        chunk_len = get_chunked_dim_size(st_size[shard_dim], split_size,
                                         chunk_idx)
        shard_sizes = list(st_size)
        shard_sizes[shard_dim] = chunk_len
        # Snapshot the running offsets for this shard before advancing.
        shards_metadata[dest_rank] = ShardMetadata(  # type: ignore[call-overload]
            shard_offsets=list(offsets),
            shard_sizes=shard_sizes,
            placement=placement,
        )
        offsets[shard_dim] += chunk_len
    return shards_metadata, ranks  # type: ignore[return-value]
def _init_from_local_tensor(
    cls,
    local_tensor: torch.Tensor,
    sharding_spec: ShardingSpec,
    *global_size: Sequence[int],
    process_group: dist.ProcessGroup = None,
    init_rrefs=False,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensor given only one local tensor, global sharded tensor
    size and sharding spec on each rank.

    Args:
        local_tensor (Tensor): Single tensor of local shard stored in each rank.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.
        global_size (Sequence[int]): Size of the sharded tensor.
        process_group (ProcessGroup, optional): The process group to aggregate on.
            Default: None
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` sharded based on the given sharding_spec with local
            tensor stored in the current rank.

    Examples:
        >>> # All tensors below are of torch.int64 type.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]))
        >>> local_tensor
        tensor([[1, 2, 3, 4]]) # Rank 0
        tensor([[3, 4, 5, 6]]) # Rank 1
        >>> sharding_dim = 0
        >>> sharding_spec = ChunkShardingSpec(
                dim=sharding_dim,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                ],
            )
        >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
        >>> st
        ShardedTensor(
            ShardedTensorMetadata(
                shards_metadata=[
                    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
                    ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
                ],
                size=torch.Size([2, 4])
        )
        >>> st.local_tensor()
        tensor([1, 2, 3, 4]) # Rank 0
        tensor([3, 4, 5, 6]) # Rank 1

    Warning: This API is experimental and subject to change. It lacks a full
        cross-rank validation, and we only validate the local shard on the
        current rank. We fully rely on the user to ensure the local tensor
        is sharded based on the sharding spec.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not local_tensor.is_contiguous():
        raise ValueError('local_tensor is not a contiguous Tensor.')

    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    global_tensor_size = _flatten_tensor_size(global_size)
    sharding_dim = sharding_spec.dim
    split_size = get_split_size(global_tensor_size[sharding_dim], world_size)  # type: ignore[index]
    current_offsets = [0] * len(global_tensor_size)
    gathered_metadatas = [None] * world_size

    local_shards = []
    # Walk all placements; each rank builds metadata for every shard but
    # attaches a real tensor only for its own shard (others get torch.empty
    # placeholders used solely for metadata construction).
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(
            global_tensor_size[sharding_dim], split_size, idx  # type: ignore[index]
        )
        shard_size = copy.deepcopy(global_tensor_size)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]
        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        if current_rank == placement.rank():  # type: ignore[union-attr]
            local_shard = local_tensor
        else:
            local_shard = torch.empty(
                shard_size,
                device=placement.device(),  # type: ignore[union-attr]
                requires_grad=local_tensor.requires_grad,
            )
        shards = [
            Shard(
                tensor=local_shard,
                metadata=shard_metadata,  # type: ignore[arg-type]
            )
        ]
        if current_rank == placement.rank():  # type: ignore[union-attr]
            local_shards = shards
        gathered_metadatas[placement.rank(
        )] = build_metadata_from_local_shards(  # type: ignore[call-overload, union-attr]
            shards, global_tensor_size, placement.rank(), process_group  # type: ignore[union-attr, arg-type]
        )
        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        global_sharded_tensor_metadata,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: ShardingSpec,
    resharding_spec: ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is same as the original
    sharding dim. Logically we do this in two step:
    1. To collect all shards based on original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending the local tensor to
    the new shard directly based on the resharding spec.

    Args:
        local_shard (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the followings:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(st_size, resharding_spec,
                                                    world_size)
    # Get input split size for all2all.
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    # The whole local shard goes to a single destination: the rank that owns
    # this placement position under the new spec.
    idx = get_idx_from_placements(sharding_spec.placements, current_rank)  # type: ignore[attr-defined]
    new_rank = resharding_spec.placements[idx].rank()  # type: ignore[union-attr, attr-defined]
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
    # Get output split size for all2all.
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size,
                                            new_idx)
    output_split_sizes[new_rank] = sharded_dim_size
    # Get gathered_input for all2all.
    # all_to_all_single splits along dim 0, so move reshard_dim to the front.
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(gathered_input_size, device=local_shard.device)
    # all2all.
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
    # Undo the transpose to restore the original dim order.
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
def reshard_local_shard(
    local_tensor: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: ShardingSpec,
    resharding_spec: ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
    different from the original sharding dim, we need to do two steps logically:
    1. To collect all shards based on original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending each rank the new
    shard based on the resharding spec.

    Args:
        local_tensor (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the followings:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    current_sharding_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]

    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(st_size, resharding_spec,
                                                    world_size)

    # Compute expected size
    input_split_sizes = []
    for metadata in shards_metadata:
        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
    # Input must be rearranged only when the placement ranks are not already
    # in ascending order.
    rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))

    if rearrange_input:
        # Need to re-arrange reshard_dim of local_tensor before all2all.
        indices: List[int] = []
        for metadata in shards_metadata:
            offset_start_idx = metadata.shard_offsets[reshard_dim]
            split_size = metadata.shard_sizes[reshard_dim]
            indices += range(offset_start_idx, offset_start_idx + split_size)
        local_tensor = local_tensor.index_select(
            reshard_dim, torch.tensor(indices, device=local_tensor.device))

    # Because reshard_dim != original shard_dim. We need to compute the
    # size of tensor from each rank.
    output_tensor_list = [torch.tensor(1)] * world_size
    split_size = get_split_size(st_size[current_sharding_dim], world_size)
    rearrange_output_list = False
    indices = []
    for idx, placement in enumerate(
            sharding_spec.placements):  # type: ignore[attr-defined]
        sharded_dim_size = get_chunked_dim_size(st_size[current_sharding_dim],
                                                split_size, idx)
        output_tensor_size = list(st_size)
        output_tensor_size[current_sharding_dim] = sharded_dim_size
        output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
        output_tensor_list[
            placement.rank()] = torch.empty(  # type: ignore[union-attr, index]
                output_tensor_size, device=local_tensor.device)
        indices.append(
            placement.rank())  # type: ignore[union-attr, index, arg-type]
        if idx != placement.rank():  # type: ignore[union-attr]
            rearrange_output_list = True

    # Perform autograd enabled all2all.
    input_tensor_list = torch.split(local_tensor, input_split_sizes,
                                    dim=reshard_dim)
    # all_to_all requires contiguous inputs; split views may not be.
    input_tensor_list = [tensor.contiguous() for tensor in input_tensor_list]
    output_tensor_list = all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )

    if rearrange_output_list:
        # Need to re-arrange original shard_dim of output_tensor_list.
        output_tensor_list = [output_tensor_list[idx] for idx in indices
                              ]  # type: ignore[call-overload]
    local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
def _handle_row_wise_sharding_tensor(input, world_size, weight, rank,
                                     local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise shared local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    # alltoall to gather all the appropriate inputs.
    # all_to_all_single splits along dim 0, so put the feature dim first.
    input_t = input.transpose(0, -1).contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input_size = [input_split_sizes[rank] * world_size] + list(
        input_t_size[1:])
    gathered_input = torch.empty(gathered_input_size, device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)
    gathered_input = gathered_input.transpose(0, -1)

    # Perform local matmuls for all shards
    results = []
    shard_size = local_shard_t.size()[0]
    for r in range(world_size):
        inp = torch.narrow(gathered_input, -1, r * shard_size, shard_size)
        # Bias goes through _BiasTensorPartial so it is only counted once
        # after the partial results are reduced.
        results.append(
            inp.matmul(local_shard_t) +
            _BiasTensorPartial.apply(world_size, bias))

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
def _handle_row_wise_lookup_distribute(input_sorted, input, world_size, weight,
                                       rank, padding_idx):
    """
    In the circumstance of row-wise sharding of weight, we need to distribute
    the sorted lookup IDs of embedding/embeddingBag to each rank.
    If the index in the placement is not equal to the rank number, we need to
    do the rearrangement based on the order given by the Sharding Spec (placement).

    In addition, we do two things for padding_idx. The first thing is to only
    set it if it's within the range of the current rank and the other thing
    is to do the modularization of it by sharded_dim_size_max.

    Args:
        input_sorted: sorted lookup IDs of embedding/embeddingBag.
        input: tensor to be applied op on.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient and reduction.

    Return:
        input_sorted: sorted lookup IDs of embedding/embeddingBag
            Rearrangement performed if it is needed.
        input_split_sizes: size of IDs to be assigned to each rank.
        sharded_dim_size_max: the max size of the row each rank gets.
        input_split_rearrange_indices: indices of row rearrangement.
        rearrange_indices_1d_second_order: reverse indices of row rearrangement,
            which will be used to restore the original order.
        padding_idx: Same as input if padding_idx is within the range
            of the given rank; otherwise, None is returned. It is
            also modularized by sharded_dim_size_max.
    """
    # Decide which rank the input goes to by check the sharding range.
    split_size = get_split_size(weight.size(0), world_size)
    rearrange_rows = False
    indices_flatten = None
    input_split_sizes: List[int] = [0] * world_size
    input_split_start_indices: List[int] = [0] * world_size
    start_row_idx_rank = None
    end_row_idx_rank = None
    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
    # are not possible. The expected split size will be [4, 4, 4, 1].
    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size,
                                                idx)
        start_row_idx = idx * sharded_dim_size_max
        end_row_idx = start_row_idx + sharded_dim_size
        # input_sorted is sorted, so searchsorted gives the contiguous slice
        # of IDs that fall into [start_row_idx, end_row_idx).
        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
        input_split_start_indices[placement.rank()] = int(start_idx)
        if placement.rank() != idx:
            rearrange_rows = True
        # Store the range of the current rank.
        if placement.rank() == rank:
            start_row_idx_rank = start_row_idx
            end_row_idx_rank = end_row_idx

    # Perform the modular if padding_idx is within the range.
    if padding_idx is not None:
        if padding_idx < start_row_idx_rank or padding_idx >= end_row_idx_rank:
            padding_idx = None
        else:
            padding_idx = padding_idx % sharded_dim_size_max

    rearrange_indices_1d_second_order = None
    if rearrange_rows:
        # Need to re-arrange the 1D tensor to be sent via all2all.
        indices: List[List[int]] = [[0]] * world_size
        for placement in weight._sharding_spec.placements:
            split_length = input_split_sizes[placement.rank()]
            offset_idx = input_split_start_indices[placement.rank()]
            indices[placement.rank()] = list(
                range(offset_idx, offset_idx + split_length))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_sorted = input_sorted.index_select(
            0, torch.tensor(indices_flatten, device=input.device))
        # argsort of the permutation gives its inverse, used later to restore
        # the original order of the looked-up rows.
        rearrange_indices_1d_second_order = torch.argsort(
            torch.Tensor(indices_flatten))

    return (
        input_sorted,
        input_split_sizes,
        sharded_dim_size_max,
        torch.tensor(indices_flatten, device=input.device)
        if rearrange_rows else None,
        rearrange_indices_1d_second_order,
        padding_idx,
    )
def test_get_chunked_dim_size(self): self.assertEqual(3, get_chunked_dim_size(11, 3, 0)) self.assertEqual(2, get_chunked_dim_size(11, 3, 3)) self.assertEqual(4, get_chunked_dim_size(13, 4, 0)) self.assertEqual(1, get_chunked_dim_size(13, 4, 3)) self.assertEqual(0, get_chunked_dim_size(5, 2, 3))
def _shard_tensor(tensor: torch.Tensor,
                  sharding_spec: ShardingSpec,
                  src_rank=0,
                  process_group=None):
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingspec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        # Remember which shard this rank owns; used below to narrow out the
        # local piece after the broadcast.
        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards,
                                                 tensor.size(),
                                                 process_group=pg)
def shard_parameter(module: torch.nn.Module,
                    param_name: str,
                    sharding_spec: ShardingSpec,
                    src_rank=0,
                    process_group=None):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingspec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(
            f'module: {module} does not have parameter with name: {param_name}'
        )

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}'
        )

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        # Remember which shard this rank owns; used below to narrow out the
        # local piece after the broadcast.
        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    st = ShardedTensor._init_from_local_shards(local_shards,
                                               tensor.size(),
                                               process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)