def test_empty_module():
    """An empty ``nn.Sequential`` wrapped in GPipe hands tensors back as-is."""
    # Wrapping a module with no layers (and an empty balance) is allowed.
    pipe = GPipe(nn.Sequential(), [])

    scalar = torch.tensor(42)
    assert pipe(scalar) == torch.tensor(42)

    singleton = (torch.tensor(42), )
    assert pipe(singleton) == (torch.tensor(42), )

    # Anything other than a tensor or a tuple of tensors is rejected.
    with pytest.raises(TypeError):
        pipe(42)
def test_inplace_on_not_requires_grad():
    """Backward through an in-place ReLU on a non-grad input raises RuntimeError.

    The forward pass on a tensor that does not require grad goes through
    without a RuntimeError (this case cannot currently be detected there);
    the error only shows up once ``backward()`` is called.
    """
    relu_inplace = nn.Sequential(nn.ReLU(inplace=True))
    pipe = GPipe(relu_inplace, [1], devices=['cpu'], checkpoint='always')

    out = pipe(torch.rand(1))

    expected = 'a leaf Variable that requires grad has been used in an in-place operation.'
    with pytest.raises(RuntimeError, match=expected):
        out.backward()
def test_checkpoint_mode_invalid():
    """GPipe rejects an unknown ``checkpoint`` mode with a ValueError."""
    linear = nn.Sequential(nn.Linear(1, 1))
    expected = "checkpoint is not one of 'always', 'except_last', or 'never'"

    with pytest.raises(ValueError, match=expected):
        GPipe(linear, balance=[1], devices=['cpu'], chunks=2,
              checkpoint='INVALID_CHECKPOINT')
def naive2(model: nn.Module, devices: List[int]) -> Stuffs:
    """Wrap ``model`` in a two-partition GPipe with chunks=1 and no checkpointing.

    Returns:
        The wrapped model, the batch size (47), and the wrapped model's
        devices as a list.
    """
    sequential = cast(nn.Sequential, model)

    # With chunks=1 and checkpoint='never', GPipe is equivalent to typical
    # model parallelism.
    pipe = GPipe(sequential, [84, 241], devices=devices, chunks=1, checkpoint='never')
    return pipe, 47, list(pipe.devices)
