def run_test_column_parallel_linear(rank, model_parallel_size, filename, filename_rpc):
    """Verify ColumnParallelLinear gradients against a hand-computed dense reference.

    The layer is created with ``keep_master_weight_for_test=True`` so that its full
    ``master_weight`` is available. The expected gradients are derived from that
    master weight, and each model-parallel rank compares the slice it owns.

    Args:
        rank: process rank forwarded to ``dist_init``.
        model_parallel_size: requested model-parallel world size.
        filename: rendezvous file forwarded to ``dist_init``.
        filename_rpc: RPC rendezvous file forwarded to ``dist_init``.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing ColumnParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    set_random_seed(12345)
    in_coeff, out_coeff = 13, 17
    in_features = in_coeff * model_parallel_size
    out_features = out_coeff * model_parallel_size
    batch = 7

    # Network: a trainable "input" layer feeding the column-parallel linear.
    identity_layer = IdentityLayer2D(batch, in_features).cuda()
    linear_layer = layers.ColumnParallelLinear(in_features, out_features, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch, out_features]).cuda()

    # Forward + backward with loss = sum(output * loss_weight).
    linear_layer(identity_layer()).mul(loss_weight).sum().backward()

    # With that loss, dL/dY is exactly loss_weight; derive the reference gradients.
    grad_y = loss_weight
    expected_grad_weight = grad_y.t().matmul(identity_layer.weight)
    expected_grad_bias = torch.ones(batch, 1).cuda().t().matmul(grad_y).view(-1)
    expected_grad_input = grad_y.matmul(linear_layer.master_weight.cuda())

    mp_rank = mpu.get_model_parallel_rank()

    def local_rows(full):
        # The slice of a full gradient owned by this rank (chunks of out_coeff along dim 0).
        return torch.split(full, out_coeff, dim=0)[mp_rank].contiguous().clone()

    def compare(template, expected, actual):
        err = expected.sub(actual).abs().max()
        torch.distributed.barrier()
        print(template.format(torch.distributed.get_rank(), err))
        assert err < 1.0e-6

    compare(" error in dLdA on global rank {}: {}", local_rows(expected_grad_weight), linear_layer.weight.grad)
    compare(" error in dLdb on global rank {}: {}", local_rows(expected_grad_bias), linear_layer.bias.grad)
    compare(" error in dLdX on global rank {}: {}", expected_grad_input, identity_layer.weight.grad)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    """Mean vocab-parallel cross-entropy of random logits, with gradients.

    Builds an ``IdentityLayer`` of shape ``(batch_size, seq_length, vocab_size)``
    after seeding, scatters its output to the model-parallel region, and scores it
    against random integer targets in ``[0, vocab_size)``.

    Returns:
        A tuple ``(loss, grad)`` of the mean loss and the gradient of the
        identity layer's weight after ``loss.backward()``.
    """
    set_random_seed(seed)
    shape = (batch_size, seq_length, vocab_size)
    source = IdentityLayer(shape, scale=logits_scale).cuda()
    sharded_logits = scatter_to_model_parallel_region(source())
    labels = torch.cuda.LongTensor(size=shape[:2]).random_(0, vocab_size)
    mean_loss = vocab_parallel_cross_entropy(sharded_logits, labels).mean()
    mean_loss.backward()
    return mean_loss, source.weight.grad
def simple_linears(pipeline_style):
    """Check MultiProcessPipe gradients against an un-piped nn.Sequential.

    A 4-layer linear stack is first run without a pipe to record the gradient sum
    of layers 0-1 and of layers 2-3. The same model is then wrapped in a
    ``MultiProcessPipe`` with balance ``[2, 2]``. Each rank checks that the
    gradient sum of its local partition matches the corresponding reference.

    Args:
        pipeline_style: schedule style forwarded to ``MultiProcessPipe``.
    """

    def sum_grad(parameters):
        # Sum of all gradient entries, skipping parameters that have no grad.
        return sum(p.grad.sum() for p in parameters if p.grad is not None)

    def zero_grad(parameters):
        for p in parameters:
            p.grad = None

    set_random_seed(12345)
    inputs = torch.rand(8, 1)
    model = nn.Sequential(
        nn.Linear(1, 2),
        nn.Linear(2, 4),
        nn.Linear(4, 2),
        nn.Linear(2, 1),
    )

    # Without MultiProcessPipe
    outputs = model(inputs)
    loss = outputs.mean()
    loss.backward()

    # Reference gradient sums for each of the two partitions used below.
    grad_without_pipe = [
        sum_grad([*model[0].parameters(), *model[1].parameters()]),
        sum_grad([*model[2].parameters(), *model[3].parameters()]),
    ]

    zero_grad(model.parameters())

    # With MultiProcessPipe
    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)

    outputs = model(inputs)
    if model.group.rank() == 1:
        # This rank holds the second partition and produces the final output.
        loss = outputs.mean()
        loss.backward()
        expected = grad_without_pipe[1]
    else:
        # This rank does not compute the loss; it drives backward with back_helper.
        model.back_helper(outputs)
        expected = grad_without_pipe[0]

    grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
    # Both grads should be identical.
    assert torch.allclose(grad_with_pipe, expected)

    torch.distributed.barrier()
def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    """Reference (non-parallel) mean cross-entropy of random logits, with gradients.

    Mirrors the parallel variant but uses ``F.cross_entropy`` on the full logits.
    The loss is computed per token (``reduction="none"``), reshaped back to the
    target shape, and then averaged.

    Returns:
        A tuple ``(loss, grad)`` of the mean loss and the gradient of the
        identity layer's weight after ``loss.backward()``.
    """
    set_random_seed(seed)
    shape = (batch_size, seq_length, vocab_size)
    source = IdentityLayer(shape, scale=logits_scale).cuda()
    logits = source()
    labels = torch.cuda.LongTensor(size=shape[:2]).random_(0, vocab_size)
    flat_logits = logits.view(-1, logits.size(-1))
    per_token = F.cross_entropy(flat_logits, labels.view(-1), reduction="none")
    mean_loss = per_token.view_as(labels).mean()
    mean_loss.backward()
    return mean_loss, source.weight.grad
def run_test_parallel_embedding(rank, model_parallel_size, filename, filename_rpc):
    """Compare ParallelEmbedding and VocabParallelEmbedding against torch.nn.Embedding.

    All three embeddings are constructed right after re-seeding with the same seed.
    They are fed the same token ids and reduced with the same random loss weights.
    The scalar losses must then match the dense reference. Each rank's weight
    gradient must match its slice of the reference gradient: split along the
    hidden dim for ``ParallelEmbedding``, along the vocab dim for
    ``VocabParallelEmbedding``.

    Args:
        rank: process rank forwarded to ``dist_init``.
        model_parallel_size: requested model-parallel world size.
        filename: rendezvous file forwarded to ``dist_init``.
        filename_rpc: RPC rendezvous file forwarded to ``dist_init``.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)
    if torch.distributed.get_rank() == 0:
        print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size, seq_length = 17, 23
    vocab_size, hidden_size = 48, 16
    seed = 1236
    tolerance = 1.0e-12

    # Shared inputs are drawn under a separate seed (123) before the embeddings are built.
    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    def weighted_loss(embedding):
        # Embed the shared tokens, reduce with the shared weights, and backprop.
        loss = embedding(input_data).mul(loss_weight).sum()
        loss.backward()
        return loss

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
    loss_original = weighted_loss(embedding_original)

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_parallel = weighted_loss(embedding_parallel)

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_
    ).cuda()
    loss_vocab_parallel = weighted_loss(embedding_vocab_parallel)

    def report(template, err):
        print(template.format(torch.distributed.get_rank(), err))
        assert err < tolerance, "error: {}".format(err)

    # Loss comparisons are each preceded by a barrier.
    loss_checks = (
        (" error in loss (parallel) on global rank {}: {}", loss_parallel),
        (" error in loss (vocab parallel) on global rank {}: {}", loss_vocab_parallel),
    )
    for template, candidate in loss_checks:
        torch.distributed.barrier()
        report(template, candidate.sub(loss_original).abs())

    # Gradient comparisons: (message, embedding, chunk size, split dim).
    mp_rank = mpu.get_model_parallel_rank()
    full_grad = embedding_original.weight.grad
    grad_checks = (
        (" error in grad (parallel) on global rank {}: {}", embedding_parallel, hidden_size // model_parallel_size, 1),
        (
            " error in grad (vocab parallel) on global rank {}: {}",
            embedding_vocab_parallel,
            vocab_size // model_parallel_size,
            0,
        ),
    )
    for template, embedding, chunk, dim in grad_checks:
        expected = torch.split(full_grad, chunk, dim)[mp_rank]
        report(template, embedding.weight.grad.sub(expected).abs().max())

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False):
    """End-to-end check of a model-parallel MLP, first un-piped and then under MultiProcessPipe.

    The model is ColumnParallelLinear -> ReLU -> RowParallelLinear (no bias). It is
    compared against a dense ``nn.Linear`` reference initialized from the parallel
    layers' master weights. The test proceeds in three stages:
      1. Repeated forward passes of both models must agree, and one SGD step must
         change the parallel weights.
      2. The (restored) model is wrapped in a 2-stage ``MultiProcessPipe`` with
         balance ``[2, 1]``. The last pipe stage checks the forward output against
         the reference.
      3. Backward runs on both stages (loss on stage 1, ``back_helper`` on
         stage 0), followed by an SGD step. Weights and eval-mode outputs are then
         rechecked against the reference, which was stepped in the same way.

    Args:
        rank: global rank of this process.
        world_size: total number of processes. The test returns immediately when it is 1.
        filename: rendezvous file forwarded to ``dist_init``.
        filename_rpc: RPC rendezvous file forwarded to ``dist_init``.
        skip_dist_init: if True, skip ``dist_init``. Only MASTER_ADDR/MASTER_PORT
            are set and RPC is initialized. NOTE(review): this path appears to
            assume the default process group already exists — confirm with callers.
    """
    pipe_world_size = 2

    if world_size == 1:
        return

    if not skip_dist_init:
        dist_init(rank, world_size, filename, filename_rpc)
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29502"
        rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)

    # Integer division: the model-parallel size is a process count, not a ratio.
    mpu.initialize_model_parallel(world_size // pipe_world_size, pipe_world_size)

    model_parallel_size = mpu.get_model_parallel_world_size()
    if torch.distributed.get_rank() == 0:
        print(
            "> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )
    chunk_size = 4

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 3
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 7
    output_size = output_size_coeff * model_parallel_size
    batch_size = 3 * chunk_size

    target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
    print(f"target = {target}")

    identity = IdentityLayer2D(batch_size, input_size).cuda()

    pipeline_devices = mpu.get_pipeline_parallel_group()

    set_random_seed(seed)
    model = nn.Sequential(
        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
        nn.ReLU(),
        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
    )
    set_random_seed(seed)

    reference = [
        nn.Linear(input_size, output_size, bias=False).cuda(),
        nn.ReLU(),
        nn.Linear(output_size, input_size, bias=False).cuda(),
    ]

    print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
    print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")

    # Make the dense reference start from exactly the parallel layers' full weights.
    reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
    reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()

    reference = nn.Sequential(*reference)

    def max_abs_error(expected, actual):
        # Largest absolute deviation. Signed max() would let negative errors pass.
        return expected.sub(actual).abs().max()

    def check_weights(x, y, key: str, index=None):
        # Compare master weights of the parallel layers in `x` with the dense weights in `y`.
        for i in [2, 0]:
            if index is not None and i != index:
                continue
            left = x[i].get_master_weight()
            right = y[i].weight.data
            if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
                print(f"check_weights {key}-{i}: left = {left}, \nright = {right}")
            if not torch.equal(left, right):
                print(f"check_weights NOT_EQUAL {key}-{i}: left = {left}, \nright = {right}")
            assert torch.allclose(left, right, atol=1.0e-6)

    def dump_opt_params(opt):
        for i, group in enumerate(opt.param_groups):
            for j, p in enumerate(group["params"]):
                print(f"{torch.distributed.get_rank()}:param {(i,j)} = {p}")
                print(f"{torch.distributed.get_rank()}:param.grad {(i,j)} = {p.grad}")

    def forward_model(model_, target, step=False):
        # Forward `identity()` through `model_`; optionally take one SGD step on an MSE loss.
        optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
        optimizer.zero_grad()
        model_.zero_grad()
        output = model_(identity())
        loss = nn.MSELoss()
        model_.zero_grad()
        if step:
            loss(output, target).backward()
            saved_weight_0 = model_[0].weight.data.clone()
            saved_weight_2 = model_[2].weight.data.clone()
            dump_opt_params(optimizer)
            optimizer.step()
            # The step must actually move both weight matrices.
            assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
            assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
        return output

    output = forward_model(model, target)
    reference_output = forward_model(reference, target)

    error = max_abs_error(reference_output, output)
    torch.distributed.barrier()
    assert error < 1.0e-6

    # Repeated forwards without stepping must be stable.
    output = forward_model(model, target)
    error = max_abs_error(reference_output, output)
    torch.distributed.barrier()
    assert error < 1.0e-6

    output = forward_model(model, target)
    error = max_abs_error(reference_output, output)
    torch.distributed.barrier()
    assert error < 1.0e-6

    check_weights(model, reference, "before")
    saved_weight_0 = model[0].weight.data.clone()
    saved_weight_2 = model[2].weight.data.clone()
    output = forward_model(model, target, step=True)
    error = max_abs_error(reference_output, output)
    assert error < 1.0e-6
    # Undo the step so the pipe starts from the same weights as the reference.
    model[0].weight.data = saved_weight_0
    model[2].weight.data = saved_weight_2

    worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

    if pipe_world_size == 2:
        print("actually doing pipe stuff now")
        assert torch.equal(saved_weight_0, model[0].weight.data)
        assert torch.equal(saved_weight_2, model[2].weight.data)
        pipe_model = MultiProcessPipe(
            model,
            [2, 1],
            group=pipeline_devices,
            worker_map=worker_map,
            input_device=torch.cuda.current_device(),
            chunks=chunk_size,
        ).cuda()
        torch.distributed.barrier()
        pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
        print(f"pipe rank is {pipe_rank}")
        # With balance [2, 1], stage 0 holds model[0] (and the ReLU); stage 1 holds model[2].
        if pipe_rank == 0:
            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
        else:
            if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
                print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
        optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
        optimizer.zero_grad()
        if pipe_rank == 0:
            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
            print(f"runner {rank}:\n{pipe_model[0].weight.data}")
        else:
            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
            print(f"runner {rank}:\n{pipe_model[0].weight.data}")

        if pipe_rank == 1:
            check_weights(model, reference, "pre-pipe", index=2)
        else:
            check_weights(model, reference, "pre-pipe", index=0)

        pipe_output = pipe_model(identity())
        print(f"exited pipe for {rank}")
        # Step the reference once so it matches the pipe model after its step below.
        forward_model(reference, target, step=True)

        print(f"pipe_output {rank} = {pipe_output}")
        print(f"reference_output {rank} = {reference_output}")

        torch.distributed.barrier()

        if pipe_rank == 1:
            error = max_abs_error(reference_output, pipe_output.cuda())
            if error >= 1.0e-6:
                print(f"error bad {error}")
            assert error < 1.0e-6

            loss = nn.MSELoss()
            pipe_output.retain_grad()
            with torch.autograd.profiler.profile():
                try:
                    loss(pipe_output, target).backward()
                except Exception as e:
                    print(f"got {e} while doing backward, deadlock?")
                    raise RuntimeError("failed somehow") from e
            dump_opt_params(optimizer)
            optimizer.step()

            print("calling check_weights on master")
            check_weights(model, reference, "pipe", index=2)
            print(f"waiting for barrier on master, pid={os.getpid()}")
        else:
            print(f"calling backwards on slave, pid={os.getpid()}")
            with torch.autograd.profiler.profile():
                try:
                    pipe_model.back_helper(pipe_output)
                except Exception as e:
                    print(f"got {e} while doing backward, deadlock?")
                    raise RuntimeError("failed somehow") from e
            dump_opt_params(optimizer)
            print("calling step on slave")
            optimizer.step()
            print("calling check_weights on slave")
            check_weights(model, reference, "pipe", index=0)
            print("waiting for barrier on slave")

        pipe_model.zero_grad()
        torch.distributed.barrier()

        pipe_model.eval()
        pipe_output = pipe_model(identity())
        updated_ref_output = forward_model(reference, target)
        if pipe_rank == 1:
            error = max_abs_error(updated_ref_output, pipe_output.cuda())
            print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
            assert error < 1.0e-6
        torch.distributed.barrier()

        print(f"finished waiting for barrier on, pid={os.getpid()}")

    print(f"really exited pipe for {rank}")

    rpc.shutdown()
    torch.distributed.destroy_process_group()
def run_test_initialize_affine_weight(rank, model_parallel_size, filename, filename_rpc):
    """Check that ``layers._initialize_affine_weight`` yields this rank's slice of a full init.

    For both the column-parallel case (partition dim 0) and the row-parallel case
    (partition dim 1), the sharded weight is initialized after seeding. A full
    ``(output_size, input_size)`` master weight is then initialized from the same
    seed, split, and compared with the shard for this model-parallel rank.

    Args:
        rank: process rank forwarded to ``dist_init``.
        model_parallel_size: requested model-parallel world size.
        filename: rendezvous file forwarded to ``dist_init``.
        filename_rpc: RPC rendezvous file forwarded to ``dist_init``.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing initialize_affine_weight with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    in_coeff, out_coeff = 13, 17
    in_features = in_coeff * model_parallel_size
    out_features = out_coeff * model_parallel_size

    def check_shard(shard_shape, partition_size, partition_dim, template):
        # Initialize the shard under the seed.
        shard = torch.empty(*shard_shape)
        set_random_seed(seed)
        layers._initialize_affine_weight(
            shard, out_features, in_features, partition_size, partition_dim, torch.nn.init.normal_
        )
        # Target: the same slice cut from a fully initialized master weight.
        set_random_seed(seed)
        full = torch.empty(out_features, in_features)
        torch.nn.init.normal_(full)
        mp_rank = mpu.get_model_parallel_rank()
        expected = torch.split(full, partition_size, dim=partition_dim)[mp_rank].contiguous().clone()
        # Compare.
        err = shard.sub(expected).abs().max()
        torch.distributed.barrier()
        print(template.format(torch.distributed.get_rank(), err))
        assert err < 1.0e-6

    # Column parallel: shard rows (dim 0).
    check_shard(
        (out_coeff, in_features),
        out_coeff,
        0,
        " column parallel max error (should be zero) on global rank {}: {}",
    )
    # Row parallel: shard columns (dim 1).
    check_shard(
        (out_features, in_coeff),
        in_coeff,
        1,
        " row parallel max error (should be zero) on global rank {}: {}",
    )

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
def reuse_lazy():
    """Check that a pipe whose layer list reuses one module matches an nn.Sequential.

    ``reused`` appears at three positions in the layer list. A plain
    ``nn.Sequential`` and a ``MultiProcessPipe`` (balance ``[3, 1, 1]``,
    ``AsyncSchedule``) are built from identically seeded but separate layer
    objects. After one training step on each, their eval-mode outputs on the
    final pipe stage must be equal.
    """
    # Disabled block (never executes). The "speed" note suggests it was turned off
    # for runtime; it exercises LazyModule-based reuse.
    if False:  # speed
        reused = LazyModule(lambda: nn.Linear(10, 10))
        model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
        # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
        pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
        pipe.eval()
        output = pipe(torch.rand(10))

        print(f"output on {pipe.group.rank()}, {output}")
        torch.distributed.barrier()

    set_random_seed(1234)
    # Build the dense reference model (the same `reused` Linear appears three times).
    reused = nn.Linear(10, 10)
    # NOTE: this local `layers` shadows any module-level `layers` name inside this function.
    layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
    model = nn.Sequential(*layers)
    model.eval()

    set_random_seed(1234)
    # ensure identical weights but no sharing between model and pipe
    reused = nn.Linear(10, 10)
    layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
    pipe = MultiProcessPipe(layers, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
    pipe.eval()
    model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # No optimizer when this rank's partition holds no parameters.
    pipe_optimizer = (
        torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
    )

    inputs = torch.rand(10)
    # Disabled eval-mode forward comparison (never executes).
    if False:  # speed
        model_out = model(inputs)
        pipe_out = pipe(inputs)

        torch.distributed.barrier()

        if pipe.final_stage:
            assert torch.equal(model_out, pipe_out)

    # One training step on both the reference model and the pipe.
    model.train()
    pipe.train()
    model_out = model(inputs)
    pipe_out = pipe(inputs)
    # Only the final stage computes the pipe loss and starts backward from it.
    if pipe.final_stage:
        pipe_loss = pipe_out.mean()
        pipe_loss.backward()

    model_loss = model_out.mean()
    model_loss.backward()

    model_optimizer.step()
    if pipe_optimizer:
        pipe_optimizer.step()

    # After the step, eval-mode outputs must still agree.
    model.eval()
    pipe.eval()
    model_out = model(inputs)
    pipe_out = pipe(inputs)

    print(f"before barrier on {torch.distributed.get_rank()}")
    torch.distributed.barrier()
    print(f"after barrier on {torch.distributed.get_rank()}")

    if pipe.final_stage:
        assert torch.equal(model_out, pipe_out)