def _init_from_local_tensor(
    cls,
    local_tensor: torch.Tensor,
    sharding_spec: shard_spec.ShardingSpec,
    *global_size: Sequence[int],
    process_group: dist.ProcessGroup = None,
    init_rrefs=False,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensor given only one local tensor, global sharded tensor
    size and sharding spec on each rank.

    Args:
        local_tensor (Tensor): Single tensor of local shard stored in each rank.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.
        global_size (Sequence[int]): Size of the sharded tensor.
        process_group (ProcessGroup, optional): The process group to aggregate on.
            Default: None
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` sharded based on the given sharding_spec with local
            tensor stored in the current rank.

    Examples:
        >>> # All tensors below are of torch.int64 type.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)
        >>> local_tensor
        tensor([[1, 2, 3, 4]]) # Rank 0
        tensor([[3, 4, 5, 6]]) # Rank 1
        >>> sharding_dim = 0
        >>> sharding_spec = ChunkShardingSpec(
                dim=sharding_dim,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                ],
            )
        >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
        >>> st
        ShardedTensor(
            ShardedTensorMetadata(
                shards_metadata=[
                    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
                    ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
                ],
                size=torch.Size([2, 4])
        )
        >>> st.local_tensor()
        tensor([1, 2, 3, 4]) # Rank 0
        tensor([3, 4, 5, 6]) # Rank 1

    Warning: This API is experimental and subject to change. It does not perform
        cross-rank validation; only the local shard on the current rank is validated.
        We fully rely on the user to ensure the local tensor is sharded according to
        the sharding spec.
    """
    if not local_tensor.is_contiguous():
        raise ValueError('local_tensor is not a contiguous Tensor.')

    global_tensor_size = _flatten_tensor_size(global_size)
    tensor_properties = TensorProperties(
        dtype=local_tensor.dtype,
        layout=local_tensor.layout,
        requires_grad=local_tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=local_tensor.is_pinned())
    sharded_tensor_metadata = sharding_spec.build_metadata(
        global_tensor_size, tensor_properties)

    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)

    local_shards: List[Shard] = []
    for shard_metadata in sharded_tensor_metadata.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_metadata.placement)
        if rank == current_rank:
            local_shards.append(Shard(local_tensor, shard_metadata))

    # TODO: figure out how the API should behave when some ranks have no shard
    # see https://github.com/pytorch/pytorch/issues/7313
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        sharded_tensor_metadata,
        process_group=process_group,
        init_rrefs=init_rrefs,
        sharding_spec=sharding_spec,
    )
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    *global_size,
    process_group=None,
    init_rrefs=False,
):
    # STEP 1: Validate the ShardMetadatas locally
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
    global_tensor_size = _flatten_tensor_size(global_size)

    if len(local_shards) > 0:
        local_sharded_tensor_metadata = \
            build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

    # STEP 2. Validate metadata across ranks, and build a global sharded tensor
    # metadata by gathering local ShardedTensorMetadata
    gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
    if world_size > 1:
        gathered_metadatas = [None for _ in range(world_size)]

        if isinstance(process_group, dist.ProcessGroupNCCL):
            # with GPU/NCCL, we need to set a device for all_gather_object
            # to use as we need to know which device we should put the
            # serialized tensor on before the NCCL collective.
            with torch.cuda.device(current_rank):
                dist.all_gather_object(gathered_metadatas, local_sharded_tensor_metadata, group=process_group)
        else:
            dist.all_gather_object(gathered_metadatas, local_sharded_tensor_metadata, group=process_group)
    else:
        gathered_metadatas = [local_sharded_tensor_metadata]

    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)

    # STEP 3: Validation done, create the actual ShardedTensor and populate fields
    # prepare initialization
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    # add to metadata and local_shards
    sharded_tensor._metadata = global_sharded_tensor_metadata
    sharded_tensor._local_shards = local_shards
    # make an EnumerableShardingSpec for sharded tensors that are initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    #       see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(global_sharded_tensor_metadata.shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
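A minimal usage sketch for the classmethod above. It assumes a 2-rank initialized default process group with one GPU per rank; the import paths for ``Shard``/``ShardMetadata`` and the shard shapes are assumptions for illustration, not part of this module.

# Hedged sketch: each rank holds one [2, 4] shard of a [4, 4] global tensor.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
from torch.distributed._shard.metadata import ShardMetadata  # import path is an assumption

rank = dist.get_rank()
shard_metadata = ShardMetadata(
    shard_offsets=[2 * rank, 0],
    shard_sizes=[2, 4],
    placement=f"rank:{rank}/cuda:{rank}",
)
local_shard = Shard(torch.ones(2, 4, device=f"cuda:{rank}"), shard_metadata)
# Private classmethod shown above; global size is passed as the remaining argument(s).
st = ShardedTensor._init_from_local_shards([local_shard], [4, 4])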
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
):
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
        not do cross-rank validations, and fully relies on the user
        for the correctness of sharded_tensor_metadata on each rank
    """
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    local_shard_metadatas = []

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)

        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) '
        )

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # validate if shard_meta is in the metadatas collected from sharded_tensor_metadata
        assert shard_meta in local_shard_metadatas, \
            "local shard metadata not in sharded_tensor_metadata!"

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported')

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True)
        _raise_if_mismatch(
            tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True)

    # check if shards_metadata have overlapping shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with the overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor._local_shards = local_shards
    # make an EnumerableShardingSpec for sharded tensors that are initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    #       see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
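A hedged sketch of calling the classmethod above on a 2-rank group. It reuses the field names that appear elsewhere in this file (``ShardMetadata``, ``TensorProperties``, ``ShardedTensorMetadata``); the global tensor shape and placements are illustrative assumptions.

import torch
import torch.distributed as dist

rank = dist.get_rank()
shards_metadata = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cuda:1"),
]
st_metadata = ShardedTensorMetadata(
    shards_metadata=shards_metadata,
    size=torch.Size([4, 4]),
    tensor_properties=TensorProperties(
        dtype=torch.float32,
        layout=torch.strided,
        requires_grad=False,
        memory_format=torch.contiguous_format,
        pin_memory=False,
    ),
)
# Local shard must match its metadata (size, dtype, device) since only local validation runs.
local_shard = Shard(torch.zeros(2, 4, device=f"cuda:{rank}"), shards_metadata[rank])
st = ShardedTensor._init_from_local_shards_and_global_metadata([local_shard], st_metadata)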
def process_group(self) -> Optional[ProcessGroup]:
    if self._process_group is None:
        # The strategy should have already initialized the process group in setup_environment()
        self._process_group = _get_default_group()
    return self._process_group
def __setstate__(self, state):
    # If serializable, then the process group should be the default one
    self.process_group = _get_default_group()
    self.check_previous_reduction = False
    super(DistributedDataParallel, self).__setstate__(state)
    self._ddp_init_helper()
def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingSpec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    with torch.cuda.device(tensor.device):
        dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
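A hedged usage sketch for ``shard_parameter`` above: it shards a Linear layer's weight row-wise across two GPUs. It assumes an initialized 2-rank NCCL process group and that ``ChunkShardingSpec`` and ``shard_parameter`` are imported from the same package as the function; the layer sizes are illustrative.

import torch
import torch.distributed as dist

rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
linear = torch.nn.Linear(16, 16).cuda(rank)
# Rank 0's weight is the ground truth; every rank keeps only its own [8, 16] chunk.
shard_parameter(linear, "weight", spec, src_rank=0)
# linear.weight is now a ShardedTensor instead of a torch.nn.Parameter.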
def __init__(self, module, device_ids=None,
             output_device=None, dim=0, broadcast_buffers=True,
             process_group=None, bucket_cap_mb=25,
             find_unused_parameters=False,
             check_reduction=False):

    super(DistributedDataParallel, self).__init__()

    assert any((p.requires_grad for p in module.parameters())), (
        "DistributedDataParallel is not needed when a module "
        "doesn't have any parameter that requires a gradient."
    )

    self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
    distinct_device_types = {p.device.type for p in module.parameters()}
    assert len(distinct_device_types) == 1, (
        "DistributedDataParallel's input module must be on "
        "the same type of devices, but input module parameters locate in {}."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    if self.device_type == "cpu" or self.is_multi_device_module:
        assert not device_ids and not output_device, (
            "DistributedDataParallel device_ids and output_device arguments "
            "only work with single-device GPU modules, but got "
            "device_ids {}, output_device {}, and module parameters {}."
        ).format(device_ids, output_device, {p.device for p in module.parameters()})

        self.device_ids = None
        self.output_device = None
    else:
        # Use all devices by default for single-device GPU modules
        if device_ids is None:
            device_ids = _get_all_device_indices()

        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))

        if output_device is None:
            output_device = device_ids[0]

        self.output_device = _get_device_index(output_device, True)

    if process_group is None:
        self.process_group = _get_default_group()
    else:
        self.process_group = process_group

    self.dim = dim
    self.module = module
    self.device = list(self.module.parameters())[0].device
    self.broadcast_buffers = broadcast_buffers
    self.find_unused_parameters = find_unused_parameters
    self.require_backward_grad_sync = True
    self.require_forward_param_sync = True
    self.ddp_join_enabled = False

    if check_reduction:
        # This argument is no longer used since the reducer
        # will ensure reduction completes even if some parameters
        # do not receive gradients.
        pass

    # used for intra-node param sync and inter-node sync as well
    self.broadcast_bucket_size = int(250 * 1024 * 1024)

    # reduction bucket size
    self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    # Sync params and buffers
    self._sync_params_and_buffers(authoritative_rank=0)

    self._ddp_init_helper()
def __init__(self, module, device_ids=None,
             output_device=None, dim=0, broadcast_buffers=True,
             process_group=None, bucket_cap_mb=25,
             find_unused_parameters=False,
             check_reduction=False):

    super(DistributedDataParallel, self).__init__()

    assert any((p.requires_grad for p in module.parameters())), (
        "DistributedDataParallel is not needed when a module "
        "doesn't have any parameter that requires a gradient."
    )

    self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
    self.is_cuda = all([p.device.type == 'cuda' for p in module.parameters()])

    if not self.is_cuda or self.is_multi_device_module:
        assert not device_ids and not output_device, (
            "DistributedDataParallel device_ids and output_device arguments "
            "only work with single-device CUDA modules, but got "
            "device_ids {}, output_device {}, and module parameters {}."
        ).format(device_ids, output_device, {p.device for p in module.parameters()})

        self.device_ids = None
        self.output_device = None
    else:
        # Use all devices by default for single-device CUDA modules
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))

        if output_device is None:
            output_device = device_ids[0]

        self.output_device = _get_device_index(output_device, True)

    if self.is_multi_device_module:
        assert self.is_cuda, (
            "DistributedDataParallel with multi-device module only works "
            "with CUDA devices, but module parameters locate in {}."
        ).format({p.device for p in module.parameters()})

    if process_group is None:
        self.process_group = _get_default_group()
    else:
        self.process_group = process_group

    self.dim = dim
    self.module = module
    self.broadcast_buffers = broadcast_buffers
    self.find_unused_parameters = find_unused_parameters
    self.require_backward_grad_sync = True
    self.require_forward_param_sync = True

    if check_reduction:
        # This argument is no longer used since the reducer
        # will ensure reduction completes even if some parameters
        # do not receive gradients.
        pass

    MB = 1024 * 1024

    # used for intra-node param sync and inter-node sync as well
    self.broadcast_bucket_size = int(250 * MB)

    # reduction bucket size
    self.bucket_bytes_cap = int(bucket_cap_mb * MB)

    # Sync params and buffers
    module_states = list(self.module.state_dict().values())
    if len(module_states) > 0:
        self._distributed_broadcast_coalesced(module_states, self.broadcast_bucket_size)

    self._ddp_init_helper()
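A hedged sketch of the usual single-process-per-GPU setup for the two DDP constructors above, using the public ``torch.nn.parallel.DistributedDataParallel`` entry point. The ``"nccl"`` backend and env:// rendezvous are assumptions about how the job is launched.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Linear(10, 10).cuda(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank], bucket_cap_mb=25)

out = ddp_model(torch.randn(8, 10, device=f"cuda:{rank}"))
out.sum().backward()  # gradients are all-reduced across ranks in buckets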
def _shard_tensor(tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None):
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)
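A hedged sketch of calling ``_shard_tensor`` above: rank 0's copy of the tensor is the ground truth that gets chunked along dim 0 and broadcast to the placements below. A 2-rank initialized group is assumed, and ``ChunkShardingSpec``/``_shard_tensor`` are assumed to be imported from the same module.

import torch
import torch.distributed as dist

rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
full = torch.arange(8, dtype=torch.float32, device=f"cuda:{rank}").reshape(4, 2)
st = _shard_tensor(full, spec, src_rank=0)
# Each rank now holds a [2, 2] local shard of rank 0's [4, 2] tensor.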
def __init__(self,
             params,
             lr=1e-3,
             bias_correction=True,
             betas=(0.9, 0.999),
             eps=1e-8,
             weight_decay=0.,
             amsgrad=False,
             dtype=torch.float32,
             grad_sync_dtype=None,
             param_sync_dtype=None,
             device='cuda',
             process_group=None,
             distributed_process_group=None,
             redundant_process_group=None,
             model_parallel=False,
             model_parallel_rank=0,
             average_grad_sync=True,
             overlap_grad_sync=True,
             bucket_cap_mb=15,
             pipeline_size=2,
             fused_grad_copy=False,
             max_grad_norm=0.,
             ):
    defaults = dict(lr=lr, bias_correction=bias_correction,
                    betas=betas, eps=eps, weight_decay=weight_decay)
    super(DistributedFusedAdam, self).__init__(params, defaults)

    # Adam options
    if amsgrad:
        raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')

    # Datatype options
    if grad_sync_dtype is None:
        grad_sync_dtype = dtype
    if param_sync_dtype is None:
        param_sync_dtype = grad_sync_dtype
    valid_dtypes = [
        (torch.float32, torch.float16, torch.float16),
        (torch.float32, torch.float32, torch.float32),
    ]
    if (dtype, grad_sync_dtype, param_sync_dtype) not in valid_dtypes:
        raise RuntimeError(
            'Invalid dtypes for DistributedFusedAdam '
            f'(dtype={dtype}, '
            f'grad_sync_dtype={grad_sync_dtype}, '
            f'param_sync_dtype={param_sync_dtype}))')
    if device != 'cuda':
        raise RuntimeError('DistributedFusedAdam only supports GPU')
    self.dtype = dtype
    self.grad_sync_dtype = grad_sync_dtype
    self.param_sync_dtype = param_sync_dtype
    self.device = device

    # Process groups
    self.world_process_group = (
        _get_default_group()
        if process_group is None
        else process_group
    )
    self.distributed_process_group = (
        self.world_process_group
        if distributed_process_group is None
        else distributed_process_group
    )
    self.redundant_process_group = redundant_process_group
    self.world_size = torch.distributed.get_world_size(self.world_process_group)
    self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
    self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
    self.redundant_size = (
        1
        if self.redundant_process_group is None
        else torch.distributed.get_world_size(self.redundant_process_group)
    )
    if (self.world_size != self.distributed_size * self.redundant_size):
        raise RuntimeError(
            'Invalid process group configuration '
            f'(world process group size = {self.world_size}, '
            f'distributed process group size = {self.distributed_size}, '
            f'redundant process group size = {self.redundant_size})'
        )
    self.model_parallel = model_parallel
    self.model_parallel_rank = model_parallel_rank

    # Grad sync options
    if fused_grad_copy:
        _params = list(self.parameters())
        if (_params
                and any(p.dtype != self.grad_sync_dtype for p in _params)
                and any(p.device != self.device for p in _params)):
            raise RuntimeError(
                'Attempted to use fused gradient copy in DistributedFusedAdam, '
                'but parameters do not all have expected '
                f'dtype ({self.grad_sync_dtype}) and device ({self.device})'
            )
    self.average_grad_sync = average_grad_sync
    self.overlap_grad_sync = overlap_grad_sync
    self.pipeline_size = pipeline_size
    self.fused_grad_copy = fused_grad_copy

    # Grad clipping options
    self.max_grad_norm = max_grad_norm

    # Determine bucket sizes
    dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
    self.alignment = 128 // dtype_size
    bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size
    shard_size = bucket_size / self.distributed_size
    shard_size = (int(shard_size) // self.alignment) * self.alignment
    shard_size = max(shard_size, self.alignment)
    bucket_size = shard_size * self.distributed_size
    self.bucket_size = bucket_size
    self.shard_size = shard_size

    # Load CUDA kernels
    global fused_adam_cuda, distributed_adam_cuda
    fused_adam_cuda = importlib.import_module("fused_adam_cuda")
    distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")

    # Optimizer state
    self.state['buckets'] = []
    self.state['step'] = 0

    # Objects for gradient synchronization
    self._grads_generated = set()
    self._grads_to_copy = []
    self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]

    # Check if collectives have no_copy option
    self._reduce_scatter_no_copy = (
        'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
    )
    self._all_gather_no_copy = (
        'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
    )

    # Attach hooks for gradient synchronization
    self._register_post_backward_hooks()
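A hedged usage sketch for the optimizer above. It assumes apex is built with the ``fused_adam_cuda``/``distributed_adam_cuda`` extensions and that the class is exposed as ``apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam``; the model and hyperparameters are illustrative.

import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # import path is an assumption

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Linear(1024, 1024).cuda()
opt = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    overlap_grad_sync=True,  # reduce-scatter gradient buckets as backprop produces them
    bucket_cap_mb=15,
)
loss = model(torch.randn(32, 1024, device='cuda')).sum()
loss.backward()
opt.step()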
def __init__(
    self,
    module: nn.Module,
    process_group: Optional[ProcessGroup] = None,
    cpu_offload: Optional[CPUOffload] = None
):
    torch._C._log_api_usage_once("torch.distributed.fsdp")
    super().__init__()
    self.process_group = process_group or _get_default_group()
    self.rank = self.process_group.rank()
    self.world_size = self.process_group.size()
    # device for computation, if module is on GPU, use module.device;
    # if module is on CPU, use current device;
    self.compute_device = _get_default_cuda_device(module)
    self.compute_dtype = _get_data_type(module)

    # Free full params and keep shard only after forward
    self.reshard_after_forward = True

    # setting two factors to avoid underflow and overflow
    self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
    self.gradient_postdivide_factor: float = (
        self.world_size / self.gradient_predivide_factor
    )

    self.numel_padded_per_param: List[int] = []
    self.cpu_offload = cpu_offload or CPUOffload()

    # Only handle params which are not already sharded. This enables
    # sharding individual layers of a Module, with an outer wrapper to
    # shard any leftover parameters.
    params = []
    for param_name, param in module.named_parameters():
        if not hasattr(param, "_is_sharded"):
            params.append(param)

    self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
    del module  # free original module in case it helps garbage collection
    if self._fsdp_wrapped_module.flat_param is not None:
        self.params = [self._fsdp_wrapped_module.flat_param]
    else:
        self.params = []

    # Shard module parameters in place
    self._shard_parameters()

    # Make sure all parameters are sharded.
    for n, p in self.named_parameters():
        if not hasattr(p, "_is_sharded"):
            raise RuntimeError(f"found unsharded parameter: {n} ; {p.size()}")

    self._reset_lazy_init()

    # Enum to indicate if we're in the forward/backward pass, idle, etc.
    self.training_state = TrainingState_.IDLE

    # Flag to guard against preparing gradients multiple times per backward pass.
    self._pre_backward_hook_has_run = False

    # If specified, offload parameter shard to CPU.
    if self.cpu_offload.offload_params:
        for p in self.params:
            self._offload_to_cpu(p)
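A hedged sketch of wrapping a module with the prototype FSDP constructor above. The ``torch.distributed.fsdp`` import path for ``FullyShardedDataParallel``/``CPUOffload`` is an assumption about how this version is exposed; an initialized NCCL process group is assumed.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
).cuda()
fsdp_model = FullyShardedDataParallel(
    model,
    cpu_offload=CPUOffload(offload_params=True),  # keep parameter shards on CPU between uses
)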
def __init__(
    self,
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
):
    self._rpc_initialized = False
    self._sharded_tensor_id = None
    if rpc._is_current_rpc_agent_set():
        # Validate PG and RPC ranks match.
        pg_rank = dist.get_rank()
        rpc_rank = rpc.get_worker_info().id
        if pg_rank != rpc_rank:
            raise ValueError(
                f'Default ProcessGroup and RPC ranks must be '
                f'the same for ShardedTensor, found process group rank: '
                f'{pg_rank} and RPC rank: {rpc_rank}')

    if layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    if memory_format != torch.contiguous_format:
        raise ValueError('Only torch.contiguous_format memory_format is currently supported')

    self._sharding_spec = sharding_spec
    self._dims = list(size)
    self._process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )

    if distributed_c10d._rank_not_in_group(self._process_group):
        raise ValueError(f'Global rank: {dist.get_rank()} not part of process group')

    self._local_shards: List[Shard] = []
    self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
    self._sharding_metadata: List[ShardMetadata] = []
    if isinstance(self._sharding_spec, ChunkShardingSpec):
        self._init_chunked(
            dtype,
            layout,
            requires_grad,
            pin_memory,
            memory_format,
        )
    elif isinstance(self._sharding_spec, EnumerableShardingSpec):
        self._init_enumerable(
            dtype,
            layout,
            requires_grad,
            pin_memory,
            memory_format,
        )
    else:
        raise ValueError(f'Unsupported sharding_spec: {self._sharding_spec}')

    with _sharded_tensor_lock:
        global _sharded_tensor_current_id, _sharded_tensor_map
        self._sharded_tensor_id = _sharded_tensor_current_id
        _sharded_tensor_map[self._sharded_tensor_id] = self
        _sharded_tensor_current_id += 1

    # Initialize RPC if available.
    if rpc._is_current_rpc_agent_set():
        self._init_rpc()
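A hedged sketch of the constructor above: create an empty 10x20 ShardedTensor chunked across two GPUs, so each rank allocates only its own shard. It assumes an initialized 2-rank default group and that ``ChunkShardingSpec`` is imported alongside ``ShardedTensor`` in this version of the package; shapes and placements are illustrative.

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
# Each rank locally allocates a [5, 20] shard of the 10x20 tensor.
st = ShardedTensor(spec, 10, 20, dtype=torch.float32)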
def __setstate__(self, state):
    # If serializable, then the process group should be the default one
    self.process_group = _get_default_group()
    self.check_previous_reduction = False
    super(RandomKSparsifiedDDP, self).__setstate__(state)
    self._ddp_init_helper()
def __init__(self, params,
             lr=1e-3, bias_correction=True, grad_averaging=True,
             betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0., max_grad_norm=0.,
             adam_w_mode=True, use_nvlamb=False,
             step_supports_amp_scaling=True, overlap_reductions=True,
             dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
             dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
             e5m2_allgather=False, verbose=False):
    defaults = dict(lr=lr, bias_correction=bias_correction,
                    betas=betas, eps=eps, weight_decay=weight_decay,
                    grad_averaging=grad_averaging,
                    max_grad_norm=max_grad_norm)

    super(DistributedFusedLAMB, self).__init__(params, defaults)

    global fused_adam_cuda, distributed_lamb_cuda
    fused_adam_cuda = importlib.import_module("fused_adam_cuda")
    distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

    self._overflow_buf = torch.cuda.IntTensor([0])
    self._has_overflow = False
    self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
    self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
    import amp_C
    self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

    self._grad_averaging = grad_averaging
    self._adam_w_mode = 1 if adam_w_mode else 0
    self._use_nvlamb = use_nvlamb
    self._step_supports_amp_scaling = step_supports_amp_scaling
    self._is_accumulation_step = False
    self._last_step = False
    self._overlap_reductions = overlap_reductions
    self._global_scale = None
    self._num_blocks = dwu_num_blocks
    self._num_chunks = dwu_num_chunks
    self._e5m2_allgather = e5m2_allgather
    self._verbose = verbose
    self._L2_grad_norm = None

    self._current_process_group = c10d._get_default_group()
    self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
    self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
    self._world_size = torch.distributed.get_world_size()
    self._num_groups = self._world_size // self._group_size
    self._rank_in_group = torch.distributed.get_rank() % self._group_size

    self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')

    self._resume_from_checkpoint = False
    self._step = torch.cuda.IntTensor([0])

    # Master weight, moment, gradient buffers
    self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None

    import inspect
    assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), \
        "This version of c10d does not support no_copy option"

    self._num_rs_pg = dwu_num_rs_pg
    self._num_ar_pg = dwu_num_ar_pg
    self._num_ag_pg = dwu_num_ag_pg
    if self._num_groups > 1:
        self._ar_pg = []
        for dev_i in range(self._group_size):
            ranks = [dev_i + j * self._group_size for j in range(self._num_groups)]
            for i in range(self._num_ar_pg):
                if self._verbose:
                    print(f"creating new group {i}: {ranks}")
                grp = torch.distributed.new_group(ranks=ranks)
                if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                    if self._verbose:
                        print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                    torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                if self._verbose:
                    print(f"created new group {i}")

                if torch.distributed.get_rank() in ranks:
                    self._ar_pg.append(grp)
        self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
        # for ar_pg in self._ar_pg:
        #     torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)

    rs_ranks = []
    for group_i in range(self._num_groups):
        rs_ranks.append([group_i * self._group_size + j for j in range(self._group_size)])
    self._rs_pg = []
    for group_i in range(self._num_groups):
        ranks = rs_ranks[group_i]
        for i in range(self._num_rs_pg):
            grp = torch.distributed.new_group(ranks=ranks)
            if torch.distributed.get_rank() in ranks:
                self._rs_pg.append(grp)
        l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
        if torch.distributed.get_rank() in ranks:
            self._l2_grad_norm_pg = l2_grad_norm_pg
            # torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
    self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
    # for rs_pg in self._rs_pg:
    #     torch.distributed.all_reduce(self._overflow_buf, group=rs_pg)

    if self._num_ag_pg == 0:
        self._ag_pg = self._rs_pg
        self._ag_st = self._rs_st
        self._num_ag_pg = self._num_rs_pg
    else:
        self._ag_pg = []
        for group_i in range(self._num_groups):
            ranks = rs_ranks[group_i]
            for i in range(self._num_ag_pg):
                grp = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._ag_pg.append(grp)
        self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
        # for ag_pg in self._ag_pg:
        #     torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
    self._l2_grad_norm_st = torch.cuda.Stream()
    self._completion_st = torch.cuda.Stream()
    self._step.record_stream(self._completion_st)

    self._reductions_works = [None] * self._num_blocks
    self._allgather_works = [None] * self._num_blocks

    self._one = torch.cuda.IntTensor([1])

    self._first_step = True
    self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
    self._param_order = self.AtomicCounter()
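A hedged usage sketch for the optimizer above. It assumes apex is built with the ``fused_adam_cuda``/``distributed_lamb_cuda``/``amp_C`` extensions, a c10d build that supports the ``no_copy`` collective option, and that the class is exposed as ``apex.contrib.optimizers.distributed_fused_lamb.DistributedFusedLAMB``; hyperparameters are illustrative.

import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB  # import path is an assumption

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Linear(1024, 1024).cuda()
opt = DistributedFusedLAMB(
    model.parameters(),
    lr=2e-3,
    max_grad_norm=1.0,
    overlap_reductions=True,  # overlap gradient reduce-scatter with backprop
    dwu_group_size=torch.cuda.device_count(),
)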
def __init__(
    self,
    module: nn.Module,
    process_group: Optional[ProcessGroup] = None,
    cpu_offload: Optional[CPUOffload] = None,
    fsdp_auto_wrap_policy: Optional[Callable] = None,
):
    torch._C._log_api_usage_once("torch.distributed.fsdp")
    super().__init__()
    # if fsdp_auto_wrap_policy is specified, submodules should not be
    # already wrapped, otherwise we'd attempt to double wrap them resulting
    # in errors.
    if fsdp_auto_wrap_policy is not None:
        self._check_wrapped(
            module,
            check_fn=lambda mod: not isinstance(mod, FullyShardedDataParallel),
            err_fn=lambda mod: f"Expected {mod} to NOT be FullyShardedDataParallel if auto_wrap is enabled.",
        )
        # TODO: refactor recursive_wrap so that it is not dependent on
        # ConfigAutoWrap.
        config_auto_wrap = ConfigAutoWrap(
            auto_wrap_policy=fsdp_auto_wrap_policy,
            wrapper_cls=FullyShardedDataParallel  # type: ignore[arg-type]
        )
        with config_auto_wrap:
            assert ConfigAutoWrap.in_autowrap_context
            assert ConfigAutoWrap.wrapper_cls == FullyShardedDataParallel
            assert ConfigAutoWrap.auto_wrap_policy == fsdp_auto_wrap_policy
            # This will only wrap the children, and then constructor will
            # run for root module.
            ConfigAutoWrap.recursive_wrap(
                module,
                auto_wrap_policy=fsdp_auto_wrap_policy,
                # Note that we have the recursive_wrap skip wrapping for
                # the outermost (this) module otherwise it will result in a
                # double-wrap causing issues.
                only_wrap_children=True,
                # FSDP arguments follow.
                process_group=process_group,
                cpu_offload=cpu_offload,
                # Note that recursive_wrap should not call FSDP with wrapping
                # enabled, as this recursive call handles all wrapping,
                # including for nested children.
                fsdp_auto_wrap_policy=None,
            )
        assert not ConfigAutoWrap.in_autowrap_context

    self.process_group = process_group or _get_default_group()
    self.rank = self.process_group.rank()
    self.world_size = self.process_group.size()
    # device for computation, if module is on GPU, use module.device;
    # if module is on CPU, use current device;
    self.compute_device = _get_default_cuda_device(module)

    # Free full params and keep shard only after forward
    self.reshard_after_forward = True

    # setting two factors to avoid underflow and overflow
    self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
    self.gradient_postdivide_factor: float = (
        self.world_size / self.gradient_predivide_factor
    )

    self.numel_padded_per_param: List[int] = []
    self.cpu_offload = cpu_offload or CPUOffload()

    # Only handle params which are not already sharded. This enables
    # sharding individual layers of a Module, with an outer wrapper to
    # shard any leftover parameters.
    params = []
    for param_name, param in module.named_parameters():
        if not hasattr(param, "_is_sharded"):
            params.append(param)

    self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
    del module  # free original module in case it helps garbage collection
    if self._fsdp_wrapped_module.flat_param is not None:
        self.params = [self._fsdp_wrapped_module.flat_param]
    else:
        self.params = []

    # Shard module parameters in place
    self._shard_parameters()

    # Make sure all parameters are sharded.
    for n, p in self.named_parameters():
        if not hasattr(p, "_is_sharded"):
            raise RuntimeError(f"found unsharded parameter: {n} ; {p.size()}")

    self._reset_lazy_init()

    # Enum to indicate if we're in the forward/backward pass, idle, etc.
    self.training_state = TrainingState_.IDLE

    # Flag to guard against preparing gradients multiple times per backward pass.
    self._pre_backward_hook_has_run = False

    # If specified, offload parameter shard to CPU.
    if self.cpu_offload.offload_params:
        for p in self.params:
            self._offload_to_cpu(p)
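A hedged sketch of the ``fsdp_auto_wrap_policy`` argument added in the constructor above. The ``(module, recurse, unwrapped_params)`` callable signature and the size threshold are assumptions about how ``ConfigAutoWrap`` policies are written in this version; adjust to the policy helpers shipped with your build. An initialized process group is assumed.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel

def wrap_big_layers(module, recurse, unwrapped_params) -> bool:
    # keep recursing into children; wrap any submodule that still has >= 100k unwrapped parameters
    return recurse or unwrapped_params >= 100_000

big_model = torch.nn.Transformer(d_model=512, nhead=8).cuda()
fsdp_model = FullyShardedDataParallel(big_model, fsdp_auto_wrap_policy=wrap_big_layers)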
def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # relative imports to avoid circular dependency
    from torch.distributed._shard.sharded_tensor import (
        ShardedTensor
    )
    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    current_rank = dist.get_rank(process_group)
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    local_tensor = None
    local_metadata = None
    tensors_to_scatter = [None] * dist.get_world_size(process_group)

    sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
    chunks = len(self.placements)
    split_size = get_split_size(sharding_dim_size, chunks)
    scatter_shape = list(tensor.size())
    scatter_shape[self.dim] = split_size  # type: ignore[index]

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        if current_rank == src_rank:
            # Reshape to get shard for this rank and we don't want autograd
            # recording here for the narrow op and 'local_shard' should be a
            # leaf variable in the autograd graph.
            narrowed_tensor = narrow_tensor(tensor, shard_meta)
            if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                # for the last shard that might be smaller than the other shards,
                # resize the narrowed tensor to the same size and use it for
                # the scatter collective as dist.scatter requires same size
                # inputs on every rank
                tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
            else:
                tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

            tensors_to_scatter[rank] = tensor_to_scatter

        if current_rank == rank:
            local_tensor = torch.empty(
                scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
            local_metadata = shard_meta

    # each rank should have local_tensor and local_metadata initialized if we build
    # the metadata list in a correct way.
    assert local_tensor is not None
    assert local_metadata is not None

    # Scatter the shards to all ranks in the pg
    # scatter takes the global rank as ``src``
    src_for_scatter = src_rank
    if process_group is not None and process_group is not distributed_c10d._get_default_group():
        src_for_scatter = distributed_c10d._get_global_rank(process_group, src_for_scatter)

    dist.scatter(
        local_tensor,
        scatter_list=tensors_to_scatter if current_rank == src_rank else None,
        src=src_for_scatter,
        group=process_group
    )

    if list(local_tensor.size()) != local_metadata.shard_sizes:
        # detach again after receiving to ensure local shards remain a leaf node
        local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

    # Sync requires_grad to local_shard.
    local_tensor.requires_grad = tensor.requires_grad

    local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

    st = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        tensor_meta,
        process_group=process_group)

    # Manually set sharding_spec
    st._sharding_spec = self

    return st
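A hedged sketch of calling the ``shard()`` method above directly on a ``ChunkShardingSpec``. It assumes a 2-rank initialized default group; rank 0's tensor is the source that is chunked along dim 0 and scattered, and the shapes are illustrative.

import torch
import torch.distributed as dist

rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
full = torch.rand(4, 4, device=f"cuda:{rank}")
st = spec.shard(full, src_rank=0)  # returns a ShardedTensor with one [2, 4] shard per rank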