def _test_wait_stream(self, source, target, cuda_sleep=None):
    # Enqueue work on the target stream; the CUDA sleep keeps that stream
    # busy so the synchronization below is actually exercised.
    with use_stream(target):
        if is_cuda(target):
            cuda_sleep(0.5)
        x = torch.ones(100, 100, device=get_device(target))

    # Make the source stream wait until the target stream has finished.
    wait_stream(source, target)

    with use_stream(source):
        assert x.sum().item() == 10000
def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
    device = get_device(prev_stream)

    # Produce x on prev_stream; the CUDA sleep delays it so Copy/Wait must
    # synchronize correctly for the assertions below to hold.
    with use_stream(prev_stream):
        if is_cuda(prev_stream):
            cuda_sleep(0.5)
        x = torch.ones(100, device=device, requires_grad=True)

    (y,) = Copy.apply(prev_stream, next_stream, x)
    (y,) = Wait.apply(prev_stream, next_stream, y)

    with use_stream(next_stream):
        assert torch.allclose(y.sum(), torch.tensor(100.0, device=device))
        # d(norm)/dx_i = x_i / norm = 1/10 for each of the 100 elements,
        # so the gradient sums to 10.
        y.norm().backward()

    with use_stream(prev_stream):
        assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device))
def test_get_device_cuda(self):
    stream = current_stream(torch.device("cuda"))
    assert get_device(stream).type == "cuda"
def test_get_device_cpu(self):
    assert get_device(CPUStream).type == "cpu"