def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    """Run one forward/backward pass through BertParallelSelfAttention.

    Initializes model parallelism, builds an identity input layer plus the
    parallel attention layer under a fixed seed, computes a weighted-sum
    loss, and backpropagates.  Model-parallel state is torn down before
    returning so callers can chain runs with different world sizes.

    Returns:
        Tuple of (rank, hidden_size, tensor_model_parallel_size, loss,
        attention_layer, identity_layer).
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    # Re-read the effective world size (may differ from the requested one).
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    set_random_seed(12345)
    total_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * total_heads

    # Network under test.
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, total_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()

    # Forward pass and scalar loss.
    output = attention_layer(identity_layer(), attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward pass.
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return (rank, hidden_size, tensor_model_parallel_size, loss,
            attention_layer, identity_layer)
def test_row_parallel_linear(self) -> None:
    """Check RowParallelLinear gradients against a dense reference.

    For every tensor-model-parallel world size that divides the launch
    world size, runs forward/backward and compares the sharded weight
    gradient, the bias gradient, and the input gradient with values
    computed from the unsharded master weight.
    """
    for world_size in range(1, self.world_size + 1):
        # Only world sizes that evenly divide the launch size are valid.
        if self.world_size % world_size:
            continue
        with self.subTest(tensor_model_parallel_world_size=world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=world_size)
            input_size: int = self.INPUT_SIZE_COEFF * world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * world_size

            set_random_seed(self.SEED)
            layer = layers.RowParallelLinear(
                input_size,
                output_size,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
            ).cuda()
            loss_weight = torch.randn((self.BATCH_SIZE, output_size)).cuda()

            # Forward and backward.
            input_tensor = torch.randn(
                self.BATCH_SIZE, input_size, requires_grad=True).cuda()
            # .cuda() makes this a non-leaf tensor, so explicitly retain grad.
            input_tensor.retain_grad()
            output, _ = layer(input_tensor)
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()
            self.assertIsNotNone(input_tensor.grad)

            # Reference gradients from the unsharded master weight.
            with torch.no_grad():
                grad_out = loss_weight.clone()
                x = input_tensor.clone()
                master_a = layer.master_weight.cuda()
                ref_dlda = torch.matmul(grad_out.t(), x)
                ref_dldb = torch.matmul(
                    torch.ones(self.BATCH_SIZE, 1).cuda().t(),
                    grad_out).view(-1)
                ref_dldx = torch.matmul(grad_out, master_a)
                # Row parallelism shards the weight along the input dim;
                # take this rank's slice of the reference gradient.
                local_dlda = torch.split(
                    ref_dlda, self.INPUT_SIZE_COEFF, dim=1,
                )[parallel_state.get_tensor_model_parallel_rank()].clone()

            self.assertEqual(layer.weight.grad, local_dlda)
            self.assertEqual(input_tensor.grad, ref_dldx)
            self.assertEqual(layer.bias.grad, ref_dldb)

            parallel_state.destroy_model_parallel()
def _affine_weight_init_test_impl(self, init_device: str, is_column_parallel: bool) -> None:
    """Verify sharded affine-weight init matches a same-seed reference.

    Args:
        init_device: "cpu" to test `_initialize_affine_weight_cpu`,
            anything else to test `_initialize_affine_weight_gpu`.
        is_column_parallel: column-parallel shards dim 0 (output),
            row-parallel shards dim 1 (input).
    """
    shard_dim = int(not is_column_parallel)
    for world_size in range(1, self.world_size + 1):
        if self.world_size % world_size:
            continue
        with self.subTest(tensor_model_parallel_world_size=world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=world_size)
            input_size: int = self.INPUT_SIZE_COEFF * world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * world_size
            if is_column_parallel:
                weight_shape = (self.OUTPUT_SIZE_COEFF, input_size)
                shard_size = self.OUTPUT_SIZE_COEFF
            else:
                weight_shape = (output_size, self.INPUT_SIZE_COEFF)
                shard_size = self.INPUT_SIZE_COEFF
            weight = torch.empty(weight_shape)

            set_random_seed(self.SEED)
            if init_device == "cpu":
                layers._initialize_affine_weight_cpu(
                    weight,
                    output_size,
                    input_size,
                    shard_size,
                    shard_dim,
                    nn.init.normal_,
                    params_dtype=torch.float32,
                )
            else:
                layers._initialize_affine_weight_gpu(
                    weight, torch.nn.init.normal_, shard_dim)

            # Reference: re-seed and redo the init without sharding.
            set_random_seed(self.SEED)
            if init_device == "cpu":
                # CPU path initializes the full master weight, then each
                # rank keeps one slice along the sharding dim.
                master = torch.empty(output_size, input_size)
                nn.init.normal_(master)
                expected = torch.split(master, shard_size, dim=shard_dim)[
                    parallel_state.get_tensor_model_parallel_rank()]
            else:
                # GPU path initializes the local shard directly.
                expected = torch.empty(*weight_shape)
                nn.init.normal_(expected)

            self.assertEqual(expected, weight)
            parallel_state.destroy_model_parallel()
def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    """Reference loss: plain `F.cross_entropy` over unsharded logits.

    Returns:
        Tuple of (mean loss, gradient of the identity layer's weight).
    """
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    # Per-token loss, reshaped back to (batch, seq) before averaging.
    flat_logits = logits.view(-1, logits.size()[-1])
    loss = F.cross_entropy(
        flat_logits, target.view(-1), reduction='none').view_as(target).mean()
    loss.backward()
    return loss, identity.weight.grad
def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    """Vocab-parallel cross entropy over tensor-model-parallel-sharded logits.

    Also asserts that the parallel loss does not mutate its logits input.

    Returns:
        Tuple of (mean loss, gradient of the identity layer's weight).
    """
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    # Snapshot to check that the loss kernel does not modify logits in place.
    logits_snapshot = logits_parallel.clone().detach()
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    # Check for mutation.
    assert torch.equal(logits_snapshot, logits_parallel)
    return loss, identity.weight.grad
def test_column_parallel_linear_with_async_allreduce_custom_amp(
        tensor_model_parallel_size):
    """Smoke-test ColumnParallelLinear forward/backward in reduced precision.

    Runs one forward/backward pass per supported half-precision dtype and
    asserts the output dtype matches the layer dtype.
    """
    if torch.cuda.is_bf16_supported():
        dtypes = (torch.half, torch.bfloat16)
    else:
        dtypes = (torch.half,)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    set_random_seed(12345)
    input_size = 13 * tensor_model_parallel_size
    output_size = 17 * tensor_model_parallel_size
    batch_size = 7
    for dtype in dtypes:
        # Network under test (note: the sequence dim equals batch_size here).
        identity_layer = IdentityLayer3D(
            batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)

        # Forward and backward; loss_weight broadcasts over the middle dim.
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()
        assert output.dtype == dtype

    # Reset groups.
    parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    """Run one forward/backward pass through BertParallelTransformerLayer.

    Mirrors `parallel_self_attention` but exercises the full transformer
    layer (attention + MLP with ReLU, layernorm eps 1e-5, zero dropout).

    Returns:
        Tuple of (rank, hidden_size, tensor_model_parallel_size, loss,
        transformer_layer, identity_layer).
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    # Re-read the effective world size (may differ from the requested one).
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    set_random_seed(12345)
    total_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * total_heads
    intermediate_size = 4 * hidden_size

    # Network under test.
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, total_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()

    # Forward pass and scalar loss.
    output = transformer_layer(identity_layer(), attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward pass.
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return (rank, hidden_size, tensor_model_parallel_size, loss,
            transformer_layer, identity_layer)
def test_row_parallel_linear(tensor_model_parallel_size):
    """Script-style RowParallelLinear test with printed per-rank errors.

    Compares the sharded weight gradient, bias gradient, and input gradient
    against reference values computed from the unsharded master weight.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    set_random_seed(12345)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network under test.
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward and backward.
    output, _ = linear_layer(identity_layer())
    loss = torch.mul(output, loss_weight).sum()
    loss.backward()

    # Reference gradients from the unsharded master weight.
    grad_output = loss_weight
    x = identity_layer.weight
    master_a = linear_layer.master_weight.cuda()
    ref_dlda = torch.matmul(grad_output.t(), x)
    ref_dldb = torch.matmul(
        torch.ones(batch_size, 1).cuda().t(), grad_output).view(-1)
    ref_dldx = torch.matmul(grad_output, master_a)

    # Row parallelism shards the weight along the input dim.
    rank = parallel_state.get_tensor_model_parallel_rank()
    local_dlda = torch.split(
        ref_dlda, input_size_coeff, dim=1)[rank].contiguous().clone()

    error = local_dlda.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = ref_dldb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = ref_dldx.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups.
    parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def test_parallel_embedding(tensor_model_parallel_size):
    """Compare ParallelEmbedding and VocabParallelEmbedding to nn.Embedding.

    All three embeddings are built from the same seed, driven with the same
    batch, and their losses and (appropriately sliced) weight gradients are
    required to match.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    # Data uses a fixed seed (123), apparently kept distinct from the layer
    # seed so the same batch feeds every embedding variant.
    set_random_seed(123)
    input_data = torch.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    # Reference: plain nn.Embedding.
    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
    loss_original = torch.mul(embedding_original(input_data), loss_weight).sum()
    loss_original.backward()

    # Hidden-dimension-sharded embedding.
    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_parallel = torch.mul(embedding_parallel(input_data), loss_weight).sum()
    loss_parallel.backward()

    # Vocab-dimension-sharded embedding.
    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_vocab_parallel = torch.mul(
        embedding_vocab_parallel(input_data), loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(' error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(' error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # ParallelEmbedding: this rank holds a hidden-dim slice of the weight.
    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        hidden_size // tensor_model_parallel_size,
        1)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(' error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # VocabParallelEmbedding: this rank holds a vocab-dim slice.
    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        vocab_size // tensor_model_parallel_size,
        0)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print(' error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups.
    parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_column_parallel_linear(tensor_model_parallel_size):
    """Smoke-test ColumnParallelLinear with gradient accumulation fusion.

    Runs a 3D forward/backward pass and checks the output shape; the
    analytic gradient checks are still TODO (see the commented block).
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    set_random_seed(12345)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7
    hidden_size = 9

    # Network under test; 3D input exercises gradient_accumulation_fusion.
    gradient_accumulation_fusion = True
    identity_layer = IdentityLayer3D(batch_size, hidden_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
    ).cuda()
    # A main_grad buffer must exist when fusion is enabled (presumably the
    # fused backward accumulates into it — see the layers implementation).
    with torch.no_grad():
        linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)
    loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()

    # Forward.
    output, _ = linear_layer(identity_layer())
    assert list(output.shape) == [batch_size, hidden_size, output_size]
    loss = torch.mul(output, loss_weight).sum()
    # Backward.
    loss.backward()

    # TODO (mkozuki): Fix the following commented out lines
    # as `gradient_accumulation_fusion` only takes 3D tensors.
    # Values.
    # dLdY = loss_weight  # (7, 9, 17)
    # X = identity_layer.weight  # (7, 9, 13)
    # A = linear_layer.master_weight.cuda()  # (17, 13)
    # print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
    # dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
    # print(f"dLdA.shape = {dLdA.shape}")
    # ones = torch.ones(batch_size, hidden_size, 1).cuda()
    # print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
    # dLdb = torch.matmul(ones, dLdY).view(-1)
    # dLdX = torch.matmul(dLdY, A)
    # rank = parallel_state.get_tensor_model_parallel_rank()
    # my_dLdA = torch.split(dLdA, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print(' error in dLdA on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6
    # my_dLdb = torch.split(dLdb, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    # torch.distributed.barrier()
    # print(' error in dLdb on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6
    # error = dLdX.sub(identity_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print(' error in dLdX on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # Reset groups.
    parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def test_initialize_affine_weight(tensor_model_parallel_size, device):
    """Script-style test of sharded affine weight initialization.

    For both column-parallel (shard dim 0) and row-parallel (shard dim 1)
    layouts, initializes a local shard and compares it to the matching
    slice of a same-seed unsharded reference weight.

    Args:
        device: 'cpu' tests `_initialize_affine_weight_cpu`, anything else
            tests `_initialize_affine_weight_gpu`.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight, output_size, input_size, output_size_coeff, 0,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype,
        )
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)
    # Reference: same seed, full-size init, then this rank's dim-0 slice.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(
        master_weight, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight, output_size, input_size, input_size_coeff, 1,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype)
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)
    # Reference: same seed, full-size init, then this rank's dim-1 slice.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(
        master_weight, input_size_coeff, dim=1)[rank].contiguous().clone()
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups.
    parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def test_parallel_embedding(self) -> None:
    """VocabParallelEmbedding must match nn.Embedding exactly (CPU init)."""
    for world_size in range(1, self.world_size + 1):
        if self.world_size % world_size:
            continue
        with self.subTest(tensor_model_parallel_world_size=world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=world_size,
            )
            # Data seed differs from the layer seed so inputs are
            # independent of the weight initialization stream.
            set_random_seed(self.SEED + 1)
            input_tensor = torch.randint(
                0,
                self.VOCAB_SIZE,
                (self.BATCH_SIZE, self.SEQUENCE_LENGTH),
                device="cuda",
            )
            loss_weight = torch.randn(
                (self.BATCH_SIZE, self.SEQUENCE_LENGTH, self.HIDDEN_SIZE),
                device="cuda",
            )

            # Reference embedding.
            set_random_seed(self.SEED)
            embedding_torch = nn.Embedding(self.VOCAB_SIZE, self.HIDDEN_SIZE).cuda()
            output_torch = embedding_torch(input_tensor)
            loss_torch = torch.mul(output_torch, loss_weight).sum()
            loss_torch.backward()

            # N.B. (mkozuki): With affine weight initialization on GPU,
            # it's super difficult to keep the consistency with nn.Embedding.
            # Thus, turning on `use_cpu_initialization`.
            set_random_seed(self.SEED)
            embedding_vocab_parallel = layers.VocabParallelEmbedding(
                self.VOCAB_SIZE,
                self.HIDDEN_SIZE,
                init_method=nn.init.normal_,
                use_cpu_initialization=True,
            ).cuda()
            output_vocab_parallel = embedding_vocab_parallel(input_tensor)
            loss_vocab_parallel = torch.mul(
                output_vocab_parallel, loss_weight).sum()
            loss_vocab_parallel.backward()

            self.assertEqual(output_torch, output_vocab_parallel)
            self.assertEqual(loss_torch, loss_vocab_parallel)

            # This rank holds one contiguous vocab-dim shard of the weight.
            reference_shard = torch.split(
                embedding_torch.weight.grad,
                self.VOCAB_SIZE // world_size,
                0,
            )[parallel_state.get_tensor_model_parallel_rank()]
            self.assertEqual(reference_shard,
                             embedding_vocab_parallel.weight.grad)

            parallel_state.destroy_model_parallel()