def test_input_singleton():
    """A one-element tuple input flows through GPipe and receives gradients."""
    class One(nn.Module):
        """Unpack a 1-tuple, apply a linear layer, and re-wrap the result."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, only_a):
            (tensor,) = only_a
            return (self.fc(tensor), )

    pipe = GPipe(nn.Sequential(One()), balance=[1], devices=['cpu'], chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    (a_out,) = pipe((a, ))
    a_out.mean().backward()

    # Both the module parameters and the input itself must have gradients.
    for param in pipe.parameters():
        assert param.grad is not None
    assert a.grad is not None
def pipeline1(model: nn.Module, devices: List[int]) -> Stuffs:
    """Wrap ``model`` in a single-partition GPipe that always checkpoints.

    Returns:
        The wrapped model, the batch size (96), and the wrapped model's
        devices as a list.
    """
    sequential = cast(nn.Sequential, model)
    pipe = GPipe(sequential, [370], devices=devices, chunks=1, checkpoint='always')
    return pipe, 96, list(pipe.devices)
def pipeline8(devices: List[int]) -> Stuffs:
    """Build a U-Net and wrap it in an eight-partition GPipe with 128 chunks.

    Returns:
        The wrapped model, the ``num_convs`` and ``base_channels`` values the
        U-Net was built with, and the wrapped model's devices as a list.
    """
    num_convs, base_channels = 48, 160

    net: nn.Module = unet(depth=5, num_convs=num_convs, base_channels=base_channels,
                          input_channels=3, output_channels=1)
    sequential = cast(nn.Sequential, net)

    partition_sizes = [800, 140, 62, 36, 36, 36, 36, 987]
    pipe = GPipe(sequential, partition_sizes, devices=devices, chunks=128)
    return pipe, num_convs, base_channels, list(pipe.devices)
def test_exception():
    """An exception raised inside a partition propagates to the caller."""
    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        """Module whose forward pass always fails."""

        def forward(self, *_):
            raise ExpectedException()

    pipe = GPipe(nn.Sequential(Raise()), balance=[1], devices=['cpu'], chunks=1)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(1))
def pipeline4(devices: List[int]) -> Stuffs:
    """Build a U-Net and wrap it in a four-partition GPipe with 32 chunks.

    Returns:
        The wrapped model, the ``num_convs`` and ``base_channels`` values the
        U-Net was built with, and the wrapped model's devices as a list.
    """
    num_convs, base_channels = 24, 160

    net: nn.Module = unet(depth=5, num_convs=num_convs, base_channels=base_channels,
                          input_channels=3, output_channels=1)
    sequential = cast(nn.Sequential, net)

    partition_sizes = [472, 54, 36, 515]
    pipe = GPipe(sequential, partition_sizes, devices=devices, chunks=32)
    return pipe, num_convs, base_channels, list(pipe.devices)
def test_1to3(balance, checkpoint):
    """A skip stashed by layer 1 and popped by layer 3 works across partitions."""
    required = len(balance)
    if torch.cuda.device_count() < required:
        pytest.skip('at least %d cuda devices required' % required)

    @skippable(stash=['1to3'])
    class Layer1(nn.Module):
        """Stash the raw input under '1to3', then convolve it."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            yield stash('1to3', input)
            return self.conv(input)

    class Layer2(nn.Module):
        """Plain convolution; takes no part in the skip connection."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            return self.conv(input)

    @skippable(pop=['1to3'])
    class Layer3(nn.Module):
        """Convolve the input and add the tensor stashed under '1to3'."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            skipped = yield pop('1to3')
            return self.conv(input) + skipped

    # Layers are created in order, before the input tensor is drawn.
    layers = nn.Sequential(Layer1(), Layer2(), Layer3())
    pipe = GPipe(layers, balance, chunks=3, checkpoint=checkpoint)

    first_device = pipe.devices[0]
    last_device = pipe.devices[-1]

    x = torch.rand(30, 3, 224, 224, device=first_device, requires_grad=True)
    y = pipe(x)
    y.mean().backward()

    expected_output_norm = torch.tensor(1039.159, device=last_device)
    expected_grad_norm = torch.tensor(0.0004533053, device=first_device)
    assert torch.allclose(y.norm(), expected_output_norm)
    assert torch.allclose(x.grad.norm(), expected_grad_norm)
def test_no_grad():
    """Under ``torch.no_grad()`` a partition's output carries no ``grad_fn``."""
    pipe = GPipe(nn.Sequential(nn.Linear(1, 1)), balance=[1], devices=['cpu'], chunks=2)
    batch = torch.rand(2, 1)

    # Holds the most recent first element of a partition's forward outputs.
    seen = {'latent': None}

    def hook(module, input, outputs):
        # Only the outputs matter; module and input are deliberately unused.
        del module, input
        first, _ = outputs
        seen['latent'] = first

    partitions = list(pipe.partitions())
    partitions[0].register_forward_hook(hook)

    with torch.no_grad():
        pipe(batch)

    assert seen['latent'].grad_fn is None
