def reshard(self, resharding_spec: ShardingSpec) -> ShardedTensor:
    """
    Reduce this partial tensor across ranks and shard the result.

    The reshard happens in two steps logically:

    1. Aggregate all the shards of the partial tensor.
    2. Shard this tensor according to the provided spec.

    In reality, for the sake of performance, we consolidate all partial
    tensors across multiple ranks and convert to a sharded tensor in one
    step (a single reduce-scatter).

    Args:
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how we reshard the aggregated
            local result. Only ``ChunkShardingSpec`` is supported.

    Returns:
        A :class:`ShardedTensor` filled with local aggregated result.

    Raises:
        NotImplementedError: If ``resharding_spec`` is not a
            ``ChunkShardingSpec`` or the local shard is complex.
        ValueError: If the sharding dimension is not divisible by the size
            of the process group.
    """
    if not isinstance(resharding_spec, ChunkShardingSpec):
        raise NotImplementedError(
            "Only ChunkShardingSpec supported for reshard.")

    sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    world_size = self.process_group.size()
    if self.local_shard.size(sharding_dim) % world_size != 0:
        raise ValueError(
            'World size need to divide the length of the dimension.')
    if self.local_shard.is_complex():
        raise NotImplementedError(
            "Only real partial tensor supported for reshard.")

    # One equally sized chunk per rank of the process group.
    local_shards = self.local_shard.chunk(world_size, dim=sharding_dim)
    # The collective must run on the same group that was used to size the
    # chunk list; otherwise it silently falls back to the default group.
    local_result = reduce_scatter(
        torch.empty_like(local_shards[0]),
        list(local_shards),
        op=self.reduce_op,
        group=self.process_group,
    )

    sharded_tensor_size = self.local_shard.size()
    # Walk the placements in order, accumulating the offset along the
    # sharding dimension until we reach the placement owned by this rank.
    current_offsets = [0] * len(local_result.size())
    shards = []
    rank = self.process_group.rank()
    for placement in resharding_spec.placements:  # type: ignore[attr-defined]
        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = ShardMetadata(
                shard_offsets=current_offsets,
                shard_sizes=list(local_result.size()),
                placement=placement,
            )
            shards.append(Shard(local_result, local_metadata))
            break
        current_offsets[sharding_dim] += local_result.size(sharding_dim)  # type: ignore[index]

    st = ShardedTensor._init_from_local_shards(
        shards,
        tuple(sharded_tensor_size),
        process_group=self.process_group,
    )
    # Deep copy so the result does not share the caller's spec object.
    st._sharding_spec = copy.deepcopy(resharding_spec)
    return st
def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
    """
    Reduce this partial tensor across ranks and shard the result.

    The reshard happens in two steps logically:

    1. Aggregate all the shards of the partial tensor.
    2. Shard this tensor according to the provided spec.

    In reality, for the sake of performance, we consolidate all partial
    tensors across multiple ranks and convert to a sharded tensor in one
    step (a single reduce-scatter).

    When the sharding dimension is not divisible by the size of the process
    group, the local tensor is zero-padded at the end of that dimension
    before being chunked. The padding is not stripped afterwards, so the
    global size reported to the resulting :class:`ShardedTensor` is the
    padded size.

    Args:
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how we reshard the aggregated
            local result. Only ``ChunkShardingSpec`` is supported.

    Returns:
        A :class:`ShardedTensor` filled with local aggregated result.

    Raises:
        NotImplementedError: If ``resharding_spec`` is not a
            ``ChunkShardingSpec`` or the local shard is complex.
    """
    if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
        raise NotImplementedError(
            "Only ChunkShardingSpec supported for reshard.")
    if self.local_shard.is_complex():
        raise NotImplementedError(
            "Only real partial tensor supported for reshard.")

    sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    world_size = self.process_group.size()
    # Work on a local reference so that ``self.local_shard`` is never
    # replaced by its padded version as a side effect of resharding.
    local_shard = self.local_shard
    chunk_mode_res = local_shard.size(sharding_dim) % world_size

    # Add padding when the size is not divisible by the world size.
    if chunk_mode_res != 0:
        ndim = local_shard.dim()
        padding = [0] * (ndim * 2)
        # ``F.pad`` lists (left, right) pairs starting from the LAST
        # dimension, so the right-pad slot of dimension ``d`` lives at
        # index ``2 * (ndim - 1 - d) + 1``. Normalise negative dims first.
        pad_dim = sharding_dim % ndim
        padding[2 * (ndim - 1 - pad_dim) + 1] = world_size - chunk_mode_res
        local_shard = torch.nn.functional.pad(
            local_shard,
            tuple(padding),
            "constant",
            0,
        )

    local_shards = local_shard.chunk(world_size, dim=sharding_dim)
    # The collective must run on the same group that was used to size the
    # chunk list; otherwise it silently falls back to the default group.
    local_result = reduce_scatter(
        torch.empty_like(local_shards[0]),
        list(local_shards),
        op=self.reduce_op,
        group=self.process_group,
    )

    # Global size of the (possibly padded) tensor that was scattered.
    sharded_tensor_size = local_shard.size()
    return ShardedTensor._init_from_local_tensor(
        local_result,
        resharding_spec,
        sharded_tensor_size,
        process_group=self.process_group,
    )
