def parallel_randoms(pipeline_style):
    class Dropouts(nn.Module):
        def forward(self, x):
            for _ in range(100):
                x = F.dropout(x, p=0.001)
            return x

    model = nn.Sequential(Dropouts(), Dropouts())

    x = torch.rand(10, 10, requires_grad=True).cuda()
    x.retain_grad()
    model = Pipe(
        model,
        [1, 1],
        style=pipeline_style,
        input_device=torch.cuda.current_device(),
        worker_map=get_worker_map(),
        chunks=10,
        checkpoint="always",
    ).cuda()
    y = model(x)

    # Rank 1 holds the output, rank 0 holds the input gradient. Gather both and
    # compare their zero patterns: the dropout mask applied during the original
    # forward must match the mask replayed during checkpoint recomputation.
    tensor_list = [torch.empty_like(x) for _ in range(2)]
    if model.group.rank() == 1:
        y.norm().backward()
        torch.distributed.barrier()
        tensor_list[model.group.rank()] = y
        torch.distributed.all_gather(tensor_list, y, group=model.group)
        assert tensor_list[0].to(torch.bool).tolist() == tensor_list[1].to(torch.bool).tolist()
    else:
        model.back_helper(y)
        torch.distributed.barrier()
        tensor_list[model.group.rank()] = x.grad
        torch.distributed.all_gather(tensor_list, x.grad, group=model.group)

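# A minimal single-process sketch (not from the original suite) of the property
# the test above relies on: under activation checkpointing, dropout is re-run
# during recomputation, so the checkpoint machinery must restore the RNG state
# for the recomputed mask to match the original forward. Uses only
# torch.utils.checkpoint; the helper name is illustrative.
def _sketch_checkpoint_replays_dropout_mask():
    from torch.utils.checkpoint import checkpoint

    x = torch.rand(10, 10, requires_grad=True)
    out = checkpoint(lambda t: F.dropout(t, p=0.5), x)
    out.sum().backward()
    # Dropped positions (zeros in the forward output) must be exactly the
    # positions with zero gradient; this holds because the recomputation
    # replayed the same mask. (Assumes torch.rand produced no exact zeros,
    # which is true almost surely.)
    assert torch.equal(out.detach() == 0, x.grad == 0)
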
def checkpoint_non_float_input():
    class ForkNonFloat(nn.Module):
        def forward(self, input):
            return (input * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input):
            return input[0] * 2

    model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    model = Pipe(
        model,
        balance=[1, 1],
        style=Pipe.MultiProcess,
        worker_map=get_worker_map(),
        chunks=1,
        checkpoint="always",
        pipelined_backward=False,
    )

    input = torch.rand(1, requires_grad=True)
    output = model(input)
    if model.group.rank() == 1:
        # with torch.autograd.detect_anomaly():
        output.backward()
    else:
        model.back_helper(output)

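# Hedged aside (illustrative, not part of the original tests): only floating
# point and complex tensors can carry gradients in PyTorch, which is why the
# pipeline's checkpointing must tolerate non-float values such as the bool
# tensor forked above.
def _sketch_non_float_tensors_cannot_require_grad():
    flag = torch.tensor([False])
    assert not flag.is_floating_point()
    raised = False
    try:
        flag.requires_grad_()
    except RuntimeError:
        # PyTorch refuses: only floating point and complex dtypes may
        # require gradients.
        raised = True
    assert raised
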
def none_skip(pipeline_style):
    if pipeline_style == Pipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    @skippable(stash=["none"])
    class Stash(nn.Module):
        def forward(self, input):
            yield stash("none", None)
            return input

    @skippable(pop=["none"])
    class Pop(nn.Module):
        def forward(self, input):
            none = yield pop("none")
            assert none is None
            return input

    model = nn.Sequential(Stash(), Pop())
    model = Pipe(
        model,
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        input_device=torch.cuda.current_device(),
        chunks=5,
    ).cuda()
    input = torch.rand(10, requires_grad=True).cuda()
    input.retain_grad()
    output = model(input)

    def assert_grad_fn_is_not_portal(grad_fn, visited=None):
        # Use None instead of a mutable default argument so every top-level
        # call starts from a fresh visited set.
        if visited is None:
            visited = set()
        if grad_fn in visited or grad_fn is None:
            return

        assert not isinstance(grad_fn, PortalBlue._backward_cls)
        assert not isinstance(grad_fn, PortalCopy._backward_cls)
        assert not isinstance(grad_fn, PortalOrange._backward_cls)

        visited.add(grad_fn)
        for next_grad_fn, _ in grad_fn.next_functions:
            assert_grad_fn_is_not_portal(next_grad_fn, visited)

    if model.group.rank() == 1:
        assert_grad_fn_is_not_portal(output.grad_fn)
        output.sum().backward()
    else:
        model.back_helper(output)
        assert input.grad.mean().item() == 1

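# Hedged sketch of the skip API exercised above: @skippable modules also run
# outside Pipe, with a thread-local skip tracker carrying the stashed value.
# Assumes the skippable/stash/pop imports already used by this file; the class
# and helper names are illustrative.
def _sketch_skip_roundtrip_without_pipe():
    @skippable(stash=["x"])
    class StashX(nn.Module):
        def forward(self, input):
            yield stash("x", input)
            return input * 2

    @skippable(pop=["x"])
    class PopX(nn.Module):
        def forward(self, input):
            x = yield pop("x")
            return input + x

    seq = nn.Sequential(StashX(), PopX())
    # Each element is (1 * 2) + 1 == 3.
    assert torch.equal(seq(torch.ones(3)), torch.full((3,), 3.0))
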
def tuple_wait(cuda_sleep, pipeline_style):
    # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
    # Under this behavior, if checkpointing was disabled, there's a possibility
    # that gradient accumulations on other tensors are not synchronized
    # properly to the copy stream.
    class Sleep(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.detach()

        @staticmethod
        def backward(ctx, grad):
            with torch.cuda.device(grad.device):
                cuda_sleep(0.05)
            return grad

    class Layer1(nn.Module):
        def forward(self, pair):
            a, b = pair
            return a * 1, b * 2, b * 3

    class Layer2(nn.Module):
        def forward(self, triple):
            a, b, c = triple
            b = Sleep.apply(b)
            return a + b + c

    model = nn.Sequential(Layer1(), Layer2())
    model = Pipe(
        model,
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        input_device=torch.cuda.current_device(),
        chunks=32,
        checkpoint="never",
    ).cuda()

    a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
    b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)

    y = model((a, b))
    if model.group.rank() == 1:
        y.norm().backward()
    else:
        model.back_helper(y)

    if model.group.rank() == 0:
        # y = a + 5b elementwise, so b.grad = 5 * y / ||y|| and its norm is 5.
        assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))

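# Why the asserted norm is exactly 5: Layer1 and Layer2 together compute
# y = a * 1 + b * 2 + b * 3 = a + 5b, so b.grad = 5 * y / ||y||, whose norm is
# 5. A quick single-process check of that arithmetic (illustrative only):
def _sketch_tuple_grad_norm():
    a = torch.rand(8, requires_grad=True)
    b = torch.rand(8, requires_grad=True)
    y = a * 1 + b * 2 + b * 3
    y.norm().backward()
    assert torch.isclose(b.grad.norm(), torch.tensor(5.0))
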
def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False):
    pipe_world_size = 2

    if world_size == 1:
        return

    if not skip_dist_init:
        dist_init(rank, world_size, filename, filename_rpc)
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29502"
        rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)

    # Split the ranks into model-parallel groups and a pipeline dimension of
    # size 2 (integer division keeps the group size an int).
    mpu.initialize_model_parallel(world_size // pipe_world_size, pipe_world_size)
    model_parallel_size = mpu.get_model_parallel_world_size()
    if torch.distributed.get_rank() == 0:
        print(
            "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )
    chunk_size = 4

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 3
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 7
    output_size = output_size_coeff * model_parallel_size
    batch_size = 3 * chunk_size

    target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
    print(f"target = {target}")

    identity = IdentityLayer2D(batch_size, input_size).cuda()

    pipeline_devices = mpu.get_pipeline_parallel_group()

    set_random_seed(seed)
    model = nn.Sequential(
        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
        nn.ReLU(),
        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
    )
    set_random_seed(seed)

    reference = [
        nn.Linear(input_size, output_size, bias=False).cuda(),
        nn.ReLU(),
        nn.Linear(output_size, input_size, bias=False).cuda(),
    ]

    print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
    print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")

    # Copy the master weights into the reference model so both start from
    # identical parameters.
    reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
    reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()

    reference = nn.Sequential(*reference)

    def grad_graph(depth, grad):
        result = depth * " " + str(grad)
        if grad:
            for x in grad.next_functions:
                result += "\n" + grad_graph(depth + 1, x[0])
        return result

    def check_weights(x, y, key: str, index=None):
        for i in [2, 0]:
            if index is not None and i != index:
                continue
            left = x[i].get_master_weight()
            right = y[i].weight.data
            if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
                print(f"check_weights {key}-{i}: left = {left}, \nright = {right}")
            if not torch.equal(left, right):
                print(f"check_weights NOT_EQUAL {key}-{i}: left = {left}, \nright = {right}")
            assert torch.allclose(left, right, atol=1.0e-6)

    def dump_opt_params(opt):
        for i, group in enumerate(opt.param_groups):
            for j, p in enumerate(group["params"]):
                print(f"{torch.distributed.get_rank()}:param {(i, j)} = {p}")
                print(f"{torch.distributed.get_rank()}:param.grad {(i, j)} = {p.grad}")

    def forward_model(model_, target, step=False):
        optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
        optimizer.zero_grad()
        model_.zero_grad()
        output = model_(identity())
        loss = nn.MSELoss()
        model_.zero_grad()
        if step:
            loss(output, target).backward()
            saved_weight_0 = model_[0].weight.data.clone()
            saved_weight_2 = model_[2].weight.data.clone()
            dump_opt_params(optimizer)
            optimizer.step()
            # The step must actually move the weights.
            assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
            assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
        return output

    output = forward_model(model, target)
    reference_output = forward_model(reference, target)

    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    output = forward_model(model, target)
    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    output = forward_model(model, target)
    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    check_weights(model, reference, "before")
    saved_weight_0 = model[0].weight.data.clone()
    saved_weight_2 = model[2].weight.data.clone()
    output = forward_model(model, target, step=True)
    error = reference_output.sub(output).max()
    assert error < 1.0e-6

    # Restore the pre-step weights so the pipelined run below starts from the
    # same parameters as the reference.
    model[0].weight.data = saved_weight_0
    model[2].weight.data = saved_weight_2

    worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

    style = Pipe.MultiProcess  # Pipe.AsyncSchedule

    if pipe_world_size == 2:
        print("actually doing pipe stuff now")
        assert torch.equal(saved_weight_0, model[0].weight.data)
        assert torch.equal(saved_weight_2, model[2].weight.data)
        pipe_model = Pipe(
            model,
            [2, 1],
            style=style,
            group=pipeline_devices,
            worker_map=worker_map,
            input_device=torch.cuda.current_device(),
            chunks=chunk_size,
            pipelined_backward=True,
        ).cuda()
        torch.distributed.barrier()
        pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
        print(f"pipe rank is {pipe_rank}")
        if pipe_rank == 0:
            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
        else:
            if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
                print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
                assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
        optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
        optimizer.zero_grad()
        if pipe_rank == 0:
            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
            print(f"runner {rank}:\n{pipe_model[0].weight.data}")
        else:
            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
            print(f"runner {rank}:\n{pipe_model[0].weight.data}")

        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
            check_weights(model, reference, "pre-pipe", index=2)
        else:
            check_weights(model, reference, "pre-pipe", index=0)

        pipe_output = pipe_model(identity())
        print(f"exited pipe for {rank}")
        forward_model(reference, target, step=True)

        print(f"pipe_output {rank} = {pipe_output}")
        print(f"reference_output {rank} = {reference_output}")

        torch.distributed.barrier()

        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
            error = reference_output.sub(pipe_output.cuda()).max()
            if error >= 1.0e-6:
                print(f"error bad {error}")
            assert error < 1.0e-6

            loss = nn.MSELoss()
            failed = False
            pipe_output.retain_grad()
            with torch.autograd.profiler.profile() as prof:
                try:
                    loss(pipe_output, target).backward()
                except Exception as e:
                    failed = True
                    print(f"got {e} while doing backward, deadlock?")
            if failed:
                raise RuntimeError("failed somehow")
            dump_opt_params(optimizer)
            optimizer.step()

            print("calling check_weights on master")
            check_weights(model, reference, "pipe", index=2)
            print(f"waiting for barrier on master, pid={os.getpid()}")
        else:
            print(f"calling backwards on slave, pid={os.getpid()}")
            failed = False
            with torch.autograd.profiler.profile() as prof:
                try:
                    if style == Pipe.MultiProcess:
                        pipe_model.back_helper(pipe_output)
                except Exception as e:
                    failed = True
                    print(f"got {e} while doing backward, deadlock?")
            if failed:
                raise RuntimeError("failed somehow")
            dump_opt_params(optimizer)
            print("calling step on slave")
            optimizer.step()
            print("calling check_weights on slave")
            check_weights(model, reference, "pipe", index=0)
            print("waiting for barrier on slave")

        pipe_model.zero_grad()
        torch.distributed.barrier()

        pipe_model.eval()
        pipe_output = pipe_model(identity())
        updated_ref_output = forward_model(reference, target)
        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
            error = updated_ref_output.sub(pipe_output.cuda()).max()
            print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
            assert error < 1.0e-6
        torch.distributed.barrier()

        print(f"finished waiting for barrier, pid={os.getpid()}")

    print(f"really exited pipe for {rank}")

    rpc.shutdown()
    torch.distributed.destroy_process_group()

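# run_test_pipe is written to be invoked once per rank. A hedged sketch of a
# typical driver (the world size and temp-file-based init are illustrative;
# the real suite wires this up through its own harness):
def _sketch_spawn_run_test_pipe():
    import tempfile

    import torch.multiprocessing as mp

    world_size = 4
    file1 = tempfile.NamedTemporaryFile()
    file2 = tempfile.NamedTemporaryFile()
    # mp.spawn passes the rank as the first argument to run_test_pipe.
    mp.spawn(run_test_pipe, args=(world_size, file1.name, file2.name), nprocs=world_size, join=True)
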
def x1to3(balance, checkpoint, pipeline_style):
    torch.manual_seed(0)

    if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1:
        print("skipping: Skip tensors NYI for AsyncSchedule")
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    @skippable(stash=["1to3"])
    class Layer1(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            yield stash("1to3", input)
            output = self.conv(input)
            return output

    class Layer2(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            output = self.conv(input)
            return output

    @skippable(pop=["1to3"])
    class Layer3(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            skip_1to3 = yield pop("1to3")
            output = self.conv(input) + skip_1to3
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
    model = Pipe(
        model,
        balance,
        chunks=3,
        checkpoint=checkpoint,
        input_device=torch.cuda.current_device(),
        style=pipeline_style,
        worker_map=get_worker_map(),
        pipelined_backward=False,
    ).cuda()

    input = torch.rand(30, 3, 224, 224, requires_grad=True).cuda()
    input.retain_grad()
    output = model(input)
    if model.group.rank() == len(balance) - 1:
        loss = output.mean()
        loss.backward()
    elif model.group.rank() < len(balance) - 1:
        model.back_helper(output)

    if model.group.rank() == len(balance) - 1:
        # TODO(tom): the single-process test uses atol=2e-1, but for some
        # reason multi-process is noisier; need to investigate why.
        assert torch.allclose(output.norm(), torch.tensor(1039.0).cuda(), atol=4e-1)
    if model.group.rank() == 0:
        assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053).cuda())

    torch.distributed.barrier()

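# Single-process analogue of the model above (no Pipe), showing the 1-to-3
# skip as a plain residual connection; purely illustrative:
def _sketch_1to3_as_residual():
    conv1, conv2, conv3 = nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1)
    x = torch.rand(2, 3, 8, 8)
    # Layer3 adds the tensor stashed by Layer1, i.e. the original input.
    out = conv3(conv2(conv1(x))) + x
    assert out.shape == x.shape
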
def delete_portal_tensor(train, checkpoint, pipeline_style):
    # Without checkpointing:
    #  +- Stash --+  +--- Pop ----+ - - - layers
    #  | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
    #  +----------+  +------------+
    #
    # With checkpointing:
    #  +- Stash --+  +--- Pop ----+  +--- Pop'----+  +- Stash'--+
    #  | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
    #  +----------+  +------------+  +------------+  +----------+
    if pipeline_style == Pipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    def portal_tensor_life_is(tensor_life, skip_tracker=None):
        if skip_tracker is None:
            skip_tracker = current_skip_tracker()

        # Get the current portal.
        portal = list(skip_tracker.portals.values())[0]

        if tensor_life == 0:
            return portal.tensor_life == 0 and portal.tensor is None
        else:
            return portal.tensor_life == tensor_life and portal.tensor is not None

    # Check the portal tensor after 'Stash'.
    stash_ = Stash()

    @stash_.register_forward_hook
    def check_portal_tensor_after_stash(*_):
        if is_checkpointing():
            assert portal_tensor_life_is(2)
        elif is_recomputing():
            assert portal_tensor_life_is(0)
        else:
            assert portal_tensor_life_is(1)

    # Check the portal tensor after 'Pop'.
    pop_ = Pop()

    @pop_.register_forward_hook
    def check_portal_tensor_after_pop(*_):
        if is_checkpointing():
            assert portal_tensor_life_is(1)
        elif is_recomputing():
            assert portal_tensor_life_is(0)
        else:
            assert portal_tensor_life_is(0)

    class NoPortalTensorAtBackward(nn.Module):
        class F(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.skip_tracker = current_skip_tracker()
                return input.detach()

            @staticmethod
            def backward(ctx, grad):
                assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
                return grad

        def forward(self, input):
            return self.F.apply(input)

    model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
    model = Pipe(
        model,
        balance=[2, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=2,
        checkpoint=checkpoint,
    )

    input = torch.rand(10, requires_grad=True)

    if train:
        model.train()
        output = model(input)
        if model.group.rank() == 1:
            output.norm().backward()
        else:
            model.back_helper(output)
    else:
        model.eval()
        with torch.no_grad():
            model(input)

    torch.distributed.barrier()

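# The assertions above hang off nn.Module.register_forward_hook, used as a
# decorator just like in the test. A minimal sketch of the pattern (the hook
# fires after every forward call; names are illustrative):
def _sketch_forward_hook():
    calls = []
    layer = nn.Identity()

    @layer.register_forward_hook
    def record(module, inputs, output):
        calls.append(output)

    layer(torch.ones(1))
    assert len(calls) == 1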