Example #1
def test_empty_module():
    # An empty sequential module is not illegal.
    model = nn.Sequential()
    model = GPipe(model, [])

    assert model(torch.tensor(42)) == torch.tensor(42)
    assert model((torch.tensor(42), )) == (torch.tensor(42), )

    # But only a tensor or a tuple of tensors is a legal input to GPipe.
    with pytest.raises(TypeError):
        model(42)
Example #2
def test_inplace_on_not_requires_grad():
    # An in-place operation on a tensor that doesn't require grad shouldn't
    # cause a RuntimeError, but under checkpointing it does. Currently, we
    # cannot detect this case.
    model = nn.Sequential(nn.ReLU(inplace=True))
    model = GPipe(model, [1], devices=['cpu'], checkpoint='always')

    x = torch.rand(1)
    y = model(x)

    message = 'a leaf Variable that requires grad has been used in an in-place operation.'
    with pytest.raises(RuntimeError, match=message):
        y.backward()
Example #3
def test_checkpoint_mode_invalid():
    model = nn.Sequential(nn.Linear(1, 1))

    with pytest.raises(
            ValueError,
            match="checkpoint is not one of 'always', 'except_last', or 'never'"
    ):
        GPipe(model,
              balance=[1],
              devices=['cpu'],
              chunks=2,
              checkpoint='INVALID_CHECKPOINT')
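
For reference, the three modes named in the error message are the only accepted values; a minimal sketch (constructed here for illustration, not taken from the source tests):

from torch import nn
from torchgpipe import GPipe

# All three valid checkpoint modes, per the error message above.
for mode in ('always', 'except_last', 'never'):
    GPipe(nn.Sequential(nn.Linear(1, 1)),
          balance=[1], devices=['cpu'], chunks=2, checkpoint=mode)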
Example #4
    def naive2(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 47
        balance = [84, 241]

        model = cast(nn.Sequential, model)
        # GPipe with chunks=1 and checkpoint='never' is equivalent to typical model parallelism.
        model = GPipe(model,
                      balance,
                      devices=devices,
                      chunks=1,
                      checkpoint='never')
        return model, batch_size, list(model.devices)
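
For reference, a minimal sketch of the "typical model parallelism" the comment refers to, assuming two partitions on two CUDA devices (illustrative, not from the source):

import torch
from torch import nn

# Hypothetical two-stage model parallelism: run the whole batch through
# each partition in turn, copying between devices. This is what
# GPipe(..., chunks=1, checkpoint='never') reduces to.
part0 = nn.Linear(1, 1).to('cuda:0')
part1 = nn.Linear(1, 1).to('cuda:1')

def model_parallel_forward(x):
    x = part0(x.to('cuda:0'))
    x = part1(x.to('cuda:1'))  # device-to-device copy between stages
    return x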
Example #5
def test_input_singleton():
    class One(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, only_a):
            a, = only_a
            return (self.fc(a), )

    model = nn.Sequential(One())
    model = GPipe(model, balance=[1], devices=['cpu'], chunks=2)

    a = torch.rand(10, 1, requires_grad=True)

    a_out, = model((a, ))
    loss = a_out.mean()
    loss.backward()

    assert all(p.grad is not None for p in model.parameters())
    assert a.grad is not None
Example #6
    def pipeline1(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 96
        chunks = 1
        balance = [370]

        model = cast(nn.Sequential, model)
        model = GPipe(model,
                      balance,
                      devices=devices,
                      chunks=chunks,
                      checkpoint='always')
        return model, batch_size, list(model.devices)
Example #7
    def pipeline8(devices: List[int]) -> Stuffs:
        B, C = 48, 160
        balance = [800, 140, 62, 36, 36, 36, 36, 987]

        model: nn.Module = unet(depth=5,
                                num_convs=B,
                                base_channels=C,
                                input_channels=3,
                                output_channels=1)
        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=128)

        return model, B, C, list(model.devices)
Example #8
def test_exception():
    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        def forward(self, *_):
            raise ExpectedException()

    model = nn.Sequential(Raise())
    model = GPipe(model, balance=[1], devices=['cpu'], chunks=1)

    with pytest.raises(ExpectedException):
        model(torch.rand(1))
Example #9
    def pipeline4(devices: List[int]) -> Stuffs:
        B, C = 24, 160
        balance = [472, 54, 36, 515]

        model: nn.Module = unet(depth=5,
                                num_convs=B,
                                base_channels=C,
                                input_channels=3,
                                output_channels=1)
        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=32)

        return model, B, C, list(model.devices)
Example #10
def test_1to3(balance, checkpoint):
    if torch.cuda.device_count() < len(balance):
        pytest.skip('at least %d cuda devices required' % len(balance))

    @skippable(stash=['1to3'])
    class Layer1(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            yield stash('1to3', input)
            output = self.conv(input)
            return output

    class Layer2(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            output = self.conv(input)
            return output

    @skippable(pop=['1to3'])
    class Layer3(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            skip_1to3 = yield pop('1to3')
            output = self.conv(input) + skip_1to3
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
    model = GPipe(model, balance, chunks=3, checkpoint=checkpoint)

    in_device = model.devices[0]
    out_device = model.devices[-1]

    input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
    output = model(input)
    loss = output.mean()
    loss.backward()

    assert torch.allclose(output.norm(),
                          torch.tensor(1039.159, device=out_device))
    assert torch.allclose(input.grad.norm(),
                          torch.tensor(0.0004533053, device=in_device))
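
The skip-connection API used above comes from torchgpipe.skip; a minimal sketch of the stash/pop pattern, reduced from the test (the class and skip names are illustrative):

from torch import nn
from torchgpipe.skip import skippable, stash, pop

@skippable(stash=['skip'])
class StashX(nn.Module):
    def forward(self, x):
        yield stash('skip', x)    # hand x to a later layer
        return x

@skippable(pop=['skip'])
class PopX(nn.Module):
    def forward(self, x):
        skip = yield pop('skip')  # receive the stashed tensor
        return x + skip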
Example #11
def test_no_grad():
    model = nn.Sequential(nn.Linear(1, 1))
    model = GPipe(model, balance=[1], devices=['cpu'], chunks=2)
    input = torch.rand(2, 1)

    latent = None

    def hook(module, input, outputs):
        _ = module
        _ = input

        output, _ = outputs

        nonlocal latent
        latent = output

    partition = list(model.partitions())[0]
    partition.register_forward_hook(hook)

    with torch.no_grad():
        model(input)

    assert latent.grad_fn is None
Example #12
def test_devices():
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
    c = nn.Linear(1, 1)

    # There are two extra devices.
    devices = ['cpu', 'cpu', 'cpu', 'cpu', 'cpu']

    model = nn.Sequential(a, b, c)
    model = GPipe(model, [1, 1, 1], devices=devices)

    cpu = torch.device('cpu')
    # Extra devices must be discarded.
    assert model.devices == (cpu, cpu, cpu)
Example #13
def test_identicalness():
    def sum_grad(parameters):
        return sum([p.grad.sum() for p in parameters if p.grad is not None])

    def zero_grad(parameters):
        for p in parameters:
            p.grad = None

    inputs = torch.rand(8, 1)
    model = nn.Sequential(
        nn.Linear(1, 2),
        nn.Linear(2, 4),
        nn.Linear(4, 2),
        nn.Linear(2, 1),
    )

    # Without GPipe
    outputs = model(inputs)
    loss = outputs.mean()
    loss.backward()

    grad_without_gpipe = sum_grad(model.parameters())

    zero_grad(model.parameters())

    # With GPipe
    model = GPipe(model, [2, 2], devices=['cpu', 'cpu'], chunks=4)

    outputs = model(inputs)
    loss = outputs.mean()
    loss.backward()

    grad_with_gpipe = sum_grad(model.parameters())

    # Both grads should be identical.
    assert torch.allclose(grad_with_gpipe, grad_without_gpipe)
Example #14
def _gpipe(
    model: nn.Module,
    devices: List[int],
    batch_size: int,
    chunks: int,
    balance: List[int],
    checkpoint: str,
) -> Stuffs:
    model = cast(nn.Sequential, model)
    model = GPipe(model,
                  balance,
                  devices=devices,
                  chunks=chunks,
                  checkpoint=checkpoint)
    return model, batch_size, list(model.devices)
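
A hypothetical invocation of this factory (all argument values below are illustrative assumptions, not taken from the source):

from torch import nn

# Illustrative only: a tiny two-layer model split into two partitions.
two_layer = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 1))
model, batch_size, devices = _gpipe(
    two_layer,
    devices=['cpu', 'cpu'],
    batch_size=32,
    chunks=4,
    balance=[1, 1],
    checkpoint='except_last',
)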
Example #15
def test_non_tensor_tuple():
    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, 'hello')

    model = nn.Sequential(NonTensorTuple())
    model = GPipe(model, balance=[1], devices=['cpu'])
    x = torch.rand(1)

    # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model((x, 'hello'))
Example #16
def test_non_tensor():
    class NonTensor(nn.Module):
        def forward(self, _):
            return 'hello'

    model = nn.Sequential(NonTensor())
    model = GPipe(model, balance=[1], devices=['cpu'])
    x = torch.rand(1)

    # TypeError: expected Tensor as element 0 in argument 0, but got str
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model('hello')
Example #17
def test_deferred_batch_norm(checkpoint):
    bn = nn.BatchNorm2d(3)
    gpipe_bn = deepcopy(bn)
    gpipe = GPipe(nn.Sequential(gpipe_bn),
                  balance=[1],
                  devices=['cpu'],
                  chunks=2,
                  checkpoint=checkpoint,
                  deferred_batch_norm=True)

    x = torch.rand(4, 3, 10, 10)
    gpipe(x).mean().backward()
    bn(x).mean().backward()

    assert torch.allclose(gpipe[0].running_mean, bn.running_mean, atol=1e-4)
    assert torch.allclose(gpipe[0].running_var, bn.running_var, atol=1e-4)
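
For contrast, a minimal sketch of the problem deferred batch norm solves: without it, each micro-batch applies its own momentum-weighted update to the running statistics, which then drift from the full-batch statistics (illustrative, not from the source):

import torch
from torch import nn

x = torch.rand(4, 3, 10, 10)

full = nn.BatchNorm2d(3)
full(x)                      # one update from the full batch

chunked = nn.BatchNorm2d(3)
for chunk in x.chunk(2):
    chunked(chunk)           # two separate updates, one per micro-batch

# The running means typically differ; deferred_batch_norm=True avoids this.
print(torch.allclose(chunked.running_mean, full.running_mean))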
Example #18
def test_sequential_like(balance):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = GPipe(model, balance, devices=['cpu', 'cpu'])

    assert len(model) == 2
    assert list(model) == [a, b]

    assert model[0] is a
    assert model[1] is b
    with pytest.raises(IndexError):
        _ = model[2]

    assert model[-1] is b
    assert model[-2] is a
Example #19
def test_parallel_randoms():
    class Dropouts(nn.Module):
        def forward(self, x):
            for _ in range(100):
                x = F.dropout(x, p=0.001)
            return x

    model = nn.Sequential(Dropouts(), Dropouts())

    x = torch.rand(10, 10, requires_grad=True)
    model = GPipe(model, [1, 1],
                  devices=['cpu', 'cpu'],
                  chunks=10,
                  checkpoint='always')
    y = model(x)
    y.norm().backward()

    assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()
Example #20
def test_deferred_batch_norm_params(checkpoint):
    bn = nn.BatchNorm2d(3)
    gpipe_bn = deepcopy(bn)
    gpipe = GPipe(nn.Sequential(gpipe_bn),
                  balance=[1],
                  devices=['cpu'],
                  chunks=1,
                  checkpoint=checkpoint,
                  deferred_batch_norm=True)

    x = torch.rand(4, 3, 10, 10)
    gpipe(x).mean().backward()
    bn(x).mean().backward()

    assert gpipe[0].weight.grad is not None
    assert gpipe[0].bias.grad is not None

    assert torch.allclose(gpipe[0].weight.grad, bn.weight.grad, atol=1e-4)
    assert torch.allclose(gpipe[0].bias.grad, bn.bias.grad, atol=1e-4)
Example #21
def test_public_attrs():
    class MyString:
        def __init__(self, value):
            self.value = value

        def __str__(self):
            return self.value

    model = nn.Sequential(nn.Linear(1, 1))
    gpipe = GPipe(model,
                  balance=(1, ),
                  devices=('cpu', ),
                  chunks=42.000,
                  checkpoint=MyString('always'))

    assert gpipe.balance == [1]
    assert gpipe.devices == [torch.device('cpu')]
    assert gpipe.chunks == 42
    assert isinstance(gpipe.chunks, int)
    assert gpipe.checkpoint == 'always'
    assert isinstance(gpipe.checkpoint, str)
Example #22
def test_tuple_wait(cuda_sleep):
    # In v0.0.3, Wait was applied only to the first tensor in a micro-batch.
    # Under this behavior, with checkpointing disabled, gradient
    # accumulations on the other tensors might not be synchronized properly
    # with the copy stream.
    class Sleep(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.detach()

        @staticmethod
        def backward(ctx, grad):
            with torch.cuda.device(grad.device):
                cuda_sleep(0.05)
            return grad

    class Layer1(nn.Module):
        def forward(self, pair):
            a, b = pair
            return a * 1, b * 2, b * 3

    class Layer2(nn.Module):
        def forward(self, triple):
            a, b, c = triple
            b = Sleep.apply(b)
            return a + b + c

    model = nn.Sequential(Layer1(), Layer2())
    model = GPipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint='never')

    a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
    b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)

    y = model((a, b))
    y.norm().backward()

    torch.cuda.synchronize(0)
    torch.cuda.synchronize(1)

    assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
Example #23
def test_exception_no_hang():
    # In v0.0.2, once a failed partition received a normal (non-closing)
    # message for the next micro-batch, a hang occurred. The reason was that
    # a failed partition didn't call in_queue.task_done() on a normal
    # message, so the preceding partition stayed blocked at out_queue.join()
    # for the micro-batch after next.
    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Raise())
    model = GPipe(model, [1, 1, 1], devices=['cpu', 'cpu', 'cpu'], chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))
Example #24
def test_exception_early_stop():
    class ExpectedException(Exception):
        pass

    class Counter(nn.Module):
        def __init__(self):
            super().__init__()
            self.counter = 0

        def forward(self, x):
            self.counter += 1
            time.sleep(0.01)
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    count_front = Counter()
    count_back = Counter()
    model = nn.Sequential(count_front, Raise(), count_back)
    model = GPipe(model,
                  balance=[1, 1, 1],
                  devices=['cpu', 'cpu', 'cpu'],
                  chunks=1000)

    with pytest.raises(ExpectedException):
        model(torch.rand(1000, 1))

    # This test is timing-dependent because it relies on the speed difference
    # between two partitions. But for it to fail, the exception would have to
    # arrive later than 10 seconds (0.01 * 1000), which doesn't seem to
    # happen in practice.
    count_front_counter = count_front.counter
    assert 1 <= count_front_counter < 1000
    assert count_back.counter == 0

    # The first partition should be already stopped.
    time.sleep(0.1)
    assert count_front.counter == count_front_counter
Example #25
def test_input_pair():
    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a_and_b):
            a, b = a_and_b
            return (self.fc_a(a), self.fc_b(b))

    model = nn.Sequential(Two())
    model = GPipe(model, balance=[1], devices=['cpu'], chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = model((a, b))
    loss = (a_out + b_out).mean()
    loss.backward()

    assert a.grad is not None
    assert b.grad is not None
Example #26
def test_current_microbatch():
    class Twice(nn.Module):
        def forward(self, x):
            return x * 2

    class CurrentMicrobatch(nn.Module):
        def forward(self, _):
            return current_microbatch()

    # Not in a partition.
    assert current_microbatch() is None

    input = torch.tensor([1., 2., 3.])

    model = nn.Sequential(Twice(), CurrentMicrobatch())
    model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=3)

    output = model(input)

    assert torch.allclose(output, torch.tensor([1., 2., 3.]))

    # Not in a partition.
    assert current_microbatch() is None
Example #27
def test_exception_early_stop_asap():
    """Even the first partitions have finished to process, the partition before
    the failed partition should be killed as soon as possible.
    """
    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    counter = 0

    class Counter(nn.Module):
        def forward(self, x):
            time.sleep(0.1)

            nonlocal counter
            counter += 1

            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
    model = GPipe(model, [1, 1, 1, 1],
                  devices=['cpu', 'cpu', 'cpu', 'cpu'],
                  chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))

    # If early stopping didn't work, the counter would be 3 instead.
    assert counter == 2
Example #28
def test_none_skip():
    @skippable(stash=['none'])
    class Stash(nn.Module):
        def forward(self, input):
            yield stash('none', None)
            return input

    @skippable(pop=['none'])
    class Pop(nn.Module):
        def forward(self, input):
            none = yield pop('none')
            assert none is None
            return input

    model = nn.Sequential(Stash(), Pop())
    model = GPipe(model, [1, 1], devices=['cpu', 'cpu'], chunks=5)

    input = torch.rand(10, requires_grad=True)
    output = model(input)

    def assert_grad_fn_is_not_portal(grad_fn, visited=set()):
        if grad_fn in visited or grad_fn is None:
            return

        assert not isinstance(grad_fn, PortalBlue._backward_cls)
        assert not isinstance(grad_fn, PortalCopy._backward_cls)
        assert not isinstance(grad_fn, PortalOrange._backward_cls)

        visited.add(grad_fn)
        for next_grad_fn, _ in grad_fn.next_functions:
            assert_grad_fn_is_not_portal(next_grad_fn, visited)

    assert_grad_fn_is_not_portal(output.grad_fn)

    output.sum().backward()
    assert input.grad.mean().item() == 1
Example #29
def test_checkpoint_eval():
    model = nn.Sequential(nn.Linear(1, 1))
    model = GPipe(model, balance=[1], devices=['cpu'], chunks=2)
    input = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        for next_grad_fn, _ in grad_fn.next_functions:
            if find_grad_fn(next_grad_fn, name):
                return True
        return False

    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, 'CheckpointBackward')
    assert find_grad_fn(train_output.grad_fn, 'RecomputeBackward')

    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, 'CheckpointBackward')
    assert not find_grad_fn(eval_output.grad_fn, 'RecomputeBackward')
Example #30
def main():
    parser = argparse.ArgumentParser(description='D-DNN imagenet benchmark')
    parser.add_argument('-a',
                        '--arch',
                        metavar='ARCH',
                        default='resnet50',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet50)')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=0.1,
                        type=float,
                        metavar='LR',
                        help='initial learning rate',
                        dest='lr')
    parser.add_argument('--momentum',
                        default=0.9,
                        type=float,
                        metavar='M',
                        help='momentum')
    parser.add_argument('--wd',
                        '--weight-decay',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # The value of args.synthetic_data may seem confusing: it comes from
    # bash, where 0 means true and any other value means false.
    parser.add_argument('-s',
                        '--synthetic_data',
                        type=int,
                        default=0,
                        help="Use synthetic data")
    args = parser.parse_args()

    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    cudnn.benchmark = True

    #---------------------------------------------------------------------------------
    # Move model to GPU.
    print("=> creating model '{}'".format(args.arch))
    model = model_names[args.arch].cuda()

    partitions = torch.cuda.device_count()
    if args.synthetic_data == -1:
        sample = torch.empty(batch_size, 3, 512, 512)
    else:
        sample = torch.empty(batch_size, 3, 224, 224)
    balance = balance_by_time(partitions, model, sample)
    model = GPipe(model, balance, chunks=microbatches)

    #---------------------------------------------------------------------------------
    devices = list(model.devices)
    in_device = devices[0]
    out_device = devices[-1]
    torch.cuda.set_device(in_device)

    throughputs = []
    elapsed_times = []
    #---------------------------------------------------------------------------------

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    #---------------------------------------------------------------------------------
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_comp = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ]
    val_comp = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ]

    if args.synthetic_data == -1:
        # Load highres data
        traindir = datadir + '/HIGHRES/train'
        valdir = datadir + '/HIGHRES/val'
        train_comp = [transforms.ToTensor(), normalize]
        val_comp = [transforms.ToTensor(), normalize]
    elif args.synthetic_data:
        # Load normal data
        traindir = datadir + '/train'
        valdir = datadir + '/val'
    else:
        # Load synthetic data
        traindir = datadir + '/IMAGENET/train'
        valdir = datadir + '/IMAGENET/val'

    train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        traindir, transforms.Compose(train_comp)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=cores_gpu,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir, transforms.Compose(val_comp)),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=cores_gpu,
                                             pin_memory=True)
    #---------------------------------------------------------------------------------

    for epoch in range(epochs):
        throughput, elapsed_time = run_epoch(train_loader, val_loader, model,
                                             optimizer, epoch, args, in_device,
                                             out_device)

        throughputs.append(throughput)
        elapsed_times.append(elapsed_time)

    _, valid_accuracy = evaluate(val_loader, model, args, in_device,
                                 out_device)

    n = len(throughputs)
    throughput = sum(throughputs) / n if n > 0 else 0.0
    elapsed_time = sum(elapsed_times) / n if n > 0 else 0.0
    print('valid accuracy: %.4f | %.3f samples/sec, %.3f sec/epoch (average)'
          '' % (valid_accuracy, throughput, elapsed_time))
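
The balancing step above profiles by time; torchgpipe also provides balance_by_size, which profiles peak CUDA memory instead. A minimal sketch reusing the names from main() (assumes a CUDA machine):

from torchgpipe import GPipe
from torchgpipe.balance import balance_by_size

# Alternative to balance_by_time: partition by peak memory usage.
balance = balance_by_size(partitions, model, sample)
model = GPipe(model, balance, chunks=microbatches)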