def test_devices():
    """GPipe keeps only as many devices as there are partitions."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    third = nn.Linear(1, 1)

    # Five devices are offered for three partitions, so two are surplus.
    offered = ['cpu'] * 5

    pipe = GPipe(nn.Sequential(first, second, third), [1, 1, 1], devices=offered)

    # The surplus devices must have been dropped.
    cpu = torch.device('cpu')
    assert pipe.devices == (cpu, cpu, cpu)
def test_identicalness():
    """Summed gradients are the same with and without GPipe."""
    def sum_grad(parameters):
        # Accumulate from 0, exactly as the builtin sum() would.
        total = 0
        for param in parameters:
            if param.grad is not None:
                total = total + param.grad.sum()
        return total

    def zero_grad(parameters):
        for param in parameters:
            param.grad = None

    inputs = torch.rand(8, 1)
    layers = nn.Sequential(
        nn.Linear(1, 2),
        nn.Linear(2, 4),
        nn.Linear(4, 2),
        nn.Linear(2, 1),
    )

    # Reference run: the plain sequential model.
    layers(inputs).mean().backward()
    grad_without_gpipe = sum_grad(layers.parameters())

    zero_grad(layers.parameters())

    # Same layers, now split into two partitions and four micro-batches.
    pipe = GPipe(layers, [2, 2], devices=['cpu', 'cpu'], chunks=4)
    pipe(inputs).mean().backward()
    grad_with_gpipe = sum_grad(pipe.parameters())

    # Both gradient sums should match.
    assert torch.allclose(grad_with_gpipe, grad_without_gpipe)
def _gpipe(
    model: nn.Module,
    devices: List[int],
    batch_size: int,
    chunks: int,
    balance: List[int],
    checkpoint: str,
) -> Stuffs:
    """Wrap ``model`` in GPipe with the given settings.

    Returns:
        The wrapped model, ``batch_size`` unchanged, and the wrapped model's
        devices as a list.
    """
    sequential = cast(nn.Sequential, model)
    pipe = GPipe(sequential, balance, devices=devices, chunks=chunks, checkpoint=checkpoint)
    return pipe, batch_size, list(pipe.devices)
def test_non_tensor_tuple():
    """Tuples containing a non-tensor are rejected as output and as input."""
    class NonTensorTuple(nn.Module):
        """Return the input paired with a plain string."""

        def forward(self, x):
            return (x, 'hello')

    pipe = GPipe(nn.Sequential(NonTensorTuple()), balance=[1], devices=['cpu'])
    tensor = torch.rand(1)

    # The layer's output tuple contains a str:
    # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
    with pytest.raises(TypeError):
        pipe(tensor)

    # The input tuple itself contains a str:
    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        pipe((tensor, 'hello'))
def test_non_tensor():
    """A non-tensor value is rejected both as layer output and as model input."""
    class NonTensor(nn.Module):
        """Ignore the input and return a plain string."""

        def forward(self, _):
            return 'hello'

    pipe = GPipe(nn.Sequential(NonTensor()), balance=[1], devices=['cpu'])
    tensor = torch.rand(1)

    # The layer returns a str:
    # TypeError: expected Tensor as element 0 in argument 0, but got str
    with pytest.raises(TypeError):
        pipe(tensor)

    # The input itself is a str:
    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        pipe('hello')
def test_deferred_batch_norm(checkpoint):
    """Deferred batch norm tracks the same running stats as a plain BatchNorm2d."""
    reference_bn = nn.BatchNorm2d(3)
    pipelined_bn = deepcopy(reference_bn)
    pipe = GPipe(nn.Sequential(pipelined_bn), balance=[1], devices=['cpu'], chunks=2,
                 checkpoint=checkpoint, deferred_batch_norm=True)

    batch = torch.rand(4, 3, 10, 10)
    pipe(batch).mean().backward()
    reference_bn(batch).mean().backward()

    assert torch.allclose(pipe[0].running_mean, reference_bn.running_mean, atol=1e-4)
    assert torch.allclose(pipe[0].running_var, reference_bn.running_var, atol=1e-4)
def test_sequential_like(balance):
    """GPipe supports ``len``, iteration and indexing like ``nn.Sequential``."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    pipe = GPipe(nn.Sequential(first, second), balance, devices=['cpu', 'cpu'])

    assert len(pipe) == 2
    assert list(pipe) == [first, second]

    # Non-negative indices.
    assert pipe[0] is first
    assert pipe[1] is second

    # Out-of-range index.
    with pytest.raises(IndexError):
        _ = pipe[2]

    # Negative indices count from the end.
    assert pipe[-1] is second
    assert pipe[-2] is first
def test_parallel_randoms():
    """With checkpoint='always', the output and input gradient share one zero pattern."""
    class Dropouts(nn.Module):
        """Apply 100 consecutive low-probability dropouts."""

        def forward(self, x):
            remaining = 100
            while remaining > 0:
                x = F.dropout(x, p=0.001)
                remaining -= 1
            return x

    layers = nn.Sequential(Dropouts(), Dropouts())

    x = torch.rand(10, 10, requires_grad=True)
    pipe = GPipe(layers, [1, 1], devices=['cpu', 'cpu'], chunks=10, checkpoint='always')
    y = pipe(x)
    y.norm().backward()

    # Compare which elements are non-zero in the output and in the gradient.
    output_mask = y.to(torch.bool).tolist()
    grad_mask = x.grad.to(torch.bool).tolist()
    assert output_mask == grad_mask
