def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Handle the logic of row-wise sharding of weight for Linear.

    The transposed input is redistributed with an all2all so that every rank
    receives the slice of columns matching its weight shard, local matmuls are
    computed for each rank's portion, and a reduce_scatter combines the
    partial products back to the owning ranks.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation on this rank.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        # BUGFIX: the previous version computed offsets as
        # `placement.rank() * split_size` with the chunk size evaluated at
        # `placement.rank()`. Rows must instead be grouped by destination
        # rank at offsets `idx * sharded_dim_size_max`, where idx is the
        # chunk index in placement order (matching the other row-wise
        # sharding handlers in this file).
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks
        # get max out and then the Nth chunk gets the rest. So
        # input_split_sizes like [3, 3, 3, 4] are not possible. The expected
        # split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def generate_local_weight_sharding_params_for_test(local_weight, sharded_dim, gpu_num, spec, rank):
    """
    Shard the local weight based on the given spec, so we can compare against
    the one from sharded tensor.

    Args:
        local_weight: weight matrix to be sharded.
        sharded_dim: The dimension which we shard on.
        gpu_num: number of ranks.
        spec: sharding spec.
        rank: # of cuda process.

    Returns:
        start_pos: start position of sharded weight on the given rank.
        chunk_size: chunk size of sharded weight on the given rank.
    """
    dim_size = local_weight.size(sharded_dim)
    split_size = get_split_size(dim_size, gpu_num)

    # Walk the placements in chunk order, accumulating offsets until we hit
    # the requested rank.
    offset = 0
    start_pos = offset
    for chunk_idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(dim_size, split_size, chunk_idx)
        if placement.rank() == rank:
            start_pos = offset
            break
        offset += chunk_size
    return start_pos, chunk_size
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Handle the logic of row-wise sharding of weight for Linear.

    Redistributes the transposed input with all2all so each rank receives the
    input columns that match its weight shard, performs local matmuls, and
    combines partial products with reduce_scatter.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation on this rank.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    # input_split_sizes is keyed by destination rank; idx is the chunk index
    # in placement order. If they ever differ, the placement permutes ranks
    # and rows must be re-ordered before the all2all.
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        # NOTE: [[0]] * world_size aliases one inner list, but every slot is
        # replaced by assignment below (each rank appears exactly once in the
        # placements), so the aliasing is harmless.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    # Every rank contributes the same number of rows, so the receive buffer
    # holds world_size blocks of this rank's split size.
    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def _init_chunked(
    self,
    dims,
    tensor_init_params: TensorInitParams,
):
    """
    Initialize a chunk-sharded tensor from ``self._sharding_spec``.

    Builds the ShardMetadata for every placement, creates the local shard on
    this rank (if any), and records the overall ShardedTensorMetadata.

    Args:
        dims: list of the global tensor's dimension sizes.
        tensor_init_params: parameters controlling how the local shard's
            tensor is created (see _create_tensor_from_params).
    """
    current_rank = dist.get_rank(self._process_group)
    sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]

    # Validate the sharding spec.
    if not isinstance(sharding_dim, int):
        raise ValueError(
            f"Sharding dim needs to be an integer, found: {sharding_dim}")
    if sharding_dim >= len(dims) or sharding_dim < -len(dims):
        raise ValueError(f"Invalid sharding dim: {sharding_dim}")

    dim_size = dims[sharding_dim]
    remote_devices = self._sharding_spec.placements  # type: ignore[attr-defined]
    chunks = len(remote_devices)
    # split_size computed similar to 'torch.chunk'
    split_size = get_split_size(dim_size, chunks)

    shards_metadata = []
    for idx, remote_device in enumerate(remote_devices):
        rank, local_device = _parse_and_validate_remote_device(
            self._process_group, remote_device)

        # Adjust the sharding dim for this rank.
        sharded_dim_size = get_chunked_dim_size(dim_size, split_size, idx)

        # Trailing placements may get an empty chunk when dim_size is not
        # large enough; those produce no shard metadata at all.
        if sharded_dim_size > 0:
            # Build sharding_metadata.

            # deepcopy for modification.
            rank_dims = dims.copy()

            rank_offsets = [0] * len(dims)
            rank_offsets[sharding_dim] = split_size * idx
            rank_dims[sharding_dim] = sharded_dim_size

            shard_metadata = ShardMetadata(rank_offsets, rank_dims, remote_device)
            shards_metadata.append(shard_metadata)

            # Build the local shard for the current rank if it is involved in the sharding spec.
            if current_rank == rank:
                # Initialize the local shard.
                local_shard = _create_tensor_from_params(
                    *rank_dims, local_device=local_device,
                    tensor_init_params=tensor_init_params)
                self._local_shards.append(Shard(local_shard, shard_metadata))

    # Build overall metadata
    self._metadata = ShardedTensorMetadata(
        shards_metadata,
        dims,
        tensor_init_params.tensor_properties,
    )
