Example #1
    def __init__(
        self,
        batches: List[Batch],
        partitions: List[nn.Sequential],
        devices: Optional[List[torch.device]] = None,
        copy_streams: Optional[List[List[AbstractStream]]] = None,
        skip_layout: Optional[SkipLayout] = None,
        checkpoint_stop: int = 0,
    ) -> None:
        self.batches = batches
        self.partitions = partitions

        # With no explicit placement, keep every partition on CPU.
        if devices is None:
            devices = [torch.device('cpu') for _ in partitions]
        self.devices = devices

        # copy_streams[i][j] carries micro-batch j in and out of
        # partition i; by default every entry aliases the device's
        # current stream.
        if copy_streams is None:
            copy_streams = [[current_stream(d)] * len(batches)
                            for d in devices]
        self.copy_streams = copy_streams

        if skip_layout is None:
            skip_layout = inspect_skip_layout(partitions)

        self.skip_layout = skip_layout
        self.checkpoint_stop = checkpoint_stop
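
The constructor above only stores its arguments and fills in defaults. A minimal usage sketch, assuming these snippets are torchgpipe's internals (the import paths `torchgpipe.microbatch.Batch` and `torchgpipe.pipeline.Pipeline` are assumptions based on that):

import torch
from torch import nn
from torchgpipe.microbatch import Batch   # assumed import path
from torchgpipe.pipeline import Pipeline  # assumed import path

# Hypothetical setup: two partitions, three micro-batches, defaults
# everywhere else.
partitions = [nn.Sequential(nn.Linear(8, 8)) for _ in range(2)]
batches = [Batch(torch.randn(4, 8)) for _ in range(3)]

pipeline = Pipeline(batches, partitions)
# With devices omitted, every partition stays on CPU, and copy_streams
# defaults to a len(partitions) x len(batches) grid whose entries all
# alias the device's current stream.
assert len(pipeline.copy_streams) == len(partitions)
assert len(pipeline.copy_streams[0]) == len(batches)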
Example #2
    def backward(ctx: Context,
                 *grad_output: Tensor,
                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
        input_stream = current_stream(get_device(prev_stream))

        with use_stream(prev_stream), use_stream(next_stream):
            for x in reversed(grad_output):
                y = x.to(get_device(prev_stream))
                grad_input.appendleft(y)

                # 'x' was just read on 'next_stream', which is not where it
                # was allocated, so record that use to keep its memory alive.
                record_stream(x, next_stream)
                # 'y' has been allocated on 'prev_stream', but it might be
                # used on the current stream captured as 'input_stream'.
                record_stream(y, input_stream)

        # The two stream arguments receive no gradient.
        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
        return grad_streams + tuple(grad_input)
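
The `record_stream` calls above are what keep the caching allocator from recycling a tensor's memory while another stream still has work queued against it. A minimal sketch of such a helper, assuming CPU streams are a no-op sentinel (an illustration, not the verbatim implementation):

import torch

def record_stream(tensor: torch.Tensor, stream) -> None:
    # Only CUDA tensors are tracked by the stream-aware caching
    # allocator; for a CPU sentinel stream there is nothing to record.
    if isinstance(stream, torch.cuda.Stream):
        # Mark `tensor` as in use on `stream`, so its memory is not
        # reclaimed until the work already queued there has finished.
        tensor.record_stream(stream)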
Example #3
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
                *input: Tensor,
                ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        output = []
        output_stream = current_stream(get_device(next_stream))

        with use_stream(prev_stream), use_stream(next_stream):
            for x in input:
                y = x.to(get_device(next_stream))
                output.append(y)

                # 'x' was just read on 'prev_stream', which is not where it
                # was allocated, so record that use to keep its memory alive.
                record_stream(x, prev_stream)
                # 'y' has been allocated on 'next_stream', but it might be
                # used on the current stream captured as 'output_stream'.
                record_stream(y, output_stream)

        return tuple(output)
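
This `forward` and the `backward` in Example #2 pair up into a single autograd function that copies tensors across devices on dedicated streams in both directions. A sketch of the enclosing class and its invocation (the name `Copy` is an assumption):

import torch

class Copy(torch.autograd.Function):
    """Copies tensors from prev_stream's device to next_stream's."""

    @staticmethod
    def forward(ctx, prev_stream, next_stream, *input):
        ...  # body as in Example #3

    @staticmethod
    def backward(ctx, *grad_output):
        ...  # body as in Example #2

# Usage sketch: the two streams come first, receive no gradient (the
# (None, None) prefix in Example #2), and the copied tensors follow.
# moved = Copy.apply(prev_stream, next_stream, *tensors)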
Example #4
    def compute(
        self,
        schedule: List[Tuple[int, int]],
        in_queues: List[InQueue],
        out_queues: List[OutQueue],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[j]
            partition = partitions[i]
            device = devices[i]

            # Synchronize with the copied input. ([1] in the diagram)
            if i != 0:
                wait(batch, copy_streams[i][j], streams[i])

            # Decide whether to checkpoint this micro-batch.
            checkpoint = (j < checkpoint_stop)
            if checkpoint:
                chk = Checkpointing(partition, batch)
                task = Task(device,
                            streams[i],
                            compute=chk.checkpoint,
                            finalize=chk.recompute)
                del chk

            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition) -> Batch:
                    return batch.call(partition)

                task = Task(device, streams[i], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[i].put(task)

        for i, j in schedule:
            ok, payload = out_queues[i].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if i != n - 1:
                wait(batch, streams[i], copy_streams[i][j])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            task.finalize(batch)

            batches[j] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
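
The `schedule` argument is a list of `(i, j)` pairs pairing partition `i` with micro-batch `j`, and `compute` runs once per clock cycle. A sketch of a GPipe-style generator for such schedules under that convention (the name `clock_cycles` and this exact formulation are assumptions):

from typing import Iterable, List, Tuple

def clock_cycles(n: int, m: int) -> Iterable[List[Tuple[int, int]]]:
    # n partitions, m micro-batches: at clock k, partition i works on
    # micro-batch j = k - i, so the pipeline drains in n + m - 1 clocks.
    for k in range(n + m - 1):
        yield [(i, k - i) for i in range(max(0, k - m + 1), min(k + 1, n))]

# For n=2 partitions and m=3 micro-batches this yields:
#   [(0, 0)], [(0, 1), (1, 0)], [(0, 2), (1, 1)], [(1, 2)]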
Example #5
def test_copy_wait_cuda_cuda(cuda_sleep):
    prev_stream = current_stream(torch.device('cuda'))
    next_stream = new_stream(torch.device('cuda'))
    _test_copy_wait(prev_stream, next_stream, cuda_sleep)
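
A plausible shape for the `new_stream` helper used here, reflecting only the behavior the tests imply (the CPU branch and sentinel are assumptions):

import torch

CPUStream = object()  # stand-in for the CPU sentinel of Example #9

def new_stream(device: torch.device):
    # Fresh streams only exist on CUDA; every CPU "stream" is the same
    # sentinel object.
    if device.type == 'cuda':
        return torch.cuda.Stream(device)
    return CPUStream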
Example #6
def test_get_device_cuda(self):
    stream = current_stream(torch.device('cuda'))
    assert get_device(stream).type == 'cuda'
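
A plausible sketch of the `get_device` helper this test exercises (an assumption, not the verbatim implementation): CUDA streams already carry their device, and the CPU sentinel maps to the CPU device.

import torch

def get_device(stream) -> torch.device:
    if isinstance(stream, torch.cuda.Stream):
        return stream.device
    return torch.device('cpu')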
Example #7
def test_use_stream_cuda(self):
    stream = new_stream(torch.device('cuda'))
    with use_stream(stream):
        assert current_stream(torch.device('cuda')) == stream
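
A plausible implementation of the `use_stream` context manager being tested, assuming it defers to `torch.cuda.stream` on CUDA and does nothing for the CPU sentinel:

from contextlib import contextmanager
import torch

@contextmanager
def use_stream(stream):
    if not isinstance(stream, torch.cuda.Stream):
        # Nothing to select on CPU.
        yield
        return
    with torch.cuda.stream(stream):
        yield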
Example #8
def test_current_stream_cuda(self):
    stream = current_stream(torch.device('cuda'))
    assert isinstance(stream, torch.cuda.Stream)
    assert stream == torch.cuda.current_stream()
Example #9
def test_current_stream_cpu(self):
    stream = current_stream(torch.device('cpu'))
    assert stream is CPUStream
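
Examples #8 and #9 together pin down `current_stream`: on CUDA it returns the stream `torch.cuda.current_stream` reports, and on CPU it returns one shared sentinel. A minimal sketch consistent with both tests (an assumption, not the verbatim implementation):

import torch

class CPUStreamType:
    """Singleton sentinel standing in for a stream on CPU."""

CPUStream = CPUStreamType()

def current_stream(device: torch.device):
    if device.type == 'cuda':
        return torch.cuda.current_stream(device)
    return CPUStream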
Example #10
def test_wait_stream_cuda_cuda(self, cuda_sleep):
    source = current_stream(torch.device('cuda'))
    target = new_stream(torch.device('cuda'))
    self._test_wait_stream(source, target, cuda_sleep)
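
A plausible shape for the `wait_stream` helper behind this test, assuming `source` must not run further work until everything queued on `target` completes (the mixed CPU/CUDA cases are assumptions):

import torch

def wait_stream(source, target) -> None:
    if isinstance(target, torch.cuda.Stream):
        if isinstance(source, torch.cuda.Stream):
            # Asynchronous: future work on source waits for target.
            source.wait_stream(target)
        else:
            # A CPU "stream" can only wait by blocking the host.
            target.synchronize()
    # If target is the CPU sentinel, its work completed synchronously,
    # so there is nothing to wait for.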