def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size):
    """Verify the seeds installed by ``model_parallel_cuda_manual_seed``.

    After seeding with 12345 the default CUDA generator must report 12345,
    and inside the tracker's ``fork()`` context the reported seed must equal
    ``12345 + 2718 + <model parallel rank>``.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size, forwarded to
            ``dist_init`` and ``mpu.initialize_model_parallel``.
    """
    base_seed = 12345

    dist_init(rank, model_parallel_size)

    banner = "> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print(banner)

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    model_parallel_cuda_manual_seed(base_seed)

    # Default generator carries the base seed unchanged.
    assert torch.cuda.initial_seed() == base_seed

    # Inside the fork the seed is shifted by a constant and by the rank.
    with get_cuda_rng_tracker().fork():
        forked_seed = torch.cuda.initial_seed()
        assert forked_seed == (base_seed + 2718 + mpu.get_model_parallel_rank())

    # Leave no tracker state or process groups behind.
    get_cuda_rng_tracker().reset()
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_column_parallel_linear(rank, model_parallel_size, filename, filename_rpc):
    """Compare ``ColumnParallelLinear`` gradients with hand-computed ones.

    A scalar loss ``sum(output * loss_weight)`` is back-propagated through an
    identity layer followed by the parallel linear layer. The gradients of
    the weight shard, the bias shard and the input are then checked against
    closed-form matrix products built from the layer's master weight.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size.
        filename: Forwarded to ``dist_init``.
        filename_rpc: Forwarded to ``dist_init``.
    """
    tolerance = 1.0e-6

    dist_init(rank, model_parallel_size, filename, filename_rpc)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing ColumnParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    set_random_seed(12345)

    # Problem dimensions scale with the number of model-parallel workers.
    batch_size = 7
    input_size_coeff = 13
    output_size_coeff = 17
    input_size = input_size_coeff * model_parallel_size
    output_size = output_size_coeff * model_parallel_size

    # Network (construction order is kept: it consumes random numbers).
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward and backward.
    activations = linear_layer(identity_layer())
    torch.mul(activations, loss_weight).sum().backward()

    # Reference gradients.
    grad_out = loss_weight
    layer_input = identity_layer.weight
    full_weight = linear_layer.master_weight.cuda()
    ref_weight_grad = torch.matmul(grad_out.t(), layer_input)
    ones_row = torch.ones(batch_size, 1).cuda().t()
    ref_bias_grad = torch.matmul(ones_row, grad_out).view(-1)
    ref_input_grad = torch.matmul(grad_out, full_weight)

    rank = mpu.get_model_parallel_rank()

    # Weight gradient: only this rank's slice along the output dimension.
    local_weight_grad = torch.split(ref_weight_grad, output_size_coeff, dim=0)[rank]
    local_weight_grad = local_weight_grad.contiguous().clone()
    error = local_weight_grad.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance

    # Bias gradient: sliced the same way as the weight.
    local_bias_grad = torch.split(ref_bias_grad, output_size_coeff, dim=0)[rank]
    local_bias_grad = local_bias_grad.contiguous().clone()
    error = local_bias_grad.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance

    # Input gradient: compared in full, without slicing.
    error = ref_input_grad.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
def rpc_megatron_reuse():
    """Run the pipe-vs-reference check on models with and without layer reuse.

    The existing model-parallel state and process group are torn down, a
    "gloo" process group is created from the ``RANK`` / ``WORLD_SIZE``
    environment variables, and model parallelism is re-initialised with
    arguments ``(2, 3)``. Only the process whose pipeline-group rank is 0
    runs the reference checks; every process then shuts RPC down and joins
    a barrier.
    """
    from fairscale.nn.model_parallel import layers
    from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel

    def make_model_simple():
        # Every parallel layer is a distinct instance.
        modules = []
        for _ in range(2):
            modules.append(layers.ColumnParallelLinear(10, 10))
            modules.append(nn.ReLU())
            modules.append(layers.RowParallelLinear(10, 10))
            modules.append(nn.ReLU())
        modules.append(nn.Linear(10, 10))
        modules.append(nn.ReLU())
        return modules

    def make_model_with_reuse():
        # The same column/row instances appear twice in the sequence.
        column = layers.ColumnParallelLinear(10, 10)
        row = layers.RowParallelLinear(10, 10)
        modules = []
        for _ in range(2):
            modules.extend([column, nn.ReLU(), row, nn.ReLU()])
        modules.extend([nn.Linear(10, 10), nn.ReLU()])
        return modules

    # Start from a clean slate before building the new groups.
    destroy_model_parallel()
    torch.distributed.destroy_process_group()

    env_rank = int(os.environ["RANK"])
    env_world_size = int(os.environ["WORLD_SIZE"])
    torch.distributed.init_process_group("gloo", rank=env_rank, world_size=env_world_size)
    initialize_model_parallel(2, 3, model_parallel_backend="nccl", pipeline_backend="mpi")
    init_rpc()

    if get_pipeline_parallel_group().rank() == 0:
        check_pipe_against_reference([4, 4, 2], make_model_simple, "always")
        check_pipe_against_reference([4, 2, 2], make_model_with_reuse)

    rpc.shutdown()
    torch.distributed.barrier()
