def test_cached_phony():
    p1 = get_phony(torch.device('cpu'), requires_grad=True)
    p2 = get_phony(torch.device('cpu'), requires_grad=True)
    assert p1 is p2

    p3 = get_phony(torch.device('cpu'), requires_grad=False)
    p4 = get_phony(torch.device('cpu'), requires_grad=False)
    assert p3 is p4

    assert p1 is not p3

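# The caching asserted above suggests an implementation along these lines.
# This is only a sketch, not the library's actual code: it assumes a
# module-level dict keyed by (device, requires_grad) and a zero-element
# tensor as the phony. Names with a _sketch suffix are hypothetical.
from typing import Dict, Tuple

import torch
from torch import Tensor

_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}

def get_phony_sketch(device: torch.device, *, requires_grad: bool) -> Tensor:
    key = (device, requires_grad)
    try:
        phony = _phonies[key]
    except KeyError:
        # A phony occupies no space; it exists only to carry autograd edges.
        phony = torch.empty(0, device=device, requires_grad=requires_grad)
        _phonies[key] = phony
    return phony
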
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor."""
    if input.requires_grad and torch.is_grad_enabled():
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)

    return input, phony

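# fork() is presumably paired with a join() that merges the phony back into
# the main autograd lane so both branches stay alive during backprop. A
# minimal sketch of such a counterpart, assuming a Join autograd function
# symmetric to Fork; the _sketch names are illustrative, not the library's.
import torch
from torch import Tensor

class JoinSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Tensor, phony: Tensor) -> Tensor:
        return input.detach()

    @staticmethod
    def backward(ctx, grad_input: Tensor):
        # The phony branch needs no gradient of its own.
        return grad_input, None

def join_sketch(input: Tensor, phony: Tensor) -> Tensor:
    """Merges an autograd lane carried by a phony back into ``input``."""
    if input.requires_grad and torch.is_grad_enabled():
        input = JoinSketch.apply(input, phony)
    return input
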
def forward(
    ctx: Context,  # type: ignore
    portal: Portal,
    # This tensor must be retrieved by portal.use_tensor().
    tensor: Tensor,
) -> Tensor:
    ctx.portal = portal

    phony = get_phony(tensor.device, requires_grad=False)
    return phony.detach()

def checkpoint(self) -> Batch:
    """Returns a batch applied by :class:`Checkpoint`."""
    input_atomic = self.batch.atomic
    input = tuple(self.batch)

    # Use a phony which requires grad to ensure that Checkpoint can be
    # tracked by the autograd engine even when none of the input tensors
    # require grad.
    phony = get_phony(self.batch[0].device, requires_grad=True)

    output = Checkpoint.apply(phony, self.recomputed, self.rng_states,
                              self.function, input_atomic, *input)
    return Batch(output)

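# Why the phony must require grad: the autograd engine records an
# autograd.Function only if at least one of its inputs requires grad. A
# small demonstration of that behavior; Track is hypothetical, not part of
# the library.
import torch

class Track(torch.autograd.Function):
    @staticmethod
    def forward(ctx, phony, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for the phony; pass the real gradient through.
        return None, grad_output

phony = torch.empty(0, requires_grad=True)
x = torch.rand(2)                # does not require grad
y = Track.apply(phony, x)
assert y.grad_fn is not None     # still tracked, thanks to the phony
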
def test_phony_in_autograd_function():
    class Phonify(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            phony = get_phony(input.device, requires_grad=False)
            return phony.detach()

    x = torch.rand(1, requires_grad=True)

    p1 = Phonify.apply(x)
    p2 = get_phony(torch.device('cpu'), requires_grad=True)

    assert p1 is not p2
    assert p1.grad_fn is not None
    assert p2.grad_fn is None

def forward(
    ctx: Context,  # type: ignore
    portal: Portal,
    prev_stream: AbstractStream,
    next_stream: AbstractStream,
    phony: Tensor,
) -> Tensor:
    ctx.portal = portal

    assert portal.tensor is not None
    portal.tensor, = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)

    phony = get_phony(get_device(next_stream), requires_grad=False)
    return phony.detach()

def blue(self) -> Tensor:
    """Creates a :class:`PortalBlue` which hides the underlying tensor from
    the autograd engine.

    Join the returning phony to the main lane of the autograd graph to
    assure the correct backpropagation::

        PortalBlue --+
                     |
        ---------- Join --

    """
    tensor = self.use_tensor()

    if tensor is None:
        return get_phony(torch.device('cpu'), requires_grad=False)

    return PortalBlue.apply(self, tensor)

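# Hypothetical usage of blue(), following the docstring's diagram: merge the
# returned phony into the main lane (e.g. with a join helper like the one
# sketched earlier) so the PortalBlue node takes part in backpropagation.
# `portal` and `output` are assumed to be in scope:
#
#     phony = portal.blue()
#     output = join_sketch(output, phony)
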
def copy(
    self,
    prev_stream: AbstractStream,
    next_stream: AbstractStream,
    phony: Tensor,
) -> Tensor:
    """Copies the hidden tensor by a :class:`PortalCopy`.

    Give a phony and use the returning phony to keep backpropagation::

         +-- PortalCopy --+
         |                |
    -- Fork ---------- Join --

    """
    if self.tensor is None:
        return get_phony(torch.device('cpu'), requires_grad=False)

    return PortalCopy.apply(self, prev_stream, next_stream, phony)

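# Hypothetical wiring of the diagram above, assuming fork/join helpers like
# those sketched earlier and streams prev_stream/next_stream already in hand:
#
#     tensor, phony = fork(tensor)
#     phony = portal.copy(prev_stream, next_stream, phony)
#     tensor = join_sketch(tensor, phony)
#
# fork keeps the main lane alive, the phony threads the PortalCopy into the
# graph, and the join merges that dependency back so backprop reaches it.
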
def test_phony_size():
    p = get_phony(torch.device('cpu'), requires_grad=False)
    assert p.size() == (0,)

def forward(ctx, input):
    phony = get_phony(input.device, requires_grad=False)
    return phony.detach()

def test_phony_requires_grad():
    p1 = get_phony(torch.device('cpu'), requires_grad=True)
    p2 = get_phony(torch.device('cpu'), requires_grad=False)
    assert p1.requires_grad
    assert not p2.requires_grad

def forward(ctx: 'Fork', input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
    phony = get_phony(input.device, requires_grad=False)
    return input.detach(), phony.detach()