def __init__(self,
             batches: List[Batch],
             partitions: List[nn.Sequential],
             devices: Optional[List[torch.device]] = None,
             copy_streams: Optional[List[List[AbstractStream]]] = None,
             skip_layout: Optional[SkipLayout] = None,
             checkpoint_stop: int = 0,
             ) -> None:
    self.batches = batches
    self.partitions = partitions

    if devices is None:
        devices = [torch.device('cpu') for _ in partitions]
    self.devices = devices

    if copy_streams is None:
        copy_streams = [[current_stream(d)] * len(batches) for d in devices]
    self.copy_streams = copy_streams

    if skip_layout is None:
        skip_layout = inspect_skip_layout(partitions)
    self.skip_layout = skip_layout

    self.checkpoint_stop = checkpoint_stop
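# An illustrative check (hedged; 'Pipeline' is an assumed name for the class
# owning the '__init__' above, which this excerpt does not show). With the
# defaults, 'copy_streams' holds one stream per (partition, micro-batch)
# slot, reusing each device's current stream across that partition's
# micro-batches:
#
#     pipeline = Pipeline(batches, partitions, devices=devices)
#     assert len(pipeline.copy_streams) == len(partitions)
#     assert all(len(row) == len(batches) for row in pipeline.copy_streams)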
def backward(ctx: Context,
             *grad_output: Tensor,
             ) -> Tuple[Optional[Tensor], ...]:
    prev_stream = ctx.prev_stream
    next_stream = ctx.next_stream

    grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
    input_stream = current_stream(get_device(prev_stream))

    with use_stream(prev_stream), use_stream(next_stream):
        for x in reversed(grad_output):
            y = x.to(get_device(prev_stream))
            grad_input.appendleft(y)

            # 'next_stream' is not where 'x' has been allocated.
            record_stream(x, next_stream)
            # 'y' has been allocated on 'prev_stream'.
            # It might be used on the current stream captured as 'input_stream'.
            record_stream(y, input_stream)

    grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
    return grad_streams + tuple(grad_input)
def forward(ctx: Context,  # type: ignore
            prev_stream: AbstractStream,
            next_stream: AbstractStream,
            *input: Tensor,
            ) -> Tensors:
    ctx.prev_stream = prev_stream
    ctx.next_stream = next_stream

    output = []
    output_stream = current_stream(get_device(next_stream))

    with use_stream(prev_stream), use_stream(next_stream):
        for x in input:
            y = x.to(get_device(next_stream))
            output.append(y)

            # 'prev_stream' is not where 'x' has been allocated.
            record_stream(x, prev_stream)
            # 'y' has been allocated on 'next_stream'.
            # It might be used on the current stream captured as 'output_stream'.
            record_stream(y, output_stream)

    return tuple(output)
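# Usage sketch (an assumption, not shown in this excerpt): 'forward' and
# 'backward' above read like the static methods of an autograd Function,
# here assumed to be named 'Copy'. A copy across streams would then look
# like:
#
#     prev_stream = current_stream(torch.device('cuda:0'))
#     next_stream = new_stream(torch.device('cuda:1'))
#     y, = Copy.apply(prev_stream, next_stream, x)
#
# The leading (None, None) returned by 'backward' lines up with the two
# stream arguments of 'forward', which are non-tensor inputs and therefore
# receive no gradient.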
def compute(self,
            schedule: List[Tuple[int, int]],
            in_queues: List[InQueue],
            out_queues: List[OutQueue],
            ) -> None:
    """Runs tasks with synchronization to copy streams."""
    batches = self.batches
    partitions = self.partitions
    devices = self.devices
    copy_streams = self.copy_streams
    checkpoint_stop = self.checkpoint_stop

    n = len(partitions)
    streams = [current_stream(d) for d in devices]
    exc_info: Optional[ExcInfo] = None

    # With checkpointing, the autograd graph looks like this diagram:
    # ┌─────┸──────┐
    # │    Copy    │
    # └─────┰──────┘   (fence)
    # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
    #       ┃          (compute)
    # ┌─────┸──────┐
    # │    Wait    │ [1] Synchronize the current stream with the copy stream.
    # └─────┰──────┘
    # ┌─────┸──────┐
    # │ Checkpoint │ [2] Compute a partition within checkpointing.
    # └─────┰──────┘
    # ┌─────┸──────┐
    # │    Wait    │ [3] Synchronize the copy stream with the current stream.
    # └─────┰──────┘
    #       ┠ ─ ─ ─ ┐
    #       ┃ ┌─────┴─────┐
    #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
    #       ┃ └─────┬─────┘
    #       ┠ ─ ─ ─ ┘
    #       ┃
    # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
    # ┌─────┸──────┐   (fence)
    # │    Copy    │
    # └─────┰──────┘
    for i, j in schedule:
        batch = batches[j]
        partition = partitions[i]
        device = devices[i]

        # Synchronize with the copied input. ([1] in the diagram)
        if i != 0:
            wait(batch, copy_streams[i][j], streams[i])

        # Determine whether to use checkpointing.
        checkpoint = (j < checkpoint_stop)
        if checkpoint:
            chk = Checkpointing(partition, batch)
            task = Task(device, streams[i], compute=chk.checkpoint, finalize=chk.recompute)
            del chk
        else:
            def compute(batch: Batch = batch,
                        partition: nn.Sequential = partition) -> Batch:
                return batch.call(partition)
            task = Task(device, streams[i], compute=compute, finalize=None)
            del compute

        # Compute tasks in parallel. ([2] in the diagram)
        in_queues[i].put(task)

    for i, j in schedule:
        ok, payload = out_queues[i].get()

        # Hold the first exception.
        if exc_info is not None:
            continue
        elif not ok:
            exc_info = cast(ExcInfo, payload)
            continue

        task, batch = cast(Tuple[Task, Batch], payload)

        # The copy stream synchronizes to copy the output. ([3] in the diagram)
        if i != n - 1:
            wait(batch, streams[i], copy_streams[i][j])

        # Finalize tasks. If checkpointing is enabled, here the
        # recomputation is scheduled at backpropagation. ([4] in the diagram)
        task.finalize(batch)

        batches[j] = batch

    # Fail at the first exception.
    if exc_info is not None:
        raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
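# A minimal sketch (an assumption, not part of this excerpt) of how the
# '(i, j)' schedules consumed by 'compute' could be generated. It assumes
# the usual GPipe clock-cycle pipelining, where partition i runs
# micro-batch j at clock k = i + j; the name 'clock_cycles' is illustrative.
from typing import Iterable, List, Tuple


def clock_cycles(n_partitions: int, n_batches: int) -> Iterable[List[Tuple[int, int]]]:
    """Yields, per clock cycle, the (partition, micro-batch) pairs to run."""
    for k in range(n_partitions + n_batches - 1):
        # Partitions i with a valid micro-batch j = k - i in [0, n_batches).
        yield [(i, k - i)
               for i in range(max(0, k - n_batches + 1),
                              min(k + 1, n_partitions))]


# For 2 partitions and 3 micro-batches this yields:
#     [(0, 0)], [(0, 1), (1, 0)], [(0, 2), (1, 1)], [(1, 2)]
# and each cycle's list would be passed to 'compute' as 'schedule'.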
def test_copy_wait_cuda_cuda(cuda_sleep):
    prev_stream = current_stream(torch.device('cuda'))
    next_stream = new_stream(torch.device('cuda'))
    _test_copy_wait(prev_stream, next_stream, cuda_sleep)


def test_get_device_cuda(self):
    stream = current_stream(torch.device('cuda'))
    assert get_device(stream).type == 'cuda'


def test_use_stream_cuda(self):
    stream = new_stream(torch.device('cuda'))
    with use_stream(stream):
        assert current_stream(torch.device('cuda')) == stream


def test_current_stream_cuda(self):
    stream = current_stream(torch.device('cuda'))
    assert isinstance(stream, torch.cuda.Stream)
    assert stream == torch.cuda.current_stream()


def test_current_stream_cpu(self):
    stream = current_stream(torch.device('cpu'))
    assert stream is CPUStream


def test_wait_stream_cuda_cuda(self, cuda_sleep):
    source = current_stream(torch.device('cuda'))
    target = new_stream(torch.device('cuda'))
    self._test_wait_stream(source, target, cuda_sleep)