def mpi_pipe():
    """Run ``run_test_pipe`` on an already-initialised process group.

    Model-parallel state is destroyed first, then two temporary file paths
    are created and handed to ``run_test_pipe`` together with the current
    global rank and world size. ``skip_dist_init=True`` is passed through.
    """
    # Function-scope import keeps this block self-contained.
    import os

    mpu.destroy_model_parallel()

    # ``mkstemp`` returns an *open* OS-level descriptor along with the path.
    # Only the path is needed here, so close the descriptor right away
    # instead of leaking it for the lifetime of the process.
    init_fd, tempfile_init = tempfile.mkstemp()
    os.close(init_fd)
    rpc_fd, tempfile_rpc_init = tempfile.mkstemp()
    os.close(rpc_fd)

    run_test_pipe(
        torch.distributed.get_rank(),
        torch.distributed.get_world_size(),
        tempfile_init,
        tempfile_rpc_init,
        skip_dist_init=True,
    )
def pipelined_backward():
    """Check the ``pipelined_backward`` flag chosen by a multi-process ``Pipe``.

    With model parallelism initialised as ``(1, 4)`` the flag must be
    ``False``; with ``(2, 2)`` it must be ``True``.
    """
    model = nn.Sequential(nn.ReLU(), nn.ReLU())

    # (initialize_model_parallel arguments, expected flag value)
    scenarios = (
        ((1, 4), False),
        ((2, 2), True),
    )
    for parallel_args, expected_flag in scenarios:
        destroy_model_parallel()
        initialize_model_parallel(*parallel_args)
        pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())
        assert pipe.pipelined_backward is expected_flag
def pipelined_backward(pipe_class):
    """Check the ``pipelined_backward`` flag chosen by ``pipe_class``.

    With model parallelism initialised as ``(1, 4)`` the flag must be
    ``False``; with ``(2, 2)`` it must be ``True``.

    Args:
        pipe_class: Callable invoked as
            ``pipe_class(model, [1, 1], worker_map=...)`` whose result
            exposes a ``pipelined_backward`` attribute.
    """
    model = nn.Sequential(nn.ReLU(), nn.ReLU())

    # (initialize_model_parallel arguments, expected flag value)
    scenarios = (
        ((1, 4), False),
        ((2, 2), True),
    )
    for parallel_args, expected_flag in scenarios:
        destroy_model_parallel()
        initialize_model_parallel(*parallel_args)
        pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
        assert pipe.pipelined_backward is expected_flag
