def bench_vanilla_all_gather(nelem, device, repeat=100, warm_up=10):
    torch.manual_seed(1)
    nranks = dist.get_world_size()
    rank = dist.get_rank()
    partition_size = _divup(nelem, nranks)
    nelem = nranks * partition_size  # make sure nelem divides evenly across ranks
    data_tensor = torch.rand(nelem, device=device, dtype=torch.half)

    all_gather_output = data_tensor
    start = rank * partition_size
    all_gather_input = data_tensor.narrow(0, start, partition_size)

    time_costs = []
    for i in range(repeat + warm_up):
        _start_time = time.time()
        dist._all_gather_base(all_gather_output, all_gather_input, group=None)
        torch.cuda.synchronize()
        _end_time = time.time()
        if i >= warm_up:
            time_costs.append((_end_time - _start_time) * 1e3)
    if dist.get_rank() == 0:
        print(
            f'all-gather from all ranks average costs {np.mean(time_costs)} ms, std {np.std(time_costs)}'
        )

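# The benchmark above relies on a ceiling-division helper (_divup) and an
# already-initialized default process group, neither of which is shown in the
# snippet. The launcher below is a minimal sketch, assuming NCCL, one GPU per
# rank, and torchrun-style environment variables; the _divup definition and the
# 512M-element size are illustrative assumptions, not taken from the source.
import os
import time

import numpy as np
import torch
import torch.distributed as dist


def _divup(a, b):
    # ceiling division, matching how the benchmark rounds nelem up to a
    # multiple of the world size
    return (a + b - 1) // b


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    bench_vanilla_all_gather(nelem=512 * 1024 * 1024, device=f"cuda:{local_rank}")
    dist.destroy_process_group()
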
def _rebuild_full_params(self) -> None:
    """
    Gather all shards of params.
    """

    def update_p_data(output_tensor: torch.Tensor) -> None:
        """
        Helper function to update p.data pointer.

        Args:
            output_tensor (torch.Tensor): this tensor contains the data we just gathered.
        """
        p.data = output_tensor
        # Trim any padding and reshape to match original size.
        p.data = p.data[:p._orig_size.numel()].view(p._orig_size)  # type: ignore[attr-defined]

    with torch.cuda.stream(self._streams["all_gather"]):
        for p in self.params:
            # e.g., when world_size == 1
            if not p._is_sharded:  # type: ignore[attr-defined]
                continue
            # If full param has been rebuilt or has not been freed, no need to call all gather
            elif (
                p._full_param_padded.storage().size()  # type: ignore[attr-defined]
                == p._full_param_padded.size().numel()  # type: ignore[attr-defined]
            ):
                update_p_data(p._full_param_padded)  # type: ignore[attr-defined]
                continue
            else:
                # If full param has not been rebuilt or has been freed, call all gather
                # Move params in CPU to CUDA for the all-gather.
                p_data = p.data  # type: ignore[attr-defined]
                p_full_size = p._full_param_padded.size()  # type: ignore[attr-defined]
                assert (
                    p_full_size.numel() == p_data.numel() * self.world_size
                ), "Param full size should be equal to its shard size multiplied by world_size."
                assert (
                    p._full_param_padded.storage().size() == 0  # type: ignore[attr-defined]
                ), "Full param's storage should have been freed before if all gather is needed."  # type: ignore[attr-defined]
                # Allocate based on full size from all shards.
                _alloc_storage(p._full_param_padded, size=p_full_size)  # type: ignore[attr-defined]
                output_tensor = p._full_param_padded  # type: ignore[attr-defined]

                # Fill output_tensor with (p.data for each shard in self.world_size)
                dist._all_gather_base(output_tensor, p_data, group=self.process_group)

                # Set p.data = output_tensor (with padding trimmed)
                update_p_data(output_tensor)

    torch.cuda.current_stream().wait_stream(self._streams["all_gather"])

def timed_all_gather(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                # fallback: regular all_gather on a list of chunk views of the
                # flat output (uses this function's own input/output arguments)
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()

    # time the actual comm op args.trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_gather', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")

def bench_hierarchy_all_gather(nelem,
                               device,
                               intra_shard_group,
                               inter_shard_group,
                               repeat=100,
                               warm_up=10):
    torch.manual_seed(1)
    nranks = dist.get_world_size()
    rank = dist.get_rank()
    local_size = torch.cuda.device_count()
    local_rank = rank % local_size

    partition_size = _divup(nelem, nranks)
    nelem = nranks * partition_size  # make sure nelem divides evenly across ranks
    data_tensor = torch.rand(nelem, dtype=torch.half, device=device)

    inter_node_size = dist.get_world_size(group=inter_shard_group)
    inter_node_part_size = inter_node_size * partition_size
    start = inter_node_part_size * local_rank
    inter_node_output = data_tensor.narrow(0, start, inter_node_part_size)

    inter_node_shard_rank = dist.get_rank(group=inter_shard_group)
    start = inter_node_shard_rank * partition_size
    inter_node_input = inter_node_output.narrow(0, start, partition_size)

    # input of the intra-node all-gather is the output of the inter-node all-gather
    intra_node_input = inter_node_output
    # the intra-node output produces the complete parameter tensor
    intra_node_output = data_tensor

    time_costs = []
    for i in range(repeat + warm_up):
        _start_time = time.time()
        # inter-node
        dist._all_gather_base(inter_node_output, inter_node_input, group=inter_shard_group)
        # intra-node
        dist._all_gather_base(intra_node_output, intra_node_input, group=intra_shard_group)
        # device sync
        torch.cuda.synchronize()
        _end_time = time.time()
        if i >= warm_up:
            time_costs.append((_end_time - _start_time) * 1e3)
    if dist.get_rank() == 0:
        print(
            f'average cost of hierarchy all-gather {np.mean(time_costs)} ms, std {np.std(time_costs)}'
        )

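# bench_hierarchy_all_gather expects the caller to supply intra-node and
# inter-node process groups. The sketch below is one plausible construction
# using dist.new_group; it is an assumption, not the source's setup code, and
# the rank-to-node mapping must match how the benchmark partitions data_tensor,
# so treat the exact group layout as illustrative only.
import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_size = torch.cuda.device_count()  # ranks (GPUs) per node
    num_nodes = world_size // local_size

    # every rank must participate in every new_group() call, even for groups
    # it does not belong to
    intra_groups = [
        dist.new_group(list(range(n * local_size, (n + 1) * local_size)))
        for n in range(num_nodes)
    ]
    inter_groups = [
        dist.new_group(list(range(l, world_size, local_size)))
        for l in range(local_size)
    ]
    intra_shard_group = intra_groups[rank // local_size]
    inter_shard_group = inter_groups[rank % local_size]

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    bench_hierarchy_all_gather(512 * 1024 * 1024, f"cuda:{local_rank}",
                               intra_shard_group, inter_shard_group)
    dist.destroy_process_group()
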
# imports used by this helper; the ShardedTensor import path may differ
# slightly across PyTorch versions
import math
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor


def _all_gather_sharded_tensor(
        sharded_tensor: ShardedTensor,
        pg: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    local_tensor = shards[0].tensor.flatten()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    num_padding = chunk_size - local_tensor.numel()
    if num_padding > 0:
        local_tensor = F.pad(local_tensor, [0, num_padding])
    tensor = torch.empty(chunk_size * world_size, dtype=local_tensor.dtype).cuda()
    dist._all_gather_base(tensor, local_tensor, group=pg)

    return tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())

def all_gather_base(self, collectiveArgs, retFlag=False, pair=False):
    retObj = dist._all_gather_base(
        output_tensor=collectiveArgs.opTensor,
        input_tensor=collectiveArgs.ipTensor,
        group=collectiveArgs.group,
        async_op=collectiveArgs.asyncOp,
    )  # synchronicity is maintained in runColl

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)

    if retFlag:
        return retObj

def _all_gather_sharded_tensor(
        sharded_tensor: ShardedTensor,
        pg: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    cuda_device = torch.device("cuda", torch.cuda.current_device())
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if not local_tensor.is_cuda:
            move_to_cpu = torch.ones(1, device=cuda_device)
            local_tensor = local_tensor.cuda()
        else:
            move_to_cpu = torch.zeros(1, device=cuda_device)

        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=cuda_device
        )
        move_to_cpu = torch.zeros(1, device=cuda_device)

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=cuda_device,
    )
    dist._all_gather_base(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor

def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    if not input.is_contiguous(memory_format=torch.channels_last):
        input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    size = int(input.numel() // input.size(1))
    if size == 1 and world_size < 2:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

    num_channels = input.shape[1]
    if input.numel() > 0:
        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full(
            (1,),
            input.numel() // input.size(1),
            dtype=mean.dtype,
            device=mean.device
        )

        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
    else:
        # for empty input, set stats and the count to zero. The stats with
        # zero count will be filtered out later when computing the global mean
        # & invstd, but they still need to participate in the all_gather
        # collective communication to unblock other peer processes.
        combined = torch.zeros(
            2 * num_channels + 1,
            dtype=input.dtype,
            device=input.device
        )

    # Use allgather instead of allreduce because count could be different across
    # ranks; a simple all-reduce op cannot give correct results.
    # batch_norm_gather_stats_with_counts calculates the global mean & invstd based on
    # all gathered mean, invstd and count.
    # for the nccl backend, use the optimized version of all gather.
    if process_group._get_backend_name() == 'nccl':
        # world_size * (2C + 1)
        combined_size = combined.numel()
        combined_flat = torch.empty(1,
                                    combined_size * world_size,
                                    dtype=combined.dtype,
                                    device=combined.device)
        dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
        combined = torch.reshape(combined_flat, (world_size, combined_size))
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
    else:
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for _ in range(world_size)
        ]
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    # remove stats from empty inputs
    mask = count_all.squeeze(-1) >= 1
    count_all = count_all[mask]
    mean_all = mean_all[mask]
    invstd_all = invstd_all[mask]

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1)
    )

    self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
    self.process_group = process_group

    # apply element-wise normalization
    if input.numel() > 0:
        return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    else:
        return torch.empty_like(input)

def _communicate_optim_state(
    fsdp_module,
    flat_param: FlatParameter,
    flat_param_state: Dict[str, Any],
    to_save: bool,
) -> ConsolidatedOptimState:
    """
    Communicates the optimizer state for a flattened parameter ``flat_param``
    across ranks so that the target rank holds the entire non-sharded
    optimizer state.

    If ``N`` is the number of tensor optimizer states in the optimizer state
    dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1``
    otherwise (where the plus 1 comes from all-gathering the padding per rank).

    Args:
        flat_param (FlatParameter): The flattened parameter.
        flat_param_state (Dict[str, Any]): The entry in the "state" part of the
            optimizer state dict corresponding to the flattened parameter.
        to_save (bool): Whether to save the state on this rank.

    Returns:
        ConsolidatedOptimState: Consolidated optimizer state for
        ``flat_param``; the state is not populated for non-target ranks.
    """
    param_index = -1
    for i, param in enumerate(fsdp_module.params):
        if param is flat_param:
            param_index = i
            break
    assert param_index >= 0, "`fsdp_module` must own `flat_param`"

    state = ConsolidatedOptimState()
    tensor_state, zero_dim_tensor_state, non_tensor_state = \
        state.tensor_state, state.zero_dim_tensor_state, state.non_tensor_state
    process_group = fsdp_module.process_group

    tensor_buffer = None  # initialize lazily in case it is not needed
    for state_name, value in flat_param_state.items():
        # Positive-dimension tensor state: communicate across ranks
        if torch.is_tensor(value) and value.dim() > 0:
            # If the parameter is not sharded (e.g. world size of 1), then
            # neither is the positive-dimension tensor state, so no need to
            # communicate it -- we take the target rank's value
            if not flat_param._is_sharded:
                tensor_state[state_name] = value.cpu()
                continue
            if tensor_buffer is None:
                # Assume that positive-dimension tensor optimizer state
                # has the same shape as the sharded flattened parameter
                buffer_size = flat_param._full_param_padded.size()  # type: ignore[attr-defined]
                tensor_buffer = value.new_zeros(*buffer_size)
            dist._all_gather_base(tensor_buffer, value, group=process_group)
            if to_save:
                assert hasattr(flat_param, "_orig_size"), \
                    "Sharded flattened parameter should have `_orig_size` set"
                unpadded_numel = flat_param._orig_size.numel()  # type: ignore[attr-defined]
                tensor_state[state_name] = tensor_buffer[:unpadded_numel].cpu()
        # Zero-dimension tensor state and non-tensor state: take this rank's
        # value directly
        elif to_save:
            if _is_zero_dim_tensor(value):
                zero_dim_tensor_state[state_name] = value.cpu()
            else:
                non_tensor_state[state_name] = value
    return state

def forward(ctx, output_tensor, input_tensor, group):
    ctx.group = group
    dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
    return output_tensor

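# The forward above is one half of an autograd-aware all-gather. The class
# below is a minimal, self-contained sketch of how such a forward is typically
# paired with a backward (a reduce-scatter of the incoming gradient); the class
# name _AllGatherBase and the backward are assumptions for illustration, not
# taken from the source, and the sketch assumes a flat 1-D input shard.
import torch
import torch.distributed as dist


class _AllGatherBase(torch.autograd.Function):
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # reduce-scatter the gradient so each rank keeps the slice that
        # corresponds to its original input shard
        world_size = dist.get_world_size(group=ctx.group)
        grad_input = torch.empty(grad_output.numel() // world_size,
                                 dtype=grad_output.dtype,
                                 device=grad_output.device)
        dist._reduce_scatter_base(grad_input, grad_output.contiguous(), group=ctx.group)
        # no gradients for the preallocated output buffer or the group argument
        return None, grad_input, None
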
def _rebuild_full_params(self) -> None:
    """
    Gather all shards of params.
    """

    def update_p_data(output_tensor: torch.Tensor) -> None:
        """
        Helper function to update p.data pointer.

        Args:
            output_tensor (torch.Tensor): this tensor contains the data we just gathered.
        """
        p.data = output_tensor
        # Trim any padding and reshape to match original size.
        p.data = p.data[:p._orig_size.numel()].view(p._orig_size)  # type: ignore[attr-defined]

    with torch.cuda.stream(self._streams["all_gather"]):
        for p in self.params:
            if self.cpu_offload.offload_params:
                # Move params to GPU if needed. Note that we don't use
                # self._full_param_padded.device here because the attr is
                # not always set, i.e. when world_size=1 and
                # p._is_sharded = False. However when it is set, the
                # device is always self.compute_device.
                p.data = p.data.to(self.compute_device, non_blocking=True)
            # e.g., when world_size == 1
            if not p._is_sharded:  # type: ignore[attr-defined]
                continue
            # If full param has been rebuilt or has not been freed, no need to call all gather
            elif (
                p._full_param_padded.storage().size()  # type: ignore[attr-defined]
                == p._full_param_padded.size().numel()  # type: ignore[attr-defined]
            ):
                update_p_data(p._full_param_padded)  # type: ignore[attr-defined]
                continue
            else:
                # If full param has not been rebuilt or has been freed, call all gather
                p_data = p.data  # type: ignore[attr-defined]
                p_full_size = p._full_param_padded.size()  # type: ignore[attr-defined]
                assert (
                    p_full_size.numel() == p_data.numel() * self.world_size
                ), "Param full size should be equal to its shard size multiplied by world_size."
                assert (
                    p._full_param_padded.storage().size() == 0  # type: ignore[attr-defined]
                ), "Full param's storage should have been freed before if all gather is needed."  # type: ignore[attr-defined]
                # Allocate based on full size from all shards.
                _alloc_storage(p._full_param_padded, size=p_full_size)  # type: ignore[attr-defined]
                output_tensor = p._full_param_padded  # type: ignore[attr-defined]

                # Fill output_tensor with (p.data for each shard in self.world_size)
                dist._all_gather_base(output_tensor, p_data, group=self.process_group)

                # Set p.data = output_tensor (with padding trimmed)
                update_p_data(output_tensor)

    torch.cuda.current_stream().wait_stream(self._streams["all_gather"])

def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    if not input.is_contiguous(memory_format=torch.channels_last):
        input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    size = int(input.numel() // input.size(1))
    if size == 1 and world_size < 2:
        raise ValueError(
            'Expected more than 1 value per channel when training, got input size {}'
            .format(size))

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count = torch.full((1, ),
                       input.numel() // input.size(1),
                       dtype=mean.dtype,
                       device=mean.device)

    num_channels = input.shape[1]
    # C, C, 1 -> (2C + 1)
    combined = torch.cat([mean, invstd, count], dim=0)
    # Use allgather instead of allreduce because count could be different across
    # ranks; a simple all-reduce op cannot give correct results.
    # batch_norm_gather_stats_with_counts calculates the global mean & invstd based on
    # all gathered mean, invstd and count.
    # for the nccl backend, use the optimized version of all gather.
    if process_group._get_backend_name() == 'nccl':
        # world_size * (2C + 1)
        combined_size = combined.numel()
        combined_flat = torch.empty(1,
                                    combined_size * world_size,
                                    dtype=combined.dtype,
                                    device=combined.device)
        dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
        combined = torch.reshape(combined_flat, (world_size, combined_size))
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
    else:
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for _ in range(world_size)
        ]
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input, mean_all, invstd_all, running_mean, running_var, momentum, eps,
        count_all.view(-1))

    self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
    self.process_group = process_group

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out