def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()

            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(
        tensor_model_parallel_size_,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    assert parallel_state.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
    assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank is None

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
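# Illustration only (not part of the test suite): the source-rank check above relies on
# tensor-model-parallel groups being contiguous blocks of global ranks, so the group's
# source rank is the global rank rounded down to a multiple of the tensor-parallel size.
# A minimal single-process sketch of that arithmetic, with made-up sizes:
world_size = 8
tensor_model_parallel_size = 2

for rank in range(world_size):
    tp_rank = rank % tensor_model_parallel_size            # rank within the TP group
    src_rank = (rank // tensor_model_parallel_size) * tensor_model_parallel_size
    # Matches the check in the test: src_rank == rank - tp_rank
    assert src_rank == rank - tp_rank
    print(f"global rank {rank}: tp_rank={tp_rank}, tp_src_rank={src_rank}")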
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()

    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()

    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, num_att_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()

    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()

    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with tensor_parallel.random.get_cuda_rng_tracker().fork():
        assert (torch.cuda.initial_seed() ==
                12345 + 2718 + parallel_state.get_tensor_model_parallel_rank())

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def test_cross_entropy(self):
    batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
    logits_scale = 1000.0
    seed = 1234

    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
            loss_torch, grad_torch = torch_cross_entropy(
                batch_size, sequence_length, vocab_size, logits_scale, seed)
            (
                loss_tensor_parallel,
                grad_tensor_parallel,
            ) = tensor_sharded_cross_entropy(
                batch_size, sequence_length, vocab_size, logits_scale, seed)

            torch.testing.assert_close(loss_torch, loss_tensor_parallel)
            torch.testing.assert_close(grad_torch, grad_tensor_parallel)

            parallel_state.destroy_model_parallel()
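# Illustration only (single process, no model parallelism, made-up sizes): the
# vocab-sharded cross entropy compared above can be assembled from per-partition
# statistics, since loss = logsumexp(logits) - logit_of_target. The real implementation
# additionally all-reduces these statistics across tensor-parallel ranks; this sketch
# only shows the math on local shards.
import torch
import torch.nn.functional as F

vocab_size_per_partition, num_partitions = 11, 4
vocab_size = vocab_size_per_partition * num_partitions

logits = torch.randn(13, vocab_size)            # [tokens, vocab]
target = torch.randint(0, vocab_size, (13,))    # [tokens]

# Reference: ordinary cross entropy over the full vocabulary.
reference = F.cross_entropy(logits, target, reduction='none')

# Shard the vocabulary and combine per-shard statistics.
shards = logits.split(vocab_size_per_partition, dim=-1)
shard_logsumexp = torch.stack([s.logsumexp(dim=-1) for s in shards], dim=0)
total_logsumexp = shard_logsumexp.logsumexp(dim=0)

target_logit = torch.zeros_like(total_logsumexp)
for i, s in enumerate(shards):
    offset = i * vocab_size_per_partition
    mask = (target >= offset) & (target < offset + vocab_size_per_partition)
    target_logit[mask] = s[mask, target[mask] - offset]

torch.testing.assert_close(total_logsumexp - target_logit, reference)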
def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
    if self.world_size < 4:
        self.skipTest("requires >= 4 GPUs")
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
    pipeline_model_parallel_world_size = (
        self.world_size // tensor_model_parallel_world_size)
    virtual_pipeline_model_parallel_world_size = 2
    pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=tensor_model_parallel_world_size,
        pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
        virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
    )
    self.assertEqual(
        calc_expected_tensor_model_paralell_rank(
            self.rank, tensor_model_parallel_world_size),
        parallel_state.get_tensor_model_parallel_rank(),
    )
    self.assertEqual(
        pipeline_model_parallel_world_size,
        parallel_state.get_pipeline_model_parallel_world_size(),
    )
    self.assertEqual(
        virtual_pipeline_model_parallel_world_size,
        parallel_state.get_virtual_pipeline_model_parallel_world_size(),
    )

    expected_pipeline_rank = (
        self.rank - (self.rank % tensor_model_parallel_world_size)
    ) % pipeline_model_parallel_world_size
    self.assertEqual(
        expected_pipeline_rank,
        parallel_state.get_pipeline_model_parallel_rank(),
    )
    # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
    # `initialize_model_parallel`, it's set to 0.
    self.assertEqual(
        0,
        parallel_state.get_virtual_pipeline_model_parallel_rank(),
    )

    self.assertEqual(
        pipeline_model_parallel_split_rank,
        parallel_state.get_pipeline_model_parallel_split_rank(),
    )

    fake_split_rank = 77
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    self.assertEqual(
        fake_split_rank,
        parallel_state.get_pipeline_model_parallel_split_rank())

    parallel_state.destroy_model_parallel()
def convert(local_rank, rank, world_size, args):
    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
    if args.bcp:
        trainer = Trainer(
            devices=args.gpus_per_node,
            num_nodes=num_nodes,
            accelerator='gpu',
            plugins=[TorchElasticEnvironment()],
        )
    else:
        trainer = Trainer(
            devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = (
        app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size)

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # inject model parallel rank
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} '
        f'for tp_rank: {app_state.tensor_model_parallel_rank} '
        f'and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)
    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
def test__reduce(args, tensor_model_parallel_size):
    print("Testing reduction size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._reduce(torch.full((10, 10, 10, 10), 50)),
        torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
def test_broadcast_data(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12],
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5,)).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, total_numel = \
        data_utils._build_key_size_numel_dictionaries(keys, data)

    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def test_cuda_rng_tracker(self):
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)

            seed_1, seed_2, size = 1234, 4321, [12, 21]
            tensor = torch.cuda.FloatTensor(size)

            torch.cuda.manual_seed(seed_1)
            torch.randn(size, out=tensor)
            target_11 = tensor.clone()
            torch.randn(size, out=tensor)
            target_12 = tensor.clone()

            torch.cuda.manual_seed(seed_2)
            torch.randn(size, out=tensor)
            target_21 = tensor.clone()
            torch.randn(size, out=tensor)
            target_22 = tensor.clone()

            torch.cuda.manual_seed(seed_1)
            tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)

            torch.randn(size, out=tensor)
            result_11 = tensor.clone()

            with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                torch.randn(size, out=tensor)
                result_21 = tensor.clone()

            torch.randn(size, out=tensor)
            result_12 = tensor.clone()

            with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                torch.randn(size, out=tensor)
                result_22 = tensor.clone()

            self.assertEqual(target_11, result_11)
            self.assertEqual(target_12, result_12)
            self.assertEqual(target_21, result_21)
            self.assertEqual(target_22, result_22)
            self.assertNotEqual(result_11, result_21)
            self.assertNotEqual(result_21, result_22)

            tensor_parallel.random.get_cuda_rng_tracker().reset()
            parallel_state.destroy_model_parallel()
def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
    for data_parallel_size in range(1, self.world_size + 1):
        expected_global_batch_size = self.GLOBAL_BATCH_SIZE
        expected_micro_batch_size = self.MICRO_BATCH_SIZE
        if rampup_batch_size:
            expected_global_batch_size = rampup_batch_size[0]
            num_consumed_samples = 0
            step_of_global_batch_size = rampup_batch_size[1]
            threshold = rampup_batch_size[2]

        if data_parallel_size > 1 and data_parallel_size % 2 != 0:
            continue
        if self.world_size % data_parallel_size != 0:
            continue
        with self.subTest(data_parallel_size=data_parallel_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size,
                             parallel_state.get_data_parallel_world_size())

            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )

            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
            self.assertEqual(
                get_num_microbatches(),
                expected_global_batch_size / expected_micro_batch_size / data_parallel_size)
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size)

            # Make sure `global_batch_size` equals the final global batch size after
            # a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for i in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(get_current_global_batch_size(),
                                 self.GLOBAL_BATCH_SIZE)
            parallel_state.destroy_model_parallel()
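# Illustration only (made-up sizes, no microbatch calculator involved): the relationship
# the test above checks between global batch size, micro batch size, data-parallel size,
# and the number of microbatches per step.
global_batch_size = 64
micro_batch_size = 2
data_parallel_size = 4

samples_per_data_parallel_rank = global_batch_size // data_parallel_size   # 16
num_microbatches = samples_per_data_parallel_rank // micro_batch_size      # 8

assert num_microbatches == global_batch_size // (micro_batch_size * data_parallel_size)
print(f"num_microbatches = {num_microbatches}")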
def test_row_parallel_linear(self) -> None:
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)

            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

            set_random_seed(self.SEED)
            linear_layer = layers.RowParallelLinear(
                input_size,
                output_size,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
            ).cuda()
            loss_weight = torch.randn((self.BATCH_SIZE, output_size)).cuda()

            # Forward and backward
            input_tensor = torch.randn(
                self.BATCH_SIZE, input_size, requires_grad=True).cuda()
            input_tensor.retain_grad()
            output, _ = linear_layer(input_tensor)
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()
            self.assertIsNotNone(input_tensor.grad)

            with torch.no_grad():
                dldy = loss_weight.clone()
                x = input_tensor.clone()
                a = linear_layer.master_weight.cuda()
            dlda = torch.matmul(dldy.t(), x)
            dldb = torch.matmul(
                torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy).view(-1)
            dldx = torch.matmul(dldy, a)

            with torch.no_grad():
                curr_dlda = torch.split(
                    dlda, self.INPUT_SIZE_COEFF, dim=1
                )[parallel_state.get_tensor_model_parallel_rank()].clone()
            self.assertEqual(linear_layer.weight.grad, curr_dlda)
            self.assertEqual(input_tensor.grad, dldx)
            self.assertEqual(linear_layer.bias.grad, dldb)

            parallel_state.destroy_model_parallel()
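# Illustration only (plain torch.nn.Linear, no tensor parallelism, made-up sizes): the
# manual gradients checked above follow from y = x @ W.T + b with an elementwise loss
# weight, i.e. dL/dW = dldy.T @ x, dL/db = sum_batch(dldy), dL/dx = dldy @ W. This sketch
# verifies the same identities against autograd on a single device.
import torch
import torch.nn as nn

batch, in_features, out_features = 7, 13, 17
linear = nn.Linear(in_features, out_features)

x = torch.randn(batch, in_features, requires_grad=True)
loss_weight = torch.randn(batch, out_features)

y = linear(x)                      # y = x @ W.T + b
loss = (y * loss_weight).sum()
loss.backward()

dldy = loss_weight                 # d(loss)/dy for an elementwise-weighted sum
torch.testing.assert_close(linear.weight.grad, dldy.t() @ x.detach())
torch.testing.assert_close(linear.bias.grad, dldy.sum(dim=0))
torch.testing.assert_close(x.grad, dldy @ linear.weight.detach())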
def test__gather(args, tensor_model_parallel_size):
    print("Testing gathering size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._gather(
            torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
        torch.tensor(list(range(tensor_model_parallel_size))),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
def _affine_weight_init_test_impl(self, init_device: str,
                                  is_column_parallel: bool) -> None:
    dim = int(not is_column_parallel)
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)

            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

            weight_shape = (
                (self.OUTPUT_SIZE_COEFF, input_size)
                if is_column_parallel
                else (output_size, self.INPUT_SIZE_COEFF))
            weight = torch.empty(weight_shape)
            set_random_seed(self.SEED)

            sharding_dim_size = (
                self.OUTPUT_SIZE_COEFF if is_column_parallel else self.INPUT_SIZE_COEFF)

            if init_device == "cpu":
                layers._initialize_affine_weight_cpu(
                    weight,
                    output_size,
                    input_size,
                    sharding_dim_size,
                    dim,
                    nn.init.normal_,
                    params_dtype=torch.float32,
                )
            else:
                layers._initialize_affine_weight_gpu(
                    weight, torch.nn.init.normal_, dim)

            # Target
            set_random_seed(self.SEED)
            if init_device == "cpu":
                main_weight = torch.empty(output_size, input_size)
                nn.init.normal_(main_weight)
                curr_weight = torch.split(
                    main_weight, sharding_dim_size, dim=dim)[
                        parallel_state.get_tensor_model_parallel_rank()]
            else:
                curr_weight = torch.empty(*weight_shape)
                nn.init.normal_(curr_weight)

            self.assertEqual(curr_weight, weight)
            parallel_state.destroy_model_parallel()
def test__split(args, tensor_model_parallel_size):
    print("Testing splitting size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    listy = []
    for i in range(tensor_model_parallel_size):
        listy.append(torch.randn(10, 1))
    x = torch.cat(tuple(listy), 1)
    out = mappings._split(x)
    assert torch.equal(
        out, listy[parallel_state.get_tensor_model_parallel_rank()])
    parallel_state.destroy_model_parallel()
    print("Passed!")
def test_column_parallel_linear_with_async_allreduce_custom_amp(
        tensor_model_parallel_size):
    dtypes = ((torch.half, torch.bfloat16)
              if torch.cuda.is_bf16_supported() else (torch.half,))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    for dtype in dtypes:
        # Network
        identity_layer = IdentityLayer3D(
            batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)
        # Forward
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()
        assert output.dtype == dtype

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def test_reduce(self):
    for tensor_model_paralell_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_paralell_world_size > 0:
            continue
        with self.subTest(
                tensor_model_paralell_world_size=tensor_model_paralell_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size)
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_paralell_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(torch.equal(mappings._reduce(t), expected))
            parallel_state.destroy_model_parallel()
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_devices
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
            )

            # assert that fake tp and pp rank match after model parallel init
            assert app_state.tensor_model_parallel_rank == \
                parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == \
                parallel_state.get_pipeline_model_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
def test_gather(self):
    for tensor_model_paralell_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_paralell_world_size > 0:
            continue
        with self.subTest(
                tensor_model_paralell_world_size=tensor_model_paralell_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size)
            device = f"cuda:{self.rank}"
            gathered = mappings._gather(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()],
                    device=device))
            expected = torch.tensor(
                [rank for rank in range(tensor_model_paralell_world_size)],
                device=device,
            )
            self.assertTrue(torch.equal(gathered, expected))
            parallel_state.destroy_model_parallel()
def test_set_cuda_rng_state(self):
    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_parallel_world_size:
            continue
        with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size)

            size, seed = 123, 1234
            torch.cuda.manual_seed(seed)
            tensor = torch.cuda.FloatTensor(size)

            rng_state = torch.cuda.get_rng_state()
            rng_state_clone = rng_state.clone()

            for _ in range(5):
                torch.randn(size, out=tensor)
            result_1 = tensor.clone()

            self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)
            self.assertGreater(
                torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0)

            new_rng_state = torch.cuda.get_rng_state()
            self.assertGreater(new_rng_state.sub(rng_state).max(), 0)

            tensor_parallel.random._set_cuda_rng_state(rng_state)
            for _ in range(5):
                torch.randn(size, out=tensor)
            tensor_parallel.random._set_cuda_rng_state(rng_state)
            for _ in range(5):
                torch.randn(size, out=tensor)
            result_2 = tensor.clone()

            torch.testing.assert_close(result_2, result_1)

            self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)

            parallel_state.destroy_model_parallel()
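# Illustration only: the RNG-state save/restore behaviour exercised above, shown with
# nothing but standard torch calls. This sketch uses the CPU generator so it runs
# without a GPU; the CUDA path used by the tracker is analogous.
import torch

torch.manual_seed(1234)
saved_state = torch.get_rng_state()

first_draw = torch.randn(4)

# Restoring the saved state replays the same random stream exactly.
torch.set_rng_state(saved_state)
second_draw = torch.randn(4)

torch.testing.assert_close(first_draw, second_draw)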
def test_initialize_model_parallel(self) -> None:
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    for tensor_model_parallel_world_size in range(1, self.world_size + 1):
        with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size):
            if self.world_size % tensor_model_parallel_world_size:
                continue

            pipeline_model_parallel_world_size = (
                self.world_size // tensor_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            )
            self.assertEqual(
                tensor_model_parallel_world_size,
                parallel_state.get_tensor_model_parallel_world_size(),
            )
            expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size)
            self.assertEqual(
                expected_tensor_model_parallel_rank,
                parallel_state.get_tensor_model_parallel_rank(),
            )

            expected_tensor_model_parallel_src_rank = (
                self.rank // tensor_model_parallel_world_size
            ) * tensor_model_parallel_world_size
            self.assertEqual(
                expected_tensor_model_parallel_src_rank,
                parallel_state.get_tensor_model_parallel_src_rank(),
            )

            parallel_state.destroy_model_parallel()
            self.assertFalse(parallel_state.model_parallel_is_initialized())
def test_split(self):
    for tensor_model_paralell_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_paralell_world_size > 0:
            continue
        with self.subTest(
                tensor_model_paralell_world_size=tensor_model_paralell_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size)
            tensors = [
                torch.randn(10, 1)
                for rank in range(tensor_model_paralell_world_size)
            ]
            x = torch.cat(tensors, 1)
            out = mappings._split(x)
            self.assertTrue(
                torch.equal(
                    out,
                    tensors[parallel_state.get_tensor_model_parallel_rank()]))
            parallel_state.destroy_model_parallel()
def test_cross_entropy(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(
        batch_size, seq_length, vocab_size, logits_scale, seed)
    loss_mpu, grad_mpu = tensor_sharded_cross_entropy(
        batch_size, seq_length, vocab_size, logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def test_broadcast_data(self):
    # Use half the world size for tensor model parallelism when more than one GPU
    # is available. (The explicit int() is needed: without it, `1 + world_size > 1`
    # parses as a comparison and always evaluates to True.)
    tensor_model_parallel_world_size: int = self.world_size // (
        1 + int(self.world_size > 1))
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=tensor_model_parallel_world_size)

    target_key_size = {
        "key1": [7, 11],
        "key2": [8, 2, 1],
        "key3": [13],
        "key4": [5, 1, 2],
        "key5": [5, 12],
    }
    keys = [k for k in target_key_size]

    data = {}
    data_t = {}
    with torch.no_grad():
        for key in target_key_size:
            data[key] = torch.randint(0, 1000, size=target_key_size[key])
            data_t[key] = data[key].clone()
        # "key_x" is supposed to be ignored.
        data["key_x"] = torch.rand(5)
        data_t["key_x"] = data["key_x"].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)

    for key in keys:
        self.assertEqual(target_key_size[key], key_size[key])

    broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        torch.testing.assert_close(broadcasted_data[key], data_t[key].cuda())

    parallel_state.destroy_model_parallel()
def test_pipeline_model_parallel_split_rank():
    pipeline_model_parallel_split_rank_ = 1
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_)
    assert parallel_state.model_parallel_is_initialized()

    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank == pipeline_model_parallel_split_rank_

    fake_split_rank = 7
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank == fake_split_rank

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_initialize_model_parallel(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
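# Illustration only (made-up sizes): which global ranks form each tensor-parallel and
# data-parallel group when only tensor model parallelism is used, matching the
# rank arithmetic checked in the test above.
world_size, tp = 8, 2
dp = world_size // tp

tensor_parallel_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(dp)]
data_parallel_groups = [list(range(j, world_size, tp)) for j in range(tp)]

print("TP groups:", tensor_parallel_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print("DP groups:", data_parallel_groups)    # [[0, 2, 4, 6], [1, 3, 5, 7]]

# For any global rank: tp_rank = rank % tp, dp_rank = rank // tp.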
def test_split_tensor_along_last_dim(self):
    for tensor_model_paralell_world_size in range(1, self.world_size + 1):
        if self.world_size % tensor_model_paralell_world_size > 0:
            continue
        with self.subTest(
                tensor_model_paralell_world_size=tensor_model_paralell_world_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size)
            device = "cpu"
            input_tensor = torch.randn((100, 100, 100), device=device)
            splits = utils.split_tensor_along_last_dim(input_tensor, 10)
            last_dim_shapes = torch.tensor(
                [int(split.size()[-1]) for split in splits])
            self.assertTrue(
                torch.equal(last_dim_shapes, torch.full((10,), 10)))
            parallel_state.destroy_model_parallel()
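# Illustration only: the shape behaviour checked above can be reproduced with plain
# torch.chunk along the last dimension, without any parallel_state setup.
import torch

x = torch.randn(100, 100, 100)
chunks = torch.chunk(x, 10, dim=-1)

assert all(c.shape == (100, 100, 10) for c in chunks)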
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
    args.rank,
    args.rampup_batch_size,
    args.global_batch_size,
    args.micro_batch_size,
    args.data_parallel_size,
)
world_size = torch.distributed.get_world_size()
print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=args.tensor_model_parallel_size,
    pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
    default_backend="nccl",
    p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
)
pipeline_model_parallel_size = (
    parallel_state.get_pipeline_model_parallel_world_size()
)

model_parallel_cuda_manual_seed(0)
model = build_model(
    gpt_model_provider,
    wrap_with_ddp=True,
    virtual_pipeline_model_parallel_size=None,
    cpu_offload=args.cpu_offload,
)
assert isinstance(model, list), model
_param_groups = _get_params_for_weight_decay_optimization(model)
def _forward_backward_test_impl(
    self,
    forward_only: bool,
    fwd_bwd_func: FwdStepFunc,
    pipeline_model_parallel_world_size: Optional[int],
    virtual_pipeline_model_parallel_size: Optional[int],
    async_comm: bool = False,
    *,
    default_backend: Optional[str] = None,
    p2p_backend: Optional[str] = None,
) -> None:
    if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
        self.assertIsNotNone(virtual_pipeline_model_parallel_size)
        self.assertGreater(virtual_pipeline_model_parallel_size, 1)
    dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

    for dtype, deallocate_pipeline_outputs in itertools.product(
            dtype_options, self.deallocate_options,
    ):
        grad_scaler = (
            torch.cuda.amp.GradScaler(init_scale=4.0)
            if dtype == torch.half
            else None
        )

        (tensor_model_parallel_world_size,
         data_parallel_size,
         pipeline_model_parallel_world_size
         ) = _get_default_world_sizes_model_parallel_world_size(
            pipeline_model_parallel_world_size)

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
            default_backend=default_backend,
            p2p_backend=p2p_backend,
        )
        pp_utils._reconfigure_microbatch_calculator(
            rank=parallel_state.get_tensor_model_parallel_rank(),
            rampup_batch_size=None,
            global_batch_size=self.GLOBAL_BATCH_SIZE,
            micro_batch_size=self.MICRO_BATCH_SIZE,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

        global_batch_shape = (
            self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(),
            self.HIDDEN_SIZE,
            self.HIDDEN_SIZE,
        )

        batch = None
        if parallel_state.is_pipeline_first_stage():
            batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),)

        model = build_model(
            testing_utils.model_provider_func,
            # Wrap with DDP only when there is more than one data-parallel rank.
            wrap_with_ddp=data_parallel_size > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            hidden_size=self.HIDDEN_SIZE,
        )

        offset = (pipeline_model_parallel_world_size
                  if virtual_pipeline_model_parallel_size is not None else 0)
        for idx, model_module in enumerate(model):
            model_module = model_module.to(dtype)
            model_module.apply(get_init_weights_func(idx * offset))

        _param_groups = _get_params_for_weight_decay_optimization(model)
        optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

        pp_utils.update_num_microbatches(0)

        loss = fwd_bwd_func(
            testing_utils.fwd_step_func,
            batch,
            model,
            forward_only=forward_only,
            # `tensor_shape` is the shape of a micro batch.
            tensor_shape=(
                self.MICRO_BATCH_SIZE,
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            ),
            dtype=dtype,
            async_comm=async_comm,
            grad_scaler=grad_scaler,
            deallocate_pipeline_output=deallocate_pipeline_outputs,
        )

        if dtype == torch.double:
            hidden_size = self.HIDDEN_SIZE
            microbatch_size = self.MICRO_BATCH_SIZE
            total_layers = pipeline_model_parallel_world_size
            if virtual_pipeline_model_parallel_size is not None:
                total_layers *= virtual_pipeline_model_parallel_size
            target_loss, target_model = get_target_loss_and_model(
                global_batch_shape, hidden_size, total_layers)

            for loss_item in loss:
                x = loss_item['avg']
                torch.testing.assert_close(x.item() / microbatch_size,
                                           target_loss.item())

            if not forward_only:
                for vm_id, model_module in enumerate(model):
                    params = list(model_module.parameters())
                    rank = params[0].get_device()
                    offset = pipeline_model_parallel_world_size
                    param_id = rank // data_parallel_size + vm_id * offset
                    target_params = target_model[param_id]

                    torch.testing.assert_close(params[0].cpu(), target_params[0])
                    torch.testing.assert_close(params[1].cpu(), target_params[1])
                    torch.testing.assert_close(
                        params[0].grad.cpu() / microbatch_size,
                        target_params[0].grad)
                    torch.testing.assert_close(
                        params[1].grad.cpu() / microbatch_size,
                        target_params[1].grad)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    self.assertIsNotNone(p.grad)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        parallel_state.destroy_model_parallel()