# (assumed module-level context for this version: torch,
#  parallel_state from apex.transformer, and AppState/logging from nemo.utils)
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()

            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')
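
# A minimal sketch (an illustration, not part of the original code) of how
# Megatron's default rank layout produces the mp_rank/dp_rank values logged
# above, assuming world_size is divisible by model_parallel_size and that
# tensor model parallel ranks are contiguous (Megatron's default ordering).
def _sketch_mp_dp_ranks(global_rank: int, world_size: int, model_parallel_size: int):
    assert world_size % model_parallel_size == 0
    mp_rank = global_rank % model_parallel_size   # position inside its model parallel group
    dp_rank = global_rank // model_parallel_size  # which model replica this rank belongs to
    return mp_rank, dp_rank

# e.g. world_size=8, model_parallel_size=2: global rank 5 -> mp_rank=1, dp_rank=2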
# (assumed module-level context for this version: os, torch, mpu from Megatron-LM,
#  AppState/logging from nemo.utils, and a Megatron-style _set_random_seed helper)
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            mpu.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = mpu.get_model_parallel_group()
            app_state.data_parallel_group = mpu.get_data_parallel_group()
            app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = mpu.get_data_parallel_rank()
            app_state.data_parallel_size = mpu.get_data_parallel_world_size()

            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')

            # TODO: get random seed from PTL
            # environment variables are strings, so cast before seeding
            seed = int(os.environ.get("PL_GLOBAL_SEED", 1234))
            # random seed must be set for megatron model parallel init
            _set_random_seed(seed)
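
# A hedged sketch of why the seed matters here: Megatron keeps one RNG stream
# shared across a tensor parallel group (so replicated computation agrees) and
# a distinct per-rank stream for dropout inside model parallel regions. The
# helper and offset below are illustrative, not Megatron's actual constants.
import random

def _sketch_model_parallel_rngs(base_seed: int, tp_rank: int):
    shared = random.Random(base_seed)                  # identical on every rank of the group
    per_rank = random.Random(base_seed + 1 + tp_rank)  # differs per tensor parallel rank
    return shared, per_rank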
# (assumed module-level context for this version: torch,
#  parallel_state from apex.transformer, and AppState from nemo.utils)
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_devices
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
            )

            # assert that the fake tp and pp ranks (predicted before DDP init)
            # match the real ranks after model parallel init
            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
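
# A hedged sketch of the "fake" rank computation the asserts above refer to:
# before DDP and the model parallel groups exist, NeMo predicts each rank's
# tensor and pipeline parallel rank from the global rank alone. The helper
# below assumes Megatron's default ordering (tensor parallel varies fastest,
# then data parallel, then pipeline parallel) and is illustrative, not NeMo's
# actual code.
def _sketch_fake_ranks(global_rank: int, world_size: int, tp_size: int, pp_size: int):
    ranks_per_pipeline_stage = world_size // pp_size
    tensor_model_parallel_rank = global_rank % tp_size
    pipeline_model_parallel_rank = global_rank // ranks_per_pipeline_stage
    data_parallel_rank = (global_rank % ranks_per_pipeline_stage) // tp_size
    return tensor_model_parallel_rank, pipeline_model_parallel_rank, data_parallel_rank

# e.g. world_size=8, tp_size=2, pp_size=2: global rank 5 -> tp_rank=1, pp_rank=1, dp_rank=0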