def _result_distribute_with_col_rearrange(results, input, sharding_dim_size, world_size, weight, pg):
    """
    For col-wise sharding of weight, distribute results back to each rank.

    If the index in the sharding spec is not equal to the rank number, the
    received rows are additionally rearranged into the order given by the
    sharding spec (placement).

    Args:
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor to be applied op to.
        sharding_dim_size: the max size of the column each rank gets.
        world_size: number of ranks.
        weight: sharded weight tensor.
        pg: process group.

    Return:
        column rearranged result.
    """
    # Allocate the all2all receive buffer and flatten the per-rank results
    # into a single send buffer.
    out_dims = list(results[0].size())
    out_dims[0] = sharding_dim_size
    output = torch.empty(*out_dims, device=input.device)
    combined_results = torch.cat(results)

    # Receive sizes, keyed by the source rank; the chunk index comes from the
    # placement order.
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for chunk_idx, placement in enumerate(weight._sharding_spec.placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, chunk_idx)

    # distribute the outputs using all2all.
    output = all_to_all_single(output,
                               combined_results,
                               output_split_sizes=output_split_sizes,
                               group=pg)

    # Only rearrange when the placement order differs from the rank order.
    rearrange_columns = any(
        chunk_idx != placement.rank()
        for chunk_idx, placement in enumerate(weight._sharding_spec.placements))
    if not rearrange_columns:
        return output

    # Gather the received row ranges in placement order; the start of each
    # rank's range is the prefix sum of the split sizes before it.
    indices = []
    for placement in weight._sharding_spec.placements:
        source_rank = placement.rank()
        dim_size = output_split_sizes[source_rank]
        start = sum(output_split_sizes[:source_rank])
        indices.extend(range(start, start + dim_size))

    return output.index_select(0, torch.tensor(indices, device=output.device))
def _handle_col_wise_sharding(input, world_size, weight, local_shard_t, bias, pg):
    """
    Handle the logic of col-wise sharding of weight for Linear.

    Gathers the full input on every rank, multiplies it by the local weight
    shard, then redistributes the partial outputs with all2all so each rank
    ends up with its own rows of the result.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard_t: col-wise sharded local weight (transposed).
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation on this rank.
    """
    # allgather the inputs first.
    gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
    dist.all_gather(gathered_inputs, input, group=pg)

    # matmul all the inputs.
    results = []
    for i, inp in enumerate(gathered_inputs):
        results.append(inp.matmul(local_shard_t).t())

    # Process inputs and outputs for all2all.
    sharding_dim_size = weight.size()[0]
    output = torch.empty((sharding_dim_size, input.size(0)),
                         device=input.device)
    combined_results = torch.cat(results)

    # Compute output splits, keyed by the source rank with the chunk index
    # taken from the placement order.
    # BUGFIX: previously the list was built in placement-enumeration order
    # with placement.rank() used as the chunk index. The rearrangement below
    # indexes output_split_sizes by placement.rank() as a rank, so for a
    # placement permutation that is not its own inverse the two disagreed.
    # This now matches _result_distribute_with_col_rearrange.
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for idx, placement in enumerate(weight._sharding_spec.placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, idx)

    # distribute the outputs using all2all.
    dist.all_to_all_single(output,
                           combined_results,
                           output_split_sizes=output_split_sizes,
                           group=pg)

    # Check if we need to rearrange rows appropriately for output.
    rearrange_rows = any([
        idx != placement.rank()
        for idx, placement in enumerate(weight._sharding_spec.placements)
    ])
    if rearrange_rows:
        indices = []
        for placement in weight._sharding_spec.placements:
            dim_size = output_split_sizes[placement.rank()]
            start = sum([
                split_size if i < placement.rank() else 0
                for i, split_size in enumerate(output_split_sizes)
            ])
            indices += list(range(start, start + dim_size))
        output = output.index_select(
            0, torch.tensor(indices, device=output.device))

    # add bias and return result.
    return output.t() + bias
