def test_gather_tensors():
    # Variant using the Batch wrapper API.
    a = torch.zeros(1, 1)
    b = torch.zeros(1, 1)

    ab = gather([Batch(a), Batch(b)])

    assert ab.size() == (2, 1)


def test_gather_tensors():
    # Variant taking raw tensors and an explicit target device.
    a = torch.zeros(1, 1)
    b = torch.zeros(1, 1)

    ab = gather([a, b], device=torch.device('cpu'))

    assert ab.size() == (2, 1)
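# Both variants above exercise the same behavior: gather() concatenates
# micro-batches along the batch dimension. A minimal stand-in sketch, assuming
# only torch.cat semantics (gather_tensors is a hypothetical name, not the
# library's API):

import torch

def gather_tensors(tensors, device):
    # Concatenate micro-batches along dim 0, then move to the target device.
    return torch.cat(tensors, dim=0).to(device)

assert gather_tensors([torch.zeros(1, 1), torch.zeros(1, 1)],
                      torch.device('cpu')).size() == (2, 1)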
def test_gather_tuples():
    # Variant taking raw tuples and an explicit target device.
    a = (torch.zeros(1, 1), torch.zeros(2, 2))
    b = (torch.zeros(1, 1), torch.zeros(2, 2))

    ab = gather([a, b], device=torch.device('cpu'))

    assert isinstance(ab, tuple)
    assert ab[0].size() == (2, 1)
    assert ab[1].size() == (4, 2)


def test_gather_tuples():
    # Variant using the Batch wrapper API.
    a = (torch.zeros(1, 1), torch.zeros(2, 2))
    b = (torch.zeros(1, 1), torch.zeros(2, 2))

    ab = gather([Batch(a), Batch(b)])

    assert isinstance(ab, tuple)
    assert ab[0].size() == (2, 1)
    assert ab[1].size() == (4, 2)
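# For tuple inputs, the tests above imply position-wise concatenation: the
# i-th elements of every micro-batch tuple are joined along dim 0. A sketch
# under the same assumption (gather_tuples is a hypothetical name):

import torch

def gather_tuples(tuples, device):
    # zip(*tuples) groups the i-th element of each micro-batch together.
    return tuple(torch.cat(ts, dim=0).to(device) for ts in zip(*tuples))

a = (torch.zeros(1, 1), torch.zeros(2, 2))
b = (torch.zeros(1, 1), torch.zeros(2, 2))
ab = gather_tuples([a, b], torch.device('cpu'))
assert ab[0].size() == (2, 1) and ab[1].size() == (4, 2)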
def test_default_device_index():
    default_cuda = torch.device('cuda')
    assert default_cuda.index is None

    x = torch.rand(2, 1)
    a, b = scatter(x, chunks=2, device=default_cuda)
    y = gather([a, b], device=default_cuda)

    assert a.is_cuda
    assert b.is_cuda
    assert y.is_cuda
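# torch.device('cuda') carries no index, so code that needs a concrete device
# must resolve it against the current CUDA device. A sketch of such a
# normalization helper (resolve_device is hypothetical, not part of the
# library):

import torch

def resolve_device(device):
    # An index-less CUDA device refers to the current device.
    if device.type == 'cuda' and device.index is None:
        return torch.device('cuda', torch.cuda.current_device())
    return device

assert torch.device('cuda').index is None
# resolve_device(torch.device('cuda')) -> e.g. device(type='cuda', index=0)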
def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
    """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
    modify the input and output signature of the underlying module. But
    there's a type restriction. Input and output have to be a
    :class:`~torch.Tensor` or a tuple of tensors. This restriction is
    applied at partition boundaries too.

    Args:
        input (torch.Tensor or tensors): input mini-batch

    Returns:
        tensor or tensors: output mini-batch

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    microbatch.check(input)

    if not self.devices:
        # An empty sequential module is not illegal.
        return input

    # Divide a mini-batch into micro-batches.
    batches = microbatch.scatter(input, self.chunks)

    # Separate CUDA streams for copy.
    copy_streams = self._ensure_copy_streams()

    # The micro-batch index where the checkpointing stops.
    if self.training:
        checkpoint_stop = {
            'always': self.chunks,
            'except_last': self.chunks - 1,
            'never': 0,
        }[self.checkpoint]
    else:
        checkpoint_stop = 0

    # Run pipeline parallelism.
    pipeline = Pipeline(batches,
                        self.partitions,
                        self.devices,
                        copy_streams,
                        self._skip_layout,
                        checkpoint_stop)
    pipeline.run()

    # Merge the micro-batches into one mini-batch.
    output = microbatch.gather(batches)
    return output
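# A hedged usage sketch for the forward() above, assuming the torchgpipe
# package and its GPipe(module, balance=..., devices=..., chunks=...)
# constructor; the balance/devices values here are illustrative:

import torch
from torch import nn
from torchgpipe import GPipe

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
model = GPipe(model, balance=[2, 1], devices=['cpu', 'cpu'], chunks=4)

x = torch.rand(8, 16)  # one mini-batch; forward() scatters it into 4 micro-batches
y = model(x)           # micro-batch outputs are gathered back into one mini-batch
assert y.size() == (8, 8)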
def _pull_output(self,
                 num_inputs: int,
                 in_queue: PriorityQueue,
                 out_queue: PriorityQueue,
                 ) -> Tensor:
    """Collects and concatenates chunked outputs from the last partition.

    If an exception from a partition is detected, all workers are closed
    and the exception is re-raised.

    Raises:
        Exception: any exception from a partition

    """
    # All worker threads will be closed when receiving this message.
    close = Message(-1, None)

    outputs = []
    for _ in range(num_inputs):
        msg = out_queue.get()
        out_queue.task_done()

        if msg.i == -1:
            # Close worker threads immediately.
            in_queue.put(close)
            out_queue.get()

            # Raise the exception from a partition.
            exc_info = msg.payload
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

        output, _ = msg.payload
        outputs.append(output)

    # Indicate the end of micro-batches.
    in_queue.put(close)

    # Merge the micro-batches into one mini-batch.
    out_device = self.devices[-1]
    output = gather(outputs, device=out_device)

    # Wait until the last worker thread closes.
    out_queue.get()

    return output
def _pull_output(self,
                 num_inputs: int,
                 in_queue: PriorityQueue,
                 out_queue: PriorityQueue,
                 ) -> Tensor:
    """Collects and concatenates chunked outputs from the last partition.

    If an exception from a partition is detected, all workers are closed
    and the exception is re-raised.

    Raises:
        Exception: any exception from a partition

    """
    outputs = []
    for _ in range(num_inputs):
        msg = out_queue.get()
        out_queue.task_done()

        if msg.i == -1:
            # Close worker threads immediately.
            close = Message(-1, None)
            in_queue.put(close)
            out_queue.get()

            # Raise the exception from a partition.
            exc_info = msg.payload
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

        # The payload also carries values this method does not use.
        output, _, _ = msg.payload
        outputs.append(output)

    # Merge the micro-batches into one mini-batch.
    out_device = self.devices[-1]
    output = gather(outputs, device=out_device)

    # Wait for the final message from the last worker thread.
    out_queue.get()

    return output
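# The re-raise line above uses the unbound-method form of with_traceback():
# calling BaseException.with_traceback on the exception *class* with the
# instance and traceback from sys.exc_info() re-raises the worker's exception
# with its original traceback intact. A self-contained sketch of the pattern
# (worker and the queue names are illustrative):

import queue
import sys
import threading

def worker(out_queue):
    try:
        raise ValueError('boom')  # stand-in for a failing partition
    except Exception:
        out_queue.put(sys.exc_info())  # ship (type, value, traceback)

q = queue.Queue()
t = threading.Thread(target=worker, args=(q,))
t.start()
t.join()

exc_info = q.get()
try:
    # Equivalent to: raise exc_info[1].with_traceback(exc_info[2])
    raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
except ValueError as exc:
    assert str(exc) == 'boom'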