def reshard(self, resharding_spec: ShardingSpec) -> ShardedTensor:
    """
    Reduce this partial tensor across ranks and shard the result.

    The reshard happens in two steps logically:

    1. Aggregate all the shards of the partial tensor.
    2. Shard this tensor according to the provided spec.

    In reality, for the sake of performance, we consolidate all partial
    tensors across multiple ranks and convert to a sharded tensor in one
    step (a single reduce-scatter).

    Args:
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how we reshard the aggregated
            local result. Only ``ChunkShardingSpec`` is supported.

    Returns:
        A :class:`ShardedTensor` filled with local aggregated result.

    Raises:
        NotImplementedError: If ``resharding_spec`` is not a
            ``ChunkShardingSpec`` or the local shard is complex.
        ValueError: If the sharding dimension is not divisible by the size
            of the process group.
    """
    if not isinstance(resharding_spec, ChunkShardingSpec):
        raise NotImplementedError(
            "Only ChunkShardingSpec supported for reshard.")

    sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    world_size = self.process_group.size()
    if self.local_shard.size(sharding_dim) % world_size != 0:
        raise ValueError(
            'World size need to divide the length of the dimension.')
    if self.local_shard.is_complex():
        raise NotImplementedError(
            "Only real partial tensor supported for reshard.")

    # One equally sized chunk per rank of the process group.
    local_shards = self.local_shard.chunk(world_size, dim=sharding_dim)
    # The collective must run on the same group that was used to size the
    # chunk list; otherwise it silently falls back to the default group.
    local_result = reduce_scatter(
        torch.empty_like(local_shards[0]),
        list(local_shards),
        op=self.reduce_op,
        group=self.process_group,
    )

    sharded_tensor_size = self.local_shard.size()
    return ShardedTensor._init_from_local_tensor(
        local_result,
        resharding_spec,
        sharded_tensor_size,
        process_group=self.process_group,
    )
