    def allreduce_word_and_position_embeddings(self):
        # Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407
        # All-reduce word_embeddings' grad across first and last stages to ensure
        # that word_embeddings parameters stay in sync.
        # This should only run for models that support pipelined model parallelism
        # (BERT and GPT-2).
        if parallel_state.get_pipeline_model_parallel_world_size() > 1 and (
            parallel_state.is_rank_in_embedding_group()
        ):
            if self.enc_dec_model.share_word_embeddings:
                word_embeddings_weight = self.enc_dec_model.word_embeddings_weight()
                if self.megatron_amp_o2:
                    # O2 recipe stores a "main" copy of weights and grads
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

                # All reduce position embeddings for T5.
                if (
                    parallel_state.is_rank_in_position_embedding_group()
                    and parallel_state.get_pipeline_model_parallel_world_size() > 1
                    and parallel_state.get_pipeline_model_parallel_split_rank() is not None
                ):
                    position_embeddings_weight = self.enc_dec_model.position_embeddings_weight()
                    if self.megatron_amp_o2:
                        grad = position_embeddings_weight.main_grad
                    else:
                        grad = position_embeddings_weight.grad
                    torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
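A minimal standalone sketch of the idea behind the all-reduce above (the `embedding_group` handle is a placeholder): summing the gradients with the default SUM op leaves every rank in the group with the identical result, so the first and last pipeline stages apply the same weight update.

import torch
import torch.distributed as dist

def sync_shared_embedding_grad(grad: torch.Tensor, embedding_group: dist.ProcessGroup) -> None:
    # Sum the shared embedding gradient across the embedding group in place,
    # mirroring the torch.distributed.all_reduce call in the method above.
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)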
Example #2
    def setup(self, stage=None):
        # NOTE: super().__init__ will try to set up train/val/test datasets, but we sidestep this with an `if self._train_ds is not None` check.
        # We then set things up for real only once setup() of this class is called.
        resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path
        if resume_checkpoint_path:
            try:
                init_consumed_samples = int(
                    float(re.findall(r"consumed_samples\=([0-9]+\.[0-9]+)", resume_checkpoint_path)[0])
                )
            except (ValueError, TypeError):
                logging.warning(
                    "Cannot parse the checkpoint file to get the consumed samples. This is expected if you are not using memmap datasets."
                )
                init_consumed_samples = 0
        else:
            init_consumed_samples = 0
        self.init_consumed_samples = init_consumed_samples
        if stage == 'predict':
            return

        # If the user wants to manually override train and validation dataloaders before calling `.fit()`
        if self._train_dl is not None and self._validation_dl is not None:
            return
        self.build_train_valid_test_datasets()
        self.setup_training_data(self._cfg.train_ds)
        self.setup_validation_data(self._cfg.validation_ds)
        if hasattr(self._cfg, 'test_ds'):
            self.setup_test_data(self._cfg.test_ds)

        # When using pipeline model parallelism, the final stage needs to initialize word embeddings.
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            self.enc_dec_model.sync_initial_word_embeddings()
            self.enc_dec_model.sync_initial_position_embeddings()
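For reference, a self-contained sketch of the consumed-samples parsing used in setup() above; the checkpoint file name below is a hypothetical example of the `consumed_samples=<float>` pattern the regex expects.

import re

ckpt_path = "megatron_t5--step=1000-consumed_samples=256000.0-last.ckpt"  # hypothetical name
matches = re.findall(r"consumed_samples\=([0-9]+\.[0-9]+)", ckpt_path)
init_consumed_samples = int(float(matches[0])) if matches else 0
print(init_consumed_samples)  # 256000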
Example #3
    def initialize_word_embeddings(self, init_method, vocab_size, hidden_size):
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism there is nothing to do.
        if parallel_state.get_pipeline_model_parallel_world_size() == 1:
            return

        # Parameters are shared between the word embeddings layer, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if parallel_state.is_pipeline_last_stage():
            assert not parallel_state.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                vocab_size, hidden_size, init_method=init_method)
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True
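A sketch (not the library's actual sync_initial_word_embeddings implementation) of the synchronization described in steps 2 and 3 above: because the last stage's copy is zero-initialized, an all-reduce over the embedding group effectively copies the first stage's weights into it, and the same collective applied to the gradients during training keeps the two copies in lock-step.

import torch
import torch.distributed
from apex.transformer import parallel_state  # assumed import path for the parallel_state used throughout

def sync_word_embedding_copies(word_embeddings_weight: torch.Tensor) -> None:
    # Step 2: sum the weights across the first and last stages. The last stage's
    # copy starts at 0.0, so the sum equals the first stage's weights.
    if parallel_state.is_rank_in_embedding_group():
        torch.distributed.all_reduce(word_embeddings_weight.data, group=parallel_state.get_embedding_group())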
    def setup(self, stage=None):
        resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path
        if resume_checkpoint_path:
            try:
                init_consumed_samples = int(
                    float(re.findall(r"consumed_samples\=([0-9]+\.[0-9]+)", resume_checkpoint_path)[0])
                )
            except (ValueError, TypeError):
                logging.warning("Cannot parse the checkpoint file to get the consumed samples. assume it is zero.")
                init_consumed_samples = 0
        else:
            init_consumed_samples = 0
        self.init_consumed_samples = init_consumed_samples

        """A PTL method to setup the training, validation and test datasets."""
        if stage == 'predict':
            return
        if self._train_dl is not None and self._validation_dl is not None:
            return
        self.build_train_valid_test_datasets()
        self.setup_training_data(self._cfg.data)
        self.setup_validation_data(self._cfg.data)
        self.setup_test_data(self._cfg.data)

        # When using pipeline model parallelism, the final stage needs to initialize word embeddings.
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            self.enc_dec_model.sync_initial_word_embeddings()
            self.enc_dec_model.sync_initial_position_embeddings()
def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
    """Helper function to get the model chunk ID given the iteration number."""
    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id
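A worked example of the mapping above, assuming pipeline_parallel_size = 2 and num_model_chunks = 2 (note that num_model_chunks is a free variable the helper closes over): microbatches cycle through the chunks in groups of pipeline_parallel_size * num_model_chunks = 4.

# microbatch_id:                 0  1  2  3  4  5  6  7
# get_model_chunk_id(forward):   0  0  1  1  0  0  1  1
# get_model_chunk_id(backward):  1  1  0  0  1  1  0  0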
Example #6
    def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
        if self.world_size < 4:
            self.skipTest("requires >= 4 GPUs")
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
        pipeline_model_parallel_world_size = (self.world_size //
                                              tensor_model_parallel_world_size)
        virtual_pipeline_model_parallel_world_size = 2
        pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
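        # Example: with world_size=8 this gives tensor=2, pipeline=4, virtual=2 and
        # split_rank=2, i.e. the pipeline stage at which the decoder half of an
        # encoder-decoder model starts.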

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
            pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
        )
        self.assertEqual(
            calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size),
            parallel_state.get_tensor_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_world_size,
            parallel_state.get_pipeline_model_parallel_world_size(),
        )
        self.assertEqual(
            virtual_pipeline_model_parallel_world_size,
            parallel_state.get_virtual_pipeline_model_parallel_world_size(),
        )

        expected_pipeline_rank = (self.rank -
                                  (self.rank % tensor_model_parallel_world_size
                                   )) % pipeline_model_parallel_world_size
        self.assertEqual(
            expected_pipeline_rank,
            parallel_state.get_pipeline_model_parallel_rank(),
        )
        # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
        # `initialize_model_parallel`, it's set to 0.
        self.assertEqual(
            0,
            parallel_state.get_virtual_pipeline_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank(),
        )

        fake_split_rank = 77
        parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
        self.assertEqual(
            fake_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank())

        parallel_state.destroy_model_parallel()
Example #7
def get_forward_backward_func(
    virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_size,
):
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if virtual_pipeline_model_parallel_size is not None:
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
Example #8
def get_forward_backward_func(
    virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_size,
):
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if virtual_pipeline_model_parallel_size is not None:
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            warnings.warn(
                "Pipeline Model Parallel with interleaving scheduling is experimental. "
                f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
                ExperimentalWarning)
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
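A hedged usage sketch of get_forward_backward_func: with a pipeline of 4 stages and, say, 6 microbatches, requesting the interleaved schedule would raise because 6 % 4 != 0, so the example below sticks to the non-interleaved path. `fwd_step_func`, `batch`, `model`, `micro_batch_size` and `hidden_size` are placeholders; the call signature follows the schedule functions shown elsewhere on this page.

pipeline_size = parallel_state.get_pipeline_model_parallel_world_size()
fwd_bwd_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,  # no interleaving
    pipeline_model_parallel_size=pipeline_size,
)
losses = fwd_bwd_func(
    fwd_step_func,
    batch,
    model,
    forward_only=False,
    tensor_shape=[micro_batch_size, hidden_size, hidden_size],
)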
Example #9
    def setup(self, stage=None):
        """ PTL hook that is executed after DDP spawns.
            We set up datasets here because megatron datasets require DDP to instantiate.
            See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information.
        Args:
            stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None.
        """
        resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path
        if resume_checkpoint_path:
            try:
                init_consumed_samples = int(
                    float(re.findall(r"consumed_samples\=([0-9]+\.[0-9]+)", resume_checkpoint_path)[0])
                )
            except (ValueError, TypeError):
                logging.warning(
                    "Cannot parse the checkpoint file to get the consumed samples. assume it is zero."
                )
                init_consumed_samples = 0
        else:
            init_consumed_samples = 0
        self.init_consumed_samples = init_consumed_samples

        if stage == 'predict':
            return
        else:
            # TODO: consider adding a ModelPT guard to check if model is being restored.
            # allowing restored models to optionally setup datasets
            self.build_train_valid_test_datasets()
            self.setup_training_data(self.cfg.data)
            self.setup_validation_data(self.cfg.data)
            self.setup_test_data(self.cfg.data)

        # When using pipeline model parallelism, the final stage needs to initialize word embeddings.
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            self.model.sync_initial_word_embeddings()
Example #10
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` are a list of `Batch`es and a list of `torch.nn.Module`s, respectively.
    This means that model is split into model chunks.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown
    Note that if `forward_only` is set, this schedule consists of the warmup phase only.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s on the last pipeline stage, an empty list otherwise.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`s")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
            [] for _ in range(num_model_chunks)
        ]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size -
                                       pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks -
                                        1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    ###################################################################################################################
    # Helper function definitions.
    ###################################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int,
                            curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if (parallel_state.is_pipeline_first_stage()
                and len(input_tensors[model_chunk_id]) == len(
                    output_tensors[model_chunk_id])):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs)

        return input_tensor_grad

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(
            f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
        )

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches -
                 1) and not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks -
                                1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                dtype=dtype)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            backward_model_chunk_id)
        _logger.debug(
            f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
        )
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape,
                                                dtype=dtype))
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
            )
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1,
                                                              forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype))

    return losses_reduced
Example #11
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int,
    forward_only: bool,
    BatchSamplerCls,
) -> None:
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is required for the interleaving schedule.
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously, but this test uses a custom model, so it is safe to set the value manually here.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size())

    print_separator(
        f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
    )

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(
        _get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(f"iter: {i} / {NUM_ITERATIONS} "
                     f"local batchsize: {get_num_samples(local_mini_batch)} "
                     f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}")
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (parallel_state.get_data_parallel_world_size() *
                             get_num_microbatches() * micro_batch_size)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            else:
                optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s on the last pipeline stage, an empty list otherwise.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size() -
        parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor,
                                       tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor,
                                           tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd")
                input_tensor = p2p_communication.recv_forward(
                    tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad,
                                                tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(
                tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad,
                                            tensor_shape=tensor_shape)

    return losses_reduced
Example #13
def build_model(
        model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
        wrap_with_ddp: bool = True,
        virtual_pipeline_model_parallel_size: Optional[int] = None,
        *args,
        **kwargs
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function adds `pre_process` and `post_process` to `**kwargs` and passes `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
            parallel_state.get_pipeline_model_parallel_world_size() > 1 and
            virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]
    return model
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s on the last pipeline stage, an empty list otherwise.
    """
    # timers = get_timers()

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []
    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                _logger.debug("receive fwd")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
Example #15
    def __init__(
        self,
        init_method,
        output_layer_init_method,
        encoder_attn_mask_type,
        vocab_size,
        max_position_embeddings,
        hidden_size,
        ffn_hidden_size,
        num_layers,
        num_tokentypes,
        num_attention_heads,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        add_decoder=False,
        decoder_attn_mask_type=AttnMaskType.causal,
        add_pooler=False,
        pre_process=True,
        post_process=True,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        bias_gelu_fusion=True,
        persist_layer_norm=False,
        openai_gelu=False,
        onnx_safe=False,
        use_soft_prompts=False,
        num_prompt_tokens=10,
        prompt_tags=None,
    ):
        super(TransformerLanguageModel, self).__init__()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.hidden_dropout = hidden_dropout
        self.output_layer_init_method = output_layer_init_method
        self.use_soft_prompts = use_soft_prompts
        self.prompt_tags = prompt_tags
        self.num_prompt_tokens = num_prompt_tokens

        if kv_channels is None:

            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
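            # e.g. hidden_size=1024 with num_attention_heads=16 gives kv_channels=64.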

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(
                hidden_size=self.hidden_size,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_position_embeddings,
                init_method=self.init_method,
                num_tokentypes=self.num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=self.hidden_dropout,
            )
            self._embedding_key = 'embedding'

        # Soft Prompts
        if self.use_soft_prompts:
            self.prompt_table = PromptTable(
                prompt_tags=self.prompt_tags,
                num_prompt_tokens=self.num_prompt_tokens,
                hidden_size=self.hidden_size,
            )
            self._prompt_table_key = 'prompt_table'

        # Transformer.
        self.encoder = ParallelTransformer(
            init_method=self.init_method,
            output_layer_init_method=self.output_layer_init_method,
            num_layers=self.num_layers,
            hidden_size=self.hidden_size,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            self_attn_mask_type=self.encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            hidden_dropout=hidden_dropout,
            use_cpu_initialization=use_cpu_initialization,
            bias_gelu_fusion=bias_gelu_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )
        self._encoder_key = 'encoder'

        # Decoder
        if self.add_decoder:
            assert (
                parallel_state.get_pipeline_model_parallel_world_size() == 1
            ), 'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                init_method=self.init_method,
                output_layer_init_method=self.output_layer_init_method,
                num_layers=self.num_layers,
                hidden_size=self.hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                ffn_hidden_size=ffn_hidden_size,
                pre_process=self.pre_process,
                post_process=self.post_process,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )
            self._decoder_key = 'decoder'

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'
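
A minimal sketch (not part of the original source) of how the `pre_process`/`post_process` flags above are usually derived for a pipeline stage, assuming the same apex `parallel_state` module used throughout these examples; only the first stage builds the Embedding (and prompt table), and only the last stage builds the Pooler when `add_pooler=True`.

# Hedged sketch: derive the stage flags before constructing TransformerLanguageModel.
from apex.transformer import parallel_state

pre_process = parallel_state.is_pipeline_first_stage()   # True -> Embedding / PromptTable are created
post_process = parallel_state.is_pipeline_last_stage()   # True -> Pooler is created (if add_pooler)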
Example No. 16
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function adds `pre_process` and `post_process` to `**kwargs` and passes `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type: Either `ModelType.encoder_or_decoder` or `ModelType.encoder_and_decoder`.
        *args: Positional arguments for `model_provider_func`.
        **kwargs: Keyword arguments for `model_provider_func`.

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and virtual_pipeline_model_parallel_size is not None):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder."
                    )
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (parallel_state.model_parallel_is_initialized()
            and parallel_state.get_data_parallel_rank() == 0):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            ) for model_module in model
        ]
    return model
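
A hedged usage sketch of `build_model` (assuming model parallel groups are already initialized and that `gpt_model_provider` is a provider function accepting `pre_process`/`post_process`, as in the test example further below):

model_chunks = build_model(
    gpt_model_provider,                         # assumed provider: returns an nn.Module for this stage
    wrap_with_ddp=True,                         # wrap each chunk in DistributedDataParallel
    virtual_pipeline_model_parallel_size=None,  # None -> exactly one chunk on this pipeline stage
)
# build_model always returns a list, even without interleaving.
params = [p for chunk in model_chunks for p in chunk.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-4)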
Example No. 17
    def __init__(
        self,
        init_method,
        output_layer_init_method,
        num_layers,
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        layer_type=LayerType.encoder,
        self_attn_mask_type=AttnMaskType.padding,
        pre_process=True,
        post_process=True,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.1,
        use_cpu_initialization=False,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
    ):
        super(ParallelTransformer, self).__init__()

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        self.fp32_residual_connection = fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = activations_checkpoint_method
        self.activations_checkpoint_num_layers = activations_checkpoint_num_layers

        # Number of layers.
        assert (
            num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
        ), 'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                layer_number=layer_number,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )

        # TODO: get virtual_pipeline_model_parallel_size from apex.mpu
        # if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None:
        #     assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
        #         'num_layers_per_stage must be divisible by ' 'virtual_pipeline_model_parallel_size'
        #     )
        #     # Number of layers in each model chunk is the number of layers in the stage,
        #     # divided by the number of model chunks in a stage.
        #     self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
        #     # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
        #     # layers to stages like (each list is a model chunk):
        #     # Stage 0: [0]  [2]  [4]  [6]
        #     # Stage 1: [1]  [3]  [5]  [7]
        #     # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
        #     # layers to stages like (each list is a model chunk):
        #     # Stage 0: [0, 1]  [4, 5]
        #     # Stage 1: [2, 3]  [6, 7]
        #     offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
        #         args.num_layers // args.virtual_pipeline_model_parallel_size
        #     ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        # else:
        #     # Each stage gets a contiguous set of layers.
        #     offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
        offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
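
A small sketch with hypothetical sizes showing how the `offset` computed above assigns contiguous, 1-based layer numbers to each pipeline stage when no virtual pipelining is used:

num_layers_total, pp_world_size = 24, 4                 # hypothetical values
layers_per_stage = num_layers_total // pp_world_size    # guarded by the divisibility assert above
for pp_rank in range(pp_world_size):
    offset = pp_rank * layers_per_stage
    layer_numbers = [i + 1 + offset for i in range(layers_per_stage)]
    # rank 0 -> [1..6], rank 1 -> [7..12], rank 2 -> [13..18], rank 3 -> [19..24]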
Example No. 18
def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    model_type: ModelType,
    *,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    deallocate_pipeline_outputs: bool = False,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage).

    Args:
        input_tensor: Input tensor this pipeline stage received in the forward pass;
            its ``.grad`` is what gets returned.
        output_tensor: Output tensor this pipeline stage produced in the forward pass.
        output_tensor_grad: Gradient of the loss with respect to ``output_tensor``
            (``None`` on the last stage).
        model_type: ``ModelType.encoder_or_decoder`` or ``ModelType.encoder_and_decoder``.
    Keyword Arguments:
        grad_scaler: Optional ``torch.cuda.amp.GradScaler``; when provided, the last
            stage's output is scaled before backward.
        deallocate_pipeline_outputs: Experimental.
    Returns:
        input_tensor_grad
    """

    # timers = get_timers()
    # timers("backward-compute").start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]

    input_tensor = [
        inp.get() if isinstance(inp, FutureTensor) else inp
        for inp in input_tensor
    ]

    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]

    output_tensor = [
        out.get() if isinstance(out, FutureTensor) else out
        for out in output_tensor
    ]

    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    output_tensor_grad = [
        ogr.get() if isinstance(ogr, FutureTensor) else ogr
        for ogr in output_tensor_grad
    ]

    # Backward pass.
    if grad_scaler is not None and output_tensor_grad[0] is None:
        output_tensor[0] = grad_scaler.scale(output_tensor[0])
    if deallocate_pipeline_outputs:
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0],
                                grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            input_tensor_grad.append(None if x is None else x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and parallel_state.is_pipeline_stage_after_split()
            and model_type == ModelType.encoder_and_decoder):
        if output_tensor_grad[1] is not None:
            # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
            input_tensor_grad[-1].add_(output_tensor_grad[1])

    # timers("backward-compute").stop()
    return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad
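
A minimal, self-contained sketch of the autograd pattern `backward_step` relies on: retain the gradient on the (non-leaf) stage input, run backward on the stage output with the upstream gradient, then read the input's `.grad` to send to the previous stage.

import torch

received = torch.randn(4, 8, requires_grad=True)
stage_input = received * 1.0        # stands in for the activation received from the previous stage
stage_input.retain_grad()           # non-leaf tensors need retain_grad() for .grad to be populated
stage_output = stage_input.sum()    # stands in for this stage's forward computation
torch.autograd.backward(stage_output, grad_tensors=torch.ones(()))  # upstream gradient (scalar here)
input_tensor_grad = stage_input.grad  # what would be handed to the previous pipeline stage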
Example No. 19
                args.micro_batch_size,
                args.data_parallel_size,
            )
            world_size = torch.distributed.get_world_size()

        print("tensor model parallel size:", args.tensor_model_parallel_size)

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=args.tensor_model_parallel_size,
            pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
            default_backend="nccl",
            p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
        )

        pipeline_model_parallel_size = (
            parallel_state.get_pipeline_model_parallel_world_size()
        )
        model_parallel_cuda_manual_seed(0)
        model = build_model(
            gpt_model_provider,
            wrap_with_ddp=True,
            virtual_pipeline_model_parallel_size=None,
            cpu_offload=args.cpu_offload,
        )
        assert isinstance(model, list), model
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
        runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm)

        parallel_state.destroy_model_parallel()
    torch.distributed.barrier()
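
A hedged sketch of the decomposition the snippet above implies, following the usual Megatron/apex convention that the data parallel size is whatever remains of the world size after the tensor and pipeline splits:

world_size = torch.distributed.get_world_size()
data_parallel_size = world_size // (
    args.tensor_model_parallel_size * args.pipeline_model_parallel_size
)
assert world_size == (
    args.tensor_model_parallel_size * args.pipeline_model_parallel_size * data_parallel_size
), "world size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size"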
Example No. 20
def forward_backward_func_template(
    name: str,
    forward_backward_func,
    pipeline_model_parallel_size: int,
    forward_only: bool,
) -> None:
    print_separator(
        f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}"
    )
    virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
    if name == "no_pipelining":
        # note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
        # pipeline_model_parallel_size > 1, so initialize with tensor and pipeline
        # model parallel sizes both set to 1 here.
        parallel_state.initialize_model_parallel(1, 1, None)
    else:
        # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
        # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
        # used ubiquitously but this test uses custom model so it's safe to abuse.
        parallel_state.initialize_model_parallel(
            1, pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size)
        if virtual_pipeline_model_parallel_size is not None:
            # Check the experimental warning message
            get_forward_backward_func(virtual_pipeline_model_parallel_size,
                                      pipeline_model_parallel_size)
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
    )
    assert isinstance(model, list)
    assert len(model) == (1 if virtual_pipeline_model_parallel_size is None
                          else virtual_pipeline_model_parallel_size)
    _param_groups = _get_params_for_weight_decay_optimization(model)
    torch.optim.Adam(_param_groups, lr=1e-4)

    tensor_shape = [
        batch_size // parallel_state.get_data_parallel_world_size(),
        hidden_size
    ]
    batch = (torch.randn(tensor_shape).cuda(), )
    tensor_shape[0] = micro_batch_size

    update_num_microbatches(0)
    forward_backward_func(fwd_step_func,
                          batch,
                          model,
                          forward_only=forward_only,
                          tensor_shape=tensor_shape)

    if not forward_only:
        for m in model:
            for p in m.parameters():
                if p.grad is None:
                    raise RuntimeError("grad not found")
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
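
For context, a hedged sketch of the shape `fwd_step_func` conventionally takes in apex-style pipeline schedules (the real implementation is defined elsewhere in this test): it consumes a batch and a model chunk and returns the stage output together with a loss closure.

def fwd_step_func_sketch(batch, model):
    # Forward through this stage's model chunk.
    output_tensor = model(batch[0])

    def loss_func(output_tensor):
        # Reduce the stage output to a scalar loss plus a logging dict.
        loss = output_tensor.float().mean()
        return loss, {"avg": loss.detach()}

    return output_tensor, loss_func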