def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
    """Block until a message header arrives on ``queue_name`` and decode it.

    A ``MESSAGE_TENSOR_SIZE``-element ``uint8`` buffer is allocated on
    ``self.input_device``, filled by ``torch.distributed.recv`` from any
    source rank (``src=None``) of the pipeline-parallel group, and then
    turned back into a Python object with ``tensor_to_pyobject``.

    Args:
        queue_name: Used as the ``tag`` of the receive call.
        nowait: Non-blocking receive is not implemented; when true the
            method raises ``QueueEmpty`` immediately.

    Returns:
        The object produced by ``tensor_to_pyobject`` from the received
        buffer.

    Raises:
        QueueEmpty: If ``nowait`` is true.
    """
    # FIXME(handle nowait)
    if nowait:
        raise QueueEmpty

    header_buffer = torch.empty(
        MESSAGE_TENSOR_SIZE,
        dtype=torch.uint8,
        device=self.input_device,
    )

    # The current CUDA stream is synchronized on both sides of the receive.
    torch.cuda.current_stream().synchronize()
    pipeline_group = get_pipeline_parallel_group()
    torch.distributed.recv(header_buffer, src=None, tag=queue_name, group=pipeline_group)
    torch.cuda.current_stream().synchronize()

    decoded = tensor_to_pyobject(header_buffer)
    return decoded
def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
    """Send ``message`` (header, then tensors) to rank ``message.dest``.

    The tensors are taken off ``message`` and ``message.tensors`` is reset
    to an empty tuple before anything is sent; the caller's object is left
    in that state. Unless ``skip_header`` is set, the shapes and dtypes of
    the tensors are recorded on the message, which is then serialized with
    ``pyobject_to_tensor`` and sent with ``tag=message.queue_name``. Each
    tensor follows as a separate send tagged ``message.tag + index``.

    Args:
        message: Message to transmit; mutated in place as described above.
        sync: Accepted but not read anywhere in this method.
        skip_header: When true, only the tensors are sent.
    """
    payload = message.tensors
    message.tensors = tuple()

    torch.cuda.current_stream().synchronize()

    if not skip_header:
        message.tensor_shapes = [tensor.size() for tensor in payload]
        message.tensor_dtypes = [tensor.dtype for tensor in payload]
        header = pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda()
        torch.distributed.send(
            header,
            message.dest,
            tag=message.queue_name,
            group=get_pipeline_parallel_group(),
        )

    for offset, tensor in enumerate(payload):
        # CPU tensors are moved to the GPU before being sent.
        outgoing = tensor.cuda() if tensor.device.type == "cpu" else tensor
        torch.distributed.send(
            outgoing.contiguous(),
            message.dest,
            tag=message.tag + offset,
            group=get_pipeline_parallel_group(),
        )
def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
    """Receive the tensors announced by ``message`` and attach them to it.

    For every ``(shape, dtype)`` pair in ``message.tensor_shapes`` /
    ``message.tensor_dtypes`` a buffer is allocated on ``self.input_device``
    and filled by ``torch.distributed.recv`` from ``message.src`` with
    ``tag=message.tag + index`` on the pipeline-parallel group. The current
    CUDA stream is synchronized before the first and after the last receive.

    Args:
        message: Header describing the tensors to receive; its ``tensors``
            attribute is overwritten with the received tensors.

    Returns:
        The same ``message`` object, with ``tensors`` set to a tuple of the
        received tensors in header order.
    """
    torch.cuda.current_stream().synchronize()

    received = []
    for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
        # Pass the shape as one sequence rather than unpacking it: an empty
        # shape (0-dim tensor) would otherwise leave torch.empty without any
        # size argument.
        buffer = torch.empty(tuple(shape), dtype=dtype, device=self.input_device)
        torch.distributed.recv(
            buffer,
            message.src,
            tag=message.tag + index,
            group=get_pipeline_parallel_group(),
        )
        received.append(buffer)

    message.tensors = tuple(received)

    torch.cuda.current_stream().synchronize()

    return message
