Example #1
    def test_record_stream_cuda(self, cuda_sleep):
        # This test detects unexpected block reallocation. For a reliable
        # test, the stream used to allocate tensors is isolated: the
        # allocator does not reuse free blocks that were allocated on
        # another stream.
        stream_alloc = new_stream(torch.device('cuda'))
        with torch.cuda.stream(stream_alloc):
            x = torch.rand(1, device=torch.device('cuda'))

        stream = new_stream(torch.device('cuda'))
        record_stream(x, stream)
        with use_stream(stream):
            cuda_sleep(0.5)

        # From Python's perspective, 'x' is deleted. But its block is still
        # in use by 'stream', so 'y' must not be allocated into that block.
        data_ptr = x.data_ptr()
        del x
        stream_alloc.synchronize()
        with torch.cuda.stream(stream_alloc):
            y = torch.rand(1, device=torch.device('cuda'))
        assert y.data_ptr() != data_ptr

        # Block the CPU until 'stream' finishes its queued work. Now the
        # block of 'x' is free to be reallocated.
        wait_stream(CPUStream, stream)
        with torch.cuda.stream(stream_alloc):
            z = torch.rand(1, device=torch.device('cuda'))
        assert z.data_ptr() == data_ptr
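
The test above leans on a handful of helpers (new_stream, use_stream, record_stream, wait_stream, CPUStream) that are not shown. Here is a minimal sketch of what they plausibly wrap, assuming torchgpipe-style thin wrappers over torch.cuda primitives; the bodies and the CPUStream sentinel are assumptions, not the library's actual code:

    import contextlib
    import torch

    # Sentinel standing in for "the CPU side" (an assumption; the real
    # helpers may model the CPU stream differently).
    CPUStream = None

    def new_stream(device: torch.device) -> torch.cuda.Stream:
        # A fresh CUDA stream on the given device.
        return torch.cuda.Stream(device=device)

    def use_stream(stream):
        # Context manager: queue work on 'stream' while active; no-op for CPU.
        if stream is CPUStream:
            return contextlib.nullcontext()
        return torch.cuda.stream(stream)

    def record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
        # Mark 'tensor' as in use on 'stream' so the caching allocator will
        # not hand its block to another allocation until 'stream' catches up.
        tensor.record_stream(stream)

    def wait_stream(source, target: torch.cuda.Stream) -> None:
        # Make 'source' wait until 'target' completes its queued work.
        # Only the cases exercised in these snippets are sketched.
        if source is CPUStream:
            target.synchronize()        # the CPU blocks until 'target' drains
        else:
            source.wait_stream(target)  # device-side wait; the CPU is not blocked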
Example #2
    def backward(ctx: Context,
                 *grad_input: Tensor,
                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        # Reverse of forward: make prev_stream wait until next_stream
        # finishes its queued work before the gradients are consumed.
        wait_stream(prev_stream, next_stream)

        # One gradient per forward input: None for each of the two stream
        # arguments, then the gradients for the tensors.
        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
        return grad_streams + grad_input
Example #3
    def backward(
        ctx: Context,  # type: ignore
        grad_input: Tensor,
    ) -> Tuple[None, None, Tensor]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        # Reverse of forward: prev_stream waits for next_stream.
        wait_stream(prev_stream, next_stream)

        # None gradients for the two stream arguments, then the tensor's.
        return None, None, grad_input
Example #4
    def _test_wait_stream(self, source, target, cuda_sleep=None):
        # Queue slow work on 'target' so the write of 'x' is still
        # pending when 'source' starts.
        with use_stream(target):
            if is_cuda(target):
                cuda_sleep(0.5)
            x = torch.ones(100, 100, device=get_device(target))

        # 'source' must not read 'x' before 'target' has written it.
        wait_stream(source, target)

        with use_stream(source):
            assert x.sum().item() == 10000
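
is_cuda and get_device are again unshown helpers. Plausible one-liners, assuming the same CPUStream sentinel as in the sketch after Example #1:

    def is_cuda(stream) -> bool:
        # CPUStream is the only non-CUDA stream in these tests (assumption).
        return stream is not CPUStream

    def get_device(stream) -> torch.device:
        # The device a tensor should live on to be produced by 'stream'.
        return stream.device if is_cuda(stream) else torch.device('cpu')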
Example #5
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
                *input: Tensor,
                ) -> Tensors:
        # Stash the streams for backward, which reverses the dependency.
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        # Make next_stream wait until prev_stream finishes its queued work.
        wait_stream(next_stream, prev_stream)

        return tuple(x.detach() for x in input)
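
Examples #5 and #2 are the forward and backward halves of one autograd Function. A sketch of how they assemble and get called, reusing the helpers sketched after Example #1 and requiring a CUDA device; the class name Wait and the usage lines are assumptions, only the apply() signature and the two leading None gradients are dictated by the code itself:

    import torch

    # Hypothetical assembly of Examples #5 (forward) and #2 (backward).
    class Wait(torch.autograd.Function):
        @staticmethod
        def forward(ctx, prev_stream, next_stream, *input):
            ctx.prev_stream = prev_stream
            ctx.next_stream = next_stream
            # Forward: next_stream waits until prev_stream's queued work is done.
            wait_stream(next_stream, prev_stream)
            return tuple(x.detach() for x in input)

        @staticmethod
        def backward(ctx, *grad_input):
            # Backward reverses the dependency: prev_stream waits for next_stream.
            wait_stream(ctx.prev_stream, ctx.next_stream)
            # One gradient per forward input: None for each stream argument.
            return (None, None) + grad_input

    # Usage: streams are passed positionally, tensors are varargs.
    s0 = new_stream(torch.device('cuda'))
    s1 = new_stream(torch.device('cuda'))
    x = torch.rand(2, device='cuda', requires_grad=True)
    (y,) = Wait.apply(s0, s1, x)   # y equals x.detach(), but grads still flow
    y.sum().backward()             # triggers the reversed wait in backward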
Example #6
    def forward(
        ctx: Context,  # type: ignore
        prev_stream: AbstractStream,
        next_stream: AbstractStream,
        input: Tensor,
    ) -> Tensor:
        # Stash the streams for backward, which reverses the dependency.
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        # Make next_stream wait until prev_stream finishes its queued work.
        wait_stream(next_stream, prev_stream)

        return input.detach()
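
Examples #6 and #3 are the single-tensor counterpart of that pair: forward takes exactly one tensor, so backward can be annotated precisely as Tuple[None, None, Tensor] instead of the looser Tuple[Optional[Tensor], ...]. Assuming an analogous class assembled from these two halves (hypothetical, as above), usage collapses to:

    # Single-tensor variant: one tensor in, one tensor out.
    y = Wait.apply(prev_stream, next_stream, x)   # y equals x.detach(); grad still flows
    y.sum().backward()   # runs the backward above: prev_stream waits for next_stream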