Beispiel #1
0
def run_mp_worker(args, available_workers):
    """Build a ``MultiProcessPipe`` for ``args.model_name`` and run it.

    Runs ``train`` when ``args.dry_run`` is set, otherwise
    ``benchmark_language_model``. ``available_workers`` is accepted for
    interface compatibility but is not used in this function.
    """
    benchmark_config = create_benchmark_config(args.model_name)
    specs = get_model_specs(args.model_name)
    config = create_model_config(args, benchmark_config=benchmark_config, model_specs=specs)
    layers = config["model"]

    stage_count = get_pipeline_parallel_group().size()
    balance = generate_balance(stage_count, len(layers))

    worker_map = get_worker_map()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    pipe_model = MultiProcessPipe(
        layers,
        balance,
        chunks=args.chunks,
        worker_map=worker_map,
        input_device=device,
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()

    # Both entry points take the same positional arguments.
    run_args = (config, pipe_model, benchmark_config, specs, args)
    if args.dry_run:
        train(*run_args)
    else:
        benchmark_language_model(*run_args)
Beispiel #2
0
def checkpoint_non_float_input(pipeline_style):
    """Checkpointing must cope with a bool tensor passed between two stages."""

    class ForkNonFloat(nn.Module):
        # Emits a float tensor together with a bool tensor.
        def forward(self, x):
            return (x * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        # Keeps only the float branch of the incoming pair.
        def forward(self, pair):
            return pair[0] * 2

    pipe = MultiProcessPipe(
        nn.Sequential(ForkNonFloat(), JoinNonFloat()),
        balance=[1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=1,
        checkpoint="always",
        pipelined_backward=False,
    )

    sample = torch.rand(1, requires_grad=True)
    result = pipe(sample)

    rank = pipe.group.rank()
    if rank == 1:
        result.backward()
    elif pipeline_style == MultiProcessPipe.MultiProcess:
        # Under the MultiProcess style, the non-last rank calls back_helper.
        pipe.back_helper(result)

    torch.distributed.barrier()
Beispiel #3
0
def exception_no_hang(pipeline_style):
    """Regression test: an exception in one partition must not hang the others.

    In v0.0.2, once a failed partition received a normal (non-closing)
    message for the next micro-batch, a hang occurred. The failed partition
    did not call ``in_queue.task_done()`` on a normal message, so the former
    partition was blocked at ``out_queue.join()`` for the next-but-one
    micro-batch.
    """

    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    pipe = MultiProcessPipe(
        nn.Sequential(Pass(), Pass(), Raise()),
        [1, 1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=3,
    )
    pipe.eval()

    # With balance [1, 1, 1], group rank 2 hosts ``Raise``. The other ranks
    # must still get through their forward call and reach the barrier.
    hosts_raising_layer = pipe.group.rank() == 2
    if not hosts_raising_layer:
        pipe(torch.rand(3))
    else:
        with pytest.raises(ExpectedException):
            pipe(torch.rand(3))

    torch.distributed.barrier()
Beispiel #4
0
def parallel_randoms(pipeline_style):
    """Compare dropout zero-patterns across a 2-stage, always-checkpointed pipe.

    Rank 1 gathers rank 0's ``x.grad`` and its own output ``y``. It then asserts
    that both have the same zero / non-zero pattern.

    NOTE(review): presumably this verifies that checkpoint recomputation
    (``checkpoint="always"``) replays the same dropout masks as the original
    forward pass. Confirm against the pipe implementation.
    """
    class Dropouts(nn.Module):
        # Applies 100 successive dropouts with p=0.001. Presumably chosen so
        # that a few elements are very likely zeroed.
        def forward(self, x):
            for _ in range(100):
                x = F.dropout(x, p=0.001)
            return x

    model = nn.Sequential(Dropouts(), Dropouts())

    # .cuda() returns a non-leaf copy of the requires_grad tensor, so
    # retain_grad() is needed for x.grad to be populated.
    x = torch.rand(10, 10, requires_grad=True).cuda()
    x.retain_grad()
    model = MultiProcessPipe(
        model,
        [1, 1],
        style=pipeline_style,
        input_device=torch.cuda.current_device(),
        worker_map=get_worker_map(),
        chunks=10,
        checkpoint="always",
    ).cuda()
    y = model(x)
    # One receive buffer per rank for all_gather, shaped like x.
    tensor_list = [torch.empty_like(x) for _ in range(2)]
    if model.group.rank() == 1:
        # Last stage: start backward from the output.
        y.norm().backward()
        # Both branches call barrier() and then all_gather() in the same
        # order, so the collectives match up across ranks.
        torch.distributed.barrier()
        # NOTE(review): all_gather below fills every slot, so this
        # assignment looks redundant.
        tensor_list[model.group.rank()] = y
        torch.distributed.all_gather(tensor_list, y, group=model.group)
        # tensor_list[0] is rank 0's x.grad and tensor_list[1] is rank 1's y.
        # Comparing their bool casts compares where each is non-zero.
        assert tensor_list[0].to(torch.bool).tolist() == tensor_list[1].to(torch.bool).tolist()
    else:
        # First stage: run its part of backward via back_helper.
        model.back_helper(y)
        torch.distributed.barrier()
        # NOTE(review): overwritten by all_gather below; looks redundant.
        tensor_list[model.group.rank()] = x.grad
        torch.distributed.all_gather(tensor_list, x.grad, group=model.group)
Beispiel #5
0
def input_singleton(pipeline_style):
    """A 1-tuple input must pass through a single-stage pipe and receive gradients."""

    class One(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, only_a):
            # Consume and produce exactly one-element tuples.
            [a] = only_a
            return (self.fc(a),)

    pipe = MultiProcessPipe(
        nn.Sequential(One()),
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=2,
        pipelined_backward=False,
    )

    a = torch.rand(10, 1, requires_grad=True)

    [a_out] = pipe((a,))
    a_out.mean().backward()

    assert all(param.grad is not None for param in pipe.parameters())
    assert a.grad is not None
Beispiel #6
0
def run_mp_worker(args, available_workers):
    """Build an AsyncSchedule ``MultiProcessPipe`` for the benchmark model and run it.

    Runs ``train`` when ``args.use_synthetic_data`` is set, otherwise
    ``benchmark_language_model``. ``available_workers`` is accepted for
    interface compatibility but is not used in this function.
    """
    benchmark_config = create_benchmark_config(args.model_name)
    model_config = create_model_config(args, config=benchmark_config)
    model = model_config["model"]

    # 0.8 is passed through as the weighting argument of generate_balance_weighted.
    balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
    pipe_model = MultiProcessPipe(
        model,
        balance,
        style=MultiProcessPipe.AsyncSchedule,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        pipelined_backward=args.pipelined_backward,
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()
    # ``pipeline`` is checked for truthiness before it is mutated.
    if args.all_at_once and pipe_model.pipeline:
        # Plain string: the original f-string had no placeholders (F541).
        print("running all at once")
        pipe_model.pipeline.all_at_once = True

    if args.use_synthetic_data:
        train(model_config, pipe_model, benchmark_config, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, args)
Beispiel #7
0
def parameters(pipeline_style):
    """Only the rank hosting the single partition exposes parameters.

    The check uses the pipe's own group rank, like the other tests do, rather
    than the global rank. It therefore stays correct when the pipeline
    group's ranks differ from the global ranks (e.g. with model parallelism).
    """
    model = nn.Sequential(nn.Linear(1, 1))
    pipe = MultiProcessPipe(model,
                            balance=[1],
                            style=pipeline_style,
                            worker_map=get_worker_map(),
                            chunks=1)
    if pipe.group.rank() == 0:
        assert list(pipe.parameters()) != []
    else:
        assert list(pipe.parameters()) == []
Beispiel #8
0
def simple_linears(pipeline_style):
    """Per-stage gradients of a 2-stage pipe must match those of the plain model.

    A 4-layer linear model is first run without the pipe to record reference
    gradient sums for its first and second halves. It is then wrapped in a
    ``MultiProcessPipe`` with ``balance=[2, 2]``. Each rank compares the
    gradient sum of its local partition against the matching half.
    """

    def sum_grad(parameters):
        # Sum of all gradient entries, skipping parameters without a grad.
        return sum(p.grad.sum() for p in parameters if p.grad is not None)

    def zero_grad(parameters):
        for p in parameters:
            p.grad = None

    set_random_seed(12345)
    inputs = torch.rand(8, 1)
    model = nn.Sequential(
        nn.Linear(1, 2),
        nn.Linear(2, 4),
        nn.Linear(4, 2),
        nn.Linear(2, 1),
    )

    # Reference run without MultiProcessPipe.
    model(inputs).mean().backward()
    grad_without_pipe = [
        sum_grad([*model[0].parameters(), *model[1].parameters()]),
        sum_grad([*model[2].parameters(), *model[3].parameters()]),
    ]
    zero_grad(model.parameters())

    # Same model, split across two pipeline stages.
    pipe = MultiProcessPipe(model, [2, 2],
                            style=pipeline_style,
                            worker_map=get_worker_map(),
                            chunks=4)

    outputs = pipe(inputs)
    is_last_stage = pipe.group.rank() == 1
    if is_last_stage:
        outputs.mean().backward()
    else:
        pipe.back_helper(outputs)

    # partitions[0] is this rank's local partition; the per-rank comparison
    # below relies on that.
    grad_with_pipe = sum_grad(pipe.pipeline.partitions[0].module.parameters())
    expected = grad_without_pipe[1] if is_last_stage else grad_without_pipe[0]
    # Both grads should be identical.
    assert torch.allclose(grad_with_pipe, expected)
    torch.distributed.barrier()
Beispiel #9
0
def checkpoint_mode(pipeline_style):
    """Count ``CheckpointBackward`` nodes produced by each ``checkpoint`` mode.

    With ``chunks=2`` the test expects 2, 1 and 0 such nodes for "always",
    "except_last" and "never" respectively.
    """

    def count_grad_fn(grad_fn, name, visited=None):
        """Count autograd nodes of class ``name`` reachable from ``grad_fn``.

        Traversal does not descend below a matching node, and each node is
        visited at most once. ``visited`` is created fresh per top-level
        call. A ``set()`` default would be shared by every call and leak
        state between the separate graphs examined below.
        """
        if visited is None:
            visited = set()
        if grad_fn is None or grad_fn in visited:
            return 0
        visited.add(grad_fn)

        if grad_fn.__class__.__name__ == name:
            return 1

        return sum(
            count_grad_fn(next_grad_fn, name, visited)
            for next_grad_fn, _ in grad_fn.next_functions
        )

    model = nn.Sequential(nn.Linear(1, 1))
    sample = torch.rand(2, 1)

    def make_pipe(checkpoint):
        # All three pipes wrap the same ``model``; only the checkpoint mode differs.
        return MultiProcessPipe(
            model,
            balance=[1],
            style=pipeline_style,
            worker_map=get_worker_map(),
            chunks=2,
            checkpoint=checkpoint,
            pipelined_backward=False,
        )

    always = make_pipe("always")
    except_last = make_pipe("except_last")
    never = make_pipe("never")

    always_output = always(sample)
    except_last_output = except_last(sample)
    never_output = never(sample)

    assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2
    assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1
    assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0
Beispiel #10
0
def recommend_auto_balance(pipeline_style):
    """A missing or mismatched ``balance`` raises a ValueError mentioning fairscale.nn.pipe.balance.

    ``pipeline_style`` is not used: none of these constructions pass it.
    """
    # Each entry is deferred so the modules are built in the same order as
    # the checks run.
    bad_constructions = [
        # balance is required
        lambda: MultiProcessPipe(nn.Sequential()),
        # module and sum of balance have different length (module: 0, sum of balance: 1)
        lambda: MultiProcessPipe(nn.Sequential(), [1]),
        # module and sum of balance have different length (module: 2, sum of balance: 1)
        lambda: MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]),
    ]
    for construct in bad_constructions:
        with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
            construct()
Beispiel #11
0
def none_skip(pipeline_style):
    """A ``None`` stashed through a skip connection must not create portal nodes.

    The last stage asserts that its output graph has no PortalBlue, PortalCopy
    or PortalOrange backward nodes. The first stage asserts that the mean
    input gradient is 1 after ``output.sum()`` is backpropagated.
    """
    if pipeline_style == MultiProcessPipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    @skippable(stash=["none"])
    class Stash(nn.Module):
        def forward(self, input):
            yield stash("none", None)
            return input

    @skippable(pop=["none"])
    class Pop(nn.Module):
        def forward(self, input):
            none = yield pop("none")
            assert none is None
            return input

    model = nn.Sequential(Stash(), Pop())
    model = MultiProcessPipe(
        model,
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        input_device=torch.cuda.current_device(),
        chunks=5,
    ).cuda()

    input = torch.rand(10, requires_grad=True).cuda()
    # ``input`` is the non-leaf result of .cuda(), so its grad must be retained explicitly.
    input.retain_grad()
    output = model(input)

    def assert_grad_fn_is_not_portal(grad_fn, visited=None):
        # Walk the autograd graph from ``grad_fn`` and fail on any portal node.
        # ``visited`` starts as a fresh set per top-level call; a ``set()``
        # default would be shared across every call of this helper.
        if visited is None:
            visited = set()
        if grad_fn in visited or grad_fn is None:
            return

        assert not isinstance(grad_fn, PortalBlue._backward_cls)
        assert not isinstance(grad_fn, PortalCopy._backward_cls)
        assert not isinstance(grad_fn, PortalOrange._backward_cls)

        visited.add(grad_fn)
        for next_grad_fn, _ in grad_fn.next_functions:
            assert_grad_fn_is_not_portal(next_grad_fn, visited)

    if model.group.rank() == 1:
        assert_grad_fn_is_not_portal(output.grad_fn)

        output.sum().backward()
    else:
        model.back_helper(output)
        # Both layers return their input unchanged, so d(sum(output))/d(input) is 1.
        assert input.grad.mean().item() == 1
def run(rank, world_size):
    """Run one training step of a two-stage ``MultiProcessPipe`` on ``rank``.

    ``balance=[2, 1]`` splits the model into a 2-layer stage and a 1-layer
    stage, mapped to RPC workers "worker0" and "worker1". Rank 1 computes
    the loss and runs backward; rank 0 calls ``back_helper``. Every rank then
    applies its own optimizer step, because each rank owns a different set of
    parameters.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)
    # RPC is initialized after switching MASTER_PORT, so it uses a different
    # port from the one set before dist_init.
    os.environ["MASTER_PORT"] = "10639"
    dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    initialize_model_parallel(1, world_size)

    model = get_model()
    data, target = get_data()[0]
    loss_fn = get_loss_fun()

    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")

    model = MultiProcessPipe(
        model,
        balance=[2, 1],
        style=MultiProcessPipe.MultiProcess,
        worker_map={
            0: "worker0",
            1: "worker1"
        },  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define optimizer (only this rank's partition parameters are visible here)
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    if rank == 1:
        # The last stage holds the final outputs, so it computes the loss.
        loss = loss_fn(outputs.to(device), target.to(device))
        loss.backward()
    else:
        # Non-last stages run their part of backward via back_helper
        # (presumably driven by gradients sent from the next stage).
        model.back_helper(outputs)

    # Each rank updates its own partition. Previously only rank 1 stepped, so
    # rank 0's parameters received gradients but were never updated.
    optimizer.step()

    print(f"Finished Training Step on {rank}")
    dist.rpc.shutdown()

    del model
Beispiel #13
0
def tuple_wait(cuda_sleep, pipeline_style):
    """Regression test: Wait must cover every tensor of a micro-batch.

    In v0.0.3, Wait was applied only to the first tensor of a micro-batch.
    With checkpointing disabled, gradient accumulation on the other tensors
    could then fail to synchronize properly with the copy stream.
    """

    class Sleep(torch.autograd.Function):
        # Identity (detached) in forward; calls cuda_sleep on the grad's device in backward.
        @staticmethod
        def forward(ctx, x):
            return x.detach()

        @staticmethod
        def backward(ctx, grad):
            with torch.cuda.device(grad.device):
                cuda_sleep(0.05)
            return grad

    class Layer1(nn.Module):
        def forward(self, pair):
            first, second = pair
            return first * 1, second * 2, second * 3

    class Layer2(nn.Module):
        def forward(self, triple):
            first, second, third = triple
            delayed = Sleep.apply(second)
            return first + delayed + third

    pipe = MultiProcessPipe(
        nn.Sequential(Layer1(), Layer2()),
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        input_device=torch.cuda.current_device(),
        chunks=32,
        checkpoint="never",
    ).cuda()

    a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
    b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)

    y = pipe((a, b))
    rank = pipe.group.rank()
    if rank == 1:
        y.norm().backward()
    else:
        pipe.back_helper(y)

    if rank == 0:
        # y = a + 2b + 3b = a + 5b, and grad of norm(y) w.r.t. y has unit
        # norm, so b.grad has norm 5.
        assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
Beispiel #14
0
def inplace_incorrect_grad(pipeline_style):
    """Checkpoint recomputation must not apply a non-idempotent in-place op twice."""

    class M(nn.Module):
        def forward(self, foo_bar):
            # 'foo' requires grad but 'bar' does not, so an in-place
            # operation on 'bar' won't cause a RuntimeError.
            foo, bar = foo_bar

            # Unlike relu_(), add_(1) is not idempotent: each extra execution
            # accumulates another 1 onto 'bar'.
            bar.add_(1)

            # 'bar' is still captured by checkpointing, so a repeated add_
            # would give 'foo' an incorrect grad.
            return foo * bar

    pipe = MultiProcessPipe(nn.Sequential(M()), [1],
                            style=pipeline_style,
                            worker_map=get_worker_map(),
                            checkpoint="always")

    foo = torch.tensor([1.0], requires_grad=True)
    bar = torch.tensor([1.0])

    output = pipe((foo, bar))
    del pipe
    output.backward()

    # d(foo * bar)/d(foo) == bar == 2 after a single add_. If add_ ran a
    # second time during recomputation, the grad would be 3 instead.
    assert foo.grad.item() == 2.0
Beispiel #15
0
def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
    """BatchNorm2d parameter grads through a deferred-batch-norm pipe match a plain copy.

    Gradients are compared with ``atol=1e-4``. When ``lazy`` is true the
    layer is supplied through a ``LazyModule`` instead of an ``nn.Sequential``.
    """
    reference_bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(reference_bn)
    model = [LazyModule(lambda: pipe_bn)] if lazy else nn.Sequential(pipe_bn)  # noqa: E731
    pipe = MultiProcessPipe(
        model,
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=1,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    reference_bn(x).mean().backward()

    layer = pipe[0]
    param_names = ("weight", "bias")
    for param_name in param_names:
        assert getattr(layer, param_name).grad is not None
    for param_name in param_names:
        assert torch.allclose(getattr(layer, param_name).grad,
                              getattr(reference_bn, param_name).grad,
                              atol=1e-4)
Beispiel #16
0
def lazy_skippable_error(pipeline_style):
    """Using skippable layers together with lazy construction is currently
    not supported; check that it raises a ValueError."""

    @skippable(stash=["1to3"])
    class Layer1(nn.Linear):
        pass

    @skippable(pop=["1to3"])
    class Layer3(nn.Linear):
        pass

    lazy_layers = [
        LazyModule(lambda: Layer1(10, 10)),
        LazyModule(lambda: nn.Linear(10, 10)),
        LazyModule(lambda: Layer3(10, 10)),
    ]

    expected_message = "Can't use Skippable layers with multi-process pipe and lazy construction"
    with pytest.raises(ValueError, match=expected_message):
        MultiProcessPipe(lazy_layers, [2, 1], style=pipeline_style, worker_map=get_worker_map())
Beispiel #17
0
def lazy_construction(pipeline_style):
    """Only this rank's share of LazyModules may be instantiated.

    Four lazy layers are split with ``balance=[2, 2]``. The test asserts that
    the pipe holds two ``Custom`` layers and that only two were constructed.
    """
    init_count = 0

    class Custom(nn.Module):
        def __init__(self):
            super().__init__()
            nonlocal init_count
            init_count += 1

        def forward(self, x):
            return x

    lazy_layers = [LazyModule(lambda: Custom()) for _ in range(4)]

    pipe = MultiProcessPipe(lazy_layers,
                            balance=[2, 2],
                            style=pipeline_style,
                            worker_map=get_worker_map())

    for index in (0, 1):
        assert isinstance(pipe[index], Custom)
    assert len(pipe) == 2
    assert init_count == 2
Beispiel #18
0
def chunks_less_than_1(pipeline_style):
    """Non-positive ``chunks`` values (0 and -1) are rejected with ValueError."""
    model = nn.Sequential(nn.Linear(1, 1))

    for bad_chunks in (0, -1):
        with pytest.raises(ValueError):
            MultiProcessPipe(model,
                             balance=[1],
                             style=pipeline_style,
                             worker_map=get_worker_map(),
                             chunks=bad_chunks)
Beispiel #19
0
def input_pair(pipeline_style):
    """A 2-tuple input must pass through a single-stage pipe and give both inputs grads."""

    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, pair):
            first, second = pair
            return (self.fc_a(first), self.fc_b(second))

    pipe = MultiProcessPipe(
        nn.Sequential(Two()),
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=2,
        pipelined_backward=False,
    )

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = pipe((a, b))
    (a_out + b_out).mean().backward()

    for leaf in (a, b):
        assert leaf.grad is not None
Beispiel #20
0
def no_grad(pipeline_style):
    """A forward under ``torch.no_grad()`` must leave the partition output without a grad_fn."""
    pipe = MultiProcessPipe(nn.Sequential(nn.Linear(1, 1)),
                            balance=[1],
                            style=pipeline_style,
                            worker_map=get_worker_map(),
                            chunks=2)
    sample = torch.rand(2, 1)

    last_output = None

    def remember_output(module, inputs, output):
        # Keep only the most recent forward output of the partition.
        nonlocal last_output
        last_output = output

    pipe.partitions[0].module.register_forward_hook(remember_output)

    with torch.no_grad():
        pipe(sample)

    assert last_output.grad_fn is None
Beispiel #21
0
def balance_less_than_1(pipeline_style):
    """Balances containing a partition with fewer than one layer raise ValueError."""
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))

    for bad_balance in ([0, 2], [-1, 3]):
        with pytest.raises(ValueError):
            MultiProcessPipe(model,
                             balance=bad_balance,
                             style=pipeline_style,
                             worker_map=get_worker_map())
Beispiel #22
0
def balance_wrong_length(pipeline_style):
    """Balances whose sum differs from the 2-layer model's length raise ValueError."""
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))

    for bad_balance in ([1], [3]):
        with pytest.raises(ValueError):
            MultiProcessPipe(model,
                             balance=bad_balance,
                             style=pipeline_style,
                             worker_map=get_worker_map())
Beispiel #23
0
def partitions(pipeline_style):
    """Each rank exposes one partition (an nn.Sequential) holding its own layer."""
    first, second = nn.Linear(1, 1), nn.Linear(1, 1)
    pipe = MultiProcessPipe(nn.Sequential(first, second), [1, 1],
                            style=pipeline_style,
                            worker_map=get_worker_map())

    assert isinstance(pipe.partitions, list)
    assert len(pipe) == 1
    assert isinstance(pipe.partitions[0].module, nn.Sequential)

    expected_key = "0.0.weight" if pipe.group.rank() == 0 else "0.1.weight"
    assert expected_key in pipe.state_dict()
Beispiel #24
0
def pipelined_backward(pipeline_style):
    """Check the default ``pipelined_backward`` under two model-parallel setups.

    After ``initialize_model_parallel(1, 4)`` it is False; after
    ``initialize_model_parallel(2, 2)`` it is True.
    """
    model = nn.Sequential(nn.ReLU(), nn.ReLU())

    scenarios = (((1, 4), False), ((2, 2), True))
    for parallel_args, expected in scenarios:
        destroy_model_parallel()
        initialize_model_parallel(*parallel_args)
        pipe = MultiProcessPipe(model, [1, 1],
                                style=pipeline_style,
                                worker_map=get_worker_map())

        assert pipe.pipelined_backward is expected
Beispiel #25
0
def too_few_devices(pipeline_style):
    """Requesting more partitions than the group has ranks raises IndexError."""
    layers = [nn.Linear(1, 1) for _ in range(4)]
    model = nn.Sequential(*layers)

    # len(balance) > group.size()
    with pytest.raises(IndexError):
        MultiProcessPipe(model,
                         balance=[1] * 4,
                         style=pipeline_style,
                         worker_map=get_worker_map())
Beispiel #26
0
def named_children(pipeline_style):
    """Child names survive in ``named_modules`` but are not exposed as attributes."""
    a, b = nn.Linear(1, 1), nn.Linear(1, 1)
    pipe = MultiProcessPipe(nn.Sequential(OrderedDict([("a", a), ("b", b)])), [1, 1],
                            style=pipeline_style,
                            worker_map=get_worker_map())

    names = {name for name, _ in pipe.named_modules()}
    expected_name = "0.a" if pipe.group.rank() == 0 else "0.b"
    assert expected_name in names

    # MultiProcessPipe doesn't support __getattr__. Unlike nn.Sequential,
    # MultiProcessPipe requires several methods in its namespace.
    with pytest.raises(AttributeError):
        _ = pipe.a
Beispiel #27
0
def python_autograd_function(pipeline_style):
    """An identity Python autograd Function must pass through the pipe unchanged.

    A Python autograd function might fail with this error::

        RuntimeError: Returning Variables sharing storage with other Variables
        that require grad is not supported in Python functions. Please submit a
        feature request if you hit this error.

    It doesn't look like an essential restriction, but it happens on the
    current PyTorch version. To avoid it, identity autograd functions such as
    Wait, Fork, and Join should detach the tensor before returning it.
    """
    # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
    torch.manual_seed(0)

    class Identity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad):
            return grad

    class M(nn.Module):
        def forward(self, input):
            return Identity.apply(input)

    pipe = MultiProcessPipe(
        nn.Sequential(M(), M()),
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        checkpoint="always",
    ).cuda()
    pipe.eval()

    x = torch.rand(42)
    y = pipe(x)
    is_last_stage = pipe.group.rank() == 1
    if is_last_stage:
        assert torch.allclose(x, y)

    torch.distributed.rpc.shutdown()
    torch.distributed.barrier()