def shard_parameter(module: torch.nn.Module, param_name: str, sharding_spec: ShardingSpec, src_rank=0, process_group=None):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._sharding_spec.ShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingspec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(
            f'module: {module} does not have parameter with name: {param_name}'
        )

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}'
        )

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    # NOTE(review): the cuda.device guard presumably pins the collective to
    # the tensor's device for NCCL — confirm this is required for
    # all_gather_object here.
    with torch.cuda.device(tensor.device):
        dist.all_gather_object(gathered_list, (src_rank, sharding_spec),
                               group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    # Build one ShardMetadata per placement; remember ours for slicing below.
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_lengths=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_lengths[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).contiguous()

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=local_shard.dtype,
            layout=local_shard.layout,
            requires_grad=local_shard.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=local_shard.is_pinned(),
        ))
    st = ShardedTensor._init_from_local_shards(local_shards,
                                               sharded_tensor_metadata,
                                               process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)
    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in
    the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise shared local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    # Send sizes are keyed by destination rank; idx is the chunk index in
    # placement order. A mismatch means the placement permutes ranks and
    # rows must be re-ordered before the all2all.
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        # NOTE: [[0]] * world_size aliases one inner list, but every slot is
        # overwritten by assignment in the loop below, so this is harmless.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    # All ranks contribute the same number of rows to this rank.
    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    # Autograd-enabled reduce_scatter returns the output tensor.
    local_result = reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard, pg):
    """
    Handle the logic of row-wise sharding of weight for Embedding.

    Flattens and sorts the lookup IDs, routes each ID to the rank owning the
    corresponding weight rows via all2all, performs the local embedding
    lookup, then routes the embeddings back and restores the original order
    and shape.

    Args:
        input: tensor of lookup IDs to be applied to the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard: row-wise sharded local weight used for lookup.
        pg: process group.

    Returns:
        final result of the embedding lookup, shaped like
        ``(*input.size(), weight.size(1))``.
    """
    # flatten the ids across all input and sort
    input_size = input.size()
    input_1d = torch.reshape(input, (-1, )).contiguous()
    input_sorted, indices_1d = torch.sort(input_1d)
    rearrange_indices_1d = torch.argsort(indices_1d)
    # BUGFIX: contiguous() is not in-place; the original discarded its
    # return value, making the call a no-op. Assign the result.
    input_sorted = input_sorted.contiguous()

    # Decide which rank the input goes to by check the sharding range.
    split_size = get_split_size(weight.size(0), world_size)
    rearrange_rows = False
    input_split_sizes: List[int] = [0] * world_size
    input_split_start_indices: List[int] = [0] * world_size
    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
    # are not possible. The expected split size will be [4, 4, 4, 1].
    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
        start_row_idx = idx * sharded_dim_size_max
        end_row_idx = start_row_idx + sharded_dim_size
        # IDs are sorted, so the contiguous range owned by this chunk is
        # found by binary search.
        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
        input_split_start_indices[placement.rank()] = int(start_idx)
        if placement.rank() != idx:
            rearrange_rows = True

    rearrange_indices_1d_second_order = None
    if rearrange_rows:
        # Need to re-arrange the 1D tensor to be sent via all2all.
        indices: List[List[int]] = [[0]] * world_size
        for placement in weight._sharding_spec.placements:
            split_length = input_split_sizes[placement.rank()]
            offset_idx = input_split_start_indices[placement.rank()]
            indices[placement.rank()] = list(
                range(offset_idx, offset_idx + split_length))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_sorted = input_sorted.index_select(
            0, torch.tensor(indices_flatten, device=input.device))
        rearrange_indices_1d_second_order = torch.argsort(
            torch.Tensor(indices_flatten))

    # Get the input split size to be sent from each rank to the current rank.
    # We can then infer the output split size.
    input_split_sizes_tensor = (
        torch.Tensor(input_split_sizes).type("torch.IntTensor").cuda(rank))
    output_split_sizes_tensor = torch.empty(world_size,
                                            dtype=torch.int32,
                                            device=input.device)
    dist.all_to_all_single(
        output_split_sizes_tensor,
        input_split_sizes_tensor,
        group=pg,
    )
    output_split_sizes = output_split_sizes_tensor.tolist()

    # Input sent from each rank to the current rank may have different sizes.
    gathered_input = torch.empty(sum(output_split_sizes),
                                 dtype=torch.int64,
                                 device=input.device)

    # Perform the modular operation of the 1D tensor to be sent to each rank.
    input_sorted = torch.remainder(input_sorted, sharded_dim_size_max)

    # Perform alltoall
    dist.all_to_all_single(
        gathered_input,
        input_sorted,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )

    # Perform local embedding look up.
    gathered_input_embeddings = torch.nn.functional.embedding(
        gathered_input, local_shard)

    # Gather all lookup result appropriately by performing alltoall again
    gathered_output = torch.empty(input_sorted.size(0),
                                  weight.size(1),
                                  device=input.device)
    dist.all_to_all_single(
        gathered_output,
        gathered_input_embeddings,
        input_split_sizes=output_split_sizes,
        output_split_sizes=input_split_sizes,
        group=pg,
    )

    # Rearrange the results to its original shape.
    if rearrange_indices_1d_second_order is not None:
        gathered_output = gathered_output[rearrange_indices_1d_second_order]
    gathered_output = gathered_output[rearrange_indices_1d]

    # Return the appropriate local result.
    return torch.reshape(gathered_output, (*input_size, weight.size(1)))
