def devices(pipeline_style):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
    c = nn.Linear(1, 1)

    # There are two extra ranks.
    model = nn.Sequential(a, b, c)
    model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())

    # Extra devices must be discarded.
    if model.group.rank() == 3:
        assert model.pipeline is None

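# The tests in this file rely on a `get_worker_map()` helper defined elsewhere.
# The sketch below only illustrates the expected shape (a mapping from
# torch.distributed rank to a worker name); `_example_worker_map` is a
# hypothetical name and the real helper may differ.
def _example_worker_map():
    import torch.distributed as dist

    # One named worker per rank in the default process group.
    return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}
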
def checkpoint_mode_invalid(pipeline_style):
    model = nn.Sequential(nn.Linear(1, 1))

    with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
        Pipe(
            model,
            balance=[1],
            style=pipeline_style,
            worker_map=get_worker_map(),
            chunks=2,
            checkpoint="INVALID_CHECKPOINT",
        )

def inplace_on_not_requires_grad(pipeline_style):
    # An in-place operation on a tensor not requiring grad doesn't cause a
    # RuntimeError. Currently, we cannot detect this case.
    model = nn.Sequential(nn.ReLU(inplace=True))
    model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")

    x = torch.rand(1)
    y = model(x)
    del model

    message = r"a leaf Variable that requires grad .* used in an in-place operation."
    with pytest.raises(RuntimeError, match=message):
        y.backward()

    torch.distributed.barrier()

def async_event_loop():
    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
    pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10)

    inputs = torch.rand(100, 10)

    output = pipe(inputs)
    if pipe.final_stage:
        loss = output.mean()
        loss.backward()

def test_verify_module_duplicate_parameters_on_distinct_devices():
    class Surrogate(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

    conv = nn.Conv2d(3, 3, 1)
    model = nn.Sequential(Surrogate(conv), Surrogate(conv))

    with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
        Pipe(model, [1, 1], devices=["cpu", "cuda"])

def checkpoint_mode_when_chunks_1(pipeline_style):
    model = nn.Sequential(nn.Linear(1, 1))

    # All checkpoint modes are fine.
    Pipe(
        model,
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=1,
        checkpoint="except_last",
    )
    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always")
    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never")

def test_1to3(balance, checkpoint):
    if torch.cuda.device_count() < len(balance):
        pytest.skip("at least %d cuda devices required" % len(balance))

    @skippable(stash=["1to3"])
    class Layer1(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            yield stash("1to3", input)
            output = self.conv(input)
            return output

    class Layer2(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            output = self.conv(input)
            return output

    @skippable(pop=["1to3"])
    class Layer3(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            skip_1to3 = yield pop("1to3")
            output = self.conv(input) + skip_1to3
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
    model = Pipe(model, balance, chunks=3, checkpoint=checkpoint)

    in_device = model.devices[0]
    out_device = model.devices[-1]

    input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
    output = model(input)
    loss = output.mean()
    loss.backward()

    assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=5e-1)
    assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))

def test_deferred_batch_norm(checkpoint):
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
    pipe = Pipe(
        nn.Sequential(pipe_bn),
        balance=[1],
        devices=["cpu"],
        chunks=2,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    bn(x).mean().backward()

    assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
    assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)

def test_non_tensor_tuple():
    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")

    model = nn.Sequential(NonTensorTuple())
    model = Pipe(model, balance=[1], devices=["cpu"])
    x = torch.rand(1)

    # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model((x, "hello"))

def test_non_tensor():
    class NonTensor(nn.Module):
        def forward(self, _):
            return "hello"

    model = nn.Sequential(NonTensor())
    model = Pipe(model, balance=[1], devices=["cpu"])
    x = torch.rand(1)

    # TypeError: expected Tensor as element 0 in argument 0, but got str
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model("hello")

def inplace_on_requires_grad():
    model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
    model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always")

    x = torch.rand(1)
    y = model(x)

    message = r"a leaf Variable that requires grad .* used in an in-place operation."
    if model.group.rank() == 1:
        with pytest.raises(RuntimeError, match=message):
            y.backward()

    torch.distributed.barrier()

def non_tensor():
    class NonTensor(nn.Module):
        def forward(self, _):
            return "hello"

    model = nn.Sequential(NonTensor())
    model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map())
    x = torch.rand(1)

    # TypeError: expected Tensor as element 0 in argument 0, but got str
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model("hello")

def non_tensor_tuple():
    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")

    model = nn.Sequential(NonTensorTuple())
    model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map())
    x = torch.rand(1)

    # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model((x, "hello"))

def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
    class Surrogate(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

    conv = nn.Conv2d(3, 3, 1)
    model = nn.Sequential(Surrogate(conv), Surrogate(conv))

    # FIXME(tom) can't have duplicate params with separate processes
    with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
        Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())

def exception(pipeline_style):
    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        def forward(self, *_):
            raise ExpectedException()

    model = nn.Sequential(Raise())
    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)

    with pytest.raises(ExpectedException):
        model(torch.rand(1))

def test_sequential_like(balance):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = Pipe(model, balance, devices=["cpu", "cpu"])

    assert len(model) == 2
    assert list(model) == [a, b]

    assert model[0] is a
    assert model[1] is b
    with pytest.raises(IndexError):
        _ = model[2]

    assert model[-1] is b
    assert model[-2] is a

def lazy_skippable_error():
    """Using skippable layers together with lazy construction is currently
    not supported; check that it raises an exception."""

    @skippable(stash=["1to3"])
    class Layer1(nn.Linear):
        pass

    @skippable(pop=["1to3"])
    class Layer3(nn.Linear):
        pass

    model = [lambda: Layer1(10, 10), lambda: nn.Linear(10, 10), lambda: Layer3(10, 10)]

    with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
        Pipe(
            model,
            [2, 1],
            style=Pipe.MultiProcess,
            worker_map=get_worker_map(),
        )

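# For contrast with the error above, lazy construction without skippable
# layers is expected to work (see deferred_batch_norm below, which passes a
# list of callables when lazy=True): each partition is materialized only on
# the rank that owns it. A hedged sketch, not part of the original tests:
#
#     model = [lambda: nn.Linear(10, 10), lambda: nn.Linear(10, 10)]
#     pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())
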
def test_parallel_randoms():
    class Dropouts(nn.Module):
        def forward(self, x):
            for _ in range(100):
                x = F.dropout(x, p=0.001)
            return x

    model = nn.Sequential(Dropouts(), Dropouts())

    x = torch.rand(10, 10, requires_grad=True)
    model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always")
    y = model(x)
    y.norm().backward()

    assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()

def test_checkpoint_non_float_input():
    class ForkNonFloat(nn.Module):
        def forward(self, input):
            return (input * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input):
            return input[0] * 2

    model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always")

    input = torch.rand(1, requires_grad=True)
    output = model(input)
    output.backward()

def test_deferred_batch_norm_params(checkpoint):
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
    pipe = Pipe(
        nn.Sequential(pipe_bn),
        balance=[1],
        devices=["cpu"],
        chunks=1,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    bn(x).mean().backward()

    assert pipe[0].weight.grad is not None
    assert pipe[0].bias.grad is not None
    assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4)
    assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4)

def test_public_attrs():
    class MyString:
        def __init__(self, value):
            self.value = value

        def __str__(self):
            return self.value

    model = nn.Sequential(nn.Linear(1, 1))
    pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always"))

    assert pipe.balance == [1]
    assert pipe.devices == [torch.device("cpu")]
    assert pipe.chunks == 42
    assert isinstance(pipe.chunks, int)
    assert pipe.checkpoint == "always"
    assert isinstance(pipe.checkpoint, str)

def test_no_grad():
    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
    input = torch.rand(2, 1)

    latent = None

    def hook(module, input, output):
        _ = module
        _ = input
        nonlocal latent
        latent = output

    partition = model.partitions[0]
    partition.register_forward_hook(hook)

    with torch.no_grad():
        model(input)

    assert latent.grad_fn is None

def no_grad():
    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2)
    input = torch.rand(2, 1)

    latent = None

    def hook(module, input, output):
        _ = module
        _ = input
        nonlocal latent
        latent = output

    partition = model.partitions[0]
    partition.register_forward_hook(hook)

    with torch.no_grad():
        model(input)

    assert latent.grad_fn is None

def test_tuple_wait(cuda_sleep):
    # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
    # Under this behavior, if checkpointing was disabled, there's a possibility
    # that gradient accumulations on other tensors are not synchronized
    # properly to the copy stream.
    class Sleep(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.detach()

        @staticmethod
        def backward(ctx, grad):
            with torch.cuda.device(grad.device):
                cuda_sleep(0.05)
            return grad

    class Layer1(nn.Module):
        def forward(self, pair):
            a, b = pair
            return a * 1, b * 2, b * 3

    class Layer2(nn.Module):
        def forward(self, triple):
            a, b, c = triple
            b = Sleep.apply(b)
            return a + b + c

    model = nn.Sequential(Layer1(), Layer2())
    model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never")

    a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
    b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)

    y = model((a, b))
    y.norm().backward()

    torch.cuda.synchronize(0)
    torch.cuda.synchronize(1)

    assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))

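# `cuda_sleep` above is assumed to be a pytest fixture returning a function
# that busy-waits on the current CUDA stream for roughly the given number of
# seconds. A minimal sketch using the private `torch.cuda._sleep(cycles)` API
# with a one-off timing calibration; `_example_cuda_sleep` is a hypothetical
# name and the real fixture may differ.
@pytest.fixture
def _example_cuda_sleep():
    # Warm up, then measure how long a fixed number of GPU cycles takes.
    torch.cuda._sleep(1_000_000)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda._sleep(1_000_000_000)
    end.record()
    end.synchronize()
    cycles_per_ms = 1_000_000_000 / start.elapsed_time(end)

    def sleep(seconds):
        # Convert the requested wall-clock time into busy-wait cycles.
        torch.cuda._sleep(int(seconds * 1000 * cycles_per_ms))

    return sleep
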
def test_exception_no_hang():
    # In v0.0.2, once a failed partition received a normal (non-closing)
    # message for the next micro-batch, a hang occurred. The reason was that a
    # failed partition didn't call in_queue.task_done() on a normal message,
    # so the former partition was blocked at out_queue.join() for the next of
    # the next micro-batch.
    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Raise())
    model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))

def sequential_like(balance, pipeline_style):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map())

    if balance == [2]:
        if torch.distributed.get_rank() == 0:
            assert len(model) == 2
            assert list(model) == [a, b]

            assert model[0] is a
            assert model[1] is b
            with pytest.raises(IndexError):
                _ = model[2]

            assert model[-1] is b
            assert model[-2] is a
        else:
            assert len(model) == 0
            assert list(model) == []
    else:
        assert len(model) == 1
        if torch.distributed.get_rank() == 0:
            assert list(model) == [a]
            assert model[0] is a
            assert model[-1] is a
        else:
            assert list(model) == [b]
            assert model[0] is b
            assert model[-1] is b

        with pytest.raises(IndexError):
            _ = model[1]

def test_input_pair():
    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a_and_b):
            a, b = a_and_b
            return (self.fc_a(a), self.fc_b(b))

    model = nn.Sequential(Two())
    model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = model((a, b))
    loss = (a_out + b_out).mean()
    loss.backward()

    assert a.grad is not None
    assert b.grad is not None

def exception_early_stop_asap(pipeline_style):
    """Even if the first partitions have finished processing, the partitions
    before the failed partition should be killed as soon as possible.
    """

    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    counter = 0

    class Counter(nn.Module):
        def forward(self, x):
            time.sleep(0.1)

            nonlocal counter
            counter += 1

            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
    model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))

    # If the early stop doesn't work, it would be 3 instead.
    assert counter == 2

def deferred_batch_norm(checkpoint, lazy):
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
    pipe_fn = lambda: pipe_bn  # noqa: E731

    if lazy:
        model = [pipe_fn]
    else:
        model = nn.Sequential(pipe_bn)

    pipe = Pipe(
        model,
        balance=[1],
        style=Pipe.MultiProcess,
        worker_map=get_worker_map(),
        chunks=2,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    bn(x).mean().backward()

    assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
    assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)

def test_none_skip():
    @skippable(stash=["none"])
    class Stash(nn.Module):
        def forward(self, input):
            yield stash("none", None)
            return input

    @skippable(pop=["none"])
    class Pop(nn.Module):
        def forward(self, input):
            none = yield pop("none")
            assert none is None
            return input

    model = nn.Sequential(Stash(), Pop())
    model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5)

    input = torch.rand(10, requires_grad=True)
    output = model(input)

    # Walk the autograd graph and check that no portal nodes leaked into it.
    # (Uses a None sentinel instead of a mutable default argument.)
    def assert_grad_fn_is_not_portal(grad_fn, visited=None):
        if visited is None:
            visited = set()
        if grad_fn in visited or grad_fn is None:
            return

        assert not isinstance(grad_fn, PortalBlue._backward_cls)
        assert not isinstance(grad_fn, PortalCopy._backward_cls)
        assert not isinstance(grad_fn, PortalOrange._backward_cls)

        visited.add(grad_fn)
        for next_grad_fn, _ in grad_fn.next_functions:
            assert_grad_fn_is_not_portal(next_grad_fn, visited)

    assert_grad_fn_is_not_portal(output.grad_fn)

    output.sum().backward()
    assert input.grad.mean().item() == 1