def test_deferred_batch_norm_params(checkpoint):
    """Deferred batch norm gives the same weight/bias gradients as BatchNorm2d."""
    reference_bn = nn.BatchNorm2d(3)
    pipelined_bn = deepcopy(reference_bn)
    pipe = GPipe(nn.Sequential(pipelined_bn), balance=[1], devices=['cpu'], chunks=1,
                 checkpoint=checkpoint, deferred_batch_norm=True)

    batch = torch.rand(4, 3, 10, 10)
    pipe(batch).mean().backward()
    reference_bn(batch).mean().backward()

    # Gradients must exist before they can be compared.
    assert pipe[0].weight.grad is not None
    assert pipe[0].bias.grad is not None

    assert torch.allclose(pipe[0].weight.grad, reference_bn.weight.grad, atol=1e-4)
    assert torch.allclose(pipe[0].bias.grad, reference_bn.bias.grad, atol=1e-4)
def test_public_attrs():
    """Public attributes are normalized from loosely-typed constructor arguments."""
    class MyString:
        """Non-``str`` object whose ``str()`` yields the wrapped value."""

        def __init__(self, value):
            self.value = value

        def __str__(self):
            return self.value

    layers = nn.Sequential(nn.Linear(1, 1))

    # Pass tuples, a float and a str-like object instead of the plain types.
    pipe = GPipe(layers,
                 balance=(1, ),
                 devices=('cpu', ),
                 chunks=42.000,
                 checkpoint=MyString('always'))

    assert pipe.balance == [1]
    assert pipe.devices == [torch.device('cpu')]

    assert pipe.chunks == 42
    assert isinstance(pipe.chunks, int)

    assert pipe.checkpoint == 'always'
    assert isinstance(pipe.checkpoint, str)
def test_tuple_wait(cuda_sleep):
    """Gradients of every tensor in a tuple micro-batch are synchronized.

    In v0.0.3, Wait was applied only to the first tensor of a micro-batch.
    With checkpointing disabled, gradient accumulation on the other tensors
    could then fail to synchronize properly with the copy stream.
    """
    class Sleep(torch.autograd.Function):
        """Identity in forward; delays the CUDA device during backward."""

        @staticmethod
        def forward(ctx, x):
            return x.detach()

        @staticmethod
        def backward(ctx, grad):
            with torch.cuda.device(grad.device):
                cuda_sleep(0.05)
            return grad

    class Layer1(nn.Module):
        def forward(self, pair):
            left, right = pair
            return left * 1, right * 2, right * 3

    class Layer2(nn.Module):
        def forward(self, triple):
            first, second, third = triple
            delayed = Sleep.apply(second)
            return first + delayed + third

    layers = nn.Sequential(Layer1(), Layer2())
    pipe = GPipe(layers, [1, 1], devices=[0, 1], chunks=32, checkpoint='never')

    a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
    b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)

    out = pipe((a, b))
    out.norm().backward()

    for device_index in (0, 1):
        torch.cuda.synchronize(device_index)

    assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
def test_exception_no_hang():
    """A failing last partition raises instead of hanging the pipeline.

    In v0.0.2, a hang occurred once a failed partition received a normal
    (non-closing) message for the next micro-batch: the failed partition did
    not call in_queue.task_done() on that message, so the previous partition
    stayed blocked at out_queue.join() for the micro-batch after next.
    """
    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        """Identity layer."""

        def forward(self, x):
            return x

    class Raise(nn.Module):
        """Layer whose forward pass always fails."""

        def forward(self, x):
            raise ExpectedException()

    layers = nn.Sequential(Pass(), Pass(), Raise())
    pipe = GPipe(layers, [1, 1, 1], devices=['cpu', 'cpu', 'cpu'], chunks=3)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(3))
def test_exception_early_stop():
    """Partitions around a failing partition stop instead of running to the end."""
    class ExpectedException(Exception):
        pass

    class Counter(nn.Module):
        """Identity layer that counts its calls and sleeps briefly on each."""

        def __init__(self):
            super().__init__()
            self.counter = 0

        def forward(self, x):
            self.counter += 1
            time.sleep(0.01)
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    count_front = Counter()
    count_back = Counter()
    layers = nn.Sequential(count_front, Raise(), count_back)
    pipe = GPipe(layers, balance=[1, 1, 1], devices=['cpu', 'cpu', 'cpu'], chunks=1000)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(1000, 1))

    # This check depends on the relative speed of two partitions, so it is
    # flaky in principle. It can only fail if the exception takes longer than
    # 10 seconds (0.01 * 1000) to arrive, which does not seem to happen.
    calls_at_failure = count_front.counter
    assert 1 <= calls_at_failure < 1000
    assert count_back.counter == 0

    # The first partition should already have stopped: no further calls.
    time.sleep(0.1)
    assert count_front.counter == calls_at_failure
