def test_fork_join():
    """Check the backward order of two branches tied only by fork/join.

    Two ``Log`` functions tag a CPU branch (1) and a CUDA branch (2).  The
    CUDA branch is joined with the phony forked from the CPU branch, and the
    test expects ``Log`` 2 to be back-propagated before ``Log`` 1.

    NOTE: one input is created on ``'cuda'``, so a CUDA device is required.
    """
    logs = []

    class Log(torch.autograd.Function):
        """Detaching identity that records its number during backward."""

        @staticmethod
        def forward(ctx, number, tensor):
            ctx.number = number
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            logs.append(ctx.number)
            # No gradient for the integer tag; the tensor gradient passes
            # through unchanged.
            return None, grad

    a = torch.rand(1, device='cpu', requires_grad=True)
    b = torch.rand(1, device='cuda', requires_grad=True)

    a = Log.apply(1, a)

    # Join the CUDA tensor (not `a` itself) with the phony.  Joining `a`
    # with its own phony would throw away `b`, turn the later `.to('cpu')`
    # into a no-op, and make Log(2) depend on Log(1) through ordinary data
    # flow, so the ordering assertion could never fail.
    a, phony = fork(a)
    b = join(b, phony)

    b = Log.apply(2, b)
    b = b.to('cpu')

    (a + b).backward()

    assert logs == [2, 1]
def recompute(self, batch: Batch) -> None:
    """Attach a :class:`Recompute` node to ``batch`` in place.

    The inputs handed to :class:`Recompute` are read from ``self.batch``;
    the ``batch`` argument is the one whose first element gets rewired.
    """
    atomic_flag = self.batch.atomic
    saved_inputs = tuple(self.batch)

    # As the original note states: batch[0] always requires grad here,
    # because it went through checkpoint together with a phony that
    # requires grad.
    forked, fork_phony = fork(batch[0])
    batch[0] = forked

    recompute_phony = Recompute.apply(
        fork_phony,
        self.recomputed,
        self.rng_states,
        self.function,
        atomic_flag,
        *saved_inputs,
    )
    batch[0] = join(batch[0], recompute_phony)
def test_join_when_fork_not_requires_grad():
    """fork/join on tensors without grad must not turn ``requires_grad`` on."""
    source = torch.rand(2, 1)
    head, tail = source.chunk(2)

    # Fork side: neither the forked tensor nor the phony requires grad.
    assert not head.requires_grad
    head, phony = fork(head)
    assert not head.requires_grad
    assert not phony.requires_grad

    # Join side: joining with that phony leaves requires_grad off.
    assert not tail.requires_grad
    tail = join(tail, phony)
    assert not tail.requires_grad
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
    """Load the skip tensor identified by ``(ns, name)``.

    When the skip layout says this skip needs a copy, the tensor comes from
    the matching portal and ``batch[0]`` is forked so that the portal is
    connected to the micro-batch through :class:`Fork`.  Otherwise the
    parent class implementation is used unchanged.
    """
    if self.skip_layout.requires_copy(ns, name):
        portal = self.portals[(ns, name)]
        batch[0], phony = fork(batch[0])
        return portal.orange(phony)

    return super().load(batch, ns, name)
def test_fork_join_no_grad(monkeypatch):
    """Under ``no_grad``, fork/join must return the input tensor itself.

    ``Function.apply`` is patched to raise, so the test also fails if either
    helper goes through an autograd function in this mode.
    """
    def _forbid_apply(*args):
        raise AssertionError('Function.apply called')

    monkeypatch.setattr('torch.autograd.Function.apply', _forbid_apply)

    tensor = torch.rand(1, requires_grad=True)

    with torch.no_grad():
        forked, phony = fork(tensor)

    assert not phony.requires_grad
    assert forked is tensor
    tensor = forked

    with torch.no_grad():
        joined = join(tensor, phony)

    assert joined is tensor
    tensor = joined
def copy(
    self,
    batch: Batch,
    prev_stream: AbstractStream,
    next_stream: AbstractStream,
    ns: Namespace,
    name: str,
) -> None:
    """Copy the skip tensor held by the portal for ``(ns, name)``.

    ``batch[0]`` is forked before the portal copy and joined with the phony
    the copy returns, tying the micro-batch and the portal together with
    :class:`Fork` and :class:`Join`.  The skip must be one that the layout
    marks as requiring a copy.
    """
    assert self.skip_layout.requires_copy(ns, name)

    forked, fork_phony = fork(batch[0])
    batch[0] = forked

    target_portal = self.portals[(ns, name)]
    copy_phony = target_portal.copy(prev_stream, next_stream, fork_phony)

    batch[0] = join(batch[0], copy_phony)