def test_get_chunked_dim_size(self):
    # Table of (dim_size, split_size, chunk_idx) -> expected chunk size,
    # covering both a full chunk and a trailing partial chunk.
    cases = [
        (11, 3, 0, 3),
        (11, 3, 3, 2),
        (13, 4, 0, 4),
        (13, 4, 3, 1),
    ]
    for dim_size, split_size, chunk_idx, expected in cases:
        self.assertEqual(
            expected, get_chunked_dim_size(dim_size, split_size, chunk_idx))
def _handle_row_wise_lookup_distribute(input_sorted, input, world_size, weight, rank, padding_idx):
    """
    In the circumstance of row-wise sharding of weight, we need to distribute
    the sorted lookup IDs of embedding/embeddingBag to each rank.
    If the index in the placement is not equal to the rank number, we need to
    do the rearrangement based on the order given by the Sharding Spec (placement).

    In addition, we do two things for padding_idx. The first thing is to only
    set it if it's within the range of the current rank and the other thing
    is to do the modularization of it by sharded_dim_size_max.

    Args:
        input_sorted: sorted lookup IDs of embedding/embeddingBag.
        input: tensor to be applied op on.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient and reduction.

    Return:
        input_sorted: sorted lookup IDs of embedding/embeddingBag
            Rearrangement performed if it is needed.
        input_split_sizes: size of IDs to be assigned to each rank.
        sharded_dim_size_max: the max size of the row each rank gets.
        input_split_rearrange_indices: indices of row rearrangement.
        rearrange_indices_1d_second_order: reverse indices of row rearrangement,
            which will be used to restore the original order.
        padding_idx: Same as input if padding_idx is within the range of the
            given rank; otherwise, None is returned. It is also modularized
            by sharded_dim_size_max.
    """
    # Decide which rank the input goes to by check the sharding range.
    split_size = get_split_size(weight.size(0), world_size)
    rearrange_rows = False
    indices_flatten = None
    input_split_sizes: List[int] = [0] * world_size
    input_split_start_indices: List[int] = [0] * world_size
    start_row_idx_rank = None
    end_row_idx_rank = None
    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
    # are not possible. The expected split size will be [4, 4, 4, 1].
    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
        start_row_idx = idx * sharded_dim_size_max
        end_row_idx = start_row_idx + sharded_dim_size
        # IDs are sorted, so each chunk's range is located via binary search.
        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
        input_split_start_indices[placement.rank()] = int(start_idx)
        if placement.rank() != idx:
            rearrange_rows = True
        # Store the range of the current rank.
        if placement.rank() == rank:
            start_row_idx_rank = start_row_idx
            end_row_idx_rank = end_row_idx

    # Perform the modular if padding_idx is within the range.
    # NOTE(review): assumes `rank` always appears in the placements so
    # start_row_idx_rank/end_row_idx_rank are set — confirm against callers.
    if padding_idx is not None:
        if padding_idx < start_row_idx_rank or padding_idx >= end_row_idx_rank:
            padding_idx = None
        else:
            padding_idx = padding_idx % sharded_dim_size_max

    rearrange_indices_1d_second_order = None
    if rearrange_rows:
        # Need to re-arrange the 1D tensor to be sent via all2all.
        # NOTE: [[0]] * world_size aliases one inner list, but every slot is
        # overwritten by assignment below, so the aliasing is harmless.
        indices: List[List[int]] = [[0]] * world_size
        for placement in weight._sharding_spec.placements:
            split_length = input_split_sizes[placement.rank()]
            offset_idx = input_split_start_indices[placement.rank()]
            indices[placement.rank()] = list(
                range(offset_idx, offset_idx + split_length))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_sorted = input_sorted.index_select(
            0, torch.tensor(indices_flatten, device=input.device))
        rearrange_indices_1d_second_order = torch.argsort(
            torch.Tensor(indices_flatten))

    return (
        input_sorted,
        input_split_sizes,
        sharded_dim_size_max,
        torch.tensor(indices_flatten, device=input.device)
        if rearrange_rows else None,
        rearrange_indices_1d_second_order,
        padding_idx,
    )