def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Compute this rank's part of ``Linear`` when the weight is sharded col-wise.

    (Detailed explanations of the logic can be found in the comment for
    sharded_linear.)

    The inputs of all ranks are all-gathered. Each one is multiplied by this
    rank's local weight shard and offset by this rank's slice of ``bias``.
    The per-rank partial results are ordered according to the weight's
    sharding-spec placements and then combined:

    * If each partial result is 1-D, they are stacked, which adds a leading
      dimension. The resulting sharded tensor is therefore not sharded along
      dim 0, so resharding keeps working. The extra dimension has to be
      squeezed away manually later on. For example, with ``input`` of size
      [15], ``weight`` of size [15, 16] and ``world_size`` 4, every rank holds
      4 tensors of size [4]. These are stacked into a [4, 4] tensor that is
      sharded along dim 1.
    * Otherwise the partial results are simply concatenated, and nothing
      else is needed afterwards.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: local weight shard, used as the right-hand operand of
            ``matmul`` (presumably the transposed local shard -- confirm at
            the caller).
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object filled with the local intermediate
        results.
    """
    # Every rank's input is needed locally.
    gathered_inputs = all_gather(input, group=pg)

    start_pos, chunk_size = get_chunk_sharding_params(
        bias.size(0), world_size, weight._sharding_spec, rank
    )
    local_bias = _BiasTensorNarrow.apply(
        world_size, start_pos, chunk_size, weight, pg, bias
    )

    # rank -> position of that rank's placement in the sharding spec.
    slot_of_rank = {
        placement.rank(): slot
        for slot, placement in enumerate(weight._sharding_spec.placements)
    }

    ordered = [None] * world_size
    for src_rank, gathered in enumerate(gathered_inputs):
        ordered[slot_of_rank[src_rank]] = gathered.matmul(local_shard_t) + local_bias

    # 1-D partial results are stacked (not concatenated) so the result is not
    # sharded along dim 0, which keeps resharding working.
    combine = torch.stack if ordered[0].dim() == 1 else torch.cat  # type: ignore[attr-defined]
    combined = combine(ordered)  # type: ignore[arg-type]

    return _init_sharded_tensor_from_local_result(
        weight, combined, 0, -1, world_size, pg  # type: ignore[arg-type]
    )
def _handle_col_wise_sharding(input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for embedding. (Detailed explanations of the logic can be found in the
    comment for sharded_embedding.)

    Args:
        input: tensor of IDs used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: col-wise sharded local weight used for lookup.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed "pad".
        pg: process group.

    Returns:
        A tuple ``(output, local_shard)``. ``output`` is the final result of
        the lookup. ``local_shard`` is the local shard, replaced by the result
        of the max-norm handling when ``max_norm`` is given.
    """
    if isinstance(input, ReplicatedTensor):
        # A replicated input needs no communication. However, the downstream
        # code treats gathered_inputs as a list with one entry per rank.
        # Passing the tensor itself would make it iterate over the tensor's
        # rows instead of the ranks, so build that list explicitly.
        gathered_inputs = [input] * world_size
    else:
        # allgather the inputs first for non Replicated Tensor.
        gathered_inputs = all_gather(input, group=pg)

    if max_norm is not None:
        # max_norm changes the weight in-place
        local_shard = _handle_max_norm_col_wise(
            max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg
        )

    # Keyword arguments guard against positional misalignment with the
    # base helper's signature.
    output = _handle_col_wise_sharding_base(
        op_func=torch.nn.functional.embedding,
        # Col-wise sharding splits dim 1 of the embedding weight; pass its
        # full size. NOTE(review): confirm this matches what
        # _result_distribute_with_col_rearrange expects.
        sharding_dim_size=weight.size(1),
        # The lookup result has one more dim than ``input``; the embedding
        # columns live in that last dim.
        col_dim=len(input.size()),
        input=input,
        world_size=world_size,
        weight=weight,
        local_shard=local_shard,
        pg=pg,
        gathered_inputs=gathered_inputs,
        padding_idx=padding_idx,
    )
    return (output, local_shard)
