Example #1
0
def test_copy_returns_on_next_device():
    """``Portal.copy`` from the CPU stream to the CUDA stream yields a CUDA phony.

    A zero-size CPU phony is passed through ``portal.copy`` from the CPU
    default stream to the CUDA default stream. The returned phony must live
    on a CUDA device. The test is skipped when no CUDA device is available.
    """
    import pytest  # local import keeps this test self-contained

    if not torch.cuda.is_available():
        # The final assertion needs a real CUDA tensor, which cannot exist
        # without a CUDA device; skip instead of erroring on CPU-only hosts.
        pytest.skip("cuda required")

    portal = Portal(torch.rand(1), tensor_life=1)

    prev_stream = default_stream(torch.device("cpu"))
    next_stream = default_stream(torch.device("cuda"))

    phony = torch.zeros(0, requires_grad=True)
    assert phony.device.type == "cpu"

    phony = portal.copy(prev_stream, next_stream, phony)
    assert phony.device.type == "cuda"
Example #2
0
def test_use_grad():
    """A gradient stored with ``put_grad`` can be taken out exactly once.

    The first ``use_grad`` call returns the very object that was stored.
    A second call is expected to raise ``RuntimeError``.
    """
    value = torch.rand(1, requires_grad=True)
    portal = Portal(value, tensor_life=1)
    portal.put_grad(value)

    retrieved = portal.use_grad()
    assert retrieved is value

    # Once used, the stored gradient is gone; asking again must fail.
    with pytest.raises(RuntimeError):
        portal.use_grad()
Example #3
0
def test_blue_orange_not_requires_grad():
    """Backprop through a portal whose carried tensor does not require grad.

    Builds the equivalent of ``tensor1*2 + tensor2``, where ``tensor2`` is
    routed through a ``Portal`` (blue -> join -> fork -> orange). After
    ``backward()``, ``tensor1`` must receive a gradient of 2.0 and
    ``tensor2`` (which does not require grad) must have no gradient.
    """
    tensor1 = torch.rand(1, requires_grad=True)
    tensor2 = torch.rand(1)

    # Same with: output = tensor1*2 + tensor2
    #
    #                +----------------------+
    #                |                      |
    # tensor2 -- PortalBlue -+      +- PortalOrange -+
    #                        |      |                |
    # tensor1 ------------ Join -- Fork --- Mul --- Add -- output
    #
    portal = Portal(tensor2, tensor_life=2)

    joined = join(tensor1, portal.blue())
    forked, phony = fork(joined)
    teleported = portal.orange(phony)

    doubled = forked * 2
    output = doubled + teleported

    output.backward()

    assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
    assert tensor2.grad is None
Example #4
0
 def new_portal(tensor_life):
     nonlocal portal
     tensor = torch.rand(1, requires_grad=True)
     portal = Portal(tensor, tensor_life)
     return portal, tensor