def test_fork_join():
    """Backward through a fork/join pair must respect the phony dependency.

    ``b`` (allocated on CUDA) is joined onto the phony forked from ``a``
    (allocated on CPU). During backward, the ``Log`` node tagged ``2`` on
    ``b``'s lane must therefore run before the ``Log`` node tagged ``1`` on
    ``a``'s lane, even though the two lanes start on different devices.

    Requires a CUDA device (``b`` is created with ``device="cuda"``).
    NOTE(review): no skip guard is visible here -- confirm this test is gated
    by a CUDA-availability skip marker.
    """
    logs = []

    class Log(torch.autograd.Function):
        """Pass-through op that appends ``number`` to ``logs`` in backward."""

        @staticmethod
        def forward(ctx, number, tensor):
            ctx.number = number
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            logs.append(ctx.number)
            # ``number`` is a plain Python value, so it gets no gradient.
            return None, grad

    a = torch.rand(1, device="cpu", requires_grad=True)
    b = torch.rand(1, device="cuda", requires_grad=True)

    a = Log.apply(1, a)

    a, phony = fork(a)
    # Join ``b`` (not ``a``) onto the phony. Joining ``a`` here discarded the
    # CUDA tensor entirely, so the cross-device dependency this test is about
    # was never exercised.
    b = join(b, phony)

    b = Log.apply(2, b)
    b = b.to("cpu")

    (a + b).backward()

    assert logs == [2, 1]
def test_join_when_fork_not_requires_grad():
    """fork/join on tensors that don't track grad must not start tracking it."""
    first_half, second_half = torch.rand(2, 1).chunk(2)
    assert not first_half.requires_grad

    forked, phony = fork(first_half)
    # Neither fork output, nor the untouched half, should require grad.
    for tensor in (forked, phony, second_half):
        assert not tensor.requires_grad

    joined = join(second_half, phony)
    assert not joined.requires_grad
def test_serial_checkpoints(device):
    """Two checkpoints chained by fork/join recompute in dependency order.

    ``b`` is joined onto the phony forked from ``a``'s checkpoint output.
    The expected ``timeline`` therefore has ``b``'s checkpoint re-running
    ``Log`` and back-propagating before ``a``'s checkpoint does the same.

    Args:
        device: fixture giving the device the tensors are allocated on
            (presumably parametrized elsewhere in this module -- confirm).
    """
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    # Ordered record of every Log forward/backward invocation.
    timeline = []

    class Log(torch.autograd.Function):
        """Pass-through op that records its forward/backward calls."""

        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            # ``name`` is a plain string, so it receives no gradient.
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    # NOTE(review): the result is discarded; these ops exist only for their
    # effect on autograd node sequencing, as in the upstream test.
    _ = a + 1 + 2 + 3 + 4 + 5

    a = checkpoint(partial(Log.apply, "a"), a)

    # Chain b's lane after a's via the phony dependency.
    a, phony = fork(a)
    b = join(b, phony)

    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                     +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                     +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    # First two entries: the initial forward pass (a, then b).
    # Remaining entries: b's checkpoint re-runs Log and back-props first,
    # then a's checkpoint does the same.
    assert timeline == [
        "a:forward", "b:forward",
        "b:forward", "b:backward",
        "a:forward", "a:backward"
    ]
def test_fork_join_no_grad(monkeypatch):
    """Under ``torch.no_grad()``, fork/join must not touch autograd at all.

    ``Function.apply`` is patched to raise, so any attempt to build an
    autograd node fails the test. The input tensor must also be handed back
    as the very same object (identity, not a copy).
    """
    def forbid_apply(*args):
        raise AssertionError("Function.apply called")

    monkeypatch.setattr("torch.autograd.Function.apply", forbid_apply)

    tensor = torch.rand(1, requires_grad=True)

    with torch.no_grad():
        forked, phony = fork(tensor)
    assert not phony.requires_grad
    assert forked is tensor

    with torch.no_grad():
        joined = join(forked, phony)
    assert joined is forked
def test_fork_leak():
    """fork/join must not keep a backward context alive after backward.

    ``F.backward`` stores a weak reference to its ``ctx``. After the graph
    outputs are deleted, that ctx should already be freed. No ``gc.collect()``
    is run, so any lingering reference (including a reference cycle) would
    keep the weak reference alive and fail the final assertion.
    """
    leak = None

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            return tensor

        @staticmethod
        def backward(ctx, grad):
            nonlocal leak
            leak = weakref.ref(ctx)
            return grad

    x = torch.rand(1, requires_grad=True)
    x = F.apply(x)
    x, phony = fork(x)
    x = join(x, phony)

    x.backward()
    del x, phony

    # If F.backward never ran, ``leak`` is still None and calling it would
    # raise a confusing TypeError instead of a clear test failure.
    assert leak is not None, "F.backward was not called"
    assert leak() is None
def test_fork_join_enable_grad():
    """With grad enabled, fork/join must record Fork/Join autograd nodes."""
    leaf = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        forked, phony = fork(leaf)

    assert phony.requires_grad
    assert forked is not leaf
    # Both fork outputs must hang off a Fork backward node.
    for produced in (forked, phony):
        assert produced.requires_grad
        assert type(produced.grad_fn) is Fork._backward_cls

    with torch.enable_grad():
        joined = join(forked, phony)

    assert joined is not forked
    assert joined.requires_grad
    assert type(joined.grad_fn) is Join._backward_cls
def test_blue_orange_not_requires_grad():
    """A Portal can carry a tensor that does not require grad.

    The graph computes the same value as ``tensor1 * 2 + tensor2``.
    ``tensor2`` travels through the portal from its blue end to its orange
    end, and the phony tensors chain both ends onto ``tensor1``'s lane:

                   +----------------------+
                   |                      |
    tensor2 -- PortalBlue -+      +- PortalOrange -+
                           |      |                |
    tensor1 ------------ Join -- Fork --- Mul --- Add -- output
    """
    grad_input = torch.rand(1, requires_grad=True)
    plain_input = torch.rand(1)

    portal = Portal(plain_input, tensor_life=2)

    # Blue end: attach the portal to the main lane.
    lane = join(grad_input, portal.blue())
    # Orange end: branch off the main lane to pull the carried tensor out.
    lane, orange_phony = fork(lane)
    carried = portal.orange(orange_phony)

    output = lane * 2 + carried
    output.backward()

    assert torch.allclose(grad_input.grad, torch.tensor([2.0]))
    assert plain_input.grad is None