def test_input_pair():
    """Both tensors of a two-element tuple input receive gradients."""
    class Two(nn.Module):
        """Apply an independent linear layer to each element of a pair."""

        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a_and_b):
            left, right = a_and_b
            return (self.fc_a(left), self.fc_b(right))

    pipe = GPipe(nn.Sequential(Two()), balance=[1], devices=['cpu'], chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = pipe((a, b))
    combined = a_out + b_out
    combined.mean().backward()

    assert a.grad is not None
    assert b.grad is not None
def test_current_microbatch():
    """``current_microbatch()`` is None outside a partition and set inside one."""
    class Twice(nn.Module):
        def forward(self, x):
            return x * 2

    class CurrentMicrobatch(nn.Module):
        """Discard the layer input and return ``current_microbatch()`` instead."""

        def forward(self, _):
            return current_microbatch()

    # Before the model runs we are not inside any partition.
    assert current_microbatch() is None

    values = torch.tensor([1., 2., 3.])

    layers = nn.Sequential(Twice(), CurrentMicrobatch())
    pipe = GPipe(layers, balance=[1, 1], devices=['cpu', 'cpu'], chunks=3)

    result = pipe(values)
    # The result equals the original input, not the doubled values.
    assert torch.allclose(result, torch.tensor([1., 2., 3.]))

    # After the model has run we are outside any partition again.
    assert current_microbatch() is None
def test_exception_early_stop_asap():
    """The partition just before a failing one is stopped as soon as possible.

    This must hold even when the earliest partitions have already finished
    processing.
    """
    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    counter = 0

    class Counter(nn.Module):
        """Slow identity layer that counts completed forward calls."""

        def forward(self, x):
            nonlocal counter
            time.sleep(0.1)
            counter += 1
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    layers = nn.Sequential(Pass(), Pass(), Counter(), Raise())
    pipe = GPipe(layers, [1, 1, 1, 1], devices=['cpu', 'cpu', 'cpu', 'cpu'], chunks=3)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(3))

    # Without a working early stop the counter would reach 3.
    assert counter == 2
def test_none_skip():
    """Stashing ``None`` as a skip value adds no portal nodes to the autograd graph."""
    @skippable(stash=['none'])
    class Stash(nn.Module):
        """Stash ``None`` under the name 'none' and pass the input through."""

        def forward(self, input):
            yield stash('none', None)
            return input

    @skippable(pop=['none'])
    class Pop(nn.Module):
        """Pop 'none', check it is ``None``, and pass the input through."""

        def forward(self, input):
            none = yield pop('none')
            assert none is None
            return input

    model = nn.Sequential(Stash(), Pop())
    model = GPipe(model, [1, 1], devices=['cpu', 'cpu'], chunks=5)

    input = torch.rand(10, requires_grad=True)
    output = model(input)

    def assert_grad_fn_is_not_portal(grad_fn, visited=None):
        """Walk the autograd graph from ``grad_fn``, asserting no portal nodes.

        ``visited`` collects nodes already checked. It defaults to a fresh set
        per top-level call, so no state is shared between calls through a
        mutable default argument.
        """
        if visited is None:
            visited = set()

        if grad_fn in visited or grad_fn is None:
            return

        assert not isinstance(grad_fn, PortalBlue._backward_cls)
        assert not isinstance(grad_fn, PortalCopy._backward_cls)
        assert not isinstance(grad_fn, PortalOrange._backward_cls)

        visited.add(grad_fn)
        for next_grad_fn, _ in grad_fn.next_functions:
            assert_grad_fn_is_not_portal(next_grad_fn, visited)

    assert_grad_fn_is_not_portal(output.grad_fn)

    output.sum().backward()
    assert input.grad.mean().item() == 1