def run_test_cross_entropy(rank, model_parallel_size):
    """Compare the model-parallel cross entropy with the reference one.

    ``torch_cross_entropy`` and ``mpu_cross_entropy`` are called with the
    same arguments; their losses and gradients must agree to within 1e-6.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size.
    """
    tolerance = 1.0e-6

    dist_init(rank, model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    # The vocabulary grows with the number of model-parallel partitions.
    vocab_size_per_partition = 11
    vocab_size = vocab_size_per_partition * model_parallel_size

    # (batch_size, seq_length, vocab_size, logits_scale, seed)
    problem = (13, 17, vocab_size, 1000.0, 1234)
    loss_torch, grad_torch = torch_cross_entropy(*problem)
    loss_mpu, grad_mpu = mpu_cross_entropy(*problem)

    # The subtraction is in place, so the reference tensors are overwritten.
    loss_error = loss_torch.sub_(loss_mpu).abs().max()
    print(" max error in loss on global rank {}: {}".format(torch.distributed.get_rank(), loss_error))
    assert loss_error < tolerance

    grad_error = grad_torch.sub_(grad_mpu).abs().max()
    print(" max error in grad on global rank {}: {}".format(torch.distributed.get_rank(), grad_error))
    assert grad_error < tolerance

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
    """Check ``mpu.get_model_parallel_src_rank`` against a direct computation.

    The expected source rank is the global rank minus the model-parallel
    rank.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size_: Requested model parallel size; it is clamped
            to the world size before initialising model parallelism.
    """
    dist_init(rank, model_parallel_size_)

    banner = "> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_)
    if torch.distributed.get_rank() == 0:
        print(banner)

    world_size = torch.distributed.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)

    # Initialisation must flip the "is initialized" flag.
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    expected_src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == expected_src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_initialize_model_parallel(rank, model_parallel_size, filename, filename_rpc):
    """Check the groups created by ``mpu.initialize_model_parallel``.

    The model-parallel and data-parallel world sizes and ranks reported by
    ``mpu`` (and by ``torch.distributed`` for the corresponding groups) are
    compared with values derived from the global rank and world size.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size; it is clamped
            to the world size before use.
        filename: Forwarded to ``dist_init``.
        filename_rpc: Forwarded to ``dist_init``.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))

    # Effective size: never larger than the number of processes.
    model_parallel_size_ = min(model_parallel_size, torch.distributed.get_world_size())

    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    # Use the clamped size for the rank as well, matching the world-size
    # computation on the line above (the original divided by the unclamped
    # ``model_parallel_size`` here).
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_cuda_rng_tracker(rank, model_parallel_size):
    """Check that the CUDA RNG tracker keeps two random streams independent.

    Two tensors are drawn after seeding with ``seed_1`` and two after
    seeding with ``seed_2``. The draws are then repeated interleaved, with
    the ``seed_2`` stream living in a tracker state named ``"test"``; the
    interleaved results must match the sequential ones.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    def sample():
        # Fill the shared buffer from the current generator; return a copy.
        torch.randn(size, out=tensor)
        return tensor.clone()

    def max_abs_diff(lhs, rhs):
        return lhs.sub(rhs).abs().max()

    # Sequential reference: two draws per seed.
    torch.cuda.manual_seed(seed_1)
    target_11 = sample()
    target_12 = sample()

    torch.cuda.manual_seed(seed_2)
    target_21 = sample()
    target_22 = sample()

    # Interleaved draws: seed_1 on the default generator, seed_2 in the
    # tracker state named "test".
    torch.cuda.manual_seed(seed_1)
    get_cuda_rng_tracker().add("test", seed_2)

    result_11 = sample()
    with get_cuda_rng_tracker().fork("test"):
        result_21 = sample()
    result_12 = sample()
    with get_cuda_rng_tracker().fork("test"):
        result_22 = sample()

    # The two streams must actually produce different numbers.
    diff = max_abs_diff(result_11, result_21)
    diff = min(diff, max_abs_diff(result_12, result_22))
    print(
        " max diff in generated tensors (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), diff
        )
    )
    assert diff > 1.0e-6

    # Interleaving must not change what each stream produces.
    error = max(
        max_abs_diff(result_11, target_11),
        max_abs_diff(result_12, target_12),
    )
    error = max(error, max_abs_diff(result_21, target_21))
    error = max(error, max_abs_diff(result_22, target_22))
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_set_cuda_rng_state(rank, model_parallel_size):
    """Check that ``random._set_cuda_rng_state`` restores the CUDA RNG.

    The RNG state is captured, some random numbers are drawn, the state is
    restored and the same draws are repeated. The repeated draws must match
    the first ones, and the captured state tensor must not be modified.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing set_rng_state with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    def draw_five_times():
        # Advance the generator; only the last draw stays in ``tensor``.
        for _ in range(5):
            torch.randn(size, out=tensor)

    # Capture the state, plus a pristine copy to detect accidental mutation.
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # First pass.
    draw_five_times()
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # Drawing numbers must have moved the generator state.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        " max diff in rng state (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), max_diff
        )
    )
    assert max_diff > 0

    # Restore and replay. The replay is performed twice; only the tensor
    # left behind by the final pass is compared.
    for _ in range(2):
        random._set_cuda_rng_state(rng_state)
        draw_five_times()
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(
        " max error in rng state (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_parallel_embedding(rank, model_parallel_size, filename, filename_rpc):
    """Compare the parallel embedding layers with ``torch.nn.Embedding``.

    The same weighted-sum loss is computed and back-propagated through a
    plain embedding, a ``ParallelEmbedding`` and a ``VocabParallelEmbedding``,
    each built after re-seeding with the same seed. Losses must match, and
    each parallel layer's weight gradient must match the corresponding slice
    of the reference gradient.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size.
        filename: Forwarded to ``dist_init``.
        filename_rpc: Forwarded to ``dist_init``.
    """
    tolerance = 1.0e-12

    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    # Inputs and loss weights come from their own fixed seed.
    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    def forward_backward(embedding):
        # Weighted-sum loss; gradients are populated before returning.
        loss = torch.mul(embedding(input_data), loss_weight).sum()
        loss.backward()
        return loss

    # Each layer is created right after re-seeding with the same seed.
    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
    loss_original = forward_backward(embedding_original)

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_parallel = forward_backward(embedding_parallel)

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_
    ).cuda()
    loss_vocab_parallel = forward_backward(embedding_vocab_parallel)

    # Loss comparisons.
    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(" error in loss (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance, "error: {}".format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(" error in loss (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance, "error: {}".format(error)

    # Gradient comparison: the reference gradient is split along dim 1 for
    # ParallelEmbedding.
    hidden_per_rank = hidden_size // model_parallel_size
    reference_slices = torch.split(embedding_original.weight.grad, hidden_per_rank, 1)
    weight_grad_orig = reference_slices[mpu.get_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance, "error: {}".format(error)

    # Gradient comparison: split along dim 0 for VocabParallelEmbedding.
    vocab_per_rank = vocab_size // model_parallel_size
    reference_slices = torch.split(embedding_original.weight.grad, vocab_per_rank, 0)
    weight_grad_orig = reference_slices[mpu.get_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < tolerance, "error: {}".format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_initialize_affine_weight(rank, model_parallel_size, filename, filename_rpc):
    """Check ``layers._initialize_affine_weight`` for both partition dims.

    For the column-parallel case (partition dim 0) and the row-parallel case
    (partition dim 1), the weight filled in by the helper must equal this
    rank's slice of a full weight initialised with the same seed.

    Args:
        rank: Rank forwarded to ``dist_init``.
        model_parallel_size: Requested model parallel size.
        filename: Forwarded to ``dist_init``.
        filename_rpc: Forwarded to ``dist_init``.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing initialize_affine_weight with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    output_size_coeff = 17
    input_size = input_size_coeff * model_parallel_size
    output_size = output_size_coeff * model_parallel_size

    def partition_error(local_shape, per_partition_size, partition_dim):
        # Weight as produced by the helper under test.
        weight = torch.empty(*local_shape)
        set_random_seed(seed)
        layers._initialize_affine_weight(
            weight,
            output_size,
            input_size,
            per_partition_size,
            partition_dim,
            torch.nn.init.normal_,
        )

        # Target: the matching slice of a full weight built from the same seed.
        set_random_seed(seed)
        master_weight = torch.empty(output_size, input_size)
        torch.nn.init.normal_(master_weight)
        mp_rank = mpu.get_model_parallel_rank()
        shards = torch.split(master_weight, per_partition_size, dim=partition_dim)
        my_weight = shards[mp_rank].contiguous().clone()

        # Compare.
        max_error = weight.sub(my_weight).abs().max()
        torch.distributed.barrier()
        return max_error

    # ---------------
    # Column parallel
    # ---------------
    error = partition_error((output_size_coeff, input_size), output_size_coeff, 0)
    print(
        " column parallel max error (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    error = partition_error((output_size, input_size_coeff), input_size_coeff, 1)
    print(
        " row parallel max error (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
def mpi_pipe():
    """Run ``run_test_pipe`` on an already-initialised process group.

    Model-parallel state is destroyed first; the current global rank and
    world size are then passed to ``run_test_pipe`` with
    ``skip_dist_init=True``.
    """
    mpu.destroy_model_parallel()

    global_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    run_test_pipe(global_rank, world_size, skip_dist_init=True)