def _handle_row_wise_sharding(input, world_size, weight, local_shard, max_norm,
                              norm_type, padding_idx, rank, pg):
    """
    Run an embedding lookup against a row-wise sharded weight.

    (Detailed explanations of the logic can be found in the comment for
    sharded_embedding.)

    Args:
        input: list of ID used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed "pad".
        rank: rank of the current process.
        pg: process group.

    Returns:
        final result of lookup.
    """
    replicated = isinstance(input, ReplicatedTensor)

    # A replicated input is used as-is; any other input is all-gathered
    # across the group first.
    full_input = input if replicated else _all_gather_base_input(input, pg)

    # Mask the input according to sharding spec.
    lookup_input, padding_idx, padding_row = _handle_row_wise_mask(
        full_input, padding_idx, weight, world_size, rank
    )

    # When input is a large tensor, the value of weight is changed.
    # This is a workaround for now. GH issue: #81717
    # The renormalising pass is issued once over the unique ids
    # (presumably the dropped last entry is the masked id - TODO confirm),
    # and the real lookup below then runs without ``max_norm``.
    if max_norm is not None:
        unique_ids = torch.unique(lookup_input)[:-1]
        torch.nn.functional.embedding(
            unique_ids,
            local_shard,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
        )

    # ``padding_row`` is appended as one extra row after the local shard.
    extended_shard = torch.cat([local_shard, padding_row])
    local_input_embeddings = torch.nn.functional.embedding(
        lookup_input,
        extended_shard,
        padding_idx=padding_idx,
        max_norm=None,
        norm_type=norm_type,
    )

    # TODO: Make the result a PartialTensor.
    if replicated:
        return all_reduce(local_input_embeddings, group=pg)

    pieces = local_input_embeddings.chunk(pg.size())
    return reduce_scatter(
        torch.empty_like(pieces[0]),
        list(pieces),
        group=pg,
    )
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    The transposed input is split along its first dimension and exchanged
    with an all-to-all, each rank multiplies the pieces it received with its
    local weight shard, and the partial products are combined with a
    reduce-scatter before the bias is added.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        # Placements that are not listed in rank order mean the rows have to
        # be permuted before the all-to-all.
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            rank_split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + rank_split_size)
            )
        indices_flatten = [row for rows in indices for row in rows]

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device)
        )

    # The receive buffer must share the input's dtype as well as its device;
    # otherwise it is created with the default dtype and no longer matches
    # the tensor being exchanged (e.g. for half or double precision inputs).
    gathered_input = torch.empty(
        input_split_sizes[rank] * world_size,
        input_t_size[1],
        device=input_t.device,
        dtype=input_t.dtype,
    )

    # Perform autograd enabled alltoall
    all_to_all_single(
        gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg
    )
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    local_result = reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def _handle_row_wise_sharding(
    input,
    world_size,
    weight,
    local_shard,
    offsets,
    per_sample_weights,
    mode,
    max_norm,
    norm_type,
    padding_idx,
    rank,
    pg,
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embeddingBag. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding_bag.)

    Args:
        input: list of ID used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        offsets: list of start positions of each bag for 1D input.
        per_sample_weights: weights for weighted sum mode.
        mode: aggregation method of each bag.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed "pad".
            Note that the embedding vector at padding_idx is
            excluded from the reduction.
        rank: rank of the current process.
        pg: process group.

    Returns:
        gathered_output: final result of lookup and aggregation.
    """
    # Offsets handed to the local lookup. They default to the caller's
    # offsets so that every path (including replicated inputs) has a value;
    # they are only rebuilt when the inputs of all ranks are concatenated.
    offsets_list = offsets

    if not isinstance(input, ReplicatedTensor):
        if input.dim() > 1 and per_sample_weights is None:
            # allgather the inputs first for non Replicated Tensor.
            gather_inp = _all_gather_base_input(input, pg)
        else:
            (
                gathered_inputs,
                gathered_per_sample_weights,
                gathered_offsets,
            ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg)
            cat_dim = 0 if input.dim() != 1 else -1
            gather_inp = torch.cat(gathered_inputs, dim=cat_dim)
            if per_sample_weights is not None:
                per_sample_weights = torch.cat(gathered_per_sample_weights, dim=cat_dim)
            # For 1D inputs the i-th rank's offsets are shifted by
            # ``i * input.size(0)`` so they index into the concatenated input.
            offset_add = 0 if input.dim() > 1 else input.size(0)
            if offsets is not None:
                offsets_list = torch.cat(
                    [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())],
                    dim=cat_dim,
                )
    else:
        gather_inp = input

    # Mask the input according to sharding spec.
    lookup_input, padding_local, padding_row = _handle_row_wise_mask(
        gather_inp, padding_idx, weight, world_size, rank
    )
    # For "max" the extra padding row is set to -inf so it can never win
    # the maximum.
    if mode == "max":
        padding_row[:] = -float("Inf")

    # When input is a large tensor, the value of weight is changed.
    # This is a workaround for now. GH issue: #81717.
    if max_norm is not None:
        torch.nn.functional.embedding_bag(
            torch.unique(lookup_input)[:-1],
            local_shard,
            offsets=torch.tensor([0], device=local_shard.device, dtype=torch.long),
            mode=mode,
            per_sample_weights=None,
            max_norm=max_norm,
            norm_type=norm_type,
            padding_idx=padding_local,
        )
        max_norm = None

    # "mean" is computed as a local "sum"; the division happens after the
    # cross-rank reduction below.
    result = torch.nn.functional.embedding_bag(
        lookup_input,
        torch.cat([local_shard, padding_row]),
        offsets=offsets_list,
        mode=mode if mode != "mean" else "sum",
        per_sample_weights=per_sample_weights,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_local,
    )

    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
    # TODO: Make the result a PartialTensor and move the logic below there.
    if isinstance(input, ReplicatedTensor):
        result = all_reduce(result, op=op, group=pg)
    else:
        local_shards = result.chunk(pg.size())
        result = reduce_scatter(
            torch.empty_like(local_shards[0]),
            list(local_shards),
            op=op,
            group=pg,
        )

    # For Mean, we cannot do the division until very end because the sum of means
    # not equal to the mean of sum. (Divisor is different)
    if mode == "mean":
        if input.dim() > 1:
            # Bag size = number of ids in each row that differ from padding_idx.
            padding_idx = padding_idx if padding_idx is not None else -1
            split_sizes = torch.sum(
                torch.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype
            )
        else:
            # Bag size = distance between consecutive offsets; the last bag
            # runs to the end of the input.
            split_sizes = torch.cat(
                (
                    offsets[1 : offsets.size(0)] - offsets[0:-1],
                    (input.size(0) - offsets[-1]).unsqueeze(0),
                ),
                dim=-1,
            )
        return torch.div(result, split_sizes.unsqueeze(1))

    # Return the appropriate local result.
    return result
def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
    """
    Reduce this partial tensor across ranks and shard the result.

    The reshard happens in two steps logically:

    1. Aggregate all the shards of the partial tensor.
    2. Shard this tensor according to the provided spec.

    In reality, for the sake of performance, we consolidate all partial
    tensors across multiple ranks and convert to a sharded tensor in one
    step (a single reduce-scatter).

    When the sharding dimension is not divisible by the size of the process
    group, a zero-padded copy of the local tensor is chunked instead, and
    the padding is stripped from the local result afterwards.
    ``self.local_shard`` itself is left untouched.

    Args:
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how we reshard the aggregated
            local result. Only ``ChunkShardingSpec`` is supported.

    Returns:
        A :class:`ShardedTensor` filled with local aggregated result.

    Raises:
        NotImplementedError: If ``resharding_spec`` is not a
            ``ChunkShardingSpec`` or the local shard is complex.
    """
    if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
        raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
    if self.local_shard.is_complex():
        raise NotImplementedError("Only real partial tensor supported for reshard.")

    sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    world_size = self.process_group.size()
    chunk_mode_res = self.local_shard.size(sharding_dim) % world_size
    local_shard = self.local_shard

    # Add padding when the size is not divisible by the world size.
    if chunk_mode_res != 0:
        ndim = self.local_shard.dim()
        padding = [0] * (ndim * 2)
        # ``F.pad`` lists (left, right) pairs starting from the LAST
        # dimension, so the right-pad slot of dimension ``d`` lives at
        # index ``2 * (ndim - 1 - d) + 1``. Normalise negative dims first.
        pad_dim = sharding_dim % ndim
        padding[2 * (ndim - 1 - pad_dim) + 1] = world_size - chunk_mode_res
        local_shard = torch.nn.functional.pad(
            self.local_shard,
            tuple(padding),
            "constant",
            0,
        )

    # Map every rank to the index of its placement in the spec. If the
    # placements are not listed in rank order, the chunks must be permuted
    # so that the i-th entry of the list is the chunk destined for rank i.
    current_rank = dist.get_rank(self.process_group)  # type: ignore[attr-defined]
    rank_idx = None
    rearrange_local_shards = False
    indices = [0] * world_size
    for idx, placement in enumerate(resharding_spec.placements):  # type: ignore[attr-defined]
        if placement.rank() == current_rank:  # type: ignore[index, union-attr]
            rank_idx = idx  # type: ignore[attr-defined]
        if placement.rank() != idx:  # type: ignore[index, union-attr]
            rearrange_local_shards = True
        indices[placement.rank()] = idx  # type: ignore[index, union-attr]

    local_shards = local_shard.chunk(world_size, dim=sharding_dim)
    if rearrange_local_shards:
        # Need to re-arrange original shard_dim of output_tensor_list.
        local_shards = [local_shards[idx] for idx in indices]  # type: ignore[call-overload]

    # The collective must run on the same group that was used to size the
    # chunk list; otherwise it silently falls back to the default group.
    local_result = reduce_scatter(
        torch.empty_like(local_shards[0]),
        list(local_shards),
        op=self.reduce_op,
        group=self.process_group,
    )

    # The global size is that of the original, unpadded tensor.
    sharded_tensor_size = self.local_shard.size()

    # Remove padding when the size is not divisible by the world size.
    if chunk_mode_res != 0:
        # Chunking the unpadded tensor yields the size this rank's shard
        # should really have; trim the reduced result down to it.
        uneven_local_shards = self.local_shard.chunk(world_size, dim=sharding_dim)
        expected_size = uneven_local_shards[rank_idx].size()  # type: ignore[index]
        if local_result.size() != expected_size:
            local_result = local_result.narrow(
                sharding_dim,
                0,
                expected_size[sharding_dim],
            )

    return ShardedTensor._init_from_local_tensor(
        local_result,
        resharding_spec,
        sharded_tensor_size,
        process_group=self.process_group,
    )