def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer

def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with tensor_parallel.random.get_cuda_rng_tracker().fork():
        assert (torch.cuda.initial_seed() ==
                12345 + 2718 + parallel_state.get_tensor_model_parallel_rank())

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, num_att_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer

def forward(ctx, vocab_parallel_logits, target):
    # Maximum value along vocab dimension across all GPUs.
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    torch.distributed.all_reduce(logits_max,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=get_tensor_model_parallel_group())
    # Subtract the maximum value.
    vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

    # Get the partition's vocab indices.
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(
        partition_vocab_size, rank, world_size)

    # Create a mask of valid vocab ids (1 means it needs to be masked).
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target.clone() - vocab_start_index
    masked_target[target_mask] = 0

    # Get predicted-logits = logits[target].
    # For simplicity, we convert logits to a 2-D tensor with size
    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
    logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                             device=logits_2d.device)
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits_1d = predicted_logits_1d.clone().contiguous()
    predicted_logits = predicted_logits_1d.view_as(target)
    predicted_logits[target_mask] = 0.0
    # All-reduce is needed to get the chunks from other GPUs.
    torch.distributed.all_reduce(predicted_logits,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=get_tensor_model_parallel_group())

    # Sum of exponential of logits along vocab dimension across all GPUs.
    exp_logits = vocab_parallel_logits
    torch.exp(vocab_parallel_logits, out=exp_logits)
    sum_exp_logits = exp_logits.sum(dim=-1)
    torch.distributed.all_reduce(sum_exp_logits,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=get_tensor_model_parallel_group())

    # Loss = log(sum(exp(logits))) - predicted-logit.
    loss = torch.log(sum_exp_logits) - predicted_logits

    # Store softmax, target-mask and masked-target for backward pass.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

    return loss

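# Hedged single-GPU reference for the loss computed above (not part of the source; the
# helper name and shapes are illustrative). It uses the same numerically stable form:
# subtract the per-row max, then loss = log(sum(exp(logits))) - logits[target].
import torch


def reference_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # logits: [*, vocab_size], target: [*] holding class indices.
    logits = logits - logits.max(dim=-1, keepdim=True)[0]                    # subtract max for stability
    sum_exp_logits = logits.exp().sum(dim=-1)                                # sum of exponentials
    predicted_logits = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)   # logits[target]
    return torch.log(sum_exp_logits) - predicted_logits
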
def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks."""
    data = tensor.view(-1)
    partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]

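# Hedged single-process illustration of the chunking arithmetic above; the world size and
# rank are made up here, whereas the real helper reads both from parallel_state.
import torch

_world_size, _rank = 4, 1
_data = torch.arange(32.)                                   # pretend flattened activation
_partition_size = _data.numel() // _world_size              # 8 elements per rank
_chunk = _data[_rank * _partition_size:(_rank + 1) * _partition_size]
assert _chunk.tolist() == [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
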
def __init__(
    self,
    num_embeddings,
    embedding_dim,
    init_method=init.xavier_normal_,
    *,
    params_dtype=torch.float32,
    use_cpu_initialization=False,
):
    super(VocabParallelEmbedding, self).__init__()
    # Keep the input dimensions.
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    # Set the defaults for compatibility.
    self.padding_idx = None
    self.max_norm = None
    self.norm_type = 2.0
    self.scale_grad_by_freq = False
    self.sparse = False
    self._weight = None
    self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
    # Divide the weight matrix along the vocabulary dimension.
    self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
        self.num_embeddings, get_tensor_model_parallel_rank(),
        self.tensor_model_parallel_size)
    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

    # Allocate weights and initialize.
    if use_cpu_initialization:
        self.weight = Parameter(
            torch.empty(self.num_embeddings_per_partition,
                        self.embedding_dim,
                        dtype=params_dtype))
        _initialize_affine_weight_cpu(
            self.weight,
            self.num_embeddings,
            self.embedding_dim,
            self.num_embeddings_per_partition,
            0,
            init_method,
            params_dtype=params_dtype,
        )
    else:
        self.weight = Parameter(
            torch.empty(
                self.num_embeddings_per_partition,
                self.embedding_dim,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            ))
        _initialize_affine_weight_gpu(self.weight,
                                      init_method,
                                      partition_dim=0,
                                      stride=1)

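# Hedged illustration of the vocabulary partitioning above. It assumes VocabUtility assigns
# equal, contiguous shards of size num_embeddings // world_size; the numbers are illustrative.
_num_embeddings, _world_size, _rank = 48, 4, 1
_per_partition = _num_embeddings // _world_size             # 12 rows of the table per rank
_vocab_start_index = _rank * _per_partition                 # 12
_vocab_end_index = _vocab_start_index + _per_partition      # 24
assert (_vocab_start_index, _vocab_end_index) == (12, 24)
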
def test__reduce(args, tensor_model_parallel_size):
    print("Testing reduction size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._reduce(torch.full((10, 10, 10, 10), 50)),
        torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")

def test_broadcast_data(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12],
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5,)).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, total_numel = data_utils._build_key_size_numel_dictionaries(
        keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_

def test__gather(args, tensor_model_parallel_size):
    print("Testing gathering size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._gather(
            torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
        torch.tensor(list(range(tensor_model_parallel_size))),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")

def gather_split_1d_tensor(tensor):
    """Opposite of split_tensor_into_1d_equal_chunks: gather values from model parallel ranks."""
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered,
                           dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    torch.distributed.all_gather(
        chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
    return gathered

def test__split(args, tensor_model_parallel_size):
    print("Testing splitting size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    listy = []
    for i in range(tensor_model_parallel_size):
        listy.append(torch.randn(10, 1))
    x = torch.cat(tuple(listy), 1)
    out = mappings._split(x)
    assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()])
    parallel_state.destroy_model_parallel()
    print("Passed!")

def _split(input_):
    """Split the tensor along its last dimension and keep the corresponding slice."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output

def test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size):
    dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    for dtype in dtypes:
        # Network
        identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).to(
            device="cuda", dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)
        # Forward
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()
        assert output.dtype == dtype

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def _initialize_affine_weight_cpu(
    weight,
    output_size,
    input_size,
    per_partition_size,
    partition_dim,
    init_method,
    stride=1,
    return_master_weight=False,
    *,
    params_dtype=torch.float32,
):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter the relevant chunk.
    """
    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size,
                                input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight,
                              per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None

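# Hedged sketch of the stride-aware sharding performed above (illustrative sizes; assumes
# torch is imported as in the surrounding module): torch.split yields blocks of
# per_partition_size // stride rows along partition_dim, and each rank keeps every
# world_size-th block starting at its rank (weight_list[rank::world_size]).
import torch

_master = torch.arange(8 * 3, dtype=torch.float).reshape(8, 3)          # output_size=8, input_size=3
_world_size, _rank, _per_partition_size, _stride = 2, 0, 4, 2
_blocks = torch.split(_master, _per_partition_size // _stride, dim=0)   # four blocks of 2 rows
_mine = torch.cat(_blocks[_rank::_world_size], dim=0)                   # rank 0 takes blocks 0 and 2
assert _mine.shape == (4, 3)
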
def test_initialize_model_parallel(self) -> None:
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            if self.world_size % tensor_model_parallel_world_size:
                continue

            pipeline_model_parallel_world_size = (
                self.world_size // tensor_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            )
            self.assertEqual(
                tensor_model_parallel_world_size,
                parallel_state.get_tensor_model_parallel_world_size(),
            )
            expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size)
            self.assertEqual(
                expected_tensor_model_parallel_rank,
                parallel_state.get_tensor_model_parallel_rank(),
            )

            expected_tensor_model_parallel_src_rank = (
                self.rank // tensor_model_parallel_world_size
            ) * tensor_model_parallel_world_size
            self.assertEqual(
                expected_tensor_model_parallel_src_rank,
                parallel_state.get_tensor_model_parallel_src_rank(),
            )

            parallel_state.destroy_model_parallel()
            self.assertFalse(parallel_state.model_parallel_is_initialized())

def test_cross_entropy(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale, seed)
    loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length,
                                                      vocab_size, logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_,
                                 group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output

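# Hedged single-process picture of the _split/_gather pair (illustrative world size; in the
# real functions each shard lives on a different rank of the tensor-model-parallel group).
import torch

_world_size = 4
_full = torch.randn(10, 4 * 3)
_shards = torch.split(_full, _full.size(-1) // _world_size, dim=-1)   # what _split hands out
_restored = torch.cat(_shards, dim=-1)                                # what _gather reassembles
assert torch.equal(_restored, _full)
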
def test_initialize_model_parallel(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

def test_cuda_rng_tracker(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors.
    torch.cuda.manual_seed(seed_1)
    tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print('   max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

def test_set_cuda_rng_state(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print('   max diff in rng state (should be non-zero) on global rank {}: {}'.format(
        torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print('   max error in rng state (should be zero) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """Base function for communication of tensors between stages.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline:
            optional, used when ``tensor_shape`` is provided to override the scatter/gather of tensors.
        dtype_: optional, the dtype of the tensors received when ``tensor_shape`` is provided.

    Keyword args:
        scatter_gather_tensors_in_pipeline:
            Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype:
            Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()`
            on your model deliberately, pass this argument.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`.
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1) //
            parallel_state.get_tensor_model_parallel_world_size(),
        )
    else:
        tensor_chunk_shape = tensor_shape

    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    #     requires_grad = False
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
    dtype = torch.float32
    requires_grad = True

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)

    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()
            )
        if recv_next:
            tensor_recv_next = (
                gather_split_1d_tensor(tensor_recv_next)
                .view(tensor_shape)
                .requires_grad_()
            )

    return tensor_recv_prev, tensor_recv_next

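# Hedged arithmetic for the scatter-gather optimization above (illustrative shapes): the P2P
# buffers shrink to prod(tensor_shape) // tensor_model_parallel_world_size elements, and
# gather_split_1d_tensor restores the full tensor_shape on the receiving side.
import operator
from functools import reduce

_tensor_shape, _tp_world_size = (2048, 4, 1536), 8    # (seq_length, micro_batch_size, hidden_size)
_tensor_chunk_shape = (reduce(operator.mul, _tensor_shape, 1) // _tp_world_size,)
assert _tensor_chunk_shape == (2048 * 4 * 1536 // 8,)
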
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: Optional[torch.dtype] = None,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
    async_comm: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
    """Base function for communication of tensors between stages.

    dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``
    is specified, torch.float32 is used.

    See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
    for the details of the ``dtype_``, ``params_dtype``, and ``fp32_residual_connection`` arguments.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline:
            optional, used when ``tensor_shape`` is provided to override the scatter/gather of tensors.
        dtype_: optional, the dtype of the tensors received when ``tensor_shape`` is provided.

    Keyword args:
        scatter_gather_tensors_in_pipeline:
            Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype:
            Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()`
            on your model deliberately, pass this argument.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`.
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
        )
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1) //
            parallel_state.get_tensor_model_parallel_world_size(),
        )
    else:
        tensor_chunk_shape = tensor_shape

    # The dtype logic below is copied from NVIDIA/Megatron-LM repo:
    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
    # NOTE (mkozuki): Currently NeMo is implementing APEX AMP O2 style using PyTorch. In O2 style, forcing p2p comm to
    # use FP32 will be a perf killer so that I decided to reanimate `dtype_` argument with the default value of `None`.
    # NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(
        tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next,
        async_comm=async_comm)

    if async_comm:
        tensor_recv_prev_waitfunc = None
        tensor_recv_next_waitfunc = None
        # TODO: investigate whether this is necessary for correctness
        # (ref: https://github.com/pytorch/pytorch/issues/38642)
        # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
        if tensor_recv_prev_req is not None:
            def tensor_recv_prev_wait():
                tensor_recv_prev_req.wait()
                torch.cuda.synchronize()
            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
        if tensor_recv_next_req is not None:
            def tensor_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
            tensor_recv_next_waitfunc = tensor_recv_next_wait
    else:
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if not async_comm:
            if recv_prev:
                tensor_recv_prev = (
                    gather_split_1d_tensor(tensor_recv_prev)
                    .view(tensor_shape)
                    .requires_grad_()
                )
            if recv_next:
                tensor_recv_next = (
                    gather_split_1d_tensor(tensor_recv_next)
                    .view(tensor_shape)
                    .requires_grad_()
                )
        else:
            def gather_recv_prev_wait():
                tensor_recv_prev_req.wait()
                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
                # A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test
                torch.cuda.synchronize()
                return (
                    gather_split_1d_tensor(tensor_recv_prev)
                    .view(tensor_shape)
                    .requires_grad_()
                )

            def gather_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
                return (
                    gather_split_1d_tensor(tensor_recv_next)
                    .view(tensor_shape)
                    .requires_grad_()
                )

            tensor_recv_prev_waitfunc = gather_recv_prev_wait
            tensor_recv_next_waitfunc = gather_recv_next_wait

    if async_comm:
        future_tensor_recv_prev = None
        future_tensor_recv_next = None
        if tensor_recv_prev is not None:
            future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
        if tensor_recv_next is not None:
            future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
        return future_tensor_recv_prev, future_tensor_recv_next

    return tensor_recv_prev, tensor_recv_next

def test_row_parallel_linear(tensor_model_parallel_size):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = parallel_state.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def test_parallel_embedding(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(
        0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print('   error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print('   error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        hidden_size // tensor_model_parallel_size,
        1)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print('   error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        vocab_size // tensor_model_parallel_size,
        0)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print('   error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

def test_column_parallel_linear(tensor_model_parallel_size):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7
    hidden_size = 9

    # Network
    gradient_accumulation_fusion = True
    identity_layer = IdentityLayer3D(batch_size, hidden_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
    ).cuda()
    with torch.no_grad():
        linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)
    loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    assert list(output.shape) == [batch_size, hidden_size, output_size]
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # TODO (mkozuki): Fix the following commented out lines
    # as `gradient_accumulation_fusion` only takes 3D tensors.
    # Values.
    # dLdY = loss_weight  # (7, 9, 17)
    # X = identity_layer.weight  # (7, 9, 13)
    # A = linear_layer.master_weight.cuda()  # (17, 13)
    # print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
    # dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
    # print(f"dLdA.shape = {dLdA.shape}")
    # ones = torch.ones(batch_size, hidden_size, 1).cuda()
    # print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
    # dLdb = torch.matmul(ones, dLdY).view(-1)
    # dLdX = torch.matmul(dLdY, A)

    # rank = parallel_state.get_tensor_model_parallel_rank()
    # my_dLdA = torch.split(dLdA, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdA on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # my_dLdb = torch.split(dLdb, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdb on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # error = dLdX.sub(identity_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdX on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def test_initialize_affine_weight(tensor_model_parallel_size, device):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight, output_size, input_size, output_size_coeff, 0,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype,
        )
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight, output_size, input_size, input_size_coeff, 1,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype)
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def __init__(
    self,
    init_method,
    output_layer_init_method,
    layer_number,
    num_attention_heads,
    hidden_size,
    attention_type=AttnType.self_attn,
    attn_mask_type=AttnMaskType.padding,
    precision=16,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    use_cpu_initialization=False,
    masked_softmax_fusion=True,
    attention_dropout=0.1,
):
    super(ParallelAttention, self).__init__()

    self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = False
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)
    self.attention_type = attention_type
    self.attn_mask_type = attn_mask_type

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads
    projection_size = kv_channels * num_attention_heads

    # Per attention head and per partition values.
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    self.hidden_size_per_partition = safe_divide(projection_size, world_size)
    self.hidden_size_per_attention_head = safe_divide(projection_size,
                                                      num_attention_heads)
    self.num_attention_heads_per_partition = safe_divide(num_attention_heads,
                                                         world_size)

    # Strided linear layer.
    if attention_type == AttnType.self_attn:
        self.query_key_value = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            3 * projection_size,
            gather_output=False,
            init_method=init_method,
            use_cpu_initialization=use_cpu_initialization,
        )
    else:
        assert attention_type == AttnType.cross_attn
        self.query = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            projection_size,
            gather_output=False,
            init_method=init_method)
        self.key_value = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            2 * projection_size,
            gather_output=False,
            init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    fused_fp16 = precision == 16
    fused_bf16 = precision == 'bf16'
    self.scale_mask_softmax = FusedScaleMaskSoftmax(
        fused_fp16,
        fused_bf16,
        self.attn_mask_type,
        masked_softmax_fusion,
        attention_mask_func,
        self.attention_softmax_in_fp32,
        coeff,
    )

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(attention_dropout)

    # Output.
    self.dense = tensor_parallel.RowParallelLinear(
        projection_size,
        hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        use_cpu_initialization=use_cpu_initialization,
    )

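# Hedged sanity check of the per-partition sizes computed above; the numbers are illustrative
# (16 heads of 64 kv channels split across a tensor-parallel world of 4), not from the source.
_num_attention_heads, _kv_channels, _world_size = 16, 64, 4
_projection_size = _kv_channels * _num_attention_heads                       # 1024
_hidden_size_per_partition = _projection_size // _world_size                 # 256
_hidden_size_per_attention_head = _projection_size // _num_attention_heads   # 64
_num_attention_heads_per_partition = _num_attention_heads // _world_size     # 4
assert _hidden_size_per_partition == (_num_attention_heads_per_partition *
                                      _hidden_size_per_attention_head)
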
def __init__(
    self,
    input_size,
    output_size,
    bias=True,
    gather_output=True,
    init_method=init.xavier_normal_,
    stride=1,
    keep_master_weight_for_test=False,
    skip_bias_add=False,
    *,
    no_async_tensor_model_parallel_allreduce=False,
    params_dtype=torch.float32,
    use_cpu_initialization=False,
    gradient_accumulation_fusion=False,
    accumulation_in_fp16: bool = False,
):
    super(ColumnParallelLinear, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.gather_output = gather_output
    # Divide the weight matrix along the last dimension.
    world_size = get_tensor_model_parallel_world_size()
    self.output_size_per_partition = divide(output_size, world_size)
    self.skip_bias_add = skip_bias_add

    # Parameters.
    # Note: torch.nn.functional.linear performs XA^T + b and as a result
    # we allocate the transpose.
    # Initialize weight.
    if use_cpu_initialization:
        self.weight = Parameter(
            torch.empty(self.output_size_per_partition, self.input_size,
                        dtype=params_dtype))
        self.master_weight = _initialize_affine_weight_cpu(
            self.weight,
            self.output_size,
            self.input_size,
            self.output_size_per_partition,
            0,
            init_method,
            stride=stride,
            return_master_weight=keep_master_weight_for_test,
            params_dtype=params_dtype,
        )
    else:
        self.weight = Parameter(
            torch.empty(
                self.output_size_per_partition,
                self.input_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            ))
        _initialize_affine_weight_gpu(self.weight,
                                      init_method,
                                      partition_dim=0,
                                      stride=stride)

    if bias:
        if use_cpu_initialization:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=params_dtype))
        else:
            self.bias = Parameter(
                torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                ))
        set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
    else:
        self.register_parameter("bias", None)

    self.async_tensor_model_parallel_allreduce = (
        not no_async_tensor_model_parallel_allreduce and world_size > 1)

    if gradient_accumulation_fusion:
        if not _grad_accum_fusion_available:
            # Basically, apex.transformer module users are expected to install APEX with
            # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
            # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
            # at the root of the APEX repository.
            import warnings
            warnings.warn(
                "`gradient_accumulation_fusion` is set to `True` but "
                "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
                "found. Thus `gradient_accumulation_fusion` set to `False`. "
                "Note that the extension requires CUDA>=11."
            )
            gradient_accumulation_fusion = False
    self.gradient_accumulation_fusion = gradient_accumulation_fusion

    self._forward_impl = (
        linear_with_grad_accumulation_and_async_allreduce_in16bit
        if accumulation_in_fp16
        else linear_with_grad_accumulation_and_async_allreduce
    )

def __init__(
    self,
    input_size,
    output_size,
    bias=True,
    input_is_parallel=False,
    init_method=init.xavier_normal_,
    stride=1,
    keep_master_weight_for_test=False,
    skip_bias_add=False,
    *,
    params_dtype=torch.float32,
    use_cpu_initialization=False,
):
    super(RowParallelLinear, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.input_is_parallel = input_is_parallel
    # Divide the weight matrix along the last dimension.
    world_size = get_tensor_model_parallel_world_size()
    self.input_size_per_partition = divide(input_size, world_size)
    self.skip_bias_add = skip_bias_add

    # Parameters.
    # Note: torch.nn.functional.linear performs XA^T + b and as a result
    # we allocate the transpose.
    # Initialize weight.
    if use_cpu_initialization:
        self.weight = Parameter(
            torch.empty(self.output_size, self.input_size_per_partition,
                        dtype=params_dtype))
        self.master_weight = _initialize_affine_weight_cpu(
            self.weight,
            self.output_size,
            self.input_size,
            self.input_size_per_partition,
            1,
            init_method,
            stride=stride,
            return_master_weight=keep_master_weight_for_test,
            params_dtype=params_dtype,
        )
    else:
        self.weight = Parameter(
            torch.empty(
                self.output_size,
                self.input_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            ))
        _initialize_affine_weight_gpu(self.weight,
                                      init_method,
                                      partition_dim=1,
                                      stride=stride)
    if bias:
        if use_cpu_initialization:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
        else:
            self.bias = Parameter(
                torch.empty(
                    self.output_size,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                ))
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
    else:
        self.register_parameter("bias", None)

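# Hedged note on how the two layers are typically composed (Megatron-style parallel MLP;
# shapes below are illustrative): a ColumnParallelLinear with gather_output=False feeds a
# RowParallelLinear with input_is_parallel=True, so each rank works on its own shard of the
# intermediate activation and a single all-reduce at the row-parallel output recombines them.
_hidden_size, _ffn_hidden_size, _world_size = 1024, 4096, 8
_column_output_per_rank = _ffn_hidden_size // _world_size   # each rank's slice of the FFN output
_row_input_per_rank = _ffn_hidden_size // _world_size       # matches RowParallelLinear's input shard
assert _column_output_per_rank == _row_input_per_rank == 512
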