def test_serial_checkpoints(device):
    """Two checkpoints chained by fork/join must recompute in a fixed order.

    Every forward and backward call of ``Log`` is appended to ``timeline``;
    the final assertion pins the exact sequence.
    """
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        """Detaching identity that logs its forward and backward calls."""

        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f'{name}:forward')
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f'{name}:backward')
            # Nothing to propagate for the string tag.
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    log_a = partial(Log.apply, 'a')
    log_b = partial(Log.apply, 'b')

    a = checkpoint(log_a, a)

    a, phony = fork(a)
    b = join(b, phony)

    b = checkpoint(log_b, b)

    c = torch.cat((a, b))
    out = c.sum()

    #                       +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                       +--> {b} --Checkpoint(Log)--> {b} --Join--> {b}
    out.backward()

    expected = [
        'a:forward',
        'b:forward',
        'b:forward',
        'b:backward',
        'a:forward',
        'a:backward',
    ]
    assert timeline == expected
def test_fork_leak():
    """The context of an upstream function must be freed after backward.

    ``F.backward`` stores a weak reference to its own ``ctx``; once the
    tensors are deleted, that reference is expected to be dead.
    """
    leak = None

    class F(torch.autograd.Function):
        """Identity function that exposes its context through ``leak``."""

        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad):
            nonlocal leak
            leak = weakref.ref(ctx)
            return grad

    # A single name is rebound at every step on purpose: keeping earlier
    # tensors alive under other names would also keep their graph alive.
    tensor = torch.rand(1, requires_grad=True)
    tensor = F.apply(tensor)

    tensor, phony = fork(tensor)
    tensor = join(tensor, phony)

    tensor.backward()
    del tensor, phony

    assert leak() is None
def test_fork_join_enable_grad():
    """With grad enabled, fork/join return new tensors wired into autograd.

    The outputs must differ from the input object, require grad, and carry
    a ``grad_fn`` of the backward class of ``Fork`` or ``Join``.
    """
    tensor = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        forked, phony = fork(tensor)

    assert phony.requires_grad
    assert forked is not tensor
    tensor = forked

    assert tensor.requires_grad
    assert phony.requires_grad
    fork_backward = Fork._backward_cls
    assert tensor.grad_fn.__class__ is fork_backward
    assert phony.grad_fn.__class__ is fork_backward

    with torch.enable_grad():
        joined = join(tensor, phony)

    assert joined is not tensor
    tensor = joined

    assert tensor.requires_grad
    assert tensor.grad_fn.__class__ is Join._backward_cls
def test_blue_orange_not_requires_grad():
    """A portal carrying a tensor without grad must not give it a gradient.

    ``tensor1`` requires grad and ``tensor2`` does not; after backward only
    ``tensor1`` is expected to have a gradient, equal to 2.
    """
    tensor1 = torch.rand(1, requires_grad=True)
    tensor2 = torch.rand(1)

    # Same with: output = tensor1*2 + tensor2
    #
    #                +----------------------+
    #                |                      |
    # tensor2 -- PortalBlue -+      +- PortalOrange -+
    #                        |      |                |
    # tensor1 ------------ Join -- Fork --- Mul --- Add -- output
    #
    portal = Portal(tensor2, tensor_life=2)

    blue_phony = portal.blue()
    main = join(tensor1, blue_phony)

    main, orange_phony = fork(main)
    sub = portal.orange(orange_phony)

    doubled = main * 2
    output = doubled + sub
    output.backward()

    expected_grad = torch.tensor([2.])
    assert torch.allclose(tensor1.grad, expected_grad)
    assert tensor2.grad is None
def depend(fork_from: Batch, join_to: Batch) -> None:
    """Link two batches through a fork/join pair, in place.

    The first element of ``fork_from`` is replaced by its forked version,
    and the first element of ``join_to`` is replaced by the result of
    joining it with the phony produced by that fork.
    """
    forked, phony = fork(fork_from[0])
    fork_from[0] = forked

    joined = join(join_to[0], phony)
    join_to[0] = joined