def test_checkpoint_eval():
    """Checkpoint/recompute nodes appear in train mode but not in eval mode."""
    pipe = GPipe(nn.Sequential(nn.Linear(1, 1)), balance=[1], devices=['cpu'], chunks=2)
    batch = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        """Return True if a node whose class is named ``name`` is reachable."""
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        return any(find_grad_fn(parent, name) for parent, _ in grad_fn.next_functions)

    pipe.train()
    train_grad_fn = pipe(batch).grad_fn
    assert find_grad_fn(train_grad_fn, 'CheckpointBackward')
    assert find_grad_fn(train_grad_fn, 'RecomputeBackward')

    pipe.eval()
    eval_grad_fn = pipe(batch).grad_fn
    assert not find_grad_fn(eval_grad_fn, 'CheckpointBackward')
    assert not find_grad_fn(eval_grad_fn, 'RecomputeBackward')
def main():
    """Run the D-DNN ImageNet benchmark on a GPipe-wrapped model.

    Parses the command line, moves the selected architecture to CUDA, balances
    it over ``torch.cuda.device_count()`` partitions, trains for ``epochs``
    epochs, and prints the validation accuracy together with the average
    throughput and epoch time.
    """
    parser = argparse.ArgumentParser(description='D-DNN imagenet benchmark')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet50)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # The values of --synthetic_data follow the bash convention of the calling
    # scripts: 0 means true and everything else means false.
    parser.add_argument('-s', '--synthetic_data', type=int, default=0,
                        help="Use synthetic data")
    args = parser.parse_args()

    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    cudnn.benchmark = True

    # ------------------------------------------------------------------
    # Move the model to GPU and partition it.
    print("=> creating model '{}'".format(args.arch))
    model = model_names[args.arch].cuda()
    partitions = torch.cuda.device_count()

    # The -1 setting uses 512x512 samples; every other setting uses 224x224.
    highres = args.synthetic_data == -1
    side = 512 if highres else 224
    sample = torch.empty(batch_size, 3, side, side)

    balance = balance_by_time(partitions, model, sample)
    model = GPipe(model, balance, chunks=microbatches)

    # ------------------------------------------------------------------
    devices = list(model.devices)
    in_device, out_device = devices[0], devices[-1]
    torch.cuda.set_device(in_device)

    throughputs = []
    elapsed_times = []

    # ------------------------------------------------------------------
    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ------------------------------------------------------------------
    # Datasets and loaders.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_comp = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    val_comp = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]

    if highres:
        # High-resolution data: no cropping or resizing.
        traindir = datadir + '/HIGHRES/train'
        valdir = datadir + '/HIGHRES/val'
        train_comp = [transforms.ToTensor(), normalize]
        val_comp = [transforms.ToTensor(), normalize]
    elif args.synthetic_data:
        # Non-zero (bash "false"): normal data.
        traindir = datadir + '/train'
        valdir = datadir + '/val'
    else:
        # Zero (bash "true"): synthetic data.
        traindir = datadir + '/IMAGENET/train'
        valdir = datadir + '/IMAGENET/val'

    def make_loader(root, steps):
        # Train and validation loaders use identical DataLoader settings.
        dataset = datasets.ImageFolder(root, transforms.Compose(steps))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=cores_gpu,
                                           pin_memory=True)

    train_loader = make_loader(traindir, train_comp)
    val_loader = make_loader(valdir, val_comp)

    # ------------------------------------------------------------------
    for epoch in range(epochs):
        epoch_throughput, epoch_elapsed = run_epoch(train_loader, val_loader, model,
                                                    optimizer, epoch, args,
                                                    in_device, out_device)
        throughputs.append(epoch_throughput)
        elapsed_times.append(epoch_elapsed)

    # NOTE(review): the final evaluation is assumed to run once, after the
    # training loop — confirm against the original layout of this function.
    _, valid_accuracy = evaluate(val_loader, model, args, in_device, out_device)

    def average(values):
        # An empty list (no epochs run) averages to 0.0.
        count = len(values)
        return sum(values) / count if count > 0 else 0.0

    throughput = average(throughputs)
    elapsed_time = average(elapsed_times)
    print('valid accuracy: %.4f | %.3f samples/sec, %.3f sec/epoch (average)'
          % (valid_accuracy, throughput, elapsed_time))