Example #1
    def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
        # FIXME(handle nowait)
        if nowait:
            raise QueueEmpty
        # Pre-allocate a fixed-size buffer, then block until the serialized
        # header arrives from any rank in the pipeline-parallel group.
        tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device)
        torch.cuda.current_stream().synchronize()
        torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group())
        torch.cuda.current_stream().synchronize()
        return tensor_to_pyobject(tensor)
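
The receive helper above deserializes the fixed-size uint8 header with tensor_to_pyobject. The sketch below shows one plausible way such a round trip could work using pickle; the function names, the use of pickle, and the buffer size are illustrative assumptions, not the library's actual implementation.

    import pickle
    import torch

    HEADER_SIZE = 1024  # assumed stand-in for MESSAGE_TENSOR_SIZE

    def pyobject_to_tensor_sketch(obj, fixed_size: int = HEADER_SIZE) -> torch.Tensor:
        # Serialize the object and pad it into a fixed-size byte buffer so the
        # receiver can pre-allocate a tensor of known shape.
        data = pickle.dumps(obj)
        if len(data) > fixed_size:
            raise ValueError("object does not fit in the fixed-size header tensor")
        buffer = torch.zeros(fixed_size, dtype=torch.uint8)
        buffer[: len(data)] = torch.tensor(list(data), dtype=torch.uint8)
        return buffer

    def tensor_to_pyobject_sketch(tensor: torch.Tensor):
        # pickle stops at its STOP opcode, so trailing zero padding is ignored.
        return pickle.loads(bytes(tensor.cpu().tolist()))
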
Example #2
    def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
        # Strip the payload tensors off the message so that only metadata is
        # serialized into the fixed-size header.
        tensors = message.tensors
        message.tensors = tuple()
        torch.cuda.current_stream().synchronize()
        if not skip_header:
            message.tensor_shapes = [t.size() for t in tensors]
            message.tensor_dtypes = [t.dtype for t in tensors]
            torch.distributed.send(
                pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(),
                message.dest,
                tag=message.queue_name,
                group=get_pipeline_parallel_group(),
            )
        # Send each payload tensor separately, tagged with the base tag plus
        # its index so the receiver can match them in order.
        for index, t in enumerate(tensors):
            if t.device.type == "cpu":
                t = t.cuda()
            torch.distributed.send(
                t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group()
            )
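
send_message records each tensor's shape and dtype in the message, sends the header tagged with queue_name, and then sends every payload tensor tagged with tag + index. The dataclass below is a minimal stand-in for PipeMessage, inferred only from the fields these examples read and write; the real class may carry more state.

    from dataclasses import dataclass, field
    from typing import List, Tuple

    import torch

    @dataclass
    class PipeMessageSketch:
        src: int                                # sending rank
        dest: int                               # receiving rank
        queue_name: int                         # tag used for the header send/recv
        tag: int                                # base tag; payload tensor i uses tag + i
        tensors: Tuple[torch.Tensor, ...] = ()  # payload, stripped before the header is sent
        tensor_shapes: List[torch.Size] = field(default_factory=list)
        tensor_dtypes: List[torch.dtype] = field(default_factory=list)
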
Example #3
    def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
        torch.cuda.current_stream().synchronize()

        message_tensors = []
        for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
            t = torch.empty(*shape, dtype=dtype, device=self.input_device)
            torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group())
            message_tensors.append(t)

        message.tensors = tuple(message_tensors)

        torch.cuda.current_stream().synchronize()
        return message
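
Receiving mirrors the send path: the fixed-size header arrives first (Example #1) and carries the shapes and dtypes needed to pre-allocate the payload buffers, after which each tensor is received with the matching tag + index (this example). A hypothetical driver on the receiving rank, assuming a transport object that owns both helpers:

    def recv_full_message_sketch(transport, queue_name: int):
        # 1. Fixed-size metadata header, tagged with the queue name.
        header = transport.recv_message_header(queue_name)
        # 2. Payload tensors, allocated from the shapes/dtypes in the header
        #    and received with tags header.tag + 0, + 1, ...
        return transport.recv_message_tensors(header)
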
Example #4
    def __init__(
        self,
        module: Union[nn.Sequential, List[LazyModule]],
        balance: Iterable[int],
        *,
        group: Optional[torch.distributed.ProcessGroup] = None,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        if get_model_parallel_world_size() > 1:
            self.pipelined_backward = True
        else:
            self.pipelined_backward = False

        self.balance = list(balance)
        verify_module(module)
        check_balance(module, self.balance)

        self.chunks = chunks
        self.checkpoint = checkpoint
        self.pipeline: Optional[MultiProcessPipeline]
        self.lock = threading.Lock()

        self.worker_map = worker_map
        self.input_device = input_device

        self.group: torch.distributed.ProcessGroup
        if group is None:
            self.group = get_pipeline_parallel_group()
        else:
            self.group = group

        if self.group.size() < len(self.balance):
            raise IndexError(
                f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                f" {len(self.balance)})")

        self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)

        rank = self.group.rank()
        self.final_stage = rank == len(self.balance) - 1
        if rank >= len(self.balance):
            warnings.warn("More ranks than partitions, some ranks unused")
            self.partition = nn.Sequential()
            self.pipeline = None
        else:
            self.partition = self.instantiate_partition(
                module, self.balance, self.group)
            if deferred_batch_norm:
                self.partition = DeferredBatchNorm.convert_deferred_batch_norm(
                    self.partition, chunks)
            self.add_module(str(0), self.partition)
            self.create_pipeline()

        del module
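
A hypothetical construction of this module, assuming the class is exposed as MultiProcessPipe (the snippet itself does not name it) and that torch.distributed plus the pipeline-parallel process group are already initialized on every participating rank:

    import torch.nn as nn

    def build_pipe_sketch():
        model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
        return MultiProcessPipe(          # assumed class name
            model,
            balance=[2, 1],               # first rank holds two layers, second holds one
            chunks=4,                     # micro-batches per input batch
            checkpoint="except_last",     # one of the values validated above
        )
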
Example #5
    def set_main_rpc_process(self):
        self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0
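
set_main_rpc_process flags whichever rank is first in the pipeline-parallel group. A hypothetical use of that flag to run one-off work (for example logging or checkpoint saving) on a single rank:

    def run_on_main_rpc_process_sketch(plugin, fn):
        # main_rpc_process is the boolean set by set_main_rpc_process() above.
        if plugin.main_rpc_process:
            fn()
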
Example #6
    def __init__(
        self,
        module: Union[nn.Sequential, List[LazyModule]],
        balance: Optional[Iterable[int]] = None,
        *,
        group: Optional[torch.distributed.ProcessGroup] = None,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
        pipelined_backward: Optional[bool] = None,
        retain_graph: bool = False,
    ) -> None:
        super().__init__()

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if balance is None:
            raise ValueError(recommend_auto_balance("balance is required"))
        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        if isinstance(module, nn.Sequential):
            verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint
        self.pipelined_backward = pipelined_backward
        self.retain_graph = retain_graph
        self.pipeline: Optional[MultiProcessPipeline]
        self.lock = threading.Lock()

        self.worker_map = worker_map
        self.input_device = input_device

        self.group: torch.distributed.ProcessGroup
        if group is None:
            self.group = get_pipeline_parallel_group()
        else:
            self.group = group

        self.balance = list(balance)

        if self.group.size() < len(self.balance):
            raise IndexError(
                f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                f" {len(self.balance)})")
        try:
            rank = self.group.rank()
            if rank >= len(self.balance):
                warnings.warn("More ranks than partitions, some ranks unused")
                self.partitions: List[ModuleWrapper] = []
            else:
                self.partitions = self.instantiate_partition(
                    module, balance, self.group)
                if deferred_batch_norm:
                    for part in self.partitions:
                        part.module = DeferredBatchNorm.convert_deferred_batch_norm(
                            part.module, chunks)
                for name, part in enumerate(self.partitions):
                    self.add_module(str(name), part.module)
            if isinstance(module, nn.Sequential):
                local_partitions = split_module(module, balance)
                self._skip_layout = inspect_skip_layout(local_partitions)
            else:
                self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)

        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc)))

        rank = self.group.rank()
        if rank >= len(self.balance):
            self.pipeline = None
            self.final_stage = False
        else:
            self.final_stage = rank == len(self.balance) - 1

            self.create_pipeline()
            del module
        if self.pipelined_backward is None:
            if get_model_parallel_world_size() > 1:
                self.pipelined_backward = True
            else:
                self.pipelined_backward = False
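
The rank bookkeeping above means that ranks beyond len(balance) hold no partition and build no pipeline, and only the rank holding the last partition is the final stage. A pure-Python illustration of that arithmetic (no distributed setup required):

    balance = [3, 2, 1]      # three partitions
    group_size = 4           # four ranks in the pipeline-parallel group

    for rank in range(group_size):
        if rank >= len(balance):
            print(f"rank {rank}: unused (warns, builds no pipeline)")
        else:
            final = rank == len(balance) - 1
            print(f"rank {rank}: holds {balance[rank]} layers"
                  + (" (final stage)" if final else ""))
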
Example #7
    def __init__(
        self,
        module: Union[nn.Sequential, ListOfLazyModules],
        balance: Optional[Iterable[int]] = None,
        *,
        style: PipelineStyle = PipelineStyle.SingleProcess,
        devices: Optional[Devices] = None,
        group: Optional[torch.distributed.ProcessGroup] = None,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
        pipelined_backward: Optional[bool] = None,
        retain_graph: bool = False,
        loss_fn: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if balance is None:
            raise ValueError(recommend_auto_balance("balance is required"))
        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        if isinstance(module, nn.Sequential):
            verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint
        self.pipelined_backward = pipelined_backward
        self.retain_graph = retain_graph
        self.pipeline: Optional[Pipeline]
        self.loss_fn = loss_fn
        self.lock = threading.Lock()

        self.group = group
        self.worker_map = worker_map
        self.input_device = input_device

        self._copy_streams: List[List[AbstractStream]] = []

        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {
            "always": self.chunks,
            "except_last": self.chunks - 1,
            "never": 0
        }[self.checkpoint]

        if style is PipelineStyle.SingleProcess:
            module = cast(nn.Sequential, module)
            if deferred_batch_norm:
                module = DeferredBatchNorm.convert_deferred_batch_norm(
                    module, chunks)

            if input_device is not None:
                raise ValueError(
                    "'input_device' argument only applies to 'PipelineStyle.MultiProcess'"
                )

            if devices is None:
                devices = range(torch.cuda.device_count())

            devices = [torch.device(d) for d in devices]
            devices = cast(List[torch.device], devices)

            try:
                self.partitions, self.balance, self.devices = split_module(
                    module, balance, devices)
            except BalanceError as exc:
                raise ValueError(recommend_auto_balance(str(exc)))
            verify_splitting(module, self.partitions, self.balance,
                             self.devices)

            self._skip_layout = inspect_skip_layout(self.partitions)

            # Separate CUDA streams for copy.
            copy_streams = self._ensure_copy_streams()
            if self.pipelined_backward is None:
                self.pipelined_backward = False
            self.pipeline = Pipeline(
                self.partitions,
                self.devices,
                copy_streams,
                self._skip_layout,
                checkpoint_stop,
                style=style,
            )

        elif style in [
                PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule
        ]:

            if self.group is None:
                self.group = get_pipeline_parallel_group()
            assert self.group

            if devices is not None:
                raise ValueError(
                    "'devices' argument only applies to 'PipelineStyle.SingleProcess'"
                )

            self.balance = list(balance)

            if self.group.size() < len(self.balance):
                raise IndexError(
                    f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                    f" {len(self.balance)})")
            try:
                rank = self.group.rank()
                if rank >= len(self.balance):
                    warnings.warn(
                        "More ranks than partitions, some ranks unused")
                    self.mp_partitions: List[ModuleWrapper] = []
                else:
                    self.mp_partitions = instantiate_partition(
                        module, balance, self.group, style)
                    if deferred_batch_norm:
                        for part in self.mp_partitions:
                            part.module = DeferredBatchNorm.convert_deferred_batch_norm(
                                part.module, chunks)
                    for name, part in enumerate(self.mp_partitions):
                        self.add_module(str(name), part.module)
                self.devices = None
                if isinstance(module, nn.Sequential):
                    local_partitions, _, _ = split_module(
                        module, balance, None)
                    self._skip_layout = inspect_skip_layout(local_partitions)
                else:
                    self._skip_layout = SkipLayout(len(module),
                                                   {})  # FIXME(tom)

            except BalanceError as exc:
                raise ValueError(recommend_auto_balance(str(exc)))

            rank = self.group.rank()
            if rank >= len(self.balance):
                self.pipeline = None
                self.final_stage = False
            else:
                self.final_stage = rank == len(self.balance) - 1
                assert loss_fn is None or self.final_stage

                self.pipeline = Pipeline(
                    cast(List[nn.Sequential], self.mp_partitions),
                    None,
                    None,
                    self._skip_layout,
                    checkpoint_stop,
                    style=style,
                    group=self.group,
                    worker_map=self.worker_map,
                    input_device=self.input_device,
                    final_stage=self.final_stage,
                )
                del module
            if self.pipelined_backward is None:
                if get_model_parallel_world_size() > 1:
                    self.pipelined_backward = True
                else:
                    self.pipelined_backward = False
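
The checkpoint_stop mapping near the top of this constructor converts the checkpoint mode into the micro-batch index where checkpointing stops: micro-batches with an index below checkpoint_stop are checkpointed. A worked illustration with chunks = 4:

    chunks = 4
    for mode in ("always", "except_last", "never"):
        stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
        if stop:
            print(f"{mode!r}: micro-batches 0..{stop - 1} are checkpointed")
        else:
            print(f"{mode!r}: no micro-batches are checkpointed")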