Example no. 1
# Imports assumed throughout these snippets (module paths as in NVIDIA apex's transformer test suite):
import torch
from apex.transformer import parallel_state


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(
        tensor_model_parallel_size_,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    assert parallel_state.model_parallel_is_initialized()

    # Checks
    src_rank = (torch.distributed.get_rank() -
                parallel_state.get_tensor_model_parallel_rank())
    assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank is None

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
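
These snippets run one copy per process and assume torch.distributed is already initialized. A minimal launcher sketch, assuming an NCCL backend and the environment variables set by torchrun (the file name and the size argument are illustrative):

# Run with: torchrun --nproc_per_node=4 run_snippets.py
import os
import torch

if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")  # env:// rendezvous via torchrun
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_=2)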
Example no. 2
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): Path to restore from; should be a file, or a directory if using model parallel.
        """
        self._restore_path = restore_path

        if os.path.isfile(restore_path):
            self._load_checkpoint(restore_path)
        elif os.path.isdir(restore_path):
            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(
                    group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                self._load_checkpoint(mp_restore_path)
            else:
                logging.info(
                    'Model parallel groups are not initialized; will not '
                    'restore the model parallel checkpoint.'
                )
        else:
            logging.error(
                f'restore_path: {restore_path} must be a file or directory.')
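
For the model parallel branch, each model parallel rank loads its own shard from a per-rank subdirectory. A hypothetical on-disk layout and call (the path and the model instance are invented for illustration):

# /checkpoints/my_run/
#     mp_rank_00/model_optim_rng.pt   <- loaded by model parallel rank 0
#     mp_rank_01/model_optim_rng.pt   <- loaded by model parallel rank 1
model.restore_weights('/checkpoints/my_run')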
Example no. 3
    def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
        if self.world_size < 4:
            self.skipTest("requires >= 4 GPUs")
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
        pipeline_model_parallel_world_size = (self.world_size //
                                              tensor_model_parallel_world_size)
        virtual_pipeline_model_parallel_world_size = 2
        pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
            pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
        )
        self.assertEqual(
            calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size),
            parallel_state.get_tensor_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_world_size,
            parallel_state.get_pipeline_model_parallel_world_size(),
        )
        self.assertEqual(
            virtual_pipeline_model_parallel_world_size,
            parallel_state.get_virtual_pipeline_model_parallel_world_size(),
        )

        expected_pipeline_rank = (
            (self.rank - self.rank % tensor_model_parallel_world_size)
            % pipeline_model_parallel_world_size)
        self.assertEqual(
            expected_pipeline_rank,
            parallel_state.get_pipeline_model_parallel_rank(),
        )
        # The virtual pipeline model parallel rank is set lazily: right after
        # `initialize_model_parallel` it is 0.
        self.assertEqual(
            0,
            parallel_state.get_virtual_pipeline_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank(),
        )

        fake_split_rank = 77
        parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
        self.assertEqual(
            fake_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank())

        parallel_state.destroy_model_parallel()
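
To make the rank arithmetic above concrete, here is a worked instance of the formulas this test (and Example no. 6 below) uses, assuming world_size = 8, so tensor_model_parallel_world_size = 2, pipeline_model_parallel_world_size = 4, and a split rank of 2:

world_size = 8
tensor_mp = 1 + int(world_size > 4)    # 2
pipeline_mp = world_size // tensor_mp  # 4
split_rank = pipeline_mp // 2          # 2
for rank in range(world_size):
    tensor_rank = rank % tensor_mp                            # as in Example no. 6
    pipeline_rank = (rank - rank % tensor_mp) % pipeline_mp   # this test's formula
    print(rank, tensor_rank, pipeline_rank)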
Example no. 4
    def test_initialize_model_parallel(self) -> None:

        self.assertFalse(parallel_state.model_parallel_is_initialized())

        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            with self.subTest(
                    tensor_model_parallel_world_size=tensor_model_parallel_world_size):
                if self.world_size % tensor_model_parallel_world_size:
                    continue

                pipeline_model_parallel_world_size = (
                    self.world_size // tensor_model_parallel_world_size)

                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size,
                    pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                )
                self.assertEqual(
                    tensor_model_parallel_world_size,
                    parallel_state.get_tensor_model_parallel_world_size(),
                )
                expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
                    self.rank, tensor_model_parallel_world_size)
                self.assertEqual(
                    expected_tensor_model_parallel_rank,
                    parallel_state.get_tensor_model_parallel_rank(),
                )

                expected_tensor_model_parallel_src_rank = (
                    self.rank // tensor_model_parallel_world_size
                ) * tensor_model_parallel_world_size
                self.assertEqual(
                    expected_tensor_model_parallel_src_rank,
                    parallel_state.get_tensor_model_parallel_src_rank(),
                )

                parallel_state.destroy_model_parallel()
                self.assertFalse(
                    parallel_state.model_parallel_is_initialized())
Example no. 5
def test_pipeline_model_parallel_split_rank():
    pipeline_model_parallel_split_rank_ = 1
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_
    )
    assert parallel_state.model_parallel_is_initialized()

    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank == pipeline_model_parallel_split_rank_  # compare by value, not identity

    fake_split_rank = 7
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    assert split_rank == fake_split_rank

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example no. 6
def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = (torch.distributed.get_world_size() //
                  tensor_model_parallel_size_)
    # Use the clamped size here too; the raw argument may exceed the world size.
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
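
The tensor/data rank arithmetic above implies a simple grid: consecutive global ranks form the tensor model parallel groups, and their strided complements form the data parallel groups. A small enumeration (plain Python, no GPUs needed) of the grouping implied by world_size = 4 and tensor_model_parallel_size_ = 2:

world_size, tp = 4, 2
tensor_groups = [list(range(i, i + tp)) for i in range(0, world_size, tp)]
data_groups = [list(range(i, world_size, tp)) for i in range(tp)]
print(tensor_groups)  # [[0, 1], [2, 3]] -> tensor rank = global_rank % tp
print(data_groups)    # [[0, 2], [1, 3]] -> data rank = global_rank // tp
# The tensor model parallel src rank of Examples no. 1 and 4 is the first
# rank of each tensor group: (global_rank // tp) * tp.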
Example no. 7
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function adds `pre_process` and `post_process` to `**kwargs` and passes `*args` and
    `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type: Either `ModelType.encoder_or_decoder` or `ModelType.encoder_and_decoder`;
            determines how `pre_process`/`post_process` and `add_encoder`/`add_decoder` are set.
        *args: Arguments for `model_provider_func`.
        **kwargs: Keyword arguments for `model_provider_func`.

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list contains one model chunk per virtual pipeline stage; otherwise, a single model.
    """
    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and virtual_pipeline_model_parallel_size is not None):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs.copy()  # copy so the caller's kwargs are not mutated
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs.copy()  # copy so the caller's kwargs are not mutated
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder."
                    )
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == split_rank - 1) or (rank == world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (parallel_state.model_parallel_is_initialized()
            and parallel_state.get_data_parallel_rank() == 0):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            ) for model_module in model
        ]
    return model
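
A minimal usage sketch for build_model. The provider below is invented for illustration; it only has to accept the pre_process/post_process keywords that build_model injects, and the sketch assumes torch.distributed and parallel_state are already initialized:

def toy_provider(hidden_size, pre_process=True, post_process=True):
    # Hypothetical provider: any nn.Module works here. In a real model the
    # two flags would decide whether the embedding and the output head are
    # built on this pipeline stage.
    return torch.nn.Linear(hidden_size, hidden_size)

models = build_model(toy_provider, wrap_with_ddp=True, hidden_size=1024)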