def checkpoint_eval(pipeline_style):
    """Check which autograd nodes show up in train mode versus eval mode.

    A single-partition ``MultiProcessPipe`` (``chunks=2``) is run once in train
    mode and once in eval mode.  The train-mode output's autograd graph must
    contain nodes whose class names are ``CheckpointBackward`` and
    ``RecomputeBackward``; the eval-mode output's graph must contain neither.

    Args:
        pipeline_style: forwarded unchanged as ``style=`` to ``MultiProcessPipe``.
    """
    pipe = MultiProcessPipe(
        nn.Sequential(nn.Linear(1, 1)),
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=2,
        pipelined_backward=False,
    )
    batch = torch.rand(2, 1)

    def graph_has_node(node, wanted):
        """Return True if ``node`` or any node reachable through
        ``next_functions`` has a class named ``wanted``."""
        if node is None:
            return False
        if node.__class__.__name__ == wanted:
            return True
        # Depth-first over the parents; stops at the first match.
        return any(graph_has_node(parent, wanted) for parent, _ in node.next_functions)

    checkpoint_nodes = ("CheckpointBackward", "RecomputeBackward")

    pipe.train()
    train_output = pipe(batch)
    for node_name in checkpoint_nodes:
        assert graph_has_node(train_output.grad_fn, node_name)

    pipe.eval()
    eval_output = pipe(batch)
    for node_name in checkpoint_nodes:
        assert not graph_has_node(eval_output.grad_fn, node_name)
def reuse_lazy():
    """Compare a plain ``nn.Sequential`` with a ``MultiProcessPipe`` that reuses a layer.

    Both networks are built from the same layer layout (one ``nn.Linear``
    instance appearing three times) right after ``set_random_seed(1234)``, so
    that they start from identical weights without sharing parameters.  Each is
    trained for one SGD step on the same input, and the eval-mode outputs are
    then required to be equal on the pipe's final stage.
    """
    if False:  # speed
        # Disabled variant that builds the reused layer through LazyModule.
        reused = LazyModule(lambda: nn.Linear(10, 10))
        model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
        # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
        pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
        pipe.eval()
        output = pipe(torch.rand(10))

        print(f"output on {pipe.group.rank()}, {output}")
        torch.distributed.barrier()

    def build_layers():
        """Create the layer list; the first Linear is reused at three positions."""
        shared = nn.Linear(10, 10)
        return [shared, nn.Linear(10, 10), nn.ReLU(), shared, nn.ReLU(), shared, nn.ReLU()]

    # Reference network: an ordinary, single-process Sequential.
    set_random_seed(1234)
    model = nn.Sequential(*build_layers())
    model.eval()

    # Re-seed so the pipe gets identical weights, but from freshly created
    # layers (no parameter sharing between model and pipe).
    set_random_seed(1234)
    pipe = MultiProcessPipe(
        build_layers(),
        [3, 1, 1],
        style=MultiProcessPipe.AsyncSchedule,
        worker_map=get_worker_map(),
    )
    pipe.eval()

    model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # Only build an optimizer for the pipe when it exposes parameters here.
    if list(pipe.parameters()):
        pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9)
    else:
        pipe_optimizer = None

    inputs = torch.rand(10)
    if False:  # speed
        # Disabled check of the forward pass before any training step.
        model_out = model(inputs)
        pipe_out = pipe(inputs)
        torch.distributed.barrier()

        if pipe.final_stage:
            assert torch.equal(model_out, pipe_out)

    # One training step on both networks.
    model.train()
    pipe.train()
    model_out = model(inputs)
    pipe_out = pipe(inputs)
    if pipe.final_stage:
        pipe_out.mean().backward()

    model_out.mean().backward()

    model_optimizer.step()
    if pipe_optimizer:
        pipe_optimizer.step()

    # Forward again in eval mode and compare the results.
    model.eval()
    pipe.eval()
    model_out = model(inputs)
    pipe_out = pipe(inputs)

    print(f"before barrier on {torch.distributed.get_rank()}")
    torch.distributed.barrier()
    print(f"after barrier on {torch.distributed.get_rank()}")

    if pipe.final_stage:
        assert torch.equal(model_out, pipe_out)