def _column_parallel_linear_test_impl(
    self,
    no_async_tensor_model_parallel_allreduce: bool,
    gradient_accumulation_fusion: bool,
) -> None:
    """Shared body for ColumnParallelLinear tests.

    For every tensor-model-parallel world size that divides the launch
    world size, runs a 3D forward/backward pass and checks the input
    gradient (and, where reference math is implemented, the weight
    gradient) against values computed from the unsharded master weight.

    Args:
        no_async_tensor_model_parallel_allreduce: forwarded to the layer.
        gradient_accumulation_fusion: forwarded to the layer; when set, a
            `main_grad` buffer is pre-seeded on the weight.
    """
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        # Skip non-divisible configurations BEFORE opening the subTest so
        # they are not recorded as vacuously-passing subtests; this also
        # matches the structure of the sibling test impls in this file.
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(
            tensor_model_parallel_world_size=tensor_model_parallel_world_size
        ):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            feature_size_coeff = self.INPUT_SIZE_COEFF
            feature_size = feature_size_coeff * tensor_model_parallel_world_size
            hidden_size = feature_size

            set_random_seed(self.SEED)
            input_tensor = torch.randn(
                self.BATCH_SIZE,
                hidden_size,
                feature_size,
                device="cuda",
                requires_grad=True,
            )
            input_tensor.retain_grad()
            loss_weight = torch.randn(
                (self.BATCH_SIZE, hidden_size, feature_size),
                device="cuda",
            )
            linear = layers.ColumnParallelLinear(
                feature_size,
                feature_size,
                bias=False,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
                no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
            ).cuda()
            if gradient_accumulation_fusion:
                # Fusion requires a pre-existing main_grad buffer on the
                # weight; seed it with noise.
                with torch.no_grad():
                    linear.weight.main_grad = torch.randn_like(linear.weight)

            output, _ = linear(input_tensor)
            self.assertEqual(
                output.shape,
                (self.BATCH_SIZE, hidden_size, feature_size),
            )
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()

            # Reference input gradient from the unsharded master weight.
            with torch.no_grad():
                dldy = loss_weight.clone()
                x = input_tensor.clone()
                a = linear.master_weight.cuda().clone()
                dldx = torch.matmul(dldy, a)
            self.assertEqual(input_tensor.grad, dldx)
            # TODO(mkozuki): Cover the other cases.
            if (tensor_model_parallel_world_size == 1
                    and not gradient_accumulation_fusion):
                dlda = torch.matmul(torch.transpose(dldy, 1, 2), x).sum(dim=0)
                curr_dlda = torch.split(
                    dlda, feature_size_coeff, dim=0
                )[parallel_state.get_tensor_model_parallel_rank()]
                self.assertEqual(linear.weight.grad, curr_dlda)

            parallel_state.destroy_model_parallel()