def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg):
    """
    For col-wise sharding of weight, we need to distribute
    results to each rank. We do that in this function.
    Note that, if the index in the Sharding Spec is not equal to
    the rank number, we need to do the rearrangement based on the
    order given by the Sharding Spec (placement).

    Args:
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor that the op is applied to.
        world_size: number of ranks.
        weight: sharded weight tensor.
        pg: process group.

    Returns:
        column rearranged result.
    """
    # Process results and outputs for all2all.
    sharding_dim = weight._sharding_spec.dim
    sharding_dim_size = weight.size(sharding_dim)
    dims = list(results[0].size())
    dims[0] = sharding_dim_size
    combined_results = torch.cat(results)
    output = torch.empty(
        *dims, device=combined_results.device, dtype=combined_results.dtype
    )

    # Compute output splits
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for idx, placement in enumerate(weight._sharding_spec.placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, idx
        )

    # Distribute the outputs using all2all.
    output = all_to_all_single(
        output, combined_results, output_split_sizes=output_split_sizes, group=pg
    )

    # Check if we need to rearrange columns appropriately for output.
    rearrange_columns = any(
        [
            idx != placement.rank()
            for idx, placement in enumerate(weight._sharding_spec.placements)
        ]
    )
    if not rearrange_columns:
        return output

    indices = []
    for placement in weight._sharding_spec.placements:
        dim_size = output_split_sizes[placement.rank()]
        start = sum(
            [
                split_size if i < placement.rank() else 0
                for i, split_size in enumerate(output_split_sizes)
            ]
        )
        indices += list(range(start, start + dim_size))

    return output.index_select(0, torch.tensor(indices, device=output.device))

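# A minimal worked example of the rearrangement in
# _result_distribute_with_col_rearrange above (illustrative numbers, not taken
# from any particular model): with world_size = 2, a sharding dim of size 3 and
# placements = [rank:1, rank:0], get_split_size(3, 2) = 2, so the logical
# chunks have sizes [2, 1]. Chunk 0 lives on rank 1 and chunk 1 on rank 0,
# hence output_split_sizes = [1, 2] (indexed by rank). The all2all returns rows
# ordered by source rank (rank 0's single row first), so the logical order
# [chunk 0, chunk 1] is restored with indices = [1, 2, 0] and the final
# index_select.
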
def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: ShardingSpec,
    resharding_spec: ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is same as the original
    sharding dim. Logically we do this in two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending the local tensor to
    the new shard directly based on the resharding spec.

    Args:
        local_shard (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(st_size, resharding_spec, world_size)
    # Get input split size for all2all.
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    idx = get_idx_from_placements(sharding_spec.placements, current_rank)  # type: ignore[attr-defined]
    new_rank = resharding_spec.placements[idx].rank()  # type: ignore[union-attr, attr-defined]
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
    # Get output split size for all2all.
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx)
    output_split_sizes[new_rank] = sharded_dim_size
    # Get gathered_input for all2all.
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(gathered_input_size, device=local_shard.device)
    # all2all.
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata

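# A sketch of the data movement in reshuffle_local_shard above (hypothetical
# 2-rank setup, assuming, as the function requires, that the reshard dim equals
# the original sharding dim): take st_size[reshard_dim] = 3, an original spec
# placing chunk 0 on rank 0 and chunk 1 on rank 1, and a resharding spec that
# swaps them. On rank 0, get_idx_from_placements returns idx = 0, so
# new_rank = 1 and the whole local shard (2 rows) is sent there
# (input_split_sizes = [0, 2]); since rank 0 now owns chunk 1, it expects
# sharded_dim_size = 1 row back from rank 1 (output_split_sizes = [0, 1]).
# Rank 1 mirrors this, so a single all2all completes the reshuffle without
# materializing the full tensor on any rank.
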
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight for
    Linear. (Detailed explanations of the logic can be found in the comment for
    sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: row-wise sharded (transposed) local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always make the first N - 1 chunks the maximal
        # split size and the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size)
            )
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device)
        )

    gathered_input = torch.empty(
        input_split_sizes[rank] * world_size, input_t_size[1], device=input_t.device
    )

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    local_result = reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias

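# Illustrative walk-through of the row rearrangement in
# _handle_row_wise_sharding above (made-up sizes): with world_size = 2,
# in_features = 3 and placements = [rank:1, rank:0], split_size = 2 and the
# chunks of input_t have sizes [2, 1], so input_split_sizes = [1, 2]
# (chunk 0 goes to rank 1, chunk 1 to rank 0). Since placements differ from
# rank order, indices becomes [[2], [0, 1]] and indices_flatten = [2, 0, 1]:
# the row destined for rank 0 is moved to the front so all2all can send
# contiguous blocks in rank order. Each rank then matmuls its slice of
# gathered_input against local_shard_t, and reduce_scatter sums the partial
# products so each rank ends up with the full Linear output for its own local
# input.
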
def _handle_row_wise_sharding_tensor(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight for
    Linear. (Detailed explanations of the logic can be found in the comment for
    sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: row-wise sharded (transposed) local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.transpose(0, -1).contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always make the first N - 1 chunks the maximal
        # split size and the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size)
            )
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device)
        )

    gathered_input_size = [input_split_sizes[rank] * world_size] + list(input_t_size[1:])
    gathered_input = torch.empty(gathered_input_size, device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg)

    # Reshape gathered_input appropriately for matmul
    shard_size = local_shard_t.size()[0]
    reshaped_inputs = [
        torch.narrow(gathered_input, 0, r * shard_size, shard_size).transpose(0, -1)
        for r in range(world_size)
    ]
    reshaped_input = torch.cat(reshaped_inputs)
    if reshaped_input.dim() == 1:
        reshaped_input = reshaped_input.view(-1, local_shard_t.size(0))

    # Perform appropriate local matmul
    if reshaped_input.dim() <= 2:
        result = torch.addmm(
            _BiasTensorPartial.apply(world_size, bias), reshaped_input, local_shard_t
        )
    else:
        result = reshaped_input.matmul(local_shard_t) + _BiasTensorPartial.apply(
            world_size, bias
        )

    # Return the partial local result.
    return _PartialTensor(result, pg)

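# Shape bookkeeping for the simplest (2-D) case of
# _handle_row_wise_sharding_tensor above, as a hedged sketch with made-up
# dimensions: for a local input of shape (B, in_features), input_t is
# (in_features, B); after the all2all, gathered_input stacks, per source rank,
# the shard_size rows of that rank's input_t matching the local weight shard,
# giving (shard_size * world_size, B). The narrow/transpose/cat sequence turns
# this into (world_size * B, shard_size), and addmm with local_shard_t of shape
# (shard_size, out_features) yields a (world_size * B, out_features) partial
# product. The returned _PartialTensor still needs a reduction across ranks
# before it equals the full Linear output; _BiasTensorPartial is presumably
# what keeps the bias from being over-counted in that reduction.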