def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg):
    """
    Gather ``input`` and the optional EmbeddingBag parameters from all ranks
    with a single ``all_gather`` call.

    The tensors to send are stacked into one tensor so that only one
    collective is issued. ``offsets`` does not share the size of ``input``
    (it is always smaller), so a copy of it is resized to ``input``'s size
    before stacking. Each gathered copy is resized back to ``offsets``'s size
    afterwards.

    Args:
        input: tensor to be applied op on.
        per_sample_weights: weights for weighted sum mode, or ``None``.
        offsets: when input is 1D, offsets determines the starting index
            position of each bag (sequence) in input; or ``None``.
        pg: process group.

    Returns:
        A tuple ``(gathered_inputs, gathered_per_sample_weights,
        gathered_offsets)``:
            gathered_inputs: list of input tensors gathered from each rank,
                cast to ``input.dtype``.
            gathered_per_sample_weights: list of per_sample_weights from each
                rank, or ``None`` when ``per_sample_weights`` is ``None``.
            gathered_offsets: list of offsets from each rank, resized to
                ``offsets.size()`` and cast to ``offsets.dtype``, or ``None``
                when ``offsets`` is ``None``.
    """
    has_weights = per_sample_weights is not None
    has_offsets = offsets is not None

    to_stack = [input]
    if has_weights:
        to_stack.append(per_sample_weights)
    if has_offsets:
        # Resize a copy so it can be stacked with input; the surplus elements
        # are dropped again after gathering.
        to_stack.append(offsets.clone().resize_(input.size()))

    # Each gathered entry holds the rows [input, (per_sample_weights), (offsets)].
    stacked_per_rank = all_gather(torch.stack(to_stack), group=pg)

    gathered_per_sample_weights = (
        [stacked[1] for stacked in stacked_per_rank] if has_weights else None
    )

    if has_offsets:
        offsets_row = 2 if has_weights else 1
        gathered_offsets = [
            stacked[offsets_row].resize_(offsets.size()).to(offsets.dtype)
            for stacked in stacked_per_rank
        ]
    else:
        gathered_offsets = None

    gathered_inputs = [stacked[0].to(input.dtype) for stacked in stacked_per_rank]
    return gathered_inputs, gathered_per_sample_weights, gathered_offsets
def _handle_col_wise_sharding_base(
    op_func,
    sharding_dim_size,
    col_dim,
    input,
    world_size,
    weight,
    local_shard,
    pg,
    gathered_inputs=None,
    mode=None,
    gathered_per_sample_weights=None,
    gathered_offsets=None,
    padding_idx=None,
):
    """
    Shared implementation for ops whose weight is sharded col-wise.

    The steps are:
      1. Obtain the input of every rank, all-gathering it unless
         ``gathered_inputs`` is supplied.
      2. Apply ``op_func`` to each gathered input with the local shard, and
         move dim ``col_dim`` of each result to the front.
      3. Hand the per-rank results to ``_result_distribute_with_col_rearrange``,
         which distributes them back to the ranks with col rearrangement.
      4. Transpose the distributed output back and return it.

    Args:
        op_func: operator which is applied to the input tensor.
        sharding_dim_size: size of the sharded column dimension, forwarded to
            ``_result_distribute_with_col_rearrange`` (previously documented
            as "the max size of the column each rank gets"; NOTE(review):
            confirm whether callers pass the full or the per-rank size).
        col_dim: dim of result tensor after the operation.
        input: tensor to be applied op on.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: col-wise sharded weight tensor.
        pg: process group.
        gathered_inputs: list of inputs from all ranks. If specified, no
            further communication is needed to obtain them.
        mode: aggregation mode of EmbeddingBag.
        gathered_per_sample_weights: per_sample_weights across all ranks.
        gathered_offsets: offsets across all ranks.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed "pad".
            Note that the embedding vector at padding_idx is
            excluded from the reduction.

    Return:
        final result of input being applied with the op.
    """
    if gathered_inputs is None:
        # No pre-gathered inputs: all-gather them from every rank.
        gathered_inputs = all_gather(input, group=pg)

    def _apply_op(rank_idx, rank_input):
        # Pass each op only the keyword arguments it accepts.
        if op_func == torch.nn.functional.embedding_bag:
            rank_offsets = (
                None if gathered_offsets is None else gathered_offsets[rank_idx]
            )
            rank_weights = (
                None
                if gathered_per_sample_weights is None
                else gathered_per_sample_weights[rank_idx]
            )
            return op_func(
                rank_input,
                local_shard,
                offsets=rank_offsets,
                mode=mode,
                per_sample_weights=rank_weights,
                padding_idx=padding_idx,
            )
        if op_func == torch.nn.functional.embedding:
            return op_func(rank_input, local_shard, padding_idx=padding_idx)
        return op_func(rank_input, local_shard)

    # Put the column dimension first in every per-rank result.
    results = [
        torch.transpose(_apply_op(rank_idx, rank_input), 0, col_dim)
        for rank_idx, rank_input in enumerate(gathered_inputs)
    ]

    # Distribute results to each rank with col rearrangement.
    output = _result_distribute_with_col_rearrange(
        results, input, sharding_dim_size, world_size, weight, pg
    )

    # Swap the column dimension back to col_dim.
    return torch.transpose(output, 0, col_dim)