def test_copy_returns_on_next_device():
    """Check that ``portal.copy`` returns the phony on the next stream's device.

    The phony starts on the CPU; after copying from a CPU stream to a CUDA
    stream, the returned phony must live on CUDA.
    """
    # The test targets a CUDA device; on CPU-only hosts skip instead of
    # failing with a device error.
    if not torch.cuda.is_available():
        pytest.skip("cuda required")

    portal = Portal(torch.rand(1), tensor_life=1)

    prev_stream = default_stream(torch.device("cpu"))
    next_stream = default_stream(torch.device("cuda"))

    phony = torch.zeros(0, requires_grad=True)
    assert phony.device.type == "cpu"

    phony = portal.copy(prev_stream, next_stream, phony)
    assert phony.device.type == "cuda"
def test_use_grad():
    """A gradient placed in a portal can be taken out exactly once."""
    source = torch.rand(1, requires_grad=True)
    portal = Portal(source, tensor_life=1)

    portal.put_grad(source)
    retrieved = portal.use_grad()
    assert retrieved is source

    # The stored gradient is ephemeral: asking for it a second time is
    # expected to raise.
    with pytest.raises(RuntimeError):
        portal.use_grad()
def test_blue_orange_not_requires_grad():
    """Round-trip a non-grad tensor through a portal alongside a grad tensor.

    The graph built here is equivalent to ``output = tensor1*2 + tensor2``::

                       +----------------------+
                       |                      |
        tensor2 -- PortalBlue -+      +- PortalOrange -+
                               |      |                |
        tensor1 ------------ Join -- Fork --- Mul --- Add -- output

    Only ``tensor1`` requires grad, so only it should receive a gradient.
    """
    tensor1 = torch.rand(1, requires_grad=True)
    tensor2 = torch.rand(1)

    portal = Portal(tensor2, tensor_life=2)

    # Attach the blue phony to the main branch, then fork a fresh phony off
    # it to pull tensor2 back out through the orange side.
    trunk = join(tensor1, portal.blue())
    trunk, token = fork(trunk)
    branch = portal.orange(token)

    output = trunk * 2 + branch
    output.backward()

    assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
    assert tensor2.grad is None
def new_portal(tensor_life): nonlocal portal tensor = torch.rand(1, requires_grad=True) portal = Portal(tensor, tensor_life) return portal, tensor