Beispiel #1
0
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """ Set up Megatron-LM model parallel state and record it on AppState.

        Returns without doing anything when ``AppState().model_parallel_size``
        is ``None`` or when ``torch.distributed`` is not initialized.

        Args:
            global_rank (int): the global process index (not read by this method).
            world_size (int): the total number of GPUs, num_nodes * num_gpus
                (not read by this method).
        """
        state = AppState()

        # The megatron-lm model parallel and data parallel groups are only
        # created when model parallelism is requested and after DDP has been
        # initialized with PTL.
        if state.model_parallel_size is None or not torch.distributed.is_initialized():
            return

        parallel_state.initialize_model_parallel(state.model_parallel_size)

        # Copy the resulting groups, ranks and data parallel size onto the
        # app state.
        state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
        state.data_parallel_group = parallel_state.get_data_parallel_group()
        state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
        state.data_parallel_rank = parallel_state.get_data_parallel_rank()
        state.data_parallel_size = parallel_state.get_data_parallel_world_size()

        logging.info(f'mp_rank: {state.model_parallel_rank}')
        logging.info(f'dp_rank: {state.data_parallel_rank}')
Beispiel #2
0
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """ Initializes Megatron-LM model parallel if using model parallelism.

        Does nothing unless ``AppState().model_parallel_size`` is set and
        ``torch.distributed`` is already initialized. Otherwise it creates the
        model/data parallel groups, records groups, ranks and data parallel
        size on the app state, and seeds the random number generators.

        Args:
            global_rank (int): the global process index (not read by this method).
            world_size (int): the total number of GPUs, num_nodes * num_gpus
                (not read by this method).

        Raises:
            ValueError: if the ``PL_GLOBAL_SEED`` environment variable is set
                to a value that cannot be parsed as an integer.
        """
        app_state = AppState()

        # we initialize megatron-lm model parallel and data parallel groups
        # after initializing DDP with PTL.
        if app_state.model_parallel_size is None:
            return
        if not torch.distributed.is_initialized():
            return

        mpu.initialize_model_parallel(app_state.model_parallel_size)
        app_state.model_parallel_group = mpu.get_model_parallel_group()
        app_state.data_parallel_group = mpu.get_data_parallel_group()
        app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
        app_state.data_parallel_rank = mpu.get_data_parallel_rank()
        app_state.data_parallel_size = mpu.get_data_parallel_world_size()
        logging.info(f'mp_rank: {app_state.model_parallel_rank}')
        logging.info(f'dp_rank: {app_state.data_parallel_rank}')

        # TODO: get random seed from PTL
        # Environment variable values are always strings, so convert to int;
        # otherwise the seed would be a str when PL_GLOBAL_SEED is set but an
        # int (1234) when it is not.
        seed = int(os.environ.get("PL_GLOBAL_SEED", 1234))
        # random seed must be set for megatron model parallel init
        _set_random_seed(seed)
Beispiel #3
0
    def init_ddp_connection(self,
                            global_rank: int,
                            world_size: int,
                            is_slurm_managing_tasks: bool = True) -> None:
        """ Override for LightningModule DDP initialization.

        Delegates to ``LightningModule.init_ddp_connection`` first, then sets
        up Megatron-LM model parallel if a model parallel size is configured
        on the app state and no model parallel group has been recorded yet.

        Args:
            global_rank (int): the global process index.
            world_size (int): the total number of GPUs, num_nodes * num_gpus
            is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
        """
        LightningModule.init_ddp_connection(
            self, global_rank, world_size, is_slurm_managing_tasks
        )

        state = AppState()

        # The megatron-lm model parallel and data parallel groups are created
        # after DDP has been initialized with PTL, and only when model
        # parallelism is requested and no group is stored on the app state.
        if state.model_parallel_size is None or state.model_parallel_group is not None:
            return

        mpu.initialize_model_parallel(state.model_parallel_size)

        state.model_parallel_group = mpu.get_model_parallel_group()
        state.data_parallel_group = mpu.get_data_parallel_group()

        # Ranks are looked up relative to the groups just stored.
        state.model_parallel_rank = torch.distributed.get_rank(group=state.model_parallel_group)
        state.data_parallel_rank = torch.distributed.get_rank(group=state.data_parallel_group)

        logging.info(f'mp_rank: {state.model_parallel_rank}')
        logging.info(f'dp_rank: {state.data_parallel_rank}')