Example #1
import torch
import torch.nn as nn

# `Pipe` and `get_worker_map` are assumed to come from fairscale's
# experimental multiprocess pipe API and its test helpers; the exact
# import paths depend on the fairscale version.
def checkpoint_eval():
    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(
        model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
    )
    input = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        for next_grad_fn, _ in grad_fn.next_functions:
            if find_grad_fn(next_grad_fn, name):
                return True
        return False

    # In train mode Pipe checkpoints each micro-batch, so the autograd
    # graph must contain the Checkpoint and Recompute function nodes.
    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
    assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")

    # In eval mode checkpointing is skipped, so neither node may appear.
    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
    assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
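The assertions work because PyTorch names an autograd graph node after the
`torch.autograd.Function` subclass that produced it, so Pipe's internal
`Checkpoint` and `Recompute` functions surface as `CheckpointBackward` and
`RecomputeBackward` nodes. A minimal sketch of that naming rule (the
`Double` class is illustrative, not part of the test suite):

import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad):
        return 2 * grad

y = Double.apply(torch.rand(3, requires_grad=True))
# The graph node is named "<ClassName>Backward".
assert y.grad_fn.__class__.__name__ == "DoubleBackward"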
Example #2
import torch
import torch.nn as nn

# The same test against the single-process Pipe API, which takes explicit
# `devices` instead of a worker map; `Pipe` is assumed to be fairscale's.
def test_checkpoint_eval():
    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
    input = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        for next_grad_fn, _ in grad_fn.next_functions:
            if find_grad_fn(next_grad_fn, name):
                return True
        return False

    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
    assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")

    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
    assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
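Example #2 checks the same contract against the single-process Pipe API. The
behavior itself is the standard activation-checkpointing trade-off: drop
activations in forward and recompute them in backward to save memory, which
only makes sense in train mode. A sketch of the same train-only gating with
plain torch.utils.checkpoint (`CheckpointedInTrain` is a hypothetical module,
not Pipe's implementation):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedInTrain(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)

    def forward(self, x):
        if self.training:
            # Drop activations now; recompute them during backward.
            return checkpoint(self.layer, x)
        return self.layer(x)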
Example #3
import torch
import torch.nn as nn

# `Pipe`, `LazyModule`, `get_worker_map`, and `set_random_seed` are assumed
# to come from fairscale and its test helpers; exact import paths vary by
# version.
def reuse_lazy():
    if False:  # disabled for test speed
        reused = LazyModule(lambda: nn.Linear(10, 10))
        model = [
            reused,
            nn.Linear(10, 10),
            nn.ReLU(),
            reused,
            nn.ReLU(),
            reused,
            nn.ReLU(),
        ]
        # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
        pipe = Pipe(model, [3, 1, 1],
                    style=Pipe.AsyncSchedule,
                    worker_map=get_worker_map())
        pipe.eval()
        output = pipe(torch.rand(10))

        print(f"output on {pipe.group.rank()}, {output}")
        torch.distributed.barrier()

    set_random_seed(1234)
    # build a plain reference model so both forward passes can be compared
    reused = nn.Linear(10, 10)
    layers = [
        reused,
        nn.Linear(10, 10),
        nn.ReLU(),
        reused,
        nn.ReLU(),
        reused,
        nn.ReLU(),
    ]
    model = nn.Sequential(*layers)
    model.eval()

    set_random_seed(1234)
    # ensure identical weights but no sharing between model and pipe
    reused = nn.Linear(10, 10)
    layers = [
        reused,
        nn.Linear(10, 10),
        nn.ReLU(),
        reused,
        nn.ReLU(),
        reused,
        nn.ReLU(),
    ]
    pipe = Pipe(layers, [3, 1, 1],
                style=Pipe.AsyncSchedule,
                worker_map=get_worker_map())
    pipe.eval()
    model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # In the multiprocess setting this stage may own no parameters at all.
    pipe_params = list(pipe.parameters())
    pipe_optimizer = torch.optim.SGD(pipe_params, lr=0.01, momentum=0.9) if pipe_params else None
    inputs = torch.rand(10)
    if False:  # disabled for test speed
        model_out = model(inputs)
        pipe_out = pipe(inputs)

        torch.distributed.barrier()

        if pipe.final_stage:
            assert torch.equal(model_out, pipe_out)

    model.train()
    pipe.train()
    model_out = model(inputs)
    pipe_out = pipe(inputs)
    if pipe.final_stage:
        pipe_loss = pipe_out.mean()
        pipe_loss.backward()

    model_loss = model_out.mean()
    model_loss.backward()

    model_optimizer.step()
    if pipe_optimizer:
        pipe_optimizer.step()

    model.eval()
    pipe.eval()
    model_out = model(inputs)
    pipe_out = pipe(inputs)

    print(f"before barrier on {torch.distributed.get_rank()}")
    torch.distributed.barrier()
    print(f"after barrier on {torch.distributed.get_rank()}")

    if pipe.final_stage:
        assert torch.equal(model_out, pipe_out)
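Example #3 exercises weight sharing: the same `reused` instance appears three
times in the layer list, so one set of parameters is applied at several
depths and must still match between the plain nn.Sequential and the pipe. In
plain PyTorch, shared parameters are deduplicated and their gradients
accumulate across uses; a minimal sketch:

import torch
import torch.nn as nn

reused = nn.Linear(10, 10)
model = nn.Sequential(reused, nn.ReLU(), reused, nn.ReLU())

# Shared parameters are counted once: one weight and one bias.
assert len(list(model.parameters())) == 2

model(torch.rand(10)).sum().backward()
# reused.weight.grad now accumulates contributions from both applications.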
Example #4
import torch
import torch.nn as nn

# `train` and `checkpoint` are pytest parameters in the original suite;
# `Stash`, `Pop`, `current_skip_tracker`, `is_checkpointing`, and
# `is_recomputing` are assumed to come from the pipe implementation's
# skip/checkpoint modules (torchgpipe, or fairscale's vendored copy).
def test_delete_portal_tensor(train, checkpoint):
    # Without checkpointing:
    # +- Stash --+  +--- Pop ----+ - - - layers
    # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
    # +----------+  +------------+
    #
    # With checkpointing:
    # +- Stash --+  +--- Pop ----+  +--- Pop'----+  +- Stash'--+
    # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
    # +----------+  +------------+  +------------+  +----------+

    def portal_tensor_life_is(tensor_life, skip_tracker=None):
        if skip_tracker is None:
            skip_tracker = current_skip_tracker()

        # Get the current portal.
        portal = list(skip_tracker.portals.values())[0]

        if tensor_life == 0:
            return portal.tensor_life == 0 and portal.tensor is None
        else:
            return portal.tensor_life == tensor_life and portal.tensor is not None

    # Check the portal tensor after 'Stash'.
    stash_ = Stash()

    @stash_.register_forward_hook
    def check_portal_tensor_after_stash(*_):
        if is_checkpointing():
            assert portal_tensor_life_is(2)
        elif is_recomputing():
            assert portal_tensor_life_is(0)
        else:
            assert portal_tensor_life_is(1)

    pop_ = Pop()

    @pop_.register_forward_hook
    def check_portal_tensor_after_pop(*_):
        if is_checkpointing():
            assert portal_tensor_life_is(1)
        elif is_recomputing():
            assert portal_tensor_life_is(0)
        else:
            assert portal_tensor_life_is(0)

    class NoPortalTensorAtBackward(nn.Module):
        class F(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.skip_tracker = current_skip_tracker()
                return input.detach()

            @staticmethod
            def backward(ctx, grad):
                assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
                return grad

        def forward(self, input):
            return self.F.apply(input)

    model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
    model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint)

    input = torch.rand(10, requires_grad=True)

    if train:
        model.train()
        output = model(input)
        output.norm().backward()
    else:
        model.eval()
        with torch.no_grad():
            model(input)
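`Stash` and `Pop` are defined elsewhere in the original test suite as
skippable modules that move a tensor through a portal across pipeline
stages. A minimal sketch of what they look like, using torchgpipe's
documented skip API (the import path is an assumption and differs for
fairscale's vendored copy):

import torch.nn as nn
from torchgpipe.skip import pop, skippable, stash

@skippable(stash=['skip'])
class Stash(nn.Module):
    def forward(self, input):
        yield stash('skip', input)
        return input

@skippable(pop=['skip'])
class Pop(nn.Module):
    def forward(self, input):
        skip = yield pop('skip')
        return skip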