Beispiel #28
0
def checkpoint_eval(pipeline_style):
    """Checkpoint/recompute nodes appear in train mode and disappear in eval mode."""
    pipe = MultiProcessPipe(
        nn.Sequential(nn.Linear(1, 1)),
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=2,
        pipelined_backward=False,
    )
    sample = torch.rand(2, 1)

    def has_grad_fn(node, name):
        # True if any autograd node reachable from ``node`` has class name ``name``.
        if node is None:
            return False
        if type(node).__name__ == name:
            return True
        return any(has_grad_fn(child, name) for child, _ in node.next_functions)

    node_names = ("CheckpointBackward", "RecomputeBackward")

    pipe.train()
    train_output = pipe(sample)
    for node_name in node_names:
        assert has_grad_fn(train_output.grad_fn, node_name)

    pipe.eval()
    eval_output = pipe(sample)
    for node_name in node_names:
        assert not has_grad_fn(eval_output.grad_fn, node_name)
Beispiel #29
0
def batch_size_small(pipeline_style):
    """A batch smaller than ``chunks`` is legal and must not emit any warning.

    ``pytest.warns(None)`` was deprecated in pytest 7 and removed in pytest 8.
    The stdlib recorder below does the same job: it records every warning
    raised inside the block.
    """
    import warnings

    model = nn.Sequential(nn.Linear(1, 1))
    model = MultiProcessPipe(model,
                             balance=[1],
                             style=pipeline_style,
                             worker_map=get_worker_map(),
                             chunks=4)

    with warnings.catch_warnings(record=True) as record:
        # "always" so no warning is hidden by filters or the once-per-location registry.
        warnings.simplefilter("always")
        model(torch.rand(2, 1))

    # Batch size smaller than chunks is legal.
    assert not record
Beispiel #30
0
def input_varargs(pipeline_style):
    """Calling the pipe with two positional inputs raises TypeError."""
    pipe = MultiProcessPipe(nn.Sequential(nn.Linear(1, 1)),
                            balance=[1],
                            style=pipeline_style,
                            worker_map=get_worker_map())

    a, b = torch.rand(1), torch.rand(1)

    # TypeError: forward() takes 2 positional arguments but 3 were given
    with pytest.raises(TypeError):
        pipe(a, b)