def __init__(
    self,
    module: Union[nn.Sequential, List[LazyModule]],
    balance: Iterable[int],
    *,
    group: Optional[torch.distributed.ProcessGroup] = None,
    worker_map: Optional[Dict[int, str]] = None,
    input_device: Union[None, int, str, torch.device] = None,
    chunks: int = chunks,
    checkpoint: str = checkpoint,
    deferred_batch_norm: bool = False,
) -> None:
    """Validate the configuration and build this rank's partition.

    Args:
        module: Module (or list of lazy modules) to partition; checked with
            ``verify_module`` and ``check_balance``.
        balance: Partition sizes; stored as a list on ``self.balance``.
        group: Process group to use. Defaults to
            ``get_pipeline_parallel_group()``.
        worker_map: Stored on ``self.worker_map``.
        input_device: Stored on ``self.input_device``.
        chunks: Number of chunks; must be positive.
        checkpoint: One of ``'always'``, ``'except_last'`` or ``'never'``.
        deferred_batch_norm: When true, the local partition is passed
            through ``DeferredBatchNorm.convert_deferred_batch_norm``.

    Raises:
        ValueError: If ``chunks`` is not positive or ``checkpoint`` is not
            one of the accepted values.
        IndexError: If the group has fewer ranks than there are partitions.
    """
    super().__init__()

    if chunks <= 0:
        raise ValueError("number of chunks must be positive integer")
    if checkpoint not in ("always", "except_last", "never"):
        raise ValueError(
            "checkpoint is not one of 'always', 'except_last', or 'never'")

    self.pipelined_backward = get_model_parallel_world_size() > 1

    self.balance = list(balance)
    verify_module(module)
    check_balance(module, self.balance)

    self.chunks = chunks
    self.checkpoint = checkpoint
    self.pipeline: Optional[MultiProcessPipeline]
    self.lock = threading.Lock()

    self.worker_map = worker_map
    self.input_device = input_device

    self.group: torch.distributed.ProcessGroup
    if group is None:
        self.group = get_pipeline_parallel_group()
    else:
        self.group = group

    if self.group.size() < len(self.balance):
        raise IndexError(
            f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
            f" {len(self.balance)})")

    self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)

    rank = self.group.rank()
    self.final_stage = rank == len(self.balance) - 1
    if rank >= len(self.balance):
        # This rank holds no partition: keep an empty module and no pipeline.
        warnings.warn("More ranks than partitions, some ranks unused")
        self.partition = nn.Sequential()
        self.pipeline = None
    else:
        self.partition = self.instantiate_partition(
            module, self.balance, self.group)
        if deferred_batch_norm:
            # Store the converted module back on ``self.partition`` so that
            # the module registered below is the converted one (and no
            # extra, misspelled submodule attribute is created).
            self.partition = DeferredBatchNorm.convert_deferred_batch_norm(
                self.partition, chunks)
        self.add_module(str(0), self.partition)
        self.create_pipeline()

    # Drop the local reference to the full module.
    del module
def set_main_rpc_process(self):
    """Set ``self.main_rpc_process`` to whether this process is rank 0.

    The rank is the one reported by ``torch_distrib.get_rank`` for the group
    returned by ``mpu.get_pipeline_parallel_group()``.
    """
    pipeline_group = mpu.get_pipeline_parallel_group()
    rank_in_group = torch_distrib.get_rank(group=pipeline_group)
    self.main_rpc_process = rank_in_group == 0
