def test_scatter_tensor():
    ab = torch.zeros(2, 1)
    a, b = scatter(ab, chunks=2, device=torch.device('cpu'))

    assert a.size() == (1, 1)
    assert b.size() == (1, 1)
def push_input(self,
               input: TensorOrTensors,
               in_queue: PriorityQueue,
               ) -> int:
    """Pushes chunked inputs to the first partition."""
    # Divide a mini-batch into micro-batches.
    inputs = scatter(input, chunks=self.chunks, device=self.in_device)

    # The number of inputs might be smaller than the number of chunks.
    num_inputs = len(inputs)

    for i, _input in enumerate(inputs):
        # NOTE(sublee): 'except_last' is the default option. Compare it first.
        if self.checkpoint == 'except_last':
            checkpoint = (i < self.chunks - 1)
        elif self.checkpoint == 'always':
            checkpoint = True
        elif self.checkpoint == 'never':
            checkpoint = False

        msg = Message(i, (_input, checkpoint))
        in_queue.put(msg)

    close = Message(num_inputs, None)
    in_queue.put(close)

    return num_inputs
def _push_input(self,
                input: TensorOrTensors,
                in_queue: PriorityQueue,
                ) -> int:
    """Pushes chunked inputs to the first partition."""
    # Divide a mini-batch into micro-batches.
    in_device = self.devices[0]
    inputs = scatter(input, chunks=self.chunks, device=in_device)

    # The number of inputs might be smaller than the number of chunks.
    num_inputs = len(inputs)

    for i, _input in enumerate(inputs):
        # NOTE(sublee): 'except_last' is the default option. Compare it first.
        if self.checkpoint == 'except_last':
            checkpoint = (i < self.chunks - 1)
        elif self.checkpoint == 'always':
            checkpoint = True
        elif self.checkpoint == 'never':
            checkpoint = False

        # Every partition should track the current micro-batch. A micro-batch
        # lane can be identified by its detached leaf tensor.
        leaf = (_input[0] if isinstance(_input, tuple) else _input).detach()

        msg = Message(i, (_input, leaf, checkpoint))
        in_queue.put(msg)

    close = Message(num_inputs, None)
    in_queue.put(close)

    return num_inputs
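The per-micro-batch branching above can also be folded into a single stop index, which is the formulation GPipe.forward below uses. A minimal sketch, with checkpoint_stop as a hypothetical helper name rather than anything from the library:

def checkpoint_stop(mode: str, chunks: int) -> int:
    # Map a checkpoint mode onto the first micro-batch index that is NOT
    # checkpointed; mirrors the dict lookup in GPipe.forward below.
    return {'always': chunks, 'except_last': chunks - 1, 'never': 0}[mode]

# Micro-batch i is then checkpointed iff i < checkpoint_stop(mode, chunks),
# so 'except_last' checkpoints every micro-batch except the last one.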
def test_scatter_tensor():
    ab = torch.zeros(2, 1)
    a, b = scatter(ab, chunks=2)

    assert a.tensor.size() == (1, 1)
    assert b.tensor.size() == (1, 1)
def test_scatter_tuple():
    ab = (torch.zeros(2, 1), torch.zeros(4, 2))
    a, b = scatter(ab, chunks=2)

    assert a.tensors[0].size() == (1, 1)
    assert b.tensors[0].size() == (1, 1)
    assert a.tensors[1].size() == (2, 2)
    assert b.tensors[1].size() == (2, 2)
def test_scatter_tuple():
    ab = (torch.zeros(2, 1), torch.zeros(4, 2))
    a, b = scatter(ab, chunks=2, device=torch.device('cpu'))

    assert isinstance(a, tuple)
    assert isinstance(b, tuple)
    assert a[0].size() == (1, 1)
    assert b[0].size() == (1, 1)
    assert a[1].size() == (2, 2)
    assert b[1].size() == (2, 2)
def test_default_device_index():
    default_cuda = torch.device('cuda')
    assert default_cuda.index is None

    x = torch.rand(2, 1)
    a, b = scatter(x, chunks=2, device=default_cuda)
    y = gather([a, b], device=default_cuda)

    assert a.is_cuda
    assert b.is_cuda
    assert y.is_cuda
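Taken together, the tests above pin down the contract of scatter and its inverse gather: a mini-batch tensor is split along the batch dimension into `chunks` micro-batches (a tuple of tensors is split element-wise), and gather concatenates them back. The sketch below is only illustrative, assuming plain torch.Tensor.chunk and torch.cat, and ignoring device placement and the Batch wrappers that some versions return:

from typing import List, Tuple, Union

import torch

Tensors = Tuple[torch.Tensor, ...]
TensorOrTensors = Union[torch.Tensor, Tensors]


def scatter_sketch(input: TensorOrTensors, chunks: int) -> List[TensorOrTensors]:
    # Split a mini-batch into micro-batches along dim 0 (illustrative only).
    if isinstance(input, tuple):
        # Chunk every tensor, then regroup so each micro-batch is a tuple.
        return list(zip(*(t.chunk(chunks) for t in input)))
    return list(input.chunk(chunks))


def gather_sketch(outputs: List[TensorOrTensors]) -> TensorOrTensors:
    # Merge micro-batches back into one mini-batch (illustrative only).
    if isinstance(outputs[0], tuple):
        return tuple(torch.cat(tensors) for tensors in zip(*outputs))
    return torch.cat(outputs)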
def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
    """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
    modify the input and output signature of the underlying module. But
    there's a type restriction. Input and output have to be a
    :class:`~torch.Tensor` or a tuple of tensors. This restriction is
    applied at partition boundaries too.

    Args:
        input (torch.Tensor or tensors): input mini-batch

    Returns:
        tensor or tensors: output mini-batch

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    microbatch.check(input)

    if not self.devices:
        # An empty sequential module is not illegal.
        return input

    # Divide a mini-batch into micro-batches.
    batches = microbatch.scatter(input, self.chunks)

    # Separate CUDA streams for copy.
    copy_streams = self._ensure_copy_streams()

    # The micro-batch index where the checkpointing stops.
    if self.training:
        checkpoint_stop = {
            'always': self.chunks,
            'except_last': self.chunks - 1,
            'never': 0,
        }[self.checkpoint]
    else:
        checkpoint_stop = 0

    # Run pipeline parallelism.
    pipeline = Pipeline(batches,
                        self.partitions,
                        self.devices,
                        copy_streams,
                        self._skip_layout,
                        checkpoint_stop)
    pipeline.run()

    # Merge the micro-batches into one mini-batch.
    output = microbatch.gather(batches)
    return output
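For context, forward is what ultimately runs when a GPipe-wrapped model is called on a mini-batch. A minimal usage sketch, assuming a torchgpipe-style constructor GPipe(module, balance=..., chunks=..., checkpoint=...) and the torchgpipe import path; exact arguments and device defaults may differ between versions:

import torch
from torch import nn
from torchgpipe import GPipe  # assumed import path

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
# Split the three layers over two partitions and push 4 micro-batches
# through the pipeline per mini-batch.
model = GPipe(model, balance=[2, 1], chunks=4, checkpoint='except_last')

x = torch.rand(8, 16)   # mini-batch of 8 samples
y = model(x)            # forward() scatters, pipelines, then gathers
assert y.size(0) == 8   # output is merged back into one mini-batch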