def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.
    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used for
                       example for dropout in the non-tensor-model-parallel regions.
        tensor-model-parallel state: This state is different among a set of model
                       parallel GPUs, but the same across data parallel groups.
                       This is used for example for dropout in model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)
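
# Hedged usage sketch (not part of the original file): dropout applied inside the
# tracker's fork() draws from the per-tensor-parallel-rank RNG stream seeded above,
# while dropout outside it uses the shared default (data-parallel) stream.
# `get_cuda_rng_tracker` is assumed to be the accessor for _CUDA_RNG_STATE_TRACKER,
# as used by the test later in this section.
def _dropout_in_tensor_parallel_region(x, p=0.1):
    with get_cuda_rng_tracker().fork():
        return torch.nn.functional.dropout(x, p)  # differs across tensor-parallel ranks

def _dropout_elsewhere(x, p=0.1):
    return torch.nn.functional.dropout(x, p)  # identical across tensor-parallel ranks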
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(
        tensor_model_parallel_size_,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    assert parallel_state.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
    assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank is None

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, num_att_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """ Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
        is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.format(
            tensor_model_parallel_size))
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with tensor_parallel.random.get_cuda_rng_tracker().fork():
        assert (torch.cuda.initial_seed() ==
                12345 + 2718 + parallel_state.get_tensor_model_parallel_rank())

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer
def forward(ctx, vocab_parallel_logits, target):
    # Maximum value along vocab dimension across all GPUs.
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    torch.distributed.all_reduce(logits_max,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=get_tensor_model_parallel_group())
    # Subtract the maximum value.
    vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

    # Get the partition's vocab indices
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

    # Create a mask of valid vocab ids (1 means it needs to be masked).
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target.clone() - vocab_start_index
    masked_target[target_mask] = 0

    # Get predicted-logits = logits[target].
    # For simplicity, we convert logits to a 2-D tensor with size
    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
    logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits_1d = predicted_logits_1d.clone().contiguous()
    predicted_logits = predicted_logits_1d.view_as(target)
    predicted_logits[target_mask] = 0.0
    # All reduce is needed to get the chunks from other GPUs.
    torch.distributed.all_reduce(predicted_logits,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=get_tensor_model_parallel_group())

    # Sum of exponential of logits along vocab dimension across all GPUs.
    exp_logits = vocab_parallel_logits
    torch.exp(vocab_parallel_logits, out=exp_logits)
    sum_exp_logits = exp_logits.sum(dim=-1)
    torch.distributed.all_reduce(sum_exp_logits,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=get_tensor_model_parallel_group())

    # Loss = log(sum(exp(logits))) - predicted-logit.
    loss = torch.log(sum_exp_logits) - predicted_logits

    # Store softmax, target-mask and masked-target for backward pass.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

    return loss
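
# Single-process sanity sketch (not part of the autograd function above): the loss
# computed there reduces to log(sum(exp(logits))) - logits[target], which is exactly
# the standard cross entropy. The all-reduces only stitch the vocab shards together.
def _reference_cross_entropy_identity():
    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    manual = torch.log(torch.exp(logits).sum(dim=-1)) - logits[torch.arange(4), target]
    reference = torch.nn.functional.cross_entropy(logits, target, reduction="none")
    torch.testing.assert_close(manual, reference)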
def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
    if self.world_size < 4:
        self.skipTest("requires >= 4 GPUs")
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
    pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_world_size
    virtual_pipeline_model_parallel_world_size = 2
    pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=tensor_model_parallel_world_size,
        pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
        virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
    )
    self.assertEqual(
        calc_expected_tensor_model_paralell_rank(self.rank, tensor_model_parallel_world_size),
        parallel_state.get_tensor_model_parallel_rank(),
    )
    self.assertEqual(
        pipeline_model_parallel_world_size,
        parallel_state.get_pipeline_model_parallel_world_size(),
    )
    self.assertEqual(
        virtual_pipeline_model_parallel_world_size,
        parallel_state.get_virtual_pipeline_model_parallel_world_size(),
    )

    expected_pipeline_rank = (
        self.rank - (self.rank % tensor_model_parallel_world_size)
    ) % pipeline_model_parallel_world_size
    self.assertEqual(
        expected_pipeline_rank,
        parallel_state.get_pipeline_model_parallel_rank(),
    )
    # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
    # `initialize_model_parallel`, it's set to 0.
    self.assertEqual(
        0,
        parallel_state.get_virtual_pipeline_model_parallel_rank(),
    )
    self.assertEqual(
        pipeline_model_parallel_split_rank,
        parallel_state.get_pipeline_model_parallel_split_rank(),
    )

    fake_split_rank = 77
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    self.assertEqual(fake_split_rank,
                     parallel_state.get_pipeline_model_parallel_split_rank())

    parallel_state.destroy_model_parallel()
def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks."""
    data = tensor.view(-1)
    partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]
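
# Hedged illustration of the indexing above, with the tensor-parallel calls replaced
# by literal values (world_size=4, rank=2); not part of the original module.
def _split_into_1d_chunks_example():
    data = torch.arange(12)
    world_size, rank = 4, 2
    partition_size = data.numel() // world_size
    start_index = partition_size * rank
    chunk = data[start_index:start_index + partition_size]
    assert torch.equal(chunk, torch.tensor([6, 7, 8]))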
def convert(local_rank, rank, world_size, args):
    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
    if args.bcp:
        trainer = Trainer(devices=args.gpus_per_node, num_nodes=num_nodes,
                          accelerator='gpu', plugins=[TorchElasticEnvironment()])
    else:
        trainer = Trainer(devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = (
        app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size)

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # inject model parallel rank
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} '
        f'for tp_rank: {app_state.tensor_model_parallel_rank} '
        f'and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)

    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
def __init__(
    self,
    num_embeddings,
    embedding_dim,
    init_method=init.xavier_normal_,
    *,
    params_dtype=torch.float32,
    use_cpu_initialization=False,
):
    super(VocabParallelEmbedding, self).__init__()
    # Keep the input dimensions.
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    # Set the defaults for compatibility.
    self.padding_idx = None
    self.max_norm = None
    self.norm_type = 2.0
    self.scale_grad_by_freq = False
    self.sparse = False
    self._weight = None
    self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
    # Divide the weight matrix along the vocabulary dimension.
    self.vocab_start_index, self.vocab_end_index = (
        VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size))
    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

    # Allocate weights and initialize.
    if use_cpu_initialization:
        self.weight = Parameter(
            torch.empty(self.num_embeddings_per_partition, self.embedding_dim,
                        dtype=params_dtype))
        _initialize_affine_weight_cpu(
            self.weight,
            self.num_embeddings,
            self.embedding_dim,
            self.num_embeddings_per_partition,
            0,
            init_method,
            params_dtype=params_dtype,
        )
    else:
        self.weight = Parameter(
            torch.empty(
                self.num_embeddings_per_partition,
                self.embedding_dim,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            ))
        _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
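
# Hedged sketch of the vocabulary partitioning assumed above: each tensor-parallel
# rank owns a contiguous [start, end) slice of the embedding rows. The arithmetic
# mirrors what vocab_range_from_global_vocab_size is expected to return.
def _vocab_range_example(global_vocab_size=10, world_size=2):
    per_partition = global_vocab_size // world_size
    # world_size=2 -> [(0, 5), (5, 10)]
    return [(rank * per_partition, (rank + 1) * per_partition) for rank in range(world_size)]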
def test_broadcast_data(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.format(
            tensor_model_parallel_size))
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12],
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def test_row_parallel_linear(self) -> None:
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)

            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

            set_random_seed(self.SEED)
            linear_layer = layers.RowParallelLinear(
                input_size,
                output_size,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
            ).cuda()
            loss_weight = torch.randn((self.BATCH_SIZE, output_size)).cuda()

            # Forward and backward
            input_tensor = torch.randn(self.BATCH_SIZE, input_size, requires_grad=True).cuda()
            input_tensor.retain_grad()
            output, _ = linear_layer(input_tensor)
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()
            self.assertIsNotNone(input_tensor.grad)

            with torch.no_grad():
                dldy = loss_weight.clone()
                x = input_tensor.clone()
                a = linear_layer.master_weight.cuda()
            dlda = torch.matmul(dldy.t(), x)
            dldb = torch.matmul(torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy).view(-1)
            dldx = torch.matmul(dldy, a)

            with torch.no_grad():
                curr_dlda = torch.split(
                    dlda, self.INPUT_SIZE_COEFF,
                    dim=1)[parallel_state.get_tensor_model_parallel_rank()].clone()
            self.assertEqual(linear_layer.weight.grad, curr_dlda)
            self.assertEqual(input_tensor.grad, dldx)
            self.assertEqual(linear_layer.bias.grad, dldb)

            parallel_state.destroy_model_parallel()
def test__gather(args, tensor_model_parallel_size):
    print("Testing gathering size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._gather(torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
        torch.tensor(list(range(tensor_model_parallel_size))),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
def test__split(args, tensor_model_parallel_size):
    print("Testing splitting size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    listy = []
    for i in range(tensor_model_parallel_size):
        listy.append(torch.randn(10, 1))
    x = torch.cat(tuple(listy), 1)
    out = mappings._split(x)
    assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()])
    parallel_state.destroy_model_parallel()
    print("Passed!")
def _affine_weight_init_test_impl(self, init_device: str, is_column_parallel: bool) -> None:
    dim = int(not is_column_parallel)
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)
            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

            weight_shape = (
                (self.OUTPUT_SIZE_COEFF, input_size)
                if is_column_parallel
                else (output_size, self.INPUT_SIZE_COEFF))
            weight = torch.empty(weight_shape)
            set_random_seed(self.SEED)

            sharding_dim_size = (
                self.OUTPUT_SIZE_COEFF if is_column_parallel else self.INPUT_SIZE_COEFF)

            if init_device == "cpu":
                layers._initialize_affine_weight_cpu(
                    weight,
                    output_size,
                    input_size,
                    sharding_dim_size,
                    dim,
                    nn.init.normal_,
                    params_dtype=torch.float32,
                )
            else:
                layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, dim)
            # Target
            set_random_seed(self.SEED)
            if init_device == "cpu":
                main_weight = torch.empty(output_size, input_size)
                nn.init.normal_(main_weight)
                curr_weight = torch.split(
                    main_weight, sharding_dim_size,
                    dim=dim)[parallel_state.get_tensor_model_parallel_rank()]
            else:
                curr_weight = torch.empty(*weight_shape)
                nn.init.normal_(curr_weight)
            self.assertEqual(curr_weight, weight)
            parallel_state.destroy_model_parallel()
def generate_fancy_data_labels(sequence_len, batch_size):
    global data_idx
    global inds
    global masks
    global MANUAL_SEED
    temps = []
    for i in range(batch_size):
        if inds is None or data_idx >= len(inds):
            # hack as use of RNG will fall out of sync due to pipelines being different
            torch.manual_seed(MANUAL_SEED)
            inds = torch.randperm(effective_length, device="cuda")
            masks = (
                torch.rand(
                    len(inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                ) >= MASK_PROB
            ).long()
            MANUAL_SEED += 1
            print("new epoch", len(inds))
            data_idx = 0
            print("my start", inds[0:5])
            print("masks_checksum:", torch.sum(masks))
        if EASY_MODE:
            data_idx_ = data_idx % EASY_MODE_SIZ
        else:
            data_idx_ = data_idx
        offset = inds[data_idx_]  # * SEQUENCE_LEN
        data_idx += 1
        curr = fancy_data[offset : offset + sequence_len].clone().detach()
        temps.append(curr)
    temp = torch.stack(temps, dim=0).cuda()
    mask = masks[data_idx // batch_size]
    mask_not = torch.logical_not(mask).long()
    data = mask * temp + mask_not * 124
    label = temp
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        data_dict = {"text": data, "label": label, "mask_not": mask_not}
    else:
        data_dict = None
    keys = ["text", "label", "mask_not"]
    dtype = torch.int64
    broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
    return (
        broadcasted_data["text"].long(),
        broadcasted_data["label"].long(),
        broadcasted_data["mask_not"],
    )
def _split(input_):
    """Split the tensor along its last dimension and keep the corresponding slice."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output
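
# Hedged single-process illustration of the last-dimension split above (world_size=2);
# split_tensor_along_last_dim is assumed to be a thin wrapper around torch.split.
def _last_dim_split_example():
    x = torch.arange(8).reshape(2, 4)
    world_size = 2
    chunks = torch.split(x, x.size(-1) // world_size, dim=-1)
    # rank 0 keeps [[0, 1], [4, 5]]; rank 1 keeps [[2, 3], [6, 7]]
    return [chunk.contiguous() for chunk in chunks]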
def _build_key_size_numel_dictionaries(keys, data):
    """Build the size on rank 0 and broadcast."""
    max_dim = _MAX_DATA_DIM
    sizes = [0 for _ in range(max_dim) for _ in keys]

    # Pack the sizes on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        offset = 0
        for key in keys:
            assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
            size = data[key].size()
            for i, s in enumerate(size):
                sizes[i + offset] = s
            offset += max_dim

    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
        sizes_cuda,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Move back to cpu and unpack.
    sizes_cpu = sizes_cuda.cpu()
    key_size = {}
    key_numel = {}
    total_numel = 0
    offset = 0
    for key in keys:
        i = 0
        size = []
        numel = 1
        while sizes_cpu[offset + i] > 0:
            this_size = sizes_cpu[offset + i]
            size.append(this_size)
            numel *= this_size
            i += 1
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel
        offset += max_dim

    return key_size, key_numel, total_numel
def _initialize_affine_weight_cpu(
    weight,
    output_size,
    input_size,
    per_partition_size,
    partition_dim,
    init_method,
    stride=1,
    return_master_weight=False,
    *,
    params_dtype=torch.float32,
):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter the relevant chunk."""
    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None
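
# Hedged sketch of the shard selection above (partition_dim=0, stride=1, world_size=2,
# rank=0, literal values instead of parallel_state queries): the master weight is
# chopped into per-stride chunks and rank r keeps chunks [r::world_size].
def _cpu_weight_shard_example():
    master_weight = torch.arange(16.0).view(4, 4)
    world_size, rank = 2, 0
    per_partition_per_stride_size = 2  # per_partition_size // stride
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=0)
    my_weight = torch.cat(weight_list[rank::world_size], dim=0)  # rows 0-1 for rank 0
    return my_weight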
def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(
        flatten_data,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Unpack
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
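
# Hedged single-process sketch of the pack/unpack scheme used by broadcast_data:
# values are flattened key by key, shipped as one buffer, then re-sliced with
# narrow/view on every rank. No distributed calls, so it is only illustrative.
def _pack_unpack_example():
    key_size = {"a": [2, 3], "b": [4]}
    data = {"a": torch.arange(6).view(2, 3), "b": torch.arange(4)}
    flatten_data = torch.cat([data[key].contiguous().view(-1) for key in key_size])
    output, offset = {}, 0
    for key, size in key_size.items():
        numel = int(torch.prod(torch.tensor(size)))
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel
    assert torch.equal(output["a"], data["a"]) and torch.equal(output["b"], data["b"])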
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """ Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_devices
        is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
            )

            # assert that fake tp and pp rank match after model parallel init
            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
def test_gather(self):
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size > 0:
            continue
        with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)
            device = f"cuda:{self.rank}"
            gathered = mappings._gather(
                torch.tensor([parallel_state.get_tensor_model_parallel_rank()], device=device))
            expected = torch.tensor(
                [rank for rank in range(tensor_model_parallel_world_size)],
                device=device,
            )
            self.assertTrue(torch.equal(gathered, expected))
            parallel_state.destroy_model_parallel()
def test_initialize_model_parallel(self) -> None:
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            if self.world_size % tensor_model_parallel_world_size:
                continue

            pipeline_model_parallel_world_size = (
                self.world_size // tensor_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            )
            self.assertEqual(
                tensor_model_parallel_world_size,
                parallel_state.get_tensor_model_parallel_world_size(),
            )
            expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size)
            self.assertEqual(
                expected_tensor_model_parallel_rank,
                parallel_state.get_tensor_model_parallel_rank(),
            )
            expected_tensor_model_parallel_src_rank = (
                self.rank // tensor_model_parallel_world_size
            ) * tensor_model_parallel_world_size
            self.assertEqual(
                expected_tensor_model_parallel_src_rank,
                parallel_state.get_tensor_model_parallel_src_rank(),
            )

            parallel_state.destroy_model_parallel()
            self.assertFalse(parallel_state.model_parallel_is_initialized())
def test_split(self):
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size > 0:
            continue
        with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)
            tensors = [
                torch.randn(10, 1)
                for rank in range(tensor_model_parallel_world_size)
            ]
            x = torch.cat(tensors, 1)
            out = mappings._split(x)
            self.assertTrue(
                torch.equal(out, tensors[parallel_state.get_tensor_model_parallel_rank()]))
            parallel_state.destroy_model_parallel()
def test_broadcast_data(self):
    tensor_model_parallel_world_size: int = self.world_size // (
        1 + int(self.world_size > 1))
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=tensor_model_parallel_world_size)

    target_key_size = {
        "key1": [7, 11],
        "key2": [8, 2, 1],
        "key3": [13],
        "key4": [5, 1, 2],
        "key5": [5, 12],
    }
    keys = [k for k in target_key_size]

    data = {}
    data_t = {}
    with torch.no_grad():
        for key in target_key_size:
            data[key] = torch.randint(0, 1000, size=target_key_size[key])
            data_t[key] = data[key].clone()
        # "key_x" is supposed to be ignored.
        data["key_x"] = torch.rand(5)
        data_t["key_x"] = data["key_x"].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)

    for key in keys:
        self.assertEqual(target_key_size[key], key_size[key])

    broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        torch.testing.assert_close(broadcasted_data[key], data_t[key].cuda())

    parallel_state.destroy_model_parallel()
def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
def test_initialize_model_parallel(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
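
# Hedged sketch of the rank arithmetic the checks above rely on, assuming the
# conventional Megatron-LM layout in which adjacent global ranks form a
# tensor-parallel group; illustrative only, no process groups are created.
def _rank_layout_example(world_size=8, tensor_model_parallel_size=2):
    layout = []
    for global_rank in range(world_size):
        tp_rank = global_rank % tensor_model_parallel_size
        dp_rank = global_rank // tensor_model_parallel_size
        layout.append((global_rank, tp_rank, dp_rank))
    # e.g. global rank 5 -> tp_rank 1, dp_rank 2
    return layout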
def main(cfg) -> None:
    # trainer required for restoring model parallel models
    trainer = Trainer(plugins=NLPDDPPlugin(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    # Load prompt tuned model, virtual_prompt_model_file must be provided in config
    if cfg.get('virtual_prompt_model_file', None) is not None:
        # Update frozen GPT model path in case it has changed
        prompt_learning_cfg = MegatronGPTPromptLearningModel.restore_from(
            cfg.virtual_prompt_model_file, trainer=trainer, return_config=True)
        with open_dict(prompt_learning_cfg):
            prompt_learning_cfg.language_model_path = cfg.gpt_model_file

        # Now load prompt learning model with frozen gpt model base
        model = MegatronGPTPromptLearningModel.restore_from(
            restore_path=cfg.virtual_prompt_model_file,
            trainer=trainer,
            override_config_path=prompt_learning_cfg)

    # Or load regular GPT model
    elif cfg.gpt_model_file:
        model = MegatronGPTModel.restore_from(restore_path=cfg.gpt_model_file, trainer=trainer)
    elif cfg.checkpoint_dir:
        app_state = AppState()
        if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
            app_state.model_parallel_size = (
                cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size)
            (
                app_state.tensor_model_parallel_rank,
                app_state.pipeline_model_parallel_rank,
                app_state.model_parallel_size,
                app_state.data_parallel_size,
                app_state.pipeline_model_parallel_split_rank,
            ) = fake_initialize_model_parallel(
                world_size=app_state.model_parallel_size,
                rank=trainer.global_rank,
                tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
                pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
            )
        checkpoint_path = inject_model_parallel_rank(
            os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
    else:
        raise ValueError("need at least a nemo file or checkpoint dir")

    model.freeze()

    # Have to turn off activations_checkpoint_method for inference
    try:
        model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    try:
        model.frozen_model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    length_params: LengthParam = {
        "max_length": cfg.inference.tokens_to_generate,
        "min_length": cfg.inference.min_tokens_to_generate,
    }

    sampling_params: SamplingParam = {
        "use_greedy": cfg.inference.greedy,
        "temperature": cfg.inference.temperature,
        "top_k": cfg.inference.top_k,
        "top_p": cfg.inference.top_p,
        "repetition_penalty": cfg.inference.repetition_penalty,
        "add_BOS": cfg.inference.add_BOS,
        "all_probs": cfg.inference.all_probs,
        "compute_logprob": cfg.inference.compute_logprob,
    }

    # First method of running text generation, call model.generate method
    response = model.generate(inputs=OmegaConf.to_container(cfg.prompts),
                              length_params=length_params,
                              sampling_params=sampling_params)

    print("***************************")
    print(response)
    print("***************************")

    # Second method of running text generation, call trainer.predict
    collate_fn = None
    if cfg.get('virtual_prompt_model', False):
        collate_fn = lambda x: list(x)

    ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
    request_dl = DataLoader(dataset=ds, collate_fn=collate_fn, batch_size=2)
    config = OmegaConf.to_container(cfg.inference)
    model.set_inference_config(config)
    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response)
    print("***************************")

    # Third method of running text generation, use inference server
    if cfg.server:
        if (parallel_state.is_pipeline_first_stage()
                and parallel_state.get_tensor_model_parallel_rank() == 0):
            server = MegatronServer(model.cuda())
            server.run("0.0.0.0", port=cfg.port)

        while True:
            choice = torch.cuda.LongTensor(1)
            torch.distributed.broadcast(choice, 0)
            if choice[0].item() == 0:
                generate(model.cuda())
def _forward_backward_test_impl(
    self,
    forward_only: bool,
    fwd_bwd_func: FwdStepFunc,
    pipeline_model_parallel_world_size: Optional[int],
    virtual_pipeline_model_parallel_size: Optional[int],
    async_comm: bool = False,
    *,
    default_backend: Optional[str] = None,
    p2p_backend: Optional[str] = None,
) -> None:
    if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
        self.assertIsNotNone(virtual_pipeline_model_parallel_size)
        self.assertGreater(virtual_pipeline_model_parallel_size, 1)
    dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

    for dtype, deallocate_pipeline_outputs in itertools.product(
            dtype_options, self.deallocate_options,
    ):
        grad_scaler = (
            torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None)

        (tensor_model_parallel_world_size,
         data_parallel_size,
         pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(
            pipeline_model_parallel_world_size)

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
            default_backend=default_backend,
            p2p_backend=p2p_backend,
        )
        pp_utils._reconfigure_microbatch_calculator(
            rank=parallel_state.get_tensor_model_parallel_rank(),
            rampup_batch_size=None,
            global_batch_size=self.GLOBAL_BATCH_SIZE,
            micro_batch_size=self.MICRO_BATCH_SIZE,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

        global_batch_shape = (
            self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(),
            self.HIDDEN_SIZE,
            self.HIDDEN_SIZE,
        )

        batch = None
        if parallel_state.is_pipeline_first_stage():
            batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )

        model = build_model(
            testing_utils.model_provider_func,
            # Use DDP only when it's better to have
            wrap_with_ddp=data_parallel_size > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            hidden_size=self.HIDDEN_SIZE,
        )

        offset = (pipeline_model_parallel_world_size
                  if virtual_pipeline_model_parallel_size is not None else 0)
        for idx, model_module in enumerate(model):
            model_module = model_module.to(dtype)
            model_module.apply(get_init_weights_func(idx * offset))

        _param_groups = _get_params_for_weight_decay_optimization(model)
        optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

        pp_utils.update_num_microbatches(0)

        loss = fwd_bwd_func(
            testing_utils.fwd_step_func,
            batch,
            model,
            forward_only=forward_only,
            # `tensor_shape` is the shape of micro batch.
            tensor_shape=(
                self.MICRO_BATCH_SIZE,
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            ),
            dtype=dtype,
            async_comm=async_comm,
            grad_scaler=grad_scaler,
            deallocate_pipeline_output=deallocate_pipeline_outputs,
        )

        if dtype == torch.double:
            hidden_size = self.HIDDEN_SIZE
            microbatch_size = self.MICRO_BATCH_SIZE
            total_layers = pipeline_model_parallel_world_size
            if virtual_pipeline_model_parallel_size is not None:
                total_layers *= virtual_pipeline_model_parallel_size
            target_loss, target_model = get_target_loss_and_model(
                global_batch_shape, hidden_size, total_layers)

            for loss_item in loss:
                x = loss_item['avg']
                torch.testing.assert_close(x.item() / microbatch_size, target_loss.item())

            if not forward_only:
                for vm_id, model_module in enumerate(model):
                    params = list(model_module.parameters())
                    rank = params[0].get_device()
                    offset = pipeline_model_parallel_world_size
                    param_id = rank // data_parallel_size + vm_id * offset
                    target_params = target_model[param_id]

                    torch.testing.assert_close(params[0].cpu(), target_params[0])
                    torch.testing.assert_close(params[1].cpu(), target_params[1])
                    torch.testing.assert_close(params[0].grad.cpu() / microbatch_size,
                                               target_params[0].grad)
                    torch.testing.assert_close(params[1].grad.cpu() / microbatch_size,
                                               target_params[1].grad)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    self.assertIsNotNone(p.grad)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        parallel_state.destroy_model_parallel()