def __init__(
    self,
    module: Union[nn.Sequential, List[LazyModule]],
    balance: Optional[Iterable[int]] = None,
    *,
    group: Optional[torch.distributed.ProcessGroup] = None,
    worker_map: Optional[Dict[int, str]] = None,
    input_device: Union[None, int, str, torch.device] = None,
    chunks: int = chunks,
    checkpoint: str = checkpoint,
    deferred_batch_norm: bool = False,
    pipelined_backward: Optional[bool] = None,
    retain_graph: bool = False,
) -> None:
    """Validate the configuration and build this rank's partitions.

    Args:
        module: Module (or list of lazy modules) to partition; checked with
            ``verify_module`` and, for ``nn.Sequential``, with
            ``verify_skippables``.
        balance: Partition sizes; required. Materialized once into
            ``self.balance``.
        group: Process group to use. Defaults to
            ``get_pipeline_parallel_group()``.
        worker_map: Stored on ``self.worker_map``.
        input_device: Stored on ``self.input_device``.
        chunks: Number of chunks; coerced with ``int`` and must be positive.
        checkpoint: Coerced with ``str``; one of ``'always'``,
            ``'except_last'`` or ``'never'``.
        deferred_batch_norm: When true, each local partition module is
            passed through ``DeferredBatchNorm.convert_deferred_batch_norm``.
        pipelined_backward: If ``None``, set to whether
            ``get_model_parallel_world_size()`` is greater than 1.
        retain_graph: Stored on ``self.retain_graph``.

    Raises:
        ValueError: If ``balance`` is missing, ``chunks`` is not positive,
            ``checkpoint`` is invalid, or partitioning raises
            ``BalanceError``.
        IndexError: If the group has fewer ranks than there are partitions.
    """
    super().__init__()

    chunks = int(chunks)
    checkpoint = str(checkpoint)

    if balance is None:
        raise ValueError(recommend_auto_balance("balance is required"))
    if chunks <= 0:
        raise ValueError("number of chunks must be positive integer")
    if checkpoint not in ("always", "except_last", "never"):
        raise ValueError(
            "checkpoint is not one of 'always', 'except_last', or 'never'")

    verify_module(module)

    # Verify if the underlying skippable modules satisfy integrity. The
    # integrity can be verified before forward() because it is static.
    if isinstance(module, nn.Sequential):
        verify_skippables(module)

    self.chunks = chunks
    self.checkpoint = checkpoint
    self.pipelined_backward = pipelined_backward
    self.retain_graph = retain_graph
    self.pipeline: Optional[MultiProcessPipeline]
    self.lock = threading.Lock()

    self.worker_map = worker_map
    self.input_device = input_device

    self.group: torch.distributed.ProcessGroup
    if group is None:
        self.group = get_pipeline_parallel_group()
    else:
        self.group = group

    # ``balance`` may be a one-shot iterable, so it is materialized exactly
    # once here and ``self.balance`` is used everywhere below.
    self.balance = list(balance)

    if self.group.size() < len(self.balance):
        raise IndexError(
            f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
            f" {len(self.balance)})")

    rank = self.group.rank()
    has_partition = rank < len(self.balance)

    try:
        if not has_partition:
            warnings.warn("More ranks than partitions, some ranks unused")
            self.partitions: List[ModuleWrapper] = []
        else:
            self.partitions = self.instantiate_partition(
                module, self.balance, self.group)
            if deferred_batch_norm:
                for part in self.partitions:
                    part.module = DeferredBatchNorm.convert_deferred_batch_norm(
                        part.module, chunks)
            for name, part in enumerate(self.partitions):
                self.add_module(str(name), part.module)

        if isinstance(module, nn.Sequential):
            local_partitions = split_module(module, self.balance)
            self._skip_layout = inspect_skip_layout(local_partitions)
        else:
            self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
    except BalanceError as exc:
        raise ValueError(recommend_auto_balance(str(exc))) from exc

    if has_partition:
        self.final_stage = rank == len(self.balance) - 1
        self.create_pipeline()
    else:
        # Ranks beyond the last partition take no part in the pipeline.
        self.pipeline = None
        self.final_stage = False

    # Drop the local reference to the full module.
    del module

    if self.pipelined_backward is None:
        self.pipelined_backward = get_model_parallel_world_size() > 1