def delete_portal_tensor(train, checkpoint, pipeline_style):
    """Check the portal's ``tensor_life`` and ``tensor`` after Stash, after Pop and in backward.

    Forward hooks on the ``Stash`` and ``Pop`` modules, plus an autograd
    function placed in front of them, assert the expected portal state at each
    point.  The expected ``tensor_life`` depends on whether ``is_checkpointing()``
    or ``is_recomputing()`` is true when the hook fires.

    Args:
        train: when truthy, run a train-mode forward followed by backward;
            otherwise run an eval-mode forward under ``torch.no_grad()``.
        checkpoint: forwarded unchanged as ``checkpoint=`` to ``MultiProcessPipe``.
        pipeline_style: forwarded as ``style=``; the test is skipped when it is
            ``MultiProcessPipe.AsyncSchedule``.
    """
    # Without checkpointing:
    # +- Stash --+  +--- Pop ----+ - - - layers
    # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
    # +----------+  +------------+
    #
    # With checkpointing:
    # +- Stash --+  +--- Pop ----+  +--- Pop'----+  +- Stash'--+
    # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
    # +----------+  +------------+  +------------+  +----------+

    if pipeline_style == MultiProcessPipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    def portal_tensor_life_is(tensor_life, skip_tracker=None):
        """Return True if the first portal has the given ``tensor_life`` and
        holds a tensor exactly when that life is non-zero."""
        tracker = current_skip_tracker() if skip_tracker is None else skip_tracker

        # Only the first portal known to the tracker is inspected.
        portal = list(tracker.portals.values())[0]

        if not portal.tensor_life == tensor_life:
            return False
        if tensor_life == 0:
            return portal.tensor is None
        return portal.tensor is not None

    def expected_life(when_checkpointing, when_recomputing, otherwise):
        """Select the expected tensor_life for the current execution phase."""
        if is_checkpointing():
            return when_checkpointing
        if is_recomputing():
            return when_recomputing
        return otherwise

    # Check the portal tensor after 'Stash'.
    stash_ = Stash()

    def check_portal_tensor_after_stash(*_):
        assert portal_tensor_life_is(expected_life(2, 0, 1))

    stash_.register_forward_hook(check_portal_tensor_after_stash)

    # Check the portal tensor after 'Pop'.
    pop_ = Pop()

    def check_portal_tensor_after_pop(*_):
        assert portal_tensor_life_is(expected_life(1, 0, 0))

    pop_.register_forward_hook(check_portal_tensor_after_pop)

    class NoPortalTensorAtBackward(nn.Module):
        """Identity-like module whose backward asserts the portal holds no tensor."""

        class F(torch.autograd.Function):
            @staticmethod
            def forward(ctx, tensor):
                # Remember the tracker seen in forward so backward checks the same one.
                ctx.skip_tracker = current_skip_tracker()
                return tensor.detach()

            @staticmethod
            def backward(ctx, grad):
                assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
                return grad

        def forward(self, tensor):
            return self.F.apply(tensor)

    layers = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
    pipe = MultiProcessPipe(
        layers,
        balance=[2, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=2,
        checkpoint=checkpoint,
    )

    sample = torch.rand(10, requires_grad=True)

    if not train:
        pipe.eval()
        with torch.no_grad():
            pipe(sample)
    else:
        pipe.train()
        result = pipe(sample)
        # Rank 1 starts backward from the output; every other rank hands its
        # output to back_helper instead.
        if pipe.group.rank() == 1:
            result.norm().backward()
        else:
            pipe.back_helper(result)

    torch.distributed.barrier()