def init_dist_connection(
    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Utility function to initialize distributed connection by setting env variables and initializing the
    distributed process group.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance
        torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
        global_rank: rank of the current process
        world_size: number of processes in the group
        kwargs: kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return
    global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
    world_size = world_size if world_size is not None else cluster_environment.world_size()
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    # on rank=0 let everyone know training is starting
    new_rank_zero_info(
        f"{'-' * 100}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {world_size} processes\n"
        f"{'-' * 100}\n"
    )

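# Usage sketch (illustrative only, kept as a comment so nothing runs on import). It shows
# how a strategy could call ``init_dist_connection`` during setup; ``LightningEnvironment``
# and the "nccl" backend are example choices here, not requirements:
#
#     from pytorch_lightning.plugins.environments import LightningEnvironment
#
#     env = LightningEnvironment()
#     init_dist_connection(env, torch_distributed_backend="nccl")
#     # MASTER_ADDR/MASTER_PORT are now set and the default process group is initialized.
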
def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
    """Function to reduce the tensors from several DDP processes to one main process.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be the string ``'avg'`` or ``'mean'`` to calculate the mean during reduction.

    Return:
        reduced value
    """
    divide_by_world_size = False

    if group is None:
        group = torch.distributed.group.WORLD

    op: Optional[ReduceOp]
    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
    if _HPU_AVAILABLE:
        is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        if is_hpu_backend:
            if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"):
                new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
                result = result.float()

    # sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)

    return result

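# Usage sketch (illustrative only; assumes ``init_dist_connection`` has already initialized
# the default process group). Every rank passes its local tensor and receives the reduced value:
#
#     local_loss = torch.tensor([1.5])
#     mean_loss = sync_ddp(local_loss, reduce_op="mean")  # all-reduce with SUM, then divide by world size
#     total_steps = sync_ddp(torch.tensor([10.0]), reduce_op="sum")
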
def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_info(*args, **kwargs)

def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain and update any state information
            that users would like to maintain as part of the training process. Examples: error feedback
            in gradient compression, peers to communicate with next in GossipGrad, etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The hook can perform whatever
            processing is needed and return a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also just return a completed Future.
            The Future should hold the new value of grad bucket's tensors. Once a bucket is ready, c10d
            reducer would call this hook and use the tensors returned by the Future and copy grads to
            individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such as FP16 compression as wrapper,
            which could be combined with ddp_comm_hook

    .. warning ::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import (  # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compressing gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups()  # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    from pytorch_lightning.utilities import rank_zero_warn

    if not _TORCH_GREATER_EQUAL_1_8:
        rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
        return
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
        else:
            new_rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)