def __init__(
    self,
    module: Union[nn.Sequential, ListOfLazyModules],
    balance: Optional[Iterable[int]] = None,
    *,
    style: PipelineStyle = PipelineStyle.SingleProcess,
    devices: Optional[Devices] = None,
    group: Optional[torch.distributed.ProcessGroup] = None,
    worker_map: Optional[Dict[int, str]] = None,
    input_device: Union[None, int, str, torch.device] = None,
    chunks: int = chunks,
    checkpoint: str = checkpoint,
    deferred_batch_norm: bool = False,
    pipelined_backward: Optional[bool] = None,
    retain_graph: bool = False,
    loss_fn: Optional[nn.Module] = None,
) -> None:
    """Validate the configuration and build the pipeline for ``style``.

    For ``PipelineStyle.SingleProcess`` the module is split across
    ``devices`` with ``split_module`` and a ``Pipeline`` is created over all
    partitions. For ``PipelineStyle.MultiProcess`` and
    ``PipelineStyle.AsyncSchedule`` only this rank's partitions are
    instantiated and a ``Pipeline`` is created when this rank holds a
    partition. Any other ``style`` value falls through both branches and
    leaves the pipeline attributes unset.

    Args:
        module: Module (or list of lazy modules) to partition.
        balance: Partition sizes; required.
        style: Which of the setups above to use.
        devices: Devices for the single-process style (defaults to all CUDA
            devices); must be ``None`` for the multi-process styles.
        group: Process group for the multi-process styles. Defaults to
            ``get_pipeline_parallel_group()``.
        worker_map: Stored and forwarded to ``Pipeline`` in the
            multi-process styles.
        input_device: Must be ``None`` for the single-process style;
            forwarded to ``Pipeline`` in the multi-process styles.
        chunks: Number of chunks; coerced with ``int`` and must be positive.
        checkpoint: Coerced with ``str``; one of ``'always'``,
            ``'except_last'`` or ``'never'``.
        deferred_batch_norm: When true, modules are passed through
            ``DeferredBatchNorm.convert_deferred_batch_norm``.
        pipelined_backward: If ``None``, defaults to ``False`` for the
            single-process style and to
            ``get_model_parallel_world_size() > 1`` otherwise.
        retain_graph: Stored on ``self.retain_graph``.
        loss_fn: Stored on ``self.loss_fn``; in the multi-process styles it
            is asserted to be ``None`` unless this rank is the final stage.

    Raises:
        ValueError: If ``balance`` is missing, ``chunks`` is not positive,
            ``checkpoint`` is invalid, an argument is given for the wrong
            style, or partitioning raises ``BalanceError``.
        IndexError: If the group has fewer ranks than there are partitions.
    """
    super().__init__()

    chunks = int(chunks)
    checkpoint = str(checkpoint)

    if balance is None:
        raise ValueError(recommend_auto_balance("balance is required"))
    if chunks <= 0:
        raise ValueError("number of chunks must be positive integer")
    if checkpoint not in ("always", "except_last", "never"):
        raise ValueError(
            "checkpoint is not one of 'always', 'except_last', or 'never'")

    verify_module(module)

    # Verify if the underlying skippable modules satisfy integrity. The
    # integrity can be verified before forward() because it is static.
    if isinstance(module, nn.Sequential):
        verify_skippables(module)

    self.chunks = chunks
    self.checkpoint = checkpoint
    self.pipelined_backward = pipelined_backward
    self.retain_graph = retain_graph
    self.pipeline: Optional[Pipeline]
    self.loss_fn = loss_fn
    self.lock = threading.Lock()

    self.group = group
    self.worker_map = worker_map
    self.input_device = input_device

    self._copy_streams: List[List[AbstractStream]] = []

    # The micro-batch index where the checkpointing stops.
    checkpoint_stop = {
        "always": self.chunks,
        "except_last": self.chunks - 1,
        "never": 0,
    }[self.checkpoint]

    if style is PipelineStyle.SingleProcess:
        module = cast(nn.Sequential, module)
        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, chunks)

        if input_device is not None:
            raise ValueError(
                "'input_device' argument only applies to 'PipelineStyle.MultiProcess'"
            )

        if devices is None:
            devices = range(torch.cuda.device_count())

        devices = [torch.device(d) for d in devices]
        devices = cast(List[torch.device], devices)

        try:
            self.partitions, self.balance, self.devices = split_module(
                module, balance, devices)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc))) from exc
        verify_splitting(module, self.partitions, self.balance, self.devices)

        self._skip_layout = inspect_skip_layout(self.partitions)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()
        if self.pipelined_backward is None:
            self.pipelined_backward = False
        self.pipeline = Pipeline(
            self.partitions,
            self.devices,
            copy_streams,
            self._skip_layout,
            checkpoint_stop,
            style=style,
        )

    elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
        if self.group is None:
            self.group = get_pipeline_parallel_group()
        assert self.group

        if devices is not None:
            raise ValueError(
                "'devices' argument only applies to 'PipelineStyle.SingleProcess'"
            )

        # ``balance`` may be a one-shot iterable, so it is materialized
        # exactly once here and ``self.balance`` is used everywhere below.
        self.balance = list(balance)

        if self.group.size() < len(self.balance):
            raise IndexError(
                f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                f" {len(self.balance)})")

        rank = self.group.rank()
        has_partition = rank < len(self.balance)

        try:
            if not has_partition:
                warnings.warn(
                    "More ranks than partitions, some ranks unused")
                self.mp_partitions: List[ModuleWrapper] = []
            else:
                self.mp_partitions = instantiate_partition(
                    module, self.balance, self.group, style)
                if deferred_batch_norm:
                    for part in self.mp_partitions:
                        part.module = DeferredBatchNorm.convert_deferred_batch_norm(
                            part.module, chunks)
                for name, part in enumerate(self.mp_partitions):
                    self.add_module(str(name), part.module)

            self.devices = None
            if isinstance(module, nn.Sequential):
                local_partitions, _, _ = split_module(
                    module, self.balance, None)
                self._skip_layout = inspect_skip_layout(local_partitions)
            else:
                self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc))) from exc

        if not has_partition:
            # Ranks beyond the last partition take no part in the pipeline.
            self.pipeline = None
            self.final_stage = False
        else:
            self.final_stage = rank == len(self.balance) - 1
            assert loss_fn is None or self.final_stage

            self.pipeline = Pipeline(
                cast(List[nn.Sequential], self.mp_partitions),
                None,
                None,
                self._skip_layout,
                checkpoint_stop,
                style=style,
                group=self.group,
                worker_map=self.worker_map,
                input_device=self.input_device,
                final_stage=self.final_stage,
            )

        # Drop the local reference to the full module.
        del module

        if self.pipelined_backward is None:
            self.pipelined_backward = get